Spaces:
Runtime error
Runtime error
import hashlib
from collections import Counter
from dataclasses import dataclass
from typing import Optional

import numpy as np
from qdrant_client import QdrantClient
from qdrant_client.http import models
from qdrant_client.http.exceptions import UnexpectedResponse
from sentence_transformers import SentenceTransformer

from config import config
@dataclass
class DocumentChunk:
    """A piece of a paper's text together with the metadata stored beside it.

    The fields are declared dataclass-style, so the ``@dataclass`` decorator is
    required for the keyword constructor used when results are rebuilt from
    Qdrant payloads.
    """

    chunk_id: str  # caller-supplied identifier; hashed into the Qdrant point id
    paper_id: str
    paper_name: str  # also the payload key used to filter by paper
    content: str  # the text that is embedded and searched
    section_title: str = ""
    subsection_title: str = ""
@dataclass
class SearchResult:
    """One ranked hit produced by a vector search.

    Like ``DocumentChunk`` this needs ``@dataclass`` so that
    ``SearchResult(chunk=..., score=..., rank=...)`` works.
    """

    chunk: "DocumentChunk"  # chunk rebuilt from the point payload
    score: float  # similarity score reported by Qdrant (cosine collection)
    rank: int  # 1-based position within the result list
class QdrantVectorStore:
    """Store and search paper chunks in a Qdrant collection.

    Chunk text is embedded with ``SentenceTransformer(config.EMBEDDING_MODEL)``
    and written to ``config.QDRANT_COLLECTION`` with the chunk metadata as the
    point payload. If ``config.QDRANT_URL`` or ``config.QDRANT_API_KEY`` is
    unset, no client is created and every public method returns an empty or
    falsy result instead of talking to Qdrant.
    """

    # Dimension of the stored vectors; must equal the output size of
    # config.EMBEDDING_MODEL (TODO confirm for the configured model).
    VECTOR_SIZE = 384
    # Point budget assumed for the hosted free tier; the collection is wiped
    # once it holds 90% of this many points.
    MAX_VECTORS_FREE_TIER = 1000000
    # Page size used when scrolling through the whole collection.
    SCROLL_PAGE_SIZE = 10000

    def __init__(self):
        self.client: Optional[QdrantClient] = None
        self.model: Optional[SentenceTransformer] = None
        self._initialize()

    def _initialize(self):
        """Connect to Qdrant when configured, then load the embedding model."""
        if config.QDRANT_URL and config.QDRANT_API_KEY:
            self.client = QdrantClient(
                url=config.QDRANT_URL,
                api_key=config.QDRANT_API_KEY,
            )
            self._ensure_collection()
        self.model = SentenceTransformer(config.EMBEDDING_MODEL)

    def _ensure_collection(self):
        """Create the collection if missing and request the paper_name index.

        Existence is checked explicitly. The previous approach treated any
        exception from ``get_collection`` as "missing", which masked
        connection errors behind a failing ``create_collection`` call. The
        keyword index on ``paper_name`` backs the filters used by ``search``
        and ``delete_paper``. It is requested on every start so that a
        collection created without it gets one.
        """
        if not self.client.collection_exists(config.QDRANT_COLLECTION):
            self.client.create_collection(
                collection_name=config.QDRANT_COLLECTION,
                vectors_config=models.VectorParams(
                    size=self.VECTOR_SIZE,
                    distance=models.Distance.COSINE,
                ),
            )
        try:
            self.client.create_payload_index(
                collection_name=config.QDRANT_COLLECTION,
                field_name="paper_name",
                field_schema=models.PayloadSchemaType.KEYWORD,
            )
        except Exception as e:
            # Best effort: a missing index must not prevent start-up.
            print(f"Could not create payload index on paper_name: {e}")

    def _check_and_cleanup_if_needed(self):
        """Reset the collection when it nears ``MAX_VECTORS_FREE_TIER`` points.

        WARNING: the reset deletes *every* stored chunk, not only old ones.
        Errors are printed and swallowed so that ingestion can still proceed.
        """
        if not self.client:
            return
        try:
            info = self.client.get_collection(config.QDRANT_COLLECTION)
            # points_count is Optional in the client's response model.
            points = info.points_count or 0
            if points >= self.MAX_VECTORS_FREE_TIER * 0.9:
                self.client.delete_collection(config.QDRANT_COLLECTION)
                self._ensure_collection()
                print("Qdrant collection reset due to approaching limit")
        except Exception as e:
            print(f"Error checking collection: {e}")

    @staticmethod
    def _point_id(chunk_id: str) -> int:
        """Map a chunk id to a stable point id in ``[0, 2**63)``.

        The built-in ``hash()`` on ``str`` is salted per process
        (PYTHONHASHSEED), so ids changed on every restart. Re-indexing the same
        chunk then inserted a duplicate instead of overwriting it.
        """
        digest = hashlib.sha256(chunk_id.encode("utf-8")).digest()
        return int.from_bytes(digest[:8], "big") >> 1

    @staticmethod
    def _paper_filter(paper_name: str) -> models.Filter:
        """Build a filter matching points whose payload paper_name equals *paper_name*."""
        return models.Filter(
            must=[models.FieldCondition(
                key="paper_name",
                match=models.MatchValue(value=paper_name),
            )]
        )

    @staticmethod
    def _chunk_from_payload(payload: dict) -> DocumentChunk:
        """Rebuild a DocumentChunk from a payload written by ``add_chunks``."""
        return DocumentChunk(
            chunk_id=payload["chunk_id"],
            paper_id=payload["paper_id"],
            paper_name=payload["paper_name"],
            content=payload["content"],
            section_title=payload.get("section_title", ""),
            subsection_title=payload.get("subsection_title", ""),
        )

    def add_chunks(self, chunks: list[DocumentChunk]) -> int:
        """Embed *chunks* and upsert them into the collection.

        Returns the number of chunks written: 0 when *chunks* is empty or no
        client is configured. This may first reset the whole collection (see
        ``_check_and_cleanup_if_needed``). Chunks with an already-stored
        ``chunk_id`` overwrite the existing point.
        """
        if not chunks or not self.client:
            return 0
        self._check_and_cleanup_if_needed()
        embeddings = self.model.encode(
            [c.content for c in chunks], normalize_embeddings=True
        )
        points = [
            models.PointStruct(
                id=self._point_id(chunk.chunk_id),
                vector=vector.tolist(),
                payload={
                    "chunk_id": chunk.chunk_id,
                    "paper_id": chunk.paper_id,
                    "paper_name": chunk.paper_name,
                    "content": chunk.content,
                    "section_title": chunk.section_title,
                    "subsection_title": chunk.subsection_title,
                },
            )
            for chunk, vector in zip(chunks, embeddings)
        ]
        self.client.upsert(
            collection_name=config.QDRANT_COLLECTION,
            points=points,
        )
        return len(chunks)

    def search(self, query: str, top_k: Optional[int] = None, paper_filter: Optional[str] = None) -> list[SearchResult]:
        """Return up to *top_k* chunks most similar to *query*, best first.

        ``top_k`` falls back to ``config.TOP_K_CHUNKS`` when it is None or 0.
        When ``paper_filter`` is given, only chunks whose ``paper_name`` equals
        it are considered. Ranks are 1-based. Returns [] when no client is
        configured.
        """
        if not self.client:
            return []
        top_k = top_k or config.TOP_K_CHUNKS
        query_embedding = self.model.encode(query, normalize_embeddings=True)
        query_filter = self._paper_filter(paper_filter) if paper_filter else None
        response = self.client.query_points(
            collection_name=config.QDRANT_COLLECTION,
            query=query_embedding.tolist(),
            query_filter=query_filter,
            limit=top_k,
            with_payload=True,
        )
        return [
            SearchResult(
                chunk=self._chunk_from_payload(hit.payload),
                score=hit.score,
                rank=rank,
            )
            for rank, hit in enumerate(response.points, start=1)
        ]

    def get_papers(self) -> list[dict]:
        """List indexed papers as ``{"paper_name", "chunk_count"}`` dicts.

        Scrolls through *all* pages of the collection. The previous version
        read only the first page, so papers beyond it were dropped and counts
        were wrong. Returns [] when no client is configured or on error.
        """
        if not self.client:
            return []
        try:
            counts = Counter()
            offset = None
            while True:
                points, offset = self.client.scroll(
                    collection_name=config.QDRANT_COLLECTION,
                    limit=self.SCROLL_PAGE_SIZE,
                    offset=offset,
                    with_payload=["paper_name"],
                    with_vectors=False,
                )
                for point in points:
                    name = (point.payload or {}).get("paper_name", "")
                    if name:
                        counts[name] += 1
                # The client returns a None offset after the last page.
                if offset is None:
                    break
            return [{"paper_name": k, "chunk_count": v} for k, v in counts.items()]
        except Exception as e:
            print(f"Error listing papers: {e}")
            return []

    def delete_paper(self, paper_name: str) -> bool:
        """Delete every chunk of *paper_name*; return True if the request succeeded."""
        if not self.client:
            return False
        try:
            self.client.delete(
                collection_name=config.QDRANT_COLLECTION,
                points_selector=models.FilterSelector(
                    filter=self._paper_filter(paper_name)
                ),
            )
            return True
        except Exception as e:
            print(f"Error deleting paper {paper_name!r}: {e}")
            return False

    def get_stats(self) -> dict:
        """Return ``{"papers_indexed": int, "chunks_indexed": int}``; zeros on error."""
        if not self.client:
            return {"papers_indexed": 0, "chunks_indexed": 0}
        try:
            info = self.client.get_collection(config.QDRANT_COLLECTION)
            papers = self.get_papers()
            return {
                "papers_indexed": len(papers),
                # points_count is Optional in the client's response model.
                "chunks_indexed": info.points_count or 0,
            }
        except Exception as e:
            print(f"Error reading collection stats: {e}")
            return {"papers_indexed": 0, "chunks_indexed": 0}
# Shared, module-wide store instance. Building it at import time connects to
# Qdrant (when URL and API key are configured) and loads the embedding model.
vector_store: QdrantVectorStore = QdrantVectorStore()