from pinecone import Pinecone, ServerlessSpec
from typing import List, Dict, Any, Optional
import logging
from config import settings
# Configure the root logger at INFO level when this module is imported, then
# create the module-level logger used by everything below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class VectorDB:
    """Thin wrapper around a single Pinecone index.

    Construction builds a Pinecone client from ``settings`` and attaches to
    the configured index, creating that index first when it is not listed.
    The public methods forward to the index handle; on failure they log the
    error and re-raise it unchanged.
    """

    def __init__(self):
        """Create the Pinecone client and attach to the configured index."""
        self.pinecone_client = Pinecone(api_key=settings.pinecone_api_key)
        self.index = None
        self.index_name = settings.pinecone_index_name
        self._connect_to_index()

    def _connect_to_index(self) -> None:
        """Point ``self.index`` at the configured index, creating it if absent."""
        known_names = {
            entry.name for entry in self.pinecone_client.list_indexes()
        }

        if self.index_name in known_names:
            logger.info(f"Connecting to existing index: {self.index_name}")
        else:
            logger.info(f"Index '{self.index_name}' not found. Creating new index...")
            self._create_index()

        self.index = self.pinecone_client.Index(self.index_name)
        self._verify_connection()

    def _create_index(self, dimension: int = 384) -> None:
        """Create the index as a cosine-metric serverless index and attach to it.

        Args:
            dimension: Vector dimensionality passed to ``create_index``.
        """
        serverless = ServerlessSpec(cloud="aws", region="us-east-1")
        self.pinecone_client.create_index(
            name=self.index_name,
            dimension=dimension,
            metric="cosine",
            spec=serverless,
        )
        logger.info(f"Index '{self.index_name}' created successfully")
        self.index = self.pinecone_client.Index(self.index_name)

    def _verify_connection(self) -> bool:
        """Probe the index by fetching its stats.

        Returns:
            ``True`` when the stats call (and logging it) succeeds, ``False``
            when anything in that step raises; the error is logged, not
            propagated.
        """
        try:
            stats = self.index.describe_index_stats()
            logger.info(f"Index stats: {stats}")
        except Exception as e:
            logger.error(f"Failed to connect to index: {e}")
            return False
        return True

    def upsert_vectors(
        self,
        vectors: List[Dict[str, Any]],
        namespace: str = ""
    ) -> Dict[str, Any]:
        """Upsert ``vectors`` into ``namespace``.

        Args:
            vectors: Records forwarded as-is to the index ``upsert`` call.
            namespace: Target namespace; empty string by default.

        Returns:
            Whatever the index ``upsert`` call returned.

        Raises:
            Exception: Any error from the call is logged and re-raised.
        """
        try:
            response = self.index.upsert(vectors=vectors, namespace=namespace)
            logger.info(f"Upserted {len(vectors)} vectors")
        except Exception as e:
            logger.error(f"Failed to upsert vectors: {e}")
            raise
        return response

    def query_vectors(
        self,
        query_vector: List[float],
        top_k: int = 5,
        include_metadata: bool = True,
        include_values: bool = False,
        namespace: str = "",
        filter_dict: Optional[Dict[str, Any]] = None
    ) -> Dict[str, Any]:
        """Run a similarity query against the index.

        Args:
            query_vector: Embedding to search with.
            top_k: Number of matches requested.
            include_metadata: Forwarded to the index ``query`` call.
            include_values: Forwarded to the index ``query`` call.
            namespace: Namespace to search; empty string by default.
            filter_dict: Forwarded as the ``filter`` argument.

        Returns:
            Whatever the index ``query`` call returned.

        Raises:
            Exception: Any error from the call is logged and re-raised.
        """
        search_args = {
            "vector": query_vector,
            "top_k": top_k,
            "include_metadata": include_metadata,
            "include_values": include_values,
            "namespace": namespace,
            "filter": filter_dict,
        }
        try:
            return self.index.query(**search_args)
        except Exception as e:
            logger.error(f"Failed to query vectors: {e}")
            raise

    def delete_vectors(
        self,
        ids: List[str],
        namespace: str = ""
    ) -> Dict[str, Any]:
        """Delete the vectors with the given ``ids`` from ``namespace``.

        Returns:
            Whatever the index ``delete`` call returned.

        Raises:
            Exception: Any error from the call is logged and re-raised.
        """
        try:
            response = self.index.delete(ids=ids, namespace=namespace)
            logger.info(f"Deleted {len(ids)} vectors")
        except Exception as e:
            logger.error(f"Failed to delete vectors: {e}")
            raise
        return response

    def delete_all_vectors(self, namespace: str = "") -> None:
        """Issue a ``delete_all`` request for ``namespace``.

        Raises:
            Exception: Any error from the call is logged and re-raised.
        """
        wipe_args = {"delete_all": True, "namespace": namespace}
        try:
            self.index.delete(**wipe_args)
            logger.info("All vectors deleted from index")
        except Exception as e:
            logger.error(f"Failed to delete all vectors: {e}")
            raise

    def get_index_stats(self) -> Dict[str, Any]:
        """Return the index stats converted with ``to_dict()``.

        Raises:
            Exception: Any error from the call is logged and re-raised.
        """
        try:
            return self.index.describe_index_stats().to_dict()
        except Exception as e:
            logger.error(f"Failed to get index stats: {e}")
            raise
# Module-level VectorDB instance used by get_relevant_context. Constructing it
# here runs VectorDB.__init__ (client creation and index connection) as soon as
# this module is imported.
vector_db = VectorDB()
def get_relevant_context(
    query_embedding: List[float],
    top_k: Optional[int] = None,
    threshold: Optional[float] = None
) -> List[Dict[str, Any]]:
    """Query the shared index and keep matches scoring at or above a threshold.

    Args:
        query_embedding: Embedding of the query to search with.
        top_k: Number of matches to request; falls back to
            ``settings.top_k_results`` when ``None``.
        threshold: Minimum score a match needs to be kept; falls back to
            ``settings.similarity_threshold`` when ``None``.

    Returns:
        One dict per kept match, in the order the query returned them, with
        the keys ``text``, ``source``, ``topic`` (each ``""`` when the match
        has no such metadata field) and ``score``.

    Raises:
        Exception: Whatever ``vector_db.query_vectors`` raises is propagated.
    """
    if top_k is None:
        top_k = settings.top_k_results
    if threshold is None:
        threshold = settings.similarity_threshold

    results = vector_db.query_vectors(
        query_vector=query_embedding,
        top_k=top_k
    )

    relevant_contexts = []
    for match in results.get("matches", []):
        score = match["score"]
        if score < threshold:
            continue
        # A match may carry no metadata at all (key absent or None); treat
        # that as empty metadata instead of failing the whole lookup.
        metadata = match.get("metadata") or {}
        relevant_contexts.append({
            "text": metadata.get("text", ""),
            "source": metadata.get("source", ""),
            "topic": metadata.get("topic", ""),
            "score": score
        })
    return relevant_contexts
|