|
|
import google.generativeai as genai
|
|
|
from chromadb import Documents, EmbeddingFunction, Embeddings, PersistentClient, Collection
|
|
|
from typing import Dict, List
|
|
|
import os
|
|
|
from dotenv import load_dotenv
|
|
|
load_dotenv(override=True)
|
|
|
from text_chunk import *
|
|
|
|
|
|
class GeminiEmbeddingFuction(EmbeddingFunction):
    """
    Custom embedding function backed by the Gemini AI API.

    Implements the chromadb ``EmbeddingFunction`` protocol: calling an
    instance with a batch of documents returns their embeddings, computed
    remotely via ``genai.embed_content``.

    Parameters:
    - input (Documents): A collection of documents to be embedded.

    Returns:
    - Embeddings: Embeddings generated for the input documents.
    """

    def __call__(self, input: Documents) -> Embeddings:
        # Re-read the key on every call so a key loaded after module import
        # (e.g. via load_dotenv) is still picked up.
        genai.configure(api_key=os.getenv("GEMINI_API"))
        response = genai.embed_content(
            model="models/embedding-001",
            content=input,
            task_type="retrieval_document",
            title="Query",
        )
        return response['embedding']
|
|
|
|
|
|
|
|
|
def create_chroma_db(documents: List[str], path: str, name: str):
    """
    Creates a Chroma database using the provided documents, path, and collection name.

    Parameters:
    - documents (List[str]): The documents to be added to the Chroma database.
    - path (str): The path where the Chroma database will be stored.
    - name (str): The name of the collection within the Chroma database.

    Returns:
    - Tuple[chromadb.Collection, str]: A tuple containing the created Chroma Collection and its name.

    Raises:
    - Propagates chromadb's error if a collection named `name` already exists.
    """
    chroma_client = PersistentClient(path=path)
    db = chroma_client.create_collection(
        name=name,
        embedding_function=GeminiEmbeddingFuction(),
    )
    # Batch every document into a single add() call: one embedding round trip
    # instead of one per document. IDs are the documents' positional indices,
    # matching the per-item str(i) ids the collection would otherwise get.
    if documents:
        db.add(documents=list(documents),
               ids=[str(i) for i in range(len(documents))])
    return db, name
|
|
|
|
|
|
def load_chroma_db(path: str, name: str):
    """
    Loads an existing Chroma collection from the specified path with the given name.

    Parameters:
    - path (str): The path where the Chroma database is stored.
    - name (str): The name of the collection within the Chroma database.

    Returns:
    - chromadb.Collection: The loaded Chroma Collection.
    """
    client = PersistentClient(path=path)
    # The embedding function must match the one used at creation time so
    # queries embed with the same model.
    return client.get_collection(
        name=name,
        embedding_function=GeminiEmbeddingFuction(),
    )
|
|
|
|
|
|
def get_relevant_passage(query: str, db: Collection, n_results: int):
    """
    Semantic search to retrieve the most similar chunks of text from the database.

    Parameters:
    query (str): The query to search for.
    n_results (int): The number of results to return.
    db (chromadb.Collection): The Chroma collection to search.

    Returns:
    List[str]: A list of the most similar chunks of text.
    """
    results = db.query(query_texts=[query], n_results=n_results)
    # query() returns one document list per query text; exactly one was sent,
    # so take the first (and only) entry.
    return results['documents'][0]
|
|
|
|
|
|
if __name__ == "__main__":
    # Smoke-run entry point: importing this module performs no work, and the
    # script itself only confirms it loaded without errors.
    print("Done")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|