Making-Repos-Speakable / vector_db.py
Abdullah9862873's picture
Upload vector_db.py with huggingface_hub
df8f756 verified
from pinecone import Pinecone, ServerlessSpec
from typing import List, Dict, Any, Optional
import logging
from config import settings
# Configure the root logger at INFO level as soon as this module is imported,
# then create the module-scoped logger used by everything below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class VectorDB:
    """Wrapper around one Pinecone index.

    Construction builds a Pinecone client from ``settings`` and then attaches
    to the configured index, creating it first when it is not listed.
    Every data method logs a failure and re-raises the original exception.
    """

    def __init__(self):
        """Create the client and bind ``self.index`` to the configured index."""
        self.pinecone_client = Pinecone(api_key=settings.pinecone_api_key)
        self.index_name = settings.pinecone_index_name
        self.index = None
        self._connect_to_index()

    def _connect_to_index(self) -> None:
        """Attach to the configured index, creating it if it is not listed."""
        already_exists = any(
            idx.name == self.index_name
            for idx in self.pinecone_client.list_indexes()
        )

        if already_exists:
            logger.info(f"Connecting to existing index: {self.index_name}")
            self.index = self.pinecone_client.Index(self.index_name)
        else:
            logger.info(f"Index '{self.index_name}' not found. Creating new index...")
            self._create_index()

        # The outcome is only logged; a failed check does not abort construction.
        self._verify_connection()

    def _create_index(self, dimension: int = 384) -> None:
        """Create the index (cosine metric, serverless on aws/us-east-1) and bind it.

        Args:
            dimension: Vector dimensionality of the new index.
        """
        serverless_spec = ServerlessSpec(cloud="aws", region="us-east-1")
        self.pinecone_client.create_index(
            name=self.index_name,
            dimension=dimension,
            metric="cosine",
            spec=serverless_spec,
        )
        logger.info(f"Index '{self.index_name}' created successfully")
        self.index = self.pinecone_client.Index(self.index_name)

    def _verify_connection(self) -> bool:
        """Fetch and log index stats; return ``False`` instead of raising on failure."""
        try:
            stats = self.index.describe_index_stats()
            logger.info(f"Index stats: {stats}")
        except Exception as exc:
            logger.error(f"Failed to connect to index: {exc}")
            return False
        return True

    def upsert_vectors(
        self,
        vectors: List[Dict[str, Any]],
        namespace: str = ""
    ) -> Dict[str, Any]:
        """Upsert ``vectors`` into ``namespace`` and return the index's response.

        Raises:
            Exception: Whatever the underlying upsert raises, after logging it.
        """
        try:
            response = self.index.upsert(vectors=vectors, namespace=namespace)
            logger.info(f"Upserted {len(vectors)} vectors")
        except Exception as exc:
            logger.error(f"Failed to upsert vectors: {exc}")
            raise
        return response

    def query_vectors(
        self,
        query_vector: List[float],
        top_k: int = 5,
        include_metadata: bool = True,
        include_values: bool = False,
        namespace: str = "",
        filter_dict: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Query the index with ``query_vector`` and return the raw response.

        Args:
            query_vector: Vector to search with.
            top_k: Number of matches requested.
            include_metadata: Forwarded to the index query.
            include_values: Forwarded to the index query.
            namespace: Namespace to search in.
            filter_dict: Forwarded to the index query as ``filter``.

        Raises:
            Exception: Whatever the underlying query raises, after logging it.
        """
        query_args = {
            "vector": query_vector,
            "top_k": top_k,
            "include_metadata": include_metadata,
            "include_values": include_values,
            "namespace": namespace,
            "filter": filter_dict,
        }
        try:
            return self.index.query(**query_args)
        except Exception as exc:
            logger.error(f"Failed to query vectors: {exc}")
            raise

    def delete_vectors(
        self,
        ids: List[str],
        namespace: str = ""
    ) -> Dict[str, Any]:
        """Delete the vectors with the given ``ids`` from ``namespace``.

        Raises:
            Exception: Whatever the underlying delete raises, after logging it.
        """
        try:
            response = self.index.delete(ids=ids, namespace=namespace)
            logger.info(f"Deleted {len(ids)} vectors")
        except Exception as exc:
            logger.error(f"Failed to delete vectors: {exc}")
            raise
        return response

    def delete_all_vectors(self, namespace: str = "") -> None:
        """Delete every vector in ``namespace``.

        Raises:
            Exception: Whatever the underlying delete raises, after logging it.
        """
        try:
            self.index.delete(delete_all=True, namespace=namespace)
            logger.info("All vectors deleted from index")
        except Exception as exc:
            logger.error(f"Failed to delete all vectors: {exc}")
            raise

    def get_index_stats(self) -> Dict[str, Any]:
        """Return the index statistics converted with ``to_dict()``.

        Raises:
            Exception: Whatever the stats call or conversion raises, after logging it.
        """
        try:
            return self.index.describe_index_stats().to_dict()
        except Exception as exc:
            logger.error(f"Failed to get index stats: {exc}")
            raise
# Shared module-level instance, built when this module is imported. The
# VectorDB constructor lists the Pinecone indexes and attaches to (or creates)
# the configured one, so the import itself calls the Pinecone client.
vector_db = VectorDB()
def get_relevant_context(
    query_embedding: List[float],
    top_k: Optional[int] = None,
    threshold: Optional[float] = None
) -> List[Dict[str, Any]]:
    """Return the stored contexts whose similarity score meets a threshold.

    Queries the shared ``vector_db`` instance with ``query_embedding`` and
    keeps only matches scoring at least ``threshold``.

    Args:
        query_embedding: Embedding vector to search with.
        top_k: Number of matches to request; defaults to
            ``settings.top_k_results`` when ``None``.
        threshold: Minimum score a match must reach to be kept; defaults to
            ``settings.similarity_threshold`` when ``None``.

    Returns:
        A list of dicts with ``text``, ``source``, ``topic`` and ``score``
        keys, in the order the matches were returned. Metadata fields that
        are absent come back as empty strings.

    Raises:
        Exception: Whatever ``vector_db.query_vectors`` raises.
    """
    if top_k is None:
        top_k = settings.top_k_results
    if threshold is None:
        threshold = settings.similarity_threshold

    results = vector_db.query_vectors(
        query_vector=query_embedding,
        top_k=top_k
    )

    relevant_contexts = []
    # `or []` guards against a response whose "matches" entry is present but None.
    for match in results.get("matches", []) or []:
        score = match["score"]
        if score >= threshold:
            # A match stored without metadata must not abort the whole
            # retrieval; treat missing or None metadata as an empty mapping.
            # AttributeError is caught as well as KeyError because the
            # response may be a client model object rather than a plain dict
            # (assumption: such models raise an AttributeError-derived error
            # for unset fields; confirm against the installed pinecone version).
            try:
                metadata = match["metadata"]
            except (KeyError, AttributeError):
                metadata = None
            metadata = metadata or {}

            relevant_contexts.append({
                "text": metadata.get("text", ""),
                "source": metadata.get("source", ""),
                "topic": metadata.get("topic", ""),
                "score": score
            })

    return relevant_contexts