# WebRAG / src/vector_store.py
# Arun21102003
# Initial clean commit
# 97f9138
import uuid
from typing import List, Dict, Optional

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    FilterSelector,
    MatchValue,
    PointStruct,
    VectorParams,
)

from src.config import config
from src.embeddings import embedding_service
class VectorStore:
    """Wrapper around one Qdrant collection holding chunk embeddings.

    The Qdrant client is created lazily: the constructor only records
    connection settings, and every public method calls :meth:`connect`
    before talking to the server.
    """

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        collection_name: Optional[str] = None,
    ):
        """Record connection settings, falling back to ``config`` values.

        Args:
            host: Qdrant host; ``config.QDRANT_HOST`` is used when falsy.
            port: Qdrant port; ``config.QDRANT_PORT`` is used when falsy.
            collection_name: Target collection; ``config.COLLECTION_NAME``
                is used when falsy.
        """
        self.host = host or config.QDRANT_HOST
        self.port = port or config.QDRANT_PORT
        self.collection_name = collection_name or config.COLLECTION_NAME
        # Created on first use by connect().
        self.client = None

    def connect(self):
        """Create the Qdrant client on first use and ensure the collection exists.

        When ``config.QDRANT_URL`` is set, the client is built from that URL
        and ``config.QDRANT_API_KEY``; otherwise from ``self.host`` and
        ``self.port``. Calling this again with a live client is a no-op.

        Raises:
            Exception: whatever the client raises while listing or creating
                the collection; the client is discarded first so the next
                call retries the whole setup.
        """
        if self.client is not None:
            return

        if config.QDRANT_URL:
            self.client = QdrantClient(
                url=config.QDRANT_URL,
                api_key=config.QDRANT_API_KEY,
            )
        else:
            self.client = QdrantClient(host=self.host, port=self.port)

        try:
            self._ensure_collection()
        except Exception:
            # Do not keep a client whose collection setup never completed;
            # otherwise later connect() calls would skip the setup for good.
            self.client = None
            raise

    def _ensure_collection(self):
        """Create ``self.collection_name`` if the server does not list it.

        The vector size comes from
        ``embedding_service.get_embedding_dimension()`` and the distance
        metric is cosine.
        """
        existing = self.client.get_collections().collections
        if any(col.name == self.collection_name for col in existing):
            return

        embedding_dim = embedding_service.get_embedding_dimension()
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=embedding_dim,
                distance=Distance.COSINE,
            ),
        )
        print(f"Created collection: {self.collection_name}")

    def store_chunks(self, url_id: str, url: str, chunks: List[Dict], embeddings: List[List[float]]):
        """Upsert one point per chunk, each under a fresh random UUID.

        Args:
            url_id: Identifier of the source URL, stored in each payload.
            url: The source URL, stored in each payload.
            chunks: Dicts providing the keys ``id``, ``text``,
                ``start_word`` and ``end_word``.
            embeddings: One vector per chunk, in the same order as ``chunks``.

        Returns:
            The number of points written.

        Raises:
            ValueError: if ``chunks`` and ``embeddings`` differ in length.
        """
        # Pairing with zip() would silently drop the unmatched tail, so
        # reject mismatched inputs instead of storing a partial document.
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"chunks ({len(chunks)}) and embeddings ({len(embeddings)}) "
                "must have the same length"
            )
        # Nothing to write: skip the connection and the empty upsert request.
        if len(chunks) == 0:
            return 0

        self.connect()
        points = [
            PointStruct(
                id=str(uuid.uuid4()),
                vector=embedding,
                payload={
                    "url_id": url_id,
                    "url": url,
                    "chunk_id": chunk["id"],
                    "text": chunk["text"],
                    "start_word": chunk["start_word"],
                    "end_word": chunk["end_word"],
                },
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]
        self.client.upsert(
            collection_name=self.collection_name,
            points=points,
        )
        return len(points)

    def search(self, query_embedding: List[float], top_k: Optional[int] = None) -> List[Dict]:
        """Return the nearest stored chunks for ``query_embedding``.

        Args:
            query_embedding: Query vector.
            top_k: Maximum number of hits; ``config.TOP_K_RESULTS`` is used
                when falsy.

        Returns:
            One dict per hit with the keys ``id``, ``score``, ``url``,
            ``url_id``, ``text`` and ``chunk_id``. Payload-derived values
            are ``None`` when the stored payload lacks the key.
        """
        self.connect()
        limit = top_k or config.TOP_K_RESULTS
        # NOTE(review): newer qdrant-client releases appear to deprecate
        # `search` in favour of `query_points` — confirm against the pinned
        # client version before upgrading.
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=limit,
        )

        results = []
        for hit in hits:
            # Guard against a hit that carries no payload at all.
            payload = hit.payload or {}
            results.append(
                {
                    "id": hit.id,
                    "score": hit.score,
                    "url": payload.get("url"),
                    "url_id": payload.get("url_id"),
                    "text": payload.get("text"),
                    "chunk_id": payload.get("chunk_id"),
                }
            )
        return results

    def delete_by_url_id(self, url_id: str):
        """Delete every point whose payload ``url_id`` equals ``url_id``."""
        self.connect()
        # Use the client's typed selector models rather than a raw dict,
        # which is not one of the documented points_selector types.
        selector = FilterSelector(
            filter=Filter(
                must=[
                    FieldCondition(
                        key="url_id",
                        match=MatchValue(value=url_id),
                    )
                ]
            )
        )
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=selector,
        )
# Shared module-level instance built from the config defaults. Constructing it
# only records settings; the Qdrant client is created later by connect().
vector_store = VectorStore()