File size: 3,665 Bytes
411c555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import uuid
import os
from pinecone import PineconeAsyncio as AsyncPinecone
from openai import AsyncOpenAI


class PineconeClient:
    """Async wrapper that stores OpenAI embeddings of text in a Pinecone index.

    Configuration is read from environment variables:
    ``PINECONE_INDEX`` (index name), ``PINECONE_API_KEY``, ``OPENAI_API_KEY``,
    ``OPENAI_EMBEDDING_MODEL`` (default ``"text-embedding-3-small"``) and
    ``PINECONE_UPSERT_BATCH_SIZE`` (default ``50``).

    The Pinecone client and index handle only exist inside ``async with``::

        async with PineconeClient() as pc:
            await pc.upsert(["some text"], [{"id": "doc-1"}])
            hits = await pc.query("text")

    The raw text of each stored vector is kept in the reserved metadata key
    ``"_document"``.
    """

    # Maximum number of inputs sent to the OpenAI embeddings endpoint per request.
    # NOTE(review): 2048 is OpenAI's documented array limit for embeddings input;
    # the total-token limit per request may still bite for very long texts.
    _OPENAI_EMBEDDING_BATCH_SIZE = 2048

    def __init__(self):
        self.index_name = os.getenv("PINECONE_INDEX")
        self.embedding_model = os.getenv(
            "OPENAI_EMBEDDING_MODEL", "text-embedding-3-small"
        )

        self.openai = AsyncOpenAI(api_key=os.getenv("OPENAI_API_KEY"))
        self.client = None
        self.collection = None
        self._client_ctx = None
        self._collection_ctx = None
        # Clamp to >= 1: a step of 0 makes range() raise, and a negative step
        # would make upsert() silently write nothing.
        self._pinecone_upsert_batch_size = max(
            1, int(os.getenv("PINECONE_UPSERT_BATCH_SIZE", "50"))
        )

    async def __aenter__(self):
        """Open the Pinecone client and the configured index; return ``self``."""
        self._client_ctx = AsyncPinecone(api_key=os.getenv("PINECONE_API_KEY"))
        self.client = await self._client_ctx.__aenter__()
        try:
            self._collection_ctx = self.client.IndexAsyncio(self.index_name)
            self.collection = await self._collection_ctx.__aenter__()
        except BaseException as exc:
            # Python does not call __aexit__ when __aenter__ raises, so release
            # the already-open client here. The index context never finished
            # entering, so it is not exited.
            self._collection_ctx = None
            await self.__aexit__(type(exc), exc, exc.__traceback__)
            raise
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        """Close the index and then the client, even if closing the index fails.

        Never suppresses the in-flight exception (returns ``None``).
        """
        collection_ctx, client_ctx = self._collection_ctx, self._client_ctx
        # Reset state first so the object is consistent even if a close raises.
        self.client = None
        self.collection = None
        self._client_ctx = None
        self._collection_ctx = None
        try:
            if collection_ctx is not None:
                await collection_ctx.__aexit__(exc_type, exc_value, traceback)
        finally:
            if client_ctx is not None:
                await client_ctx.__aexit__(exc_type, exc_value, traceback)

    def _require_collection(self):
        """Return the open index handle, or raise if used outside ``async with``."""
        if self.collection is None:
            raise RuntimeError(
                "PineconeClient is not open; use 'async with PineconeClient() as client:'"
            )
        return self.collection

    async def _get_text_embedding(self, text: str) -> list[float]:
        """Return the embedding vector for a single *text*."""
        response = await self.openai.embeddings.create(
            input=text,
            model=self.embedding_model,
        )
        return response.data[0].embedding

    async def _get_batch_embeddings(self, texts: list[str]) -> list[list[float]]:
        """Return one embedding per entry of *texts*, in the same order.

        Requests are chunked so no single call exceeds the endpoint's input cap.
        """
        embeddings: list[list[float]] = []
        step = self._OPENAI_EMBEDDING_BATCH_SIZE
        for start in range(0, len(texts), step):
            response = await self.openai.embeddings.create(
                input=texts[start : start + step],
                model=self.embedding_model,
            )
            # Each returned item carries the position of its input; sort on it
            # rather than relying on response ordering.
            ordered = sorted(response.data, key=lambda item: item.index)
            embeddings.extend(item.embedding for item in ordered)
        return embeddings

    async def text_splitter(self, text: str, splitter: str = "\n\n") -> list[str]:
        """Split *text* on *splitter* (blank lines by default)."""
        return text.split(splitter)

    async def upsert(self, texts: list[str], metadatas: list[dict] = None):
        """Embed *texts* and upsert them into the index.

        Args:
            texts: Documents to store. An empty list is a no-op.
            metadatas: Optional per-text metadata, parallel to *texts*. An
                ``"id"`` key, if present and not ``None``, is used as the vector
                id and removed from the stored metadata; otherwise a random
                UUID4 string is generated. The caller's dicts are not modified.

        Raises:
            ValueError: If *texts* and *metadatas* differ in length.
            RuntimeError: If the client is not open (used outside ``async with``).
        """
        if not texts:
            return

        if metadatas is None:
            metadatas = [{} for _ in texts]

        if len(texts) != len(metadatas):
            raise ValueError("texts and metadatas must have the same length")

        # Fail before paying for embeddings if the index is not open.
        collection = self._require_collection()

        ids: list = []
        clean_metadatas: list[dict] = []
        for meta in metadatas:
            # Work on a copy so the caller's metadata dicts stay intact.
            meta = dict(meta or {})
            id_ = meta.pop("id", None)
            ids.append(str(uuid.uuid4()) if id_ is None else id_)
            clean_metadatas.append(meta)

        embeddings = await self._get_batch_embeddings(texts)

        vectors = [
            {
                "id": id_,
                "values": embedding,
                "metadata": {**meta, "_document": text},
            }
            for id_, embedding, text, meta in zip(
                ids, embeddings, texts, clean_metadatas
            )
        ]

        batch_size = self._pinecone_upsert_batch_size
        for i in range(0, len(vectors), batch_size):
            await collection.upsert(vectors=vectors[i : i + batch_size])

    async def query(self, query: str, n_results: int = 5) -> dict:
        """Return the *n_results* stored documents most similar to *query*.

        Returns:
            A dict of parallel lists: ``"ids"``, ``"documents"`` (the stored
            ``"_document"`` text, ``""`` if absent), ``"metadatas"`` (metadata
            without ``"_document"``) and ``"distances"``. ``"distances"`` holds
            Pinecone's raw ``score`` values; whether higher means closer depends
            on the index metric.

        Raises:
            RuntimeError: If the client is not open (used outside ``async with``).
        """
        collection = self._require_collection()
        query_embedding = await self._get_text_embedding(query)

        results = await collection.query(
            vector=query_embedding,
            top_k=n_results,
            include_metadata=True,
            include_values=False,
        )

        ids, documents, metadatas, distances = [], [], [], []
        for match in results["matches"]:
            # Copy so the response object is not mutated; tolerate missing metadata.
            metadata = dict(match["metadata"] or {})
            ids.append(match["id"])
            documents.append(metadata.pop("_document", ""))
            metadatas.append(metadata)
            distances.append(match["score"])

        return {
            "ids": ids,
            "documents": documents,
            "metadatas": metadatas,
            "distances": distances,
        }