# NOTE(review): the three lines below ("Spaces: / Sleeping / Sleeping") were
# status residue from the hosting page, not Python; kept here as a comment so
# the module parses.
import os
import time

import chromadb
import chromadb.utils.embedding_functions as embedding_functions
from openai import OpenAI
class ChromaCollection:
    """Wrapper around a persistent ChromaDB collection with OpenAI embeddings.

    Documents and queries are embedded with OpenAI's ``text-embedding-ada-002``
    model; :meth:`generate_answer` answers a user question with a chat model,
    grounded in the documents returned by :meth:`query_collection`.
    """

    def __init__(self, collection_name, db_path, api_key=None):
        """Create (or attach to) a persistent Chroma collection.

        :param collection_name: Name of the Chroma collection to use.
        :param db_path: Filesystem path for the persistent Chroma database.
        :param api_key: OpenAI API key; falls back to the ``OPENAI_API_KEY``
            environment variable when omitted.
        :raises ValueError: If no API key is available from either source.
        """
        self.chroma_client = chromadb.PersistentClient(path=db_path)
        self.collection_name = collection_name
        self.collection = None
        # Use provided API key or fall back to environment variable.
        self.openai_key = api_key or os.getenv("OPENAI_API_KEY")
        if not self.openai_key:
            raise ValueError("OpenAI API key is required")
        self.openai_ef = embedding_functions.OpenAIEmbeddingFunction(
            api_key=self.openai_key,
            model_name="text-embedding-ada-002",
        )
        # Separate OpenAI client for chat completions (embeddings go through
        # the Chroma embedding function above).
        self.openai_client = OpenAI(api_key=self.openai_key)
        self._initialize_collection()

    def _initialize_collection(self):
        """Attach to the named collection, creating it if it does not exist."""
        try:
            self.collection = self.chroma_client.get_collection(
                name=self.collection_name,
                embedding_function=self.openai_ef,
            )
            print(f"Collection '{self.collection_name}' already exists.")
        except Exception:
            # Chroma raises when the collection is missing; create it then.
            # NOTE(review): this broad except also swallows unrelated failures
            # (e.g. a corrupt DB); chromadb's get_or_create_collection would be
            # a tighter fit -- confirm before switching, as it changes the logs.
            self.collection = self.chroma_client.create_collection(
                name=self.collection_name,
                embedding_function=self.openai_ef,
            )
            print(f"Created new collection '{self.collection_name}'.")

    def query_collection(self, query_texts, n_results=1):
        """Query the collection, retrying transient connection failures.

        :param query_texts: List of query strings (Chroma embeds them).
        :param n_results: Number of results to return per query.
        :return: Chroma query results, or an empty result structure shaped
            like a normal response if every attempt fails.
        """
        max_retries = 3
        for attempt in range(max_retries):
            try:
                return self.collection.query(
                    query_texts=query_texts,  # Chroma will embed this for you
                    n_results=n_results,      # How many results to return
                )
            except Exception as e:
                error_msg = str(e).lower()
                print(f"Query attempt {attempt + 1} failed: {e}")
                # Only connection/timeout errors are worth retrying.
                if "connection" in error_msg or "timeout" in error_msg:
                    if attempt < max_retries - 1:
                        wait_time = (attempt + 1) * 2  # linear backoff: 2s, 4s
                        print(f"Connection issue detected. Waiting {wait_time} seconds before retry...")
                        time.sleep(wait_time)  # module-level import; no need to re-import
                        continue
                else:
                    # Non-connection error: retrying will not help, so stop
                    # without claiming every attempt was made.
                    print("Non-retryable error; giving up.")
                    break
        else:
            # Loop exhausted without a successful return or an early break.
            print(f"All {max_retries} query attempts failed")
        # Empty result shaped like a real Chroma response so callers
        # (e.g. generate_answer) can index it safely.
        return {"documents": [[]], "metadatas": [[]], "distances": [[]]}

    def generate_answer(self, query, results):
        """Generate an answer to *query* grounded in Chroma *results*.

        :param query: User's question.
        :param results: Chroma query results (dict with a 'documents' key).
        :return: LLM-generated answer, or an explanatory message when no
            documents are available or the API call fails.
        """
        # Tolerate a missing or None 'documents' key as well as empty results,
        # instead of raising KeyError/TypeError on a malformed dict.
        documents = (results.get("documents") or [[]])[0]
        if not documents:
            return "No relevant documents found to answer your question."
        # Prepare the context for the LLM from the top-ranked documents.
        documents_text = "\n".join(documents[:5])  # Use top 5 results
        context = f"""Based on the following context from the documents, please answer the user's question accurately and concisely.
Context from documents:
{documents_text}
User's question: {query}
Please provide a clear and accurate answer based only on the information provided in the context above."""
        try:
            response = self.openai_client.chat.completions.create(
                model="gpt-4o-mini",
                messages=[
                    {
                        "role": "system",
                        "content": "You are a helpful assistant that answers questions based on provided document context. Only use information from the provided context to answer questions.",
                    },
                    {
                        "role": "user",
                        "content": context,
                    },
                ],
                max_tokens=500,
                temperature=0.1,  # low temperature keeps answers grounded
            )
            # Extract and return the answer text from the first choice.
            return response.choices[0].message.content.strip()
        except Exception as e:
            # Surface the API failure to the caller as a message rather than
            # raising -- matches the method's string-only return contract.
            return f"Error generating answer: {str(e)}"

    def get_collection_count(self):
        """Return the number of documents in the collection (0 on failure)."""
        try:
            return self.collection.count()
        except Exception as e:
            print(f"Error getting collection count: {e}")
            return 0