File size: 1,730 Bytes
206ef5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c926830
206ef5f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
"""
Contains a wrapper class for the ChromaDB client that can process and store documents and retrieve document chunks.
"""

# SQLite shim for chromaDB: load pysqlite3 and register it in sys.modules
# under the name "sqlite3", so every later `import sqlite3` in this process
# resolves to pysqlite3 instead of the standard-library module.
# This must run before `import chromadb` below.
# NOTE(review): presumably needed because chromadb requires a newer SQLite
# than the interpreter's bundled one - confirm against the deployment target.
__import__("pysqlite3")
import sys

# pop() also removes the "pysqlite3" entry, leaving only the "sqlite3" alias.
sys.modules["sqlite3"] = sys.modules.pop("pysqlite3")

from typing import List, Optional, Tuple
import chromadb


class AdvancedClient:
    """Wrapper around a persistent ChromaDB client.

    Stores text chunks together with their embeddings in named collections
    and queries those collections for the chunks matching a query.
    """

    def __init__(self, vector_database_path: str = "vectorDB") -> None:
        """Open a ``chromadb.PersistentClient`` at *vector_database_path*.

        Args:
            vector_database_path: Path handed to the persistent client.
        """
        self.client = chromadb.PersistentClient(path=vector_database_path)

    def create_collection(
        self,
        collection_id: str,
        file_datas: List[Tuple[str, int]],
    ) -> None:
        """Embed the given chunks and store them in a new collection.

        Args:
            collection_id: Name of the collection to create.
            file_datas: ``(chunk_text, chunk_id)`` pairs. IDs are converted
                to strings before being stored.

        Any error raised by the embedding function or by chromadb (for
        example while creating the collection) propagates to the caller.
        """
        chunks = []
        ids = []

        # Single pass so that one-shot iterables are consumed only once.
        for chunk, chunk_id in file_datas:
            chunks.append(chunk)
            ids.append(str(chunk_id))  # collection IDs are stored as strings

        # Deferred import, kept local as in the original code.
        # NOTE(review): presumably avoids a circular or heavy import at
        # module load time - confirm before hoisting it.
        from .ModelCallingFunctions import generate_embedding

        # Embeddings are computed before the collection is created, so a
        # failure here does not leave an empty collection behind.
        embeddings = generate_embedding(texts=chunks)

        collection = self.client.create_collection(collection_id)
        collection.add(
            ids=ids,
            embeddings=embeddings,  # type: ignore
            documents=chunks,
        )

    def retrieve_chunks(
        self,
        collection_id: str,
        query: str = "NONE",
        query_embedding: Optional[List[float]] = None,
        number_of_chunks: int = 3,
    ):
        """Return the stored chunks that best match a query.

        Args:
            collection_id: Name of the collection to search.
            query: Query text. It is embedded only when *query_embedding*
                is not supplied.
            query_embedding: Pre-computed embedding of the query. When
                given, *query* is ignored and no embedding call is made.
            number_of_chunks: Number of results requested from the
                collection.

        Returns:
            The ``"documents"`` entry of the query result for the single
            query issued, i.e. ``results["documents"][0]``.
        """
        collection = self.client.get_collection(name=collection_id)

        # Identity check, not ``== None``: equality on array-like embeddings
        # (e.g. NumPy arrays) is element-wise and cannot be used in ``if``.
        if query_embedding is None:
            # Deferred import, kept local as in the original code.
            from .ModelCallingFunctions import generate_embedding

            query_emb = generate_embedding([query])[0]
        else:
            query_emb = query_embedding

        results = collection.query(
            query_embeddings=query_emb,
            n_results=number_of_chunks,
        )

        # Only one query is issued, so take the first list of documents.
        return results["documents"][0]  # pyright: ignore