File size: 3,710 Bytes
7e820ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import google.generativeai as genai
from chromadb import Documents, EmbeddingFunction, Embeddings, PersistentClient, Collection
from typing import Dict, List
import os
from dotenv import load_dotenv
load_dotenv(override=True)
from text_chunk import *

class GeminiEmbeddingFuction(EmbeddingFunction):
    """Embedding function backed by the Gemini API, for document retrieval.

    Implements the chromadb ``EmbeddingFunction`` protocol: ``__call__``
    takes a batch of documents and returns one embedding per document.

    NOTE(review): the "Fuction" spelling is kept deliberately — other code
    in this file refers to the class by this exact name.
    """

    def __call__(self, input: Documents) -> Embeddings:
        """Embed *input* documents with Gemini's ``embedding-001`` model.

        Parameters:
        - input (Documents): the documents to embed.

        Returns:
        - Embeddings: embeddings for the input documents.
        """
        # API key is read from the environment on every call (loaded from
        # .env at import time via load_dotenv).
        genai.configure(api_key=os.getenv("GEMINI_API"))
        response = genai.embed_content(
            model="models/embedding-001",
            content=input,
            task_type="retrieval_document",
            title="Query",
        )
        return response['embedding']


def create_chroma_db(documents: List[str], path: str, name: str):
    """Create a persistent Chroma collection and populate it with *documents*.

    Parameters:
    - documents (List[str]): text chunks to add to the collection.
    - path (str): directory where the Chroma database will be stored.
    - name (str): name of the collection to create.

    Returns:
    - Tuple[Collection, str]: the created Chroma collection and its name.

    Note: ``create_collection`` raises if a collection named *name* already
    exists at *path*; use ``load_chroma_db`` to reopen an existing one.
    """
    chroma_client = PersistentClient(path=path)
    db = chroma_client.create_collection(name=name,
                                         embedding_function=GeminiEmbeddingFuction())
    # Fix: the original looped db.add(documents=[d], ids=str(i)) — one API
    # round trip (and one embedding call) per chunk, with a bare string for
    # `ids`. A single batched add is far faster and passes ids as the list
    # the API expects; the guard keeps empty input from raising.
    if documents:
        db.add(documents=list(documents),
               ids=[str(i) for i in range(len(documents))])
    return db, name

def load_chroma_db(path: str, name: str):
    """Open an existing Chroma collection stored on disk.

    Parameters:
    - path (str): directory where the Chroma database is stored.
    - name (str): name of the collection to load.

    Returns:
    - chromadb.Collection: the loaded collection, wired to the Gemini
      embedding function so queries are embedded consistently.
    """
    client = PersistentClient(path=path)
    return client.get_collection(name=name,
                                 embedding_function=GeminiEmbeddingFuction())

def get_relevant_passage(query: str, db: Collection, n_results: int): 
    """Semantic search: return the stored chunks most similar to *query*.

    Parameters:
    - query (str): the query text to search for.
    - db (chromadb.Collection): the collection to search.
    - n_results (int): how many chunks to return.

    Returns:
    - List[str]: the most similar text chunks, best match first.
    """
    results = db.query(query_texts=[query], n_results=n_results)
    # query() returns per-query lists; we issued one query, so take index 0.
    return results['documents'][0]

if __name__ == "__main__":
    # Example usage, kept commented out so importing this module has no side
    # effects. Step 1: build the database from the LinkedIn export + summary.
    # text = load_documents(data_path=f"Week_1\Data_w1")
    # print("Length of text: ", len(text))
    # chunked_text= sliding_window_chunk(text= text)
    # db, name = create_chroma_db(
    #     documents= chunked_text,
    #     path= "Week_1\Data_w1",
    #     name= 'RAG_DB'
    # )

    # Step 2: retrieval example against the stored collection.
    # db = load_chroma_db(path= "Week_1\Data_w1", name= 'RAG_DB')
    # relevant_text = get_relevant_passage(query="Your python experience",db=db,n_results=3)

    # print(relevant_text)
    print("Done")