# Spaces: Running on Zero
| from qdrant_client import QdrantClient | |
| from qdrant_client.http.models import Distance, PointStruct, VectorParams | |
| from app.core.config import settings | |
| from app.core.models import Chunk, SearchResult | |
class QdrantVectorStore:
    """Store chunk embeddings in one Qdrant collection and search them.

    The Qdrant client is built once in the constructor from the application
    settings; every method operates on ``self.collection_name``.
    """

    def __init__(self, collection_name: str | None = None):
        """Create the Qdrant client and remember the target collection.

        Args:
            collection_name: Collection to operate on. A falsy value (``None``
                or ``""``) falls back to ``settings.QDRANT_COLLECTION_NAME``.
        """
        self.collection_name = collection_name or settings.QDRANT_COLLECTION_NAME
        self.client = QdrantClient(
            url=settings.get_qdrant_url(),
            # An empty/unset key is passed to the client as None.
            api_key=settings.QDRANT_API_KEY or None,
            timeout=60,
        )

    def ensure_collection(self, vector_size: int) -> None:
        """Create the collection when no collection with its name exists.

        An existing collection is left untouched; its configured vector size is
        not compared against ``vector_size``.

        Args:
            vector_size: Dimensionality used if the collection has to be
                created (cosine distance).
        """
        collections = self.client.get_collections().collections
        exists = any(collection.name == self.collection_name for collection in collections)
        if not exists:
            self.client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=vector_size, distance=Distance.COSINE),
            )

    def upsert_chunks(self, chunks: list[Chunk], embeddings: list[list[float]]) -> None:
        """Upsert one point per chunk, pairing each chunk with its embedding.

        Args:
            chunks: Chunks to store; ``chunk.id`` becomes the point id.
            embeddings: One vector per chunk, in the same order as ``chunks``.

        Raises:
            ValueError: If ``chunks`` and ``embeddings`` differ in length.
        """
        if len(chunks) != len(embeddings):
            raise ValueError("Chunks and embeddings must have the same length.")
        if not chunks:
            # Nothing to store: return before any client call is made.
            return

        # The first embedding determines the vector size of a new collection.
        self.ensure_collection(vector_size=len(embeddings[0]))
        points = [
            PointStruct(
                id=chunk.id,
                vector=embedding,
                payload={
                    "text": chunk.text,
                    "chunk_index": chunk.index,
                    "source_type": chunk.source_type.value,
                    "source": chunk.source,
                    "title": chunk.title,
                    "metadata": chunk.metadata,
                },
            )
            for chunk, embedding in zip(chunks, embeddings, strict=True)
        ]
        self.client.upsert(collection_name=self.collection_name, points=points)

    def search(self, query_embedding: list[float], limit: int = 5) -> list[SearchResult]:
        """Return the nearest stored chunks for a query vector.

        Args:
            query_embedding: Vector to search with.
            limit: Maximum number of hits requested from the client.

        Returns:
            One ``SearchResult`` per hit, in the order the client returned them.
            Payload fields that are missing or null are reported as ``""``
            (text fields) or ``{}`` (metadata).
        """
        # Use query_points when this client object provides it; otherwise fall
        # back to the search method.
        if hasattr(self.client, "query_points"):
            response = self.client.query_points(
                collection_name=self.collection_name,
                query=query_embedding,
                limit=limit,
                with_payload=True,
            )
            hits = response.points
        else:
            hits = self.client.search(
                collection_name=self.collection_name,
                query_vector=query_embedding,
                limit=limit,
                with_payload=True,
            )

        return [
            SearchResult(score=float(hit.score), **self._payload_fields(hit.payload))
            for hit in hits
        ]

    @staticmethod
    def _payload_fields(payload: dict[str, object] | None) -> dict[str, object]:
        """Normalize a point payload into ``SearchResult`` keyword arguments.

        A null value is treated like a missing key: calling ``str`` on it would
        yield the literal text ``"None"``, and ``dict`` would raise ``TypeError``.
        """
        data = payload or {}

        def as_text(key: str) -> str:
            value = data.get(key)
            return "" if value is None else str(value)

        return {
            "text": as_text("text"),
            "title": as_text("title"),
            "source": as_text("source"),
            "source_type": as_text("source_type"),
            # Copy so the caller never aliases the payload's own mapping.
            "metadata": dict(data.get("metadata") or {}),
        }