# CV-Info-Agent / chromadb_query.py
# dure-waseem's picture
# initial code
# b1c00a1
import chromadb
import time
import chromadb.utils.embedding_functions as embedding_functions
import os
from openai import OpenAI
class ChromaCollection:
    """
    Wrapper around one persistent Chroma collection plus an OpenAI chat client.

    The collection is opened (or created) with an OpenAI embedding function,
    can be queried with retry on connection/timeout errors, and the query
    results can be turned into a natural-language answer by a chat model.
    """

    # Model used by the Chroma embedding function.
    EMBEDDING_MODEL = "text-embedding-ada-002"
    # Chat model used by generate_answer().
    CHAT_MODEL = "gpt-4o-mini"
    # Maximum number of retrieved documents placed into the LLM prompt.
    MAX_CONTEXT_DOCUMENTS = 5

    def __init__(self, collection_name, db_path, api_key=None):
        """
        Open the persistent Chroma store and prepare the collection.

        :param collection_name: Name of the collection to open or create
        :param db_path: Filesystem path handed to ``chromadb.PersistentClient``
        :param api_key: OpenAI API key; when falsy, the ``OPENAI_API_KEY``
            environment variable is used instead
        :raises ValueError: If no API key is supplied by either route
        """
        # Initialize Chroma persistent client and collection name
        self.chroma_client = chromadb.PersistentClient(path=db_path)
        self.collection_name = collection_name
        self.collection = None

        # Use provided API key or fall back to environment variable
        self.openai_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.openai_key:
            raise ValueError("OpenAI API key is required")

        self.openai_ef = embedding_functions.OpenAIEmbeddingFunction(
            api_key=self.openai_key,
            model_name=self.EMBEDDING_MODEL,
        )

        # Initialize OpenAI client
        self.openai_client = OpenAI(api_key=self.openai_key)
        self._initialize_collection()

    def _initialize_collection(self):
        """
        Bind ``self.collection`` to the named collection, creating it if the
        lookup fails.
        """
        try:
            self.collection = self.chroma_client.get_collection(
                name=self.collection_name,
                embedding_function=self.openai_ef,
            )
            print(f"Collection '{self.collection_name}' already exists.")
        except Exception:
            # Any lookup failure is treated as "collection doesn't exist".
            # NOTE(review): the exception type Chroma raises for a missing
            # collection is not pinned here - confirm before narrowing this.
            self.collection = self.chroma_client.create_collection(
                name=self.collection_name,
                embedding_function=self.openai_ef,
            )
            print(f"Created new collection '{self.collection_name}'.")

    def query_collection(self, query_texts, n_results=1, max_retries=3,
                         backoff_seconds=2):
        """
        Queries the collection with the given text and returns the results.

        Errors whose message mentions "connection" or "timeout" are retried
        with a linearly growing delay; any other error aborts immediately.
        Failures are reported on stdout and never raised to the caller.

        :param query_texts: List of query strings
        :param n_results: Number of results to return
        :param max_retries: Maximum number of query attempts (at least 1)
        :param backoff_seconds: Base delay; attempt ``k`` waits
            ``k * backoff_seconds`` seconds before the next try
        :return: Query results, or a dict with empty ``documents``,
            ``metadatas`` and ``distances`` lists when every attempt failed
        """
        max_retries = max(1, max_retries)
        attempts_made = 0
        gave_up_early = False

        for attempt_number in range(1, max_retries + 1):
            attempts_made = attempt_number
            try:
                return self.collection.query(
                    query_texts=query_texts,  # Chroma will embed this for you
                    n_results=n_results,      # How many results to return
                )
            except Exception as e:
                print(f"Query attempt {attempt_number} failed: {e}")
                error_msg = str(e).lower()
                retryable = "connection" in error_msg or "timeout" in error_msg

                if not retryable:
                    # Non-connection error, don't retry
                    gave_up_early = True
                    break

                # Only sleep when another attempt will actually follow.
                if attempt_number < max_retries:
                    wait_time = attempt_number * backoff_seconds
                    print(f"Connection issue detected. Waiting {wait_time} seconds before retry...")
                    time.sleep(wait_time)

        if gave_up_early:
            print(f"Query aborted after {attempts_made} attempt(s): error is not retryable")
        else:
            print(f"All {max_retries} query attempts failed")

        # Fresh dict on every call so callers can mutate it safely.
        return {"documents": [[]], "metadatas": [[]], "distances": [[]]}

    def generate_answer(self, query, results):
        """
        Takes the query and ChromaDB results and generates an answer using the LLM.

        Only the documents returned for the first query text
        (``results['documents'][0]``) are used, capped at
        ``MAX_CONTEXT_DOCUMENTS``.

        :param query: User's query
        :param results: ChromaDB results (mapping with a ``documents`` entry)
        :return: Generated answer from LLM, a fixed message when there are no
            documents, or an ``"Error generating answer: ..."`` string when
            the LLM call fails
        """
        # Tolerate a missing/None/empty "documents" entry instead of raising.
        document_groups = (results or {}).get("documents") or [[]]
        top_documents = document_groups[0]

        # Check if we have any results
        if not top_documents:
            return "No relevant documents found to answer your question."

        # Prepare the context for LLM by appending the query and results
        documents_text = "\n".join(top_documents[:self.MAX_CONTEXT_DOCUMENTS])
        context = (
            "Based on the following context from the documents, please answer "
            "the user's question accurately and concisely.\n"
            "Context from documents:\n"
            f"{documents_text}\n"
            f"User's question: {query}\n"
            "Please provide a clear and accurate answer based only on the "
            "information provided in the context above."
        )

        try:
            response = self.openai_client.chat.completions.create(
                model=self.CHAT_MODEL,
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on provided document context. Only use information from the provided context to answer questions.",
                    },
                    {
                        "role": "user",
                        "content": context,
                    },
                ],
                max_tokens=500,
                temperature=0.1,
            )
            answer = response.choices[0].message.content
        except Exception as e:
            return f"Error generating answer: {str(e)}"

        # The message content can be None; report that explicitly rather
        # than failing on ``None.strip()``.
        if answer is None:
            return "Error generating answer: the model returned no content"
        return answer.strip()

    def get_collection_count(self):
        """
        Get the number of documents in the collection.

        :return: The collection's count, or 0 if the count could not be read
        """
        try:
            return self.collection.count()
        except Exception as e:
            print(f"Error getting collection count: {e}")
            return 0