File size: 3,710 Bytes
7e820ed |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import google.generativeai as genai
from chromadb import Documents, EmbeddingFunction, Embeddings, PersistentClient, Collection
from typing import Dict, List
import os
from dotenv import load_dotenv
load_dotenv(override=True)
from text_chunk import *
class GeminiEmbeddingFuction(EmbeddingFunction):
    """
    Embedding function backed by the Gemini AI API.

    Implements the chromadb ``EmbeddingFunction`` protocol: calling an
    instance with a batch of documents returns their embeddings, produced
    by the ``models/embedding-001`` Gemini model.

    Parameters:
    - input (Documents): A collection of documents to be embedded.

    Returns:
    - Embeddings: Embeddings generated for the input documents.
    """

    def __call__(self, input: Documents) -> Embeddings:
        # Configure the client on every call so the GEMINI_API key is
        # always read fresh from the environment.
        genai.configure(api_key=os.getenv("GEMINI_API"))
        response = genai.embed_content(
            model="models/embedding-001",
            content=input,
            task_type="retrieval_document",
            title="Query",
        )
        return response["embedding"]
def create_chroma_db(documents: List[str], path: str, name: str):
    """
    Create a persistent Chroma collection and populate it with documents.

    Parameters:
    - documents (List[str]): Text chunks to embed and store; document i
      is stored under the string id ``str(i)``.
    - path (str): Directory where the Chroma database is persisted.
    - name (str): Name of the collection to create (must not already exist;
      ``create_collection`` raises if it does).

    Returns:
    - Tuple[chromadb.Collection, str]: The created collection and its name.
    """
    chroma_client = PersistentClient(path=path)
    db = chroma_client.create_collection(
        name=name,
        embedding_function=GeminiEmbeddingFuction(),
    )
    # Batch insert: one API call instead of one round-trip per document,
    # and ids passed as a list of strings as the Chroma `add` API expects
    # (the original passed a bare `str(i)` where a list is required).
    if documents:
        db.add(
            documents=list(documents),
            ids=[str(i) for i in range(len(documents))],
        )
    return db, name
def load_chroma_db(path: str, name: str):
    """
    Open an existing Chroma collection from disk.

    Parameters:
    - path (str): Directory where the Chroma database is persisted.
    - name (str): Name of the collection to retrieve.

    Returns:
    - chromadb.Collection: The requested collection, wired to the Gemini
      embedding function so queries are embedded consistently with the
      stored documents.
    """
    client = PersistentClient(path=path)
    return client.get_collection(
        name=name,
        embedding_function=GeminiEmbeddingFuction(),
    )
def get_relevant_passage(query: str, db: Collection, n_results: int):
    """
    Run a semantic search and return the most similar text chunks.

    Parameters:
    - query (str): Free-text query to embed and search with.
    - db (chromadb.Collection): The Chroma collection to search.
    - n_results (int): The number of results to return.

    Returns:
    - List[str]: The stored chunks most similar to the query.
    """
    result = db.query(query_texts=[query], n_results=n_results)
    # query() returns results per query text; we sent exactly one query,
    # so take the document list for that first (only) query.
    return result['documents'][0]
if __name__ == "__main__":
    # Example: build the database from LinkedIn profile and summary text.
    # text = load_documents(data_path=f"Week_1\Data_w1")
    # print("Length of text: ", len(text))
    # chunked_text= sliding_window_chunk(text= text)
    # db, name = create_chroma_db(
    #     documents= chunked_text,
    #     path= "Week_1\Data_w1",
    #     name= 'RAG_DB'
    # )
    # Example: load the database and retrieve the top-3 relevant chunks.
    # db = load_chroma_db(path= "Week_1\Data_w1", name= 'RAG_DB')
    # relevant_text = get_relevant_passage(query="Your python experience",db=db,n_results=3)
    # print(relevant_text)
    print("Done")
|