"""Qdrant-backed storage and similarity search for embedded text chunks."""

import uuid
from typing import Dict, List, Optional

from qdrant_client import QdrantClient
from qdrant_client.models import (
    Distance,
    FieldCondition,
    Filter,
    FilterSelector,
    MatchValue,
    PointStruct,
    VectorParams,
)

from src.config import config
from src.embeddings import embedding_service


class VectorStore:
    """Wrapper around one Qdrant collection holding chunk embeddings.

    The Qdrant client is created lazily: every public method calls
    :meth:`connect` first, which builds the client on first use and makes
    sure the target collection exists.
    """

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        collection_name: Optional[str] = None,
    ):
        """Record connection settings without opening a connection.

        Args:
            host: Qdrant host; falls back to ``config.QDRANT_HOST`` when falsy.
            port: Qdrant port; falls back to ``config.QDRANT_PORT`` when falsy.
            collection_name: Collection to use; falls back to
                ``config.COLLECTION_NAME`` when falsy.
        """
        self.host = host or config.QDRANT_HOST
        self.port = port or config.QDRANT_PORT
        self.collection_name = collection_name or config.COLLECTION_NAME
        # Populated by connect(); None means "not connected yet".
        self.client: Optional[QdrantClient] = None

    def connect(self):
        """Create the client on first use and ensure the collection exists.

        When ``config.QDRANT_URL`` is set, the client is built from that URL
        and ``config.QDRANT_API_KEY``; otherwise it is built from the
        host/port stored on this instance. Calling this again once a client
        exists is a no-op.

        Raises:
            Exception: Whatever the client constructor or the collection
                check raises is propagated unchanged.
        """
        if self.client is not None:
            return

        if config.QDRANT_URL:
            self.client = QdrantClient(
                url=config.QDRANT_URL,
                api_key=config.QDRANT_API_KEY,
            )
        else:
            self.client = QdrantClient(host=self.host, port=self.port)

        try:
            self._ensure_collection()
        except Exception:
            # Drop the half-initialised client so the next connect() retries
            # the collection check instead of silently skipping it.
            self.client = None
            raise

    def _ensure_collection(self):
        """Create the collection if the server does not already list it.

        The vector size comes from
        ``embedding_service.get_embedding_dimension()`` and the distance
        metric is cosine. Requires ``self.client`` to be set.
        """
        existing = self.client.get_collections().collections
        if any(col.name == self.collection_name for col in existing):
            return

        embedding_dim = embedding_service.get_embedding_dimension()
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=embedding_dim,
                distance=Distance.COSINE,
            ),
        )
        print(f"Created collection: {self.collection_name}")

    def store_chunks(
        self,
        url_id: str,
        url: str,
        chunks: List[Dict],
        embeddings: List[List[float]],
    ):
        """Upsert one point per (chunk, embedding) pair.

        Each point gets a fresh random UUID as its id and a payload made of
        ``url_id``, ``url`` and the chunk's ``id``, ``text``, ``start_word``
        and ``end_word`` entries.

        Args:
            url_id: Identifier stored in every point's payload.
            url: Source URL stored in every point's payload.
            chunks: Chunk dicts; each must contain the keys ``id``, ``text``,
                ``start_word`` and ``end_word``.
            embeddings: One vector per chunk, in the same order as ``chunks``.

        Returns:
            The number of points written (0 when ``chunks`` is empty).

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` differ in length.
            KeyError: If a chunk is missing one of the required keys.
        """
        # Pairing with zip() would silently drop the surplus items, so reject
        # a mismatch up front, before any network call is made.
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"chunks and embeddings must have the same length "
                f"(got {len(chunks)} and {len(embeddings)})"
            )

        self.connect()

        points = [
            PointStruct(
                id=str(uuid.uuid4()),
                vector=embedding,
                payload={
                    "url_id": url_id,
                    "url": url,
                    "chunk_id": chunk["id"],
                    "text": chunk["text"],
                    "start_word": chunk["start_word"],
                    "end_word": chunk["end_word"],
                },
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]

        # Nothing to write: skip the request rather than send an empty upsert.
        if not points:
            return 0

        self.client.upsert(
            collection_name=self.collection_name,
            points=points,
        )
        return len(points)

    def search(self, query_embedding: List[float], top_k: int = None) -> List[Dict]:
        """Return the nearest stored chunks for a query vector.

        Args:
            query_embedding: Query vector passed to the client as
                ``query_vector``.
            top_k: Maximum number of hits; falls back to
                ``config.TOP_K_RESULTS`` when falsy (``None`` or 0).

        Returns:
            One dict per hit with the keys ``id``, ``score``, ``url``,
            ``url_id``, ``text`` and ``chunk_id``. Payload fields missing
            from a hit are returned as ``None``.
        """
        self.connect()
        limit = top_k or config.TOP_K_RESULTS

        # NOTE(review): newer qdrant-client releases appear to deprecate
        # `search` in favour of `query_points` -- confirm against the pinned
        # client version before upgrading.
        hits = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=limit,
        )

        return [
            {
                "id": hit.id,
                "score": hit.score,
                "url": hit.payload.get("url"),
                "url_id": hit.payload.get("url_id"),
                "text": hit.payload.get("text"),
                "chunk_id": hit.payload.get("chunk_id"),
            }
            for hit in hits
        ]

    def delete_by_url_id(self, url_id: str):
        """Delete every point whose payload ``url_id`` equals *url_id*.

        Args:
            url_id: Value matched against the ``url_id`` payload field.
        """
        self.connect()

        # Use the client's typed selector models instead of a raw dict, so
        # the request does not depend on the client coercing plain dicts.
        selector = FilterSelector(
            filter=Filter(
                must=[
                    FieldCondition(
                        key="url_id",
                        match=MatchValue(value=url_id),
                    )
                ]
            )
        )
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=selector,
        )


# Shared module-level instance configured entirely from `config`.
vector_store = VectorStore()