"""Qdrant-backed vector store for document chunks embedded with SentenceTransformer."""

import hashlib
from collections import Counter
from dataclasses import dataclass
from typing import Optional

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
from sentence_transformers import SentenceTransformer

from config import config


@dataclass
class DocumentChunk:
    """A piece of a paper's text plus the metadata stored in the point payload."""

    chunk_id: str
    paper_id: str
    paper_name: str
    content: str
    section_title: str = ""
    subsection_title: str = ""


@dataclass
class SearchResult:
    """A retrieved chunk with the score returned by Qdrant and its 1-based rank."""

    chunk: DocumentChunk
    score: float
    rank: int


class QdrantVectorStore:
    """Thin wrapper around a Qdrant collection of embedded document chunks.

    When ``config.QDRANT_URL`` or ``config.QDRANT_API_KEY`` is missing no
    client is created and every public method degrades to a no-op result
    (``0``, ``[]``, ``False`` or zeroed stats).
    """

    # Size of the vectors the collection is created with; the embeddings
    # produced by config.EMBEDDING_MODEL are expected to have this dimension.
    VECTOR_SIZE = 384
    # The collection is reset once it reaches 90% of this many points.
    MAX_VECTORS_FREE_TIER = 1000000
    # Number of points requested per scroll page when listing papers.
    SCROLL_PAGE_SIZE = 10000

    def __init__(self):
        self.client: Optional[QdrantClient] = None
        self.model: Optional[SentenceTransformer] = None
        self._initialize()

    def _initialize(self) -> None:
        """Create the Qdrant client (if configured) and load the embedding model."""
        if config.QDRANT_URL and config.QDRANT_API_KEY:
            self.client = QdrantClient(
                url=config.QDRANT_URL,
                api_key=config.QDRANT_API_KEY,
            )
            self._ensure_collection()
        # The model is loaded even without a client, as in the original flow.
        self.model = SentenceTransformer(config.EMBEDDING_MODEL)

    def _ensure_collection(self) -> None:
        """Create the collection if it is missing and index ``paper_name``."""
        name = config.QDRANT_COLLECTION
        # Ask the server explicitly rather than treating *any* failure of
        # get_collection (network error, response parsing, ...) as "missing".
        if not self.client.collection_exists(name):
            self.client.create_collection(
                collection_name=name,
                vectors_config=models.VectorParams(
                    size=self.VECTOR_SIZE,
                    distance=models.Distance.COSINE,
                ),
            )
        try:
            self.client.create_payload_index(
                collection_name=name,
                field_name="paper_name",
                field_schema=models.PayloadSchemaType.KEYWORD,
            )
        except Exception as e:
            # Best effort: a failure here is reported but does not abort
            # initialisation.
            print(f"Error creating payload index: {e}")

    def _check_and_cleanup_if_needed(self) -> None:
        """Drop and recreate the collection when it nears the point limit."""
        if not self.client:
            return
        try:
            info = self.client.get_collection(config.QDRANT_COLLECTION)
            # points_count may be None; treat that as an empty collection.
            point_total = info.points_count or 0
            if point_total >= self.MAX_VECTORS_FREE_TIER * 0.9:
                self.client.delete_collection(config.QDRANT_COLLECTION)
                self._ensure_collection()
                print("Qdrant collection reset due to approaching limit")
        except Exception as e:
            print(f"Error checking collection: {e}")

    @staticmethod
    def _point_id(chunk_id: str) -> int:
        """Return a deterministic point id in ``[0, 2**63)`` for *chunk_id*.

        The built-in ``hash()`` of a string is salted per process, so it
        yields a different id on every run; a digest keeps ids stable so that
        re-ingesting a chunk overwrites its point instead of duplicating it.
        """
        digest = hashlib.sha256(chunk_id.encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big") % (2**63)

    def add_chunks(self, chunks: list[DocumentChunk]) -> int:
        """Embed *chunks* and upsert them into the collection.

        Returns the number of chunks submitted, or ``0`` when *chunks* is
        empty or no client is configured.
        """
        if not chunks or not self.client:
            return 0

        self._check_and_cleanup_if_needed()

        texts = [chunk.content for chunk in chunks]
        embeddings = self.model.encode(texts, normalize_embeddings=True)

        points = [
            models.PointStruct(
                id=self._point_id(chunk.chunk_id),
                vector=embedding.tolist(),
                payload={
                    "chunk_id": chunk.chunk_id,
                    "paper_id": chunk.paper_id,
                    "paper_name": chunk.paper_name,
                    "content": chunk.content,
                    "section_title": chunk.section_title,
                    "subsection_title": chunk.subsection_title,
                },
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]

        self.client.upsert(
            collection_name=config.QDRANT_COLLECTION,
            points=points,
        )
        return len(chunks)

    def search(
        self,
        query: str,
        top_k: Optional[int] = None,
        paper_filter: Optional[str] = None,
    ) -> list[SearchResult]:
        """Return the chunks closest to *query*, best match first.

        ``top_k`` falls back to ``config.TOP_K_CHUNKS`` when it is ``None``
        or ``0``. ``paper_filter``, when given, restricts hits to points whose
        ``paper_name`` payload equals it. Returns ``[]`` without a client.
        """
        if not self.client:
            return []

        top_k = top_k or config.TOP_K_CHUNKS
        query_embedding = self.model.encode(query, normalize_embeddings=True)

        filter_condition = None
        if paper_filter:
            filter_condition = models.Filter(
                must=[
                    models.FieldCondition(
                        key="paper_name",
                        match=models.MatchValue(value=paper_filter),
                    )
                ]
            )

        results = self.client.query_points(
            collection_name=config.QDRANT_COLLECTION,
            query=query_embedding.tolist(),
            query_filter=filter_condition,
            limit=top_k,
        )

        return [
            SearchResult(
                chunk=DocumentChunk(
                    chunk_id=hit.payload["chunk_id"],
                    paper_id=hit.payload["paper_id"],
                    paper_name=hit.payload["paper_name"],
                    content=hit.payload["content"],
                    section_title=hit.payload.get("section_title", ""),
                    subsection_title=hit.payload.get("subsection_title", ""),
                ),
                score=hit.score,
                rank=rank,
            )
            for rank, hit in enumerate(results.points, start=1)
        ]

    def get_papers(self) -> list[dict]:
        """List indexed papers as ``{"paper_name", "chunk_count"}`` dicts.

        Scrolls through the whole collection page by page, so counts are not
        truncated at the first page. Returns ``[]`` without a client or when
        the scroll fails.
        """
        if not self.client:
            return []

        counts: Counter = Counter()
        offset = None
        try:
            while True:
                points, offset = self.client.scroll(
                    collection_name=config.QDRANT_COLLECTION,
                    limit=self.SCROLL_PAGE_SIZE,
                    offset=offset,
                    with_payload=["paper_name"],
                    with_vectors=False,
                )
                for point in points:
                    name = (point.payload or {}).get("paper_name", "")
                    if name:
                        counts[name] += 1
                # A None offset signals that the last page has been read.
                if offset is None:
                    break
        except Exception as e:
            print(f"Error listing papers: {e}")
            return []

        return [
            {"paper_name": name, "chunk_count": count}
            for name, count in counts.items()
        ]

    def delete_paper(self, paper_name: str) -> bool:
        """Delete every point whose ``paper_name`` payload equals *paper_name*.

        Returns ``True`` when the delete request succeeded, ``False`` without
        a client or when the request raised.
        """
        if not self.client:
            return False
        try:
            self.client.delete(
                collection_name=config.QDRANT_COLLECTION,
                points_selector=models.FilterSelector(
                    filter=models.Filter(
                        must=[
                            models.FieldCondition(
                                key="paper_name",
                                match=models.MatchValue(value=paper_name),
                            )
                        ]
                    )
                ),
            )
            return True
        except Exception as e:
            print(f"Error deleting paper: {e}")
            return False

    def get_stats(self) -> dict:
        """Return ``{"papers_indexed", "chunks_indexed"}`` for the collection.

        Both values are ``0`` without a client or when the lookup fails.
        """
        empty = {"papers_indexed": 0, "chunks_indexed": 0}
        if not self.client:
            return empty
        try:
            info = self.client.get_collection(config.QDRANT_COLLECTION)
            papers = self.get_papers()
            return {
                "papers_indexed": len(papers),
                # points_count may be None; report 0 in that case.
                "chunks_indexed": info.points_count or 0,
            }
        except Exception as e:
            print(f"Error getting stats: {e}")
            return empty


# Shared module-level instance; constructing it connects to Qdrant (when
# configured) and loads the embedding model at import time.
vector_store = QdrantVectorStore()