| from pinecone import Pinecone, ServerlessSpec | |
| from typing import List, Dict, Any, Optional | |
| import logging | |
| from config import settings | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
| class VectorDB: | |
| def __init__(self): | |
| self.pinecone_client = Pinecone( | |
| api_key=settings.pinecone_api_key | |
| ) | |
| self.index_name = settings.pinecone_index_name | |
| self.index = None | |
| self._connect_to_index() | |
| def _connect_to_index(self) -> None: | |
| existing_indexes = self.pinecone_client.list_indexes() | |
| index_names = [idx.name for idx in existing_indexes] | |
| if self.index_name not in index_names: | |
| logger.info(f"Index '{self.index_name}' not found. Creating new index...") | |
| self._create_index() | |
| else: | |
| logger.info(f"Connecting to existing index: {self.index_name}") | |
| self.index = self.pinecone_client.Index(self.index_name) | |
| self._verify_connection() | |
| def _create_index(self, dimension: int = 384) -> None: | |
| self.pinecone_client.create_index( | |
| name=self.index_name, | |
| dimension=dimension, | |
| metric="cosine", | |
| spec=ServerlessSpec( | |
| cloud="aws", | |
| region="us-east-1" | |
| ) | |
| ) | |
| logger.info(f"Index '{self.index_name}' created successfully") | |
| self.index = self.pinecone_client.Index(self.index_name) | |
| def _verify_connection(self) -> bool: | |
| try: | |
| stats = self.index.describe_index_stats() | |
| logger.info(f"Index stats: {stats}") | |
| return True | |
| except Exception as e: | |
| logger.error(f"Failed to connect to index: {e}") | |
| return False | |
| def upsert_vectors( | |
| self, | |
| vectors: List[Dict[str, Any]], | |
| namespace: str = "" | |
| ) -> Dict[str, Any]: | |
| try: | |
| result = self.index.upsert( | |
| vectors=vectors, | |
| namespace=namespace | |
| ) | |
| logger.info(f"Upserted {len(vectors)} vectors") | |
| return result | |
| except Exception as e: | |
| logger.error(f"Failed to upsert vectors: {e}") | |
| raise | |
| def query_vectors( | |
| self, | |
| query_vector: List[float], | |
| top_k: int = 5, | |
| include_metadata: bool = True, | |
| include_values: bool = False, | |
| namespace: str = "", | |
| filter_dict: Optional[Dict[str, Any]] = None | |
| ) -> Dict[str, Any]: | |
| try: | |
| result = self.index.query( | |
| vector=query_vector, | |
| top_k=top_k, | |
| include_metadata=include_metadata, | |
| include_values=include_values, | |
| namespace=namespace, | |
| filter=filter_dict | |
| ) | |
| return result | |
| except Exception as e: | |
| logger.error(f"Failed to query vectors: {e}") | |
| raise | |
| def delete_vectors( | |
| self, | |
| ids: List[str], | |
| namespace: str = "" | |
| ) -> Dict[str, Any]: | |
| try: | |
| result = self.index.delete( | |
| ids=ids, | |
| namespace=namespace | |
| ) | |
| logger.info(f"Deleted {len(ids)} vectors") | |
| return result | |
| except Exception as e: | |
| logger.error(f"Failed to delete vectors: {e}") | |
| raise | |
| def delete_all_vectors(self, namespace: str = "") -> None: | |
| try: | |
| self.index.delete(delete_all=True, namespace=namespace) | |
| logger.info("All vectors deleted from index") | |
| except Exception as e: | |
| logger.error(f"Failed to delete all vectors: {e}") | |
| raise | |
| def get_index_stats(self) -> Dict[str, Any]: | |
| try: | |
| stats = self.index.describe_index_stats() | |
| return stats.to_dict() | |
| except Exception as e: | |
| logger.error(f"Failed to get index stats: {e}") | |
| raise | |
| vector_db = VectorDB() | |
| def get_relevant_context( | |
| query_embedding: List[float], | |
| top_k: int = None, | |
| threshold: float = None | |
| ) -> List[Dict[str, Any]]: | |
| if top_k is None: | |
| top_k = settings.top_k_results | |
| if threshold is None: | |
| threshold = settings.similarity_threshold | |
| results = vector_db.query_vectors( | |
| query_vector=query_embedding, | |
| top_k=top_k | |
| ) | |
| relevant_contexts = [] | |
| for match in results.get("matches", []): | |
| if match["score"] >= threshold: | |
| relevant_contexts.append({ | |
| "text": match["metadata"].get("text", ""), | |
| "source": match["metadata"].get("source", ""), | |
| "topic": match["metadata"].get("topic", ""), | |
| "score": match["score"] | |
| }) | |
| return relevant_contexts | |