| from qdrant_client import QdrantClient |
| from qdrant_client.http.models import ( |
| ScoredPoint, |
| Filter, |
| FieldCondition, |
| MatchText |
| ) |
| from qdrant_client.models import ( |
| VectorParams, |
| Distance, |
| PointStruct, |
| TextIndexParams, |
| TokenizerType |
| ) |
| from app.core.models import Embedder |
| from app.core.chunks import Chunk |
| import numpy as np |
| from uuid import UUID |
| from app.settings import settings |
| import time |
| from fastapi import HTTPException |
| import re |
|
|
|
|
class VectorDatabase:
    """Wrapper around a Qdrant client that stores and searches text chunks.

    Chunks are embedded with the injected ``Embedder`` and stored as points
    whose payload holds the chunk's metadata and raw text. Near-duplicate
    vectors (cosine distance below ``settings.max_delta``) are skipped on
    insert and collapsed on search.
    """

    # Token-length bounds of the full-text payload index on "text"; keyword
    # filters use the same bounds so both stay in sync.
    _MIN_TOKEN_LEN = 2
    _MAX_TOKEN_LEN = 30
    # Runs of two or more ASCII capitals (acronym-like words such as "API").
    _KEYWORD_PATTERN = re.compile(r"\b[A-Z]{2,}\b")
    # Extra fraction of hits fetched so de-duplication can still yield top_k.
    _OVERFETCH_RATIO = 0.3

    def __init__(self, embedder: Embedder, host: str = "qdrant", port: int = 6333):
        """Connect to Qdrant and remember the embedder used for all vectors.

        Args:
            embedder: Model used to turn text into vectors.
            host: Stored on the instance as ``self.host``.
            port: Stored on the instance as ``self.port``.

        NOTE(review): ``host``/``port`` are not used to connect; the client is
        built from ``settings.qdrant`` in ``_initialize_qdrant_client``.

        Raises:
            HTTPException: If Qdrant cannot be reached after retrying.
        """
        self.host: str = host
        self.port: int = port
        self.client: QdrantClient = self._initialize_qdrant_client()
        self.embedder: Embedder = embedder
        # Empty (0, dim) matrix; not read anywhere in this class.
        self.already_stored: np.ndarray = np.array([]).reshape(
            0, embedder.get_vector_dimensionality()
        )

    def store(
        self, collection_name: str, chunks: list[Chunk], batch_size: int = 1000
    ) -> None:
        """Embed ``chunks`` and upsert the non-duplicate ones into a collection.

        Each chunk becomes a point with ``id=str(chunk.id)`` and payload
        ``{"metadata": chunk.get_metadata(), "text": chunk.get_raw_text()}``.
        A chunk is skipped when ``accept_vector`` finds a near-identical
        vector already in the collection.

        NOTE(review): duplicates are checked only against points already in
        Qdrant, not against other chunks of this same call, and upserts use
        ``wait=False`` -- so near-duplicates within one batch (or across two
        quick successive calls) can still be stored.

        Args:
            collection_name: Target collection.
            chunks: Chunks to embed and store. An empty list is a no-op.
            batch_size: Maximum number of points per upsert request.

        Raises:
            ValueError: If ``batch_size`` is not positive (a negative value
                would otherwise silently upsert nothing).
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        if not chunks:
            return

        texts = [chunk.get_raw_text() for chunk in chunks]

        print("Start getting text embeddings")
        start = time.time()
        vectors = self.embedder.encode(texts)
        print(f"Embeddings - {time.time() - start}")

        points: list[PointStruct] = []
        for vector, chunk, text in zip(vectors, chunks, texts):
            if not self.accept_vector(collection_name, vector):
                continue
            points.append(
                PointStruct(
                    id=str(chunk.id),
                    # Plain Python floats serialise reliably; numpy scalars
                    # (e.g. float32) may not pass the client's validation.
                    vector=vector.tolist() if isinstance(vector, np.ndarray) else vector,
                    payload={
                        "metadata": chunk.get_metadata(),
                        "text": text,
                    },
                )
            )

        for offset in range(0, len(points), batch_size):
            self.client.upsert(
                collection_name=collection_name,
                points=points[offset : offset + batch_size],
                wait=False,
            )

    def cosine_similarity(
        self, vec1: list[float], vec2: list[float] | list[list[float]]
    ) -> float:
        """Measure the cosine of the angle between two vectors.

        ``vec2`` may also be a list of vectors, in which case the highest
        similarity between ``vec1`` and any of them is returned.

        Returns:
            The (maximum) cosine similarity, or ``0.0`` when ``vec2`` is
            empty. A zero-length vector is treated as similarity ``0.0``
            instead of producing NaN.
        """
        if len(vec2) == 0:
            return 0.0

        target = np.asarray(vec1, dtype=float)
        candidates = np.asarray(vec2, dtype=float)
        if candidates.ndim == 1:
            candidates = candidates[np.newaxis, :]

        dots = candidates @ target
        norms = np.linalg.norm(candidates, axis=1) * np.linalg.norm(target)
        # Divide only where the denominator is non-zero and leave 0.0 elsewhere,
        # so an all-zero vector cannot poison later comparisons with NaN.
        similarities = np.divide(dots, norms, out=np.zeros_like(dots), where=norms > 0)
        return float(np.max(similarities))

    def accept_vector(self, collection_name: str, vector: np.ndarray) -> bool:
        """Decide whether ``vector`` is different enough to be stored.

        Looks up the single most similar point in the collection and rejects
        ``vector`` when its cosine distance (``1 - similarity``) to that point
        is below ``settings.max_delta``.

        Returns:
            ``True`` if the lookup returns no points or the nearest point is at
            least ``settings.max_delta`` away; ``False`` otherwise.
        """
        nearest = self.client.query_points(
            collection_name=collection_name, query=vector, limit=1, with_vectors=True
        ).points
        if not nearest:
            return True

        distance = 1 - self.cosine_similarity(vector, nearest[0].vector)
        return distance >= settings.max_delta

    def construct_keywords_list(self, query: str) -> list[FieldCondition]:
        """Build full-text match conditions for acronym-like words in ``query``.

        Words made of two or more ASCII capital letters are treated as
        keywords. Duplicates are dropped (first occurrence wins) and words
        outside the text index's token-length bounds are ignored.

        Returns:
            One ``FieldCondition`` on the ``text`` payload field per keyword.
        """
        keywords = self._KEYWORD_PATTERN.findall(query)
        print(keywords)

        # dict.fromkeys de-duplicates while preserving first-seen order.
        usable = dict.fromkeys(
            word
            for word in keywords
            if self._MIN_TOKEN_LEN <= len(word) <= self._MAX_TOKEN_LEN
        )
        return [FieldCondition(key="text", match=MatchText(text=word)) for word in usable]

    def combine_points_without_duplications(
        self, first: list[ScoredPoint], second: list[ScoredPoint] | None = None
    ) -> list[ScoredPoint]:
        """Merge one or two result lists, dropping near-duplicate points.

        Points are visited in order (``first``, then ``second``). A point is
        kept only if its cosine distance to every previously kept point is
        greater than ``min(settings.max_delta, 0.2)``. Every point must carry
        its vector (e.g. queried with ``with_vectors=True``).
        """
        threshold = min(settings.max_delta, 0.2)
        groups = [first] if second is None else [first, second]

        kept: list[ScoredPoint] = []
        kept_vectors: list = []
        for group in groups:
            for point in group:
                if 1 - self.cosine_similarity(point.vector, kept_vectors) > threshold:
                    kept.append(point)
                    kept_vectors.append(point.vector)
        return kept

    def search(self, collection_name: str, query: str, top_k: int = 5) -> list[Chunk]:
        """Return up to ``top_k`` de-duplicated chunks most similar to ``query``.

        Acronym-like words in the query become a ``should`` filter on the
        text index, so when any are present only points containing at least
        one of them are candidates. About 30% extra hits are fetched so that
        removing near-duplicates can still leave ``top_k`` results.
        """
        query_embedded = self.embedder.encode(query)
        # The embedder may return a batch of one; unwrap it to a single vector.
        if isinstance(query_embedded, list) or (
            isinstance(query_embedded, np.ndarray) and query_embedded.ndim == 2
        ):
            query_embedded = query_embedded[0]

        keywords = self.construct_keywords_list(query)

        mixed_result: list[ScoredPoint] = self.client.query_points(
            collection_name=collection_name,
            query=query_embedded,
            limit=top_k + int(top_k * self._OVERFETCH_RATIO),
            # Never send an empty ``should`` clause; with no keywords the
            # search is purely semantic.
            query_filter=Filter(should=keywords) if keywords else None,
            with_vectors=True,
        ).points

        print(f"Len of original array -> {len(mixed_result)}")
        combined = self.combine_points_without_duplications(mixed_result)
        print(f"Len of combined array -> {len(combined)}")

        # The over-fetch only exists to survive de-duplication; never return
        # more results than the caller asked for.
        return [self._point_to_chunk(point) for point in combined[:top_k]]

    @staticmethod
    def _point_to_chunk(point: ScoredPoint) -> Chunk:
        """Rebuild a ``Chunk`` from a point whose payload was written by ``store``."""
        payload = point.payload or {}
        metadata = payload.get("metadata") or {}
        return Chunk(
            # Fall back to the point id (``store`` writes str(chunk.id)) when
            # the metadata has no id, instead of failing on UUID("").
            id=UUID(str(metadata.get("id", point.id))),
            filename=metadata.get("filename", ""),
            page_number=metadata.get("page_number", 0),
            start_index=metadata.get("start_index", 0),
            start_line=metadata.get("start_line", 0),
            end_line=metadata.get("end_line", 0),
            text=payload.get("text", ""),
        )

    @staticmethod
    def _close_quietly(client: QdrantClient) -> None:
        """Close ``client``, ignoring any error (used only on cleanup paths)."""
        try:
            client.close()
        except Exception:
            pass

    def _initialize_qdrant_client(self, max_retries: int = 5, delay: float = 2) -> QdrantClient:
        """Create a Qdrant client from ``settings.qdrant``, retrying on failure.

        Each attempt builds a client and calls ``get_collections()`` to check
        that the server answers. The wait between failed attempts starts at
        ``delay`` seconds and doubles each time. Clients from failed attempts
        are closed so they do not leak.

        Raises:
            ValueError: If ``max_retries`` is less than 1.
            HTTPException: (500) if every attempt fails.
        """
        if max_retries < 1:
            raise ValueError("max_retries must be at least 1")

        for attempt in range(1, max_retries + 1):
            client = None
            try:
                client = QdrantClient(**settings.qdrant.model_dump())
                client.get_collections()
                return client
            except Exception as e:
                if client is not None:
                    self._close_quietly(client)
                if attempt == max_retries:
                    raise HTTPException(
                        500,
                        f"Failed to connect to Qdrant server after {max_retries} attempts. "
                        f"Last error: {str(e)}",
                    ) from e

                print(
                    f"Connection attempt {attempt} out of {max_retries} failed. "
                    f"Retrying in {delay} seconds..."
                )
                time.sleep(delay)
                delay *= 2

    def _check_collection_exists(self, collection_name: str) -> bool:
        """Return whether ``collection_name`` exists.

        Raises:
            HTTPException: (500) if the existence check itself fails.
        """
        try:
            return self.client.collection_exists(collection_name)
        except Exception as e:
            raise HTTPException(
                500,
                f"Failed to check collection {collection_name} exists. Last error: {str(e)}",
            ) from e

    def _create_collection(self, collection_name: str) -> None:
        """Create a cosine-distance collection plus a full-text index on ``text``.

        NOTE(review): if creating the payload index fails after the collection
        was created, the collection is left without the index, and
        ``create_collection`` will later skip it because it already exists.

        Raises:
            HTTPException: (500) if either Qdrant call fails.
        """
        try:
            self.client.create_collection(
                collection_name=collection_name,
                vectors_config=VectorParams(
                    size=self.embedder.get_vector_dimensionality(),
                    distance=Distance.COSINE,
                ),
            )
            self.client.create_payload_index(
                collection_name=collection_name,
                field_name="text",
                field_schema=TextIndexParams(
                    type="text",
                    tokenizer=TokenizerType.WORD,
                    min_token_len=self._MIN_TOKEN_LEN,
                    max_token_len=self._MAX_TOKEN_LEN,
                    lowercase=True,
                ),
            )
        except Exception as e:
            # Use the argument: the instance has no ``collection_name``
            # attribute, and referencing one here would mask the real error.
            raise HTTPException(
                500, f"Failed to create collection {collection_name}: {str(e)}"
            ) from e

    def create_collection(self, collection_name: str) -> None:
        """Create ``collection_name`` unless it already exists.

        Raises:
            HTTPException: (500) on any failure. HTTPExceptions raised by the
                helpers are re-raised unchanged so their message is kept.
        """
        try:
            if self._check_collection_exists(collection_name):
                return
            self._create_collection(collection_name)
        except Exception as e:
            print(e)
            if isinstance(e, HTTPException):
                raise
            # ``detail`` must be JSON-serialisable, so pass text rather than
            # the exception object.
            raise HTTPException(500, str(e)) from e

    def __del__(self):
        """Close the Qdrant client, if one was created, when the object is finalized."""
        client = getattr(self, "client", None)
        if client is not None:
            # Exceptions escaping __del__ are only reported as warnings, so
            # swallowing them avoids noise (e.g. during interpreter shutdown).
            self._close_quietly(client)

    def get_collections(self) -> list[str]:
        """Return the collections known to the Qdrant server.

        NOTE(review): despite the ``list[str]`` annotation, this returns the
        client's ``get_collections()`` response object unchanged; callers may
        depend on that, so only the annotation is suspect.

        Raises:
            HTTPException: (500) if the request fails.
        """
        try:
            return self.client.get_collections()
        except Exception as e:
            print(e)
            raise HTTPException(500, "Failed to get collection names") from e
|
|