File size: 2,649 Bytes
46d624e
1fff755
c2ebaa7
1fff755
6bd76d2
1fff755
6bd76d2
1fff755
 
 
6bd76d2
 
 
 
1fff755
 
 
 
 
 
46d624e
 
1fff755
 
46d624e
 
 
 
1fff755
 
 
 
 
 
 
46d624e
1fff755
 
 
 
 
 
 
46d624e
1fff755
 
 
 
 
 
6bd76d2
1fff755
 
 
 
 
 
 
 
 
 
 
 
 
 
46d624e
1fff755
6bd76d2
1fff755
46d624e
6bd76d2
f698c44
6bd76d2
 
e1fb2f2
c2ebaa7
 
 
 
 
 
46d624e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import asyncio
import os
import uuid

from pinecone import Pinecone
from pinecone import ServerlessSpec
from langchain_pinecone import PineconeVectorStore

from ._config import logger

from langchain_huggingface import HuggingFaceEmbeddings

# Shared embedding model for every PineconeClient vector store. It is
# constructed once, when this module is imported, rather than per client.
_EMBEDDING_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
embeddings = HuggingFaceEmbeddings(model_name=_EMBEDDING_MODEL_NAME)

class PineconeClient:
    """Async context manager around a Pinecone index and a LangChain vector store.

    The API key and index name come from the ``PINECONE_API_KEY`` and
    ``PINECONE_INDEX_NAME`` environment variables. Entering the context
    (``async with PineconeClient() as client:``) creates the index if it is
    missing, connects to it, and builds ``self.vector_store`` on top of the
    module-level ``embeddings`` model.

    The Pinecone SDK methods used here are synchronous network calls, so each
    one is dispatched to the event loop's default executor instead of blocking
    the loop.
    """

    def __init__(self):
        self.pc = Pinecone(
            api_key=os.getenv("PINECONE_API_KEY"),
        )
        self.index_name = os.getenv("PINECONE_INDEX_NAME")
        self.index = None  # populated by __aenter__
        self.vector_store = None  # populated by __aenter__

    async def __aenter__(self):
        """Ensure the configured index exists and attach a vector store to it.

        Returns:
            This client, ready for queries and upserts.

        Raises:
            ValueError: if ``PINECONE_INDEX_NAME`` was not set.
        """
        if not self.index_name:
            raise ValueError("PINECONE_INDEX_NAME environment variable is not set")
        # _fetch_index creates the index when it is missing, so one call covers
        # both cases (the original listed indexes three times here).
        self.index = await self._fetch_index()
        self.vector_store = PineconeVectorStore(index=self.index, embedding=embeddings)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Nothing to release. Returning None lets any exception propagate.
        return None

    @staticmethod
    async def _run_blocking(func, *args, **kwargs):
        """Run synchronous ``func(*args, **kwargs)`` in the default executor."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(*args, **kwargs))

    def _require_vector_store(self):
        """Return the vector store, failing clearly if the context was not entered.

        Raises:
            RuntimeError: if ``__aenter__`` has not populated ``vector_store``.
        """
        if self.vector_store is None:
            raise RuntimeError(
                "PineconeClient must be used as 'async with PineconeClient() as client'"
            )
        return self.vector_store

    async def _list_indexs(self):
        """Return the result of ``Pinecone.list_indexes()`` (name kept for callers)."""
        return await self._run_blocking(self.pc.list_indexes)

    async def _is_index_exist(self, index_name=None):
        """Return True if an index named *index_name* (default: configured) exists."""
        index_name = index_name or self.index_name
        indexes = await self._list_indexs()
        # Assumes each listed entry supports ['name'] item access, as the
        # original code did -- confirm against the installed pinecone SDK.
        return any(index['name'] == index_name for index in indexes)

    async def _create_index(
        self,
        index_name=None,
        *,
        dimension=768,
        metric="cosine",
        cloud="aws",
        region="us-east-1",
    ):
        """Create a serverless index unless one with that name already exists.

        Args:
            index_name: Index to create. Defaults to the configured index.
            dimension: Vector size. It must match the embedding model. The
                default 768 is the output size of all-mpnet-base-v2, the model
                used for ``embeddings``.
            metric: Similarity metric passed to Pinecone.
            cloud: Cloud provider for ``ServerlessSpec``.
            region: Cloud region for ``ServerlessSpec``.
        """
        index_name = index_name or self.index_name
        if await self._is_index_exist(index_name):
            return
        await self._run_blocking(
            self.pc.create_index,
            name=index_name,
            dimension=dimension,
            metric=metric,
            spec=ServerlessSpec(cloud=cloud, region=region),
        )

    async def _fetch_index(self, index_name=None):
        """Return a handle to *index_name* (default: configured), creating it if needed."""
        index_name = index_name or self.index_name
        await self._create_index(index_name)  # no-op when the index already exists
        # pc.Index may contact the service to resolve the host, so run it off-loop.
        return await self._run_blocking(self.pc.Index, index_name)

    async def _upsert(self, texts, index_name=None):
        """Embed *texts* and add them to the vector store.

        ``index_name`` is accepted for signature compatibility but ignored. Texts
        always go to the index opened in ``__aenter__``.

        Raises:
            RuntimeError: if the client was not entered with ``async with``.
        """
        await self._require_vector_store().aadd_texts(texts)
        return

    async def get_context_for_user_query(self, query: str, top_k=1, index_name=None):
        """Return the ``top_k`` most similar stored texts, joined by newlines.

        ``index_name`` is accepted for signature compatibility but ignored.

        Raises:
            RuntimeError: if the client was not entered with ``async with``.
        """
        vector_store = self._require_vector_store()
        logger.info("Query is: %s", query)
        results = await vector_store.asimilarity_search(query, k=top_k)
        return "\n".join(result.page_content for result in results)

    async def _delete_index(self, index_name=None):
        """Delete *index_name* (default: configured) if it exists; otherwise do nothing.

        Note: ``self.index`` / ``self.vector_store`` are not reset here, so they
        become stale if the configured index is deleted.
        """
        index_name = index_name or self.index_name
        if await self._is_index_exist(index_name):
            await self._run_blocking(self.pc.delete_index, index_name)
        return