sifars-chatbot-demo/src/utils/_pinecone_client.py
Aryan Jain
add main.py for server initialization; update file handling in DatabaseUpdater and improve context retrieval in ToolCall
46d624e
import asyncio
import os
import uuid
from pinecone import Pinecone
from pinecone import ServerlessSpec
from langchain_pinecone import PineconeVectorStore
from ._config import logger
from langchain_huggingface import HuggingFaceEmbeddings
# Sentence-transformer model used to embed both stored texts and user queries.
# It is instantiated once, at import time, and shared by every PineconeClient.
# Its output dimension (768 for all-mpnet-base-v2) must equal the dimension
# the Pinecone index is created with.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
embeddings = HuggingFaceEmbeddings(model_name=_EMBEDDING_MODEL_NAME)
class PineconeClient:
    """Async context manager around a Pinecone serverless index and a
    LangChain ``PineconeVectorStore`` built on top of it.

    Configuration comes from the ``PINECONE_API_KEY`` and
    ``PINECONE_INDEX_NAME`` environment variables. Use it as::

        async with PineconeClient() as client:
            await client.get_context_for_user_query("...")

    Entering the context creates the configured index if it does not exist
    yet. The Pinecone control-plane calls are synchronous HTTP requests, so
    they are run in a worker thread (``asyncio.to_thread``) to keep the event
    loop responsive.
    """

    def __init__(self):
        self.pc = Pinecone(
            api_key=os.getenv("PINECONE_API_KEY"),
        )
        # Default index used when a method is called without ``index_name``.
        self.index_name = os.getenv("PINECONE_INDEX_NAME")
        # Both are populated by ``__aenter__``.
        self.index = None
        self.vector_store = None

    async def __aenter__(self):
        """Ensure the default index exists, bind the vector store, return self.

        Raises:
            ValueError: if ``PINECONE_INDEX_NAME`` is not set.
        """
        # _fetch_index creates the index first when it is missing.
        self.index = await self._fetch_index()
        self.vector_store = PineconeVectorStore(index=self.index, embedding=embeddings)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Nothing to release; returning None lets exceptions propagate."""
        return None

    def _resolve_index_name(self, index_name=None):
        """Return *index_name*, or the configured default when it is falsy.

        Raises:
            ValueError: if neither an explicit name nor PINECONE_INDEX_NAME is set.
        """
        name = index_name or self.index_name
        if not name:
            raise ValueError(
                "No Pinecone index name given and PINECONE_INDEX_NAME is not set"
            )
        return name

    async def _list_indexs(self):
        """Return the project's indexes as reported by ``Pinecone.list_indexes``."""
        # Blocking HTTP call: run it off the event loop.
        return await asyncio.to_thread(self.pc.list_indexes)

    async def _is_index_exist(self, index_name=None):
        """Return True if an index called *index_name* (default: configured) exists."""
        index_name = self._resolve_index_name(index_name)
        indexes = await self._list_indexs()
        return any(index["name"] == index_name for index in indexes)

    async def _create_index(
        self,
        index_name=None,
        *,
        dimension=768,
        metric="cosine",
        cloud="aws",
        region="us-east-1",
    ):
        """Create a serverless index unless one with that name already exists.

        Args:
            index_name: Index to create; defaults to the configured index.
            dimension: Vector size. It must match the embedding model's output
                (768 for all-mpnet-base-v2).
            metric: Similarity metric passed to Pinecone.
            cloud, region: Placement for the ``ServerlessSpec``.
        """
        index_name = self._resolve_index_name(index_name)
        if await self._is_index_exist(index_name):
            return
        await asyncio.to_thread(
            self.pc.create_index,
            name=index_name,
            dimension=dimension,
            metric=metric,
            spec=ServerlessSpec(cloud=cloud, region=region),
        )

    async def _fetch_index(self, index_name=None):
        """Return a handle to *index_name*, creating the index if it is missing."""
        index_name = self._resolve_index_name(index_name)
        # No-op when the index already exists (_create_index checks first).
        await self._create_index(index_name)
        return await asyncio.to_thread(self.pc.Index, index_name)

    async def _get_vector_store(self, index_name=None):
        """Return the vector store for *index_name* (default: the bound one).

        A non-default index is fetched (and created if missing) on demand.

        Raises:
            RuntimeError: if the default store is requested before
                ``__aenter__`` has run.
        """
        if not index_name or index_name == self.index_name:
            if self.vector_store is None:
                raise RuntimeError(
                    "PineconeClient is not initialised; use it via "
                    "'async with PineconeClient() as client:'"
                )
            return self.vector_store
        index = await self._fetch_index(index_name)
        return PineconeVectorStore(index=index, embedding=embeddings)

    async def _upsert(self, texts, index_name=None):
        """Embed *texts* and add them to *index_name* (default: configured index)."""
        vector_store = await self._get_vector_store(index_name)
        await vector_store.aadd_texts(texts)

    async def get_context_for_user_query(self, query: str, top_k=1, index_name=None):
        """Return the ``top_k`` most similar stored texts for *query*.

        The matching documents' ``page_content`` values are joined with
        newlines. The result is an empty string when nothing matches.
        """
        logger.info("Query is: %s", query)
        vector_store = await self._get_vector_store(index_name)
        results = await vector_store.asimilarity_search(query, k=top_k)
        return "\n".join(doc.page_content for doc in results)

    async def _delete_index(self, index_name=None):
        """Delete *index_name* (default: configured index) if it exists."""
        index_name = self._resolve_index_name(index_name)
        if await self._is_index_exist(index_name):
            await asyncio.to_thread(self.pc.delete_index, index_name)