Spaces:
Sleeping
Sleeping
| from qdrant_client import QdrantClient | |
| from qdrant_client.models import Distance, VectorParams, PointStruct | |
| from typing import List, Dict, Optional | |
| import uuid | |
| from src.config import config | |
| from src.embeddings import embedding_service | |
class VectorStore:
    """Wrapper around one Qdrant collection that stores embedded text chunks.

    The Qdrant client is created lazily: constructing an instance only records
    settings, and every public operation calls :meth:`connect` first.
    """

    def __init__(
        self,
        host: Optional[str] = None,
        port: Optional[int] = None,
        collection_name: Optional[str] = None,
    ):
        """Record connection settings.

        Args:
            host: Qdrant host; falls back to ``config.QDRANT_HOST`` when falsy.
            port: Qdrant port; falls back to ``config.QDRANT_PORT`` when falsy.
            collection_name: Target collection; falls back to
                ``config.COLLECTION_NAME`` when falsy.
        """
        self.host = host or config.QDRANT_HOST
        self.port = port or config.QDRANT_PORT
        self.collection_name = collection_name or config.COLLECTION_NAME
        # Created on first use by connect().
        self.client = None

    def connect(self) -> None:
        """Create the client on first call and make sure the collection exists.

        When ``config.QDRANT_URL`` is truthy the client is built from that URL
        and ``config.QDRANT_API_KEY``; otherwise ``self.host`` / ``self.port``
        are used. Calling this again once a client exists is a no-op.

        Raises:
            Exception: whatever the client raises while connecting or while
                checking/creating the collection; in the latter case the
                client is discarded so a later call retries from scratch.
        """
        if self.client is not None:
            return

        if config.QDRANT_URL:
            self.client = QdrantClient(
                url=config.QDRANT_URL,
                api_key=config.QDRANT_API_KEY,
            )
        else:
            self.client = QdrantClient(host=self.host, port=self.port)

        try:
            self._ensure_collection()
        except Exception:
            # Without this reset a failed collection check would leave a
            # client in place, and every later connect() would return early
            # without ever verifying that the collection exists.
            self.client = None
            raise

    def _ensure_collection(self) -> None:
        """Create ``self.collection_name`` if the server does not list it.

        The vector size comes from ``embedding_service`` and the distance
        metric is cosine. Expects ``self.client`` to be set already.
        """
        collections = self.client.get_collections().collections
        collection_names = [col.name for col in collections]
        if self.collection_name in collection_names:
            return

        embedding_dim = embedding_service.get_embedding_dimension()
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=embedding_dim,
                distance=Distance.COSINE,
            ),
        )
        print(f"Created collection: {self.collection_name}")

    def store_chunks(
        self,
        url_id: str,
        url: str,
        chunks: List[Dict],
        embeddings: List[List[float]],
    ) -> int:
        """Upsert one point per (chunk, embedding) pair.

        Args:
            url_id: Identifier of the source URL, stored in each payload.
            url: The source URL itself, stored in each payload.
            chunks: Dicts providing the keys ``id``, ``text``, ``start_word``
                and ``end_word``.
            embeddings: One vector per chunk, in the same order as ``chunks``.

        Returns:
            The number of points written (0 when there is nothing to store).

        Raises:
            ValueError: if ``chunks`` and ``embeddings`` differ in length.
            KeyError: if a chunk lacks one of the required keys.
        """
        # Pairing with zip() would silently drop the surplus items and store
        # an incomplete document, so reject mismatched input before any I/O.
        if len(chunks) != len(embeddings):
            raise ValueError(
                f"chunks and embeddings must have the same length "
                f"(got {len(chunks)} and {len(embeddings)})"
            )

        self.connect()

        points = [
            PointStruct(
                # Every stored chunk gets a fresh random point id.
                id=str(uuid.uuid4()),
                vector=embedding,
                payload={
                    "url_id": url_id,
                    "url": url,
                    "chunk_id": chunk["id"],
                    "text": chunk["text"],
                    "start_word": chunk["start_word"],
                    "end_word": chunk["end_word"],
                },
            )
            for chunk, embedding in zip(chunks, embeddings)
        ]
        if not points:
            # Nothing to write; skip the round-trip to the server.
            return 0

        self.client.upsert(
            collection_name=self.collection_name,
            points=points,
        )
        return len(points)

    def search(self, query_embedding: List[float], top_k: int = None) -> List[Dict]:
        """Return the nearest stored chunks for ``query_embedding``.

        Args:
            query_embedding: Query vector.
            top_k: Maximum number of hits; falls back to
                ``config.TOP_K_RESULTS`` when falsy.

        Returns:
            One dict per hit with the keys ``id``, ``score``, ``url``,
            ``url_id``, ``text`` and ``chunk_id``. Payload-derived values are
            ``None`` when the hit has no payload or lacks that field.
        """
        self.connect()
        k = top_k or config.TOP_K_RESULTS
        # NOTE(review): recent qdrant-client releases appear to deprecate
        # `search` in favour of `query_points` - confirm against the pinned
        # client version before upgrading.
        results = self.client.search(
            collection_name=self.collection_name,
            query_vector=query_embedding,
            limit=k,
        )

        hits = []
        for result in results:
            # A hit without a payload must not break the whole result list.
            payload = result.payload or {}
            hits.append(
                {
                    "id": result.id,
                    "score": result.score,
                    "url": payload.get("url"),
                    "url_id": payload.get("url_id"),
                    "text": payload.get("text"),
                    "chunk_id": payload.get("chunk_id"),
                }
            )
        return hits

    def delete_by_url_id(self, url_id: str) -> None:
        """Delete every point whose payload field ``url_id`` equals ``url_id``."""
        # Imported here so this method brings its own names into scope; the
        # typed models replace the hand-written nested dict selector.
        from qdrant_client.models import (
            FieldCondition,
            Filter,
            FilterSelector,
            MatchValue,
        )

        self.connect()
        self.client.delete(
            collection_name=self.collection_name,
            points_selector=FilterSelector(
                filter=Filter(
                    must=[
                        FieldCondition(
                            key="url_id",
                            match=MatchValue(value=url_id),
                        )
                    ]
                )
            ),
        )
# Module-level shared instance built from the defaults in `config`.
# Construction only records settings; the Qdrant client itself is created
# later, on the first connect() call.
vector_store = VectorStore()