firermsdata-agent / src /utils /_pinecone_client.py
Aryan Jain
migrate to pinecone and show graph color
411c555
import uuid
import os
from pinecone import PineconeAsyncio as AsyncPinecone
from openai import AsyncOpenAI
class PineconeClient:
    """Async wrapper that embeds text with OpenAI and stores/queries it in Pinecone.

    Use it as an async context manager so the Pinecone client and index
    connections are opened and released::

        async with PineconeClient() as store:
            await store.upsert(["some text"])
            hits = await store.query("question")

    Configuration comes from environment variables: ``PINECONE_API_KEY``,
    ``PINECONE_INDEX`` (an index name, or an index host), ``OPENAI_API_KEY``,
    ``OPENAI_EMBEDDING_MODEL`` (default ``text-embedding-3-small``),
    ``PINECONE_UPSERT_BATCH_SIZE`` (default 50) and
    ``OPENAI_EMBEDDING_BATCH_SIZE`` (default 256).

    The raw text of every stored chunk is kept in the reserved metadata key
    ``_document``; ``query`` moves it back out into ``documents``.
    """

    def __init__(self):
        self.index_name = os.getenv("PINECONE_INDEX")
        self.embedding_model = os.getenv(
            "OPENAI_EMBEDDING_MODEL", "text-embedding-3-small"
        )
        self.openai = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        # Populated by __aenter__, cleared by __aexit__.
        self.client = None
        self.collection = None
        self._client_ctx = None
        self._collection_ctx = None
        self._pinecone_upsert_batch_size = int(
            os.getenv("PINECONE_UPSERT_BATCH_SIZE", "50")
        )
        # Max number of texts sent per OpenAI embeddings request; the API
        # caps the size of a single request, so large upserts are chunked.
        self._embedding_batch_size = int(
            os.getenv("OPENAI_EMBEDDING_BATCH_SIZE", "256")
        )

    async def __aenter__(self):
        """Open the Pinecone client and the index connection.

        If opening the index fails, the already-opened client is closed
        before the exception propagates (Python does not call ``__aexit__``
        when ``__aenter__`` raises).
        """
        self._client_ctx = AsyncPinecone(api_key=os.getenv("PINECONE_API_KEY"))
        self.client = await self._client_ctx.__aenter__()
        try:
            host = await self._resolve_index_host()
            collection_ctx = self.client.IndexAsyncio(host)
            self.collection = await collection_ctx.__aenter__()
            # Only record the index context once it has been entered, so
            # cleanup never exits a context that was never entered.
            self._collection_ctx = collection_ctx
        except BaseException as exc:
            await self.__aexit__(type(exc), exc, exc.__traceback__)
            raise
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the index connection, then the client; never suppresses errors.

        The client is closed even if closing the index raises.
        """
        collection_ctx, client_ctx = self._collection_ctx, self._client_ctx
        self.client = None
        self.collection = None
        self._client_ctx = None
        self._collection_ctx = None
        try:
            if collection_ctx is not None:
                await collection_ctx.__aexit__(exc_type, exc_value, traceback)
        finally:
            if client_ctx is not None:
                await client_ctx.__aexit__(exc_type, exc_value, traceback)

    async def _resolve_index_host(self) -> str:
        """Return the data-plane host to connect to for ``PINECONE_INDEX``.

        ``IndexAsyncio`` connects by host rather than by index name. Pinecone
        index names may only contain lowercase letters, digits and hyphens,
        so a value containing ``.`` (or ``localhost``) is treated as a host
        and used as-is; anything else is looked up with ``describe_index``.
        """
        target = self.index_name
        if target and ("." in target or "localhost" in target):
            return target
        description = await self.client.describe_index(name=target)
        return description.host

    async def _get_text_embedding(self, text: str) -> list[float]:
        """Embed a single string with the configured OpenAI model."""
        response = await self.openai.embeddings.create(
            input=text,
            model=self.embedding_model,
        )
        return response.data[0].embedding

    async def _get_batch_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Embed ``texts`` and return one vector per text, in input order.

        Requests are chunked by ``_embedding_batch_size``. An empty input
        returns ``[]`` without calling the API.
        """
        if not texts:
            return []
        batch_size = max(1, self._embedding_batch_size)
        embeddings: list[list[float]] = []
        for start in range(0, len(texts), batch_size):
            response = await self.openai.embeddings.create(
                input=texts[start : start + batch_size],
                model=self.embedding_model,
            )
            # Each item carries the position of its input; sort on it so the
            # vectors line up with ``texts`` regardless of response order.
            ordered = sorted(response.data, key=lambda item: item.index)
            embeddings.extend(item.embedding for item in ordered)
        return embeddings

    async def text_splitter(self, text: str, splitter: str = "\n\n") -> list[str]:
        """Split ``text`` on ``splitter`` (blank lines by default).

        Empty chunks are kept. The method is async only for interface
        compatibility; it does no I/O.
        """
        return text.split(splitter)

    @staticmethod
    def _split_ids(metadatas: list[dict]) -> tuple[list[str], list[dict]]:
        """Separate vector ids from metadata without mutating the caller's dicts.

        Each metadata's ``"id"`` entry (stringified) becomes the vector id;
        when it is missing or ``None`` a random UUID4 string is used. ``None``
        metadata entries are treated as empty dicts.

        Returns:
            ``(ids, cleaned_metadatas)``, where each cleaned metadata is a copy
            without the ``"id"`` key.
        """
        ids: list[str] = []
        cleaned: list[dict] = []
        for meta in metadatas:
            meta_copy = dict(meta or {})
            raw_id = meta_copy.pop("id", None)
            ids.append(str(uuid.uuid4()) if raw_id is None else str(raw_id))
            cleaned.append(meta_copy)
        return ids, cleaned

    async def upsert(self, texts: list[str], metadatas: list[dict] = None):
        """Embed ``texts`` and upsert them into the Pinecone index.

        Args:
            texts: Text chunks to store. An empty list is a no-op.
            metadatas: Optional per-text metadata, parallel to ``texts``. An
                ``"id"`` key is used as the vector id (otherwise a UUID4 is
                generated) and is not stored as metadata. The caller's dicts
                are not modified.

        Raises:
            ValueError: if ``texts`` and ``metadatas`` differ in length.
        """
        if not texts:
            return
        if metadatas is None:
            metadatas = [{} for _ in texts]
        if len(texts) != len(metadatas):
            raise ValueError("texts and metadatas must have the same length")
        ids, clean_metadatas = self._split_ids(metadatas)
        embeddings = await self._get_batch_embeddings(texts)
        vectors = [
            {
                "id": id_,
                "values": embedding,
                # Raw text travels in metadata so query() can return it.
                "metadata": {**meta, "_document": text},
            }
            for id_, embedding, text, meta in zip(
                ids, embeddings, texts, clean_metadatas
            )
        ]
        # Guard against a zero/negative configured size, which range() rejects.
        batch_size = max(1, self._pinecone_upsert_batch_size)
        for start in range(0, len(vectors), batch_size):
            await self.collection.upsert(vectors=vectors[start : start + batch_size])

    @staticmethod
    def _format_matches(matches: list) -> dict:
        """Convert Pinecone matches into parallel ``ids``/``documents``/... lists.

        Matches whose metadata is missing or ``None`` yield an empty document
        and empty metadata. Metadata is copied, not mutated.
        """
        ids, documents, metadatas, distances = [], [], [], []
        for match in matches:
            metadata = dict(match.get("metadata") or {})
            ids.append(match["id"])
            documents.append(metadata.pop("_document", ""))
            metadatas.append(metadata)
            distances.append(match["score"])
        return {
            "ids": ids,
            "documents": documents,
            "metadatas": metadatas,
            # NOTE(review): these are Pinecone ``score`` values. Depending on
            # the index metric they may be similarities (higher = closer)
            # rather than distances; confirm callers' ordering assumptions.
            "distances": distances,
        }

    async def query(self, query: str, n_results: int = 5) -> dict:
        """Return the ``n_results`` stored chunks most relevant to ``query``.

        Returns:
            A dict with parallel lists ``ids``, ``documents`` (the stored raw
            text), ``metadatas`` (without ``_document``) and ``distances``
            (Pinecone match scores).
        """
        query_embedding = await self._get_text_embedding(query)
        results = await self.collection.query(
            vector=query_embedding,
            top_k=n_results,
            include_metadata=True,
            include_values=False,
        )
        return self._format_matches(results["matches"])