import os
from typing import Dict, List

import google.generativeai as genai
from chromadb import Collection, Documents, EmbeddingFunction, Embeddings, PersistentClient
from dotenv import load_dotenv

# Load environment variables (GEMINI_API) before any embedding calls are made.
load_dotenv(override=True)

from text_chunk import *


class GeminiEmbeddingFuction(EmbeddingFunction):
    """Custom embedding function using the Gemini AI API for document retrieval.

    Extends chromadb's ``EmbeddingFunction`` and implements ``__call__`` to
    generate embeddings for a batch of documents via the Gemini API.

    NOTE(review): class name keeps the original (misspelled) identifier for
    backward compatibility; ``GeminiEmbeddingFunction`` below is a corrected
    alias.
    """

    def __call__(self, input: Documents) -> Embeddings:
        """Embed *input* documents with the Gemini ``embedding-001`` model.

        Parameters:
        - input (Documents): A collection of documents to be embedded.

        Returns:
        - Embeddings: Embeddings generated for the input documents.
        """
        # Configure here (not at import time) so the key picked up by
        # load_dotenv(override=True) is always the one used.
        genai.configure(api_key=os.getenv("GEMINI_API"))
        return genai.embed_content(
            model="models/embedding-001",
            content=input,
            task_type="retrieval_document",
            title="Query",
        )["embedding"]


# Backward-compatible alias with the corrected spelling.
GeminiEmbeddingFunction = GeminiEmbeddingFuction


def create_chroma_db(documents: List[str], path: str, name: str):
    """Create a persistent Chroma collection from *documents*.

    Parameters:
    - documents (List[str]): Documents to be added to the Chroma database.
    - path (str): The path where the Chroma database will be stored.
    - name (str): The name of the collection within the Chroma database.

    Returns:
    - Tuple[chromadb.Collection, str]: The created collection and its name.
    """
    chroma_client = PersistentClient(path=path)
    db = chroma_client.create_collection(
        name=name, embedding_function=GeminiEmbeddingFuction()
    )
    # Batch the insert: Collection.add expects ids as a list[str] parallel to
    # documents (the original passed a bare str for ids, and added one document
    # per call).
    if documents:
        db.add(
            documents=list(documents),
            ids=[str(i) for i in range(len(documents))],
        )
    return db, name


def load_chroma_db(path: str, name: str):
    """Load an existing Chroma collection from *path* by *name*.

    Parameters:
    - path (str): The path where the Chroma database is stored.
    - name (str): The name of the collection within the Chroma database.

    Returns:
    - chromadb.Collection: The loaded Chroma collection.
    """
    chroma_client = PersistentClient(path=path)
    db = chroma_client.get_collection(
        name=name, embedding_function=GeminiEmbeddingFuction()
    )
    return db


def get_relevant_passage(query: str, db: Collection, n_results: int):
    """Semantic search for the chunks most similar to *query*.

    Parameters:
    - query (str): The query to search for.
    - db (chromadb.Collection): The Chroma collection to search.
    - n_results (int): The number of results to return.

    Returns:
    - List[str]: The most similar chunks of text.
    """
    # query() returns one list of documents per query text; we sent one query.
    passage = db.query(query_texts=[query], n_results=n_results)["documents"][0]
    return passage


if __name__ == "__main__":
    # Example: build a database from chunked documents.
    # text = load_documents(data_path=r"Week_1\Data_w1")
    # print("Length of text: ", len(text))
    # chunked_text = sliding_window_chunk(text=text)
    # db, name = create_chroma_db(
    #     documents=chunked_text,
    #     path=r"Week_1\Data_w1",
    #     name="RAG_DB",
    # )

    # Example: retrieval.
    # db = load_chroma_db(path=r"Week_1\Data_w1", name="RAG_DB")
    # relevant_text = get_relevant_passage(query="Your python experience", db=db, n_results=3)
    # print(relevant_text)
    print("Done")