# NOTE(review): scrape residue from the HuggingFace Spaces page header
# ("Spaces: Paused") — kept as a comment so the file stays valid Python.
| import os | |
| from dotenv import load_dotenv | |
| import gradio as gr | |
| import chromadb | |
| from chromadb.utils import embedding_functions | |
| from langchain_openai import ChatOpenAI | |
| from langchain_core.output_parsers import StrOutputParser | |
| from langchain_core.prompts import ChatPromptTemplate | |
| from pymongo import MongoClient | |
| from pymongo.errors import ConnectionFailure | |
| from urllib.parse import quote_plus | |
# Load environment variables from a local .env file (if present).
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")

REPO_ID = "poemsforaphrodite/rag"  # Replace with your actual space ID
CHROMA_PATH = "./chroma_db"  # On-disk location of the persistent Chroma store

# MongoDB connection details
MONGO_URI = os.getenv("MONGO_URI")

# Initialize MongoDB client.
# NOTE: MongoClient connects lazily — construction alone never raises
# ConnectionFailure, so the original try/except could never trigger.
# Ping the server so a bad/unreachable URI fails fast at startup
# instead of on the first insert.
try:
    mongo_client = MongoClient(MONGO_URI)
    mongo_client.admin.command("ping")
    db = mongo_client.get_database("chatbot_db")
    chat_logs_collection = db.get_collection("chat_logs")
except ConnectionFailure:
    raise
def clean_text_with_gpt(text: str) -> str:
    """Ask gpt-4o-mini to strip boilerplate from *text* and return only the main text."""
    template = """Clean the following text, only show the main text
{text}
Cleaned text:"""
    llm = ChatOpenAI(model="gpt-4o-mini", openai_api_key=openai_api_key)
    # Prompt -> model -> plain-string output, composed as one pipeline.
    cleaner = ChatPromptTemplate.from_template(template) | llm | StrOutputParser()
    return cleaner.invoke({"text": text})
def log_chat(query: str, response: str):
    """Persist one chat exchange to MongoDB, best-effort.

    Stores the exchange in OpenAI fine-tuning message format
    (system / user / assistant roles). Logging is intentionally
    non-fatal: a database outage must not break the chat itself,
    but the failure is printed so it is visible instead of being
    silently swallowed.
    """
    chat_entry = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": query},
            {"role": "assistant", "content": response},
        ]
    }
    try:
        chat_logs_collection.insert_one(chat_entry)
    except Exception as exc:
        # Best-effort: keep serving the user, but don't hide the failure.
        print(f"Warning: failed to log chat to MongoDB: {exc}")
def query_all_collections(query: str) -> tuple[str, str]:
    """Answer *query* using the most relevant chunks across ALL Chroma collections.

    Queries every collection in the persistent Chroma store, ranks the hits by
    embedding distance, cleans each hit with GPT for display, then asks
    gpt-4o-mini to answer from the combined raw context.

    Returns:
        (response_text, combined_docs) — the model's answer and the formatted
        supporting chunks. On failure returns an error message and "".
    """
    try:
        # Initialize Chroma client over the on-disk store.
        chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)

        # Embedding function must match the one used at ingestion time.
        openai_ef = embedding_functions.OpenAIEmbeddingFunction(
            api_key=openai_api_key,
            model_name="text-embedding-ada-002",
        )

        collection_names = [col.name for col in chroma_client.list_collections()]

        # Each entry: (document, book_name, collection_name, distance).
        all_results = []
        for collection_name in collection_names:
            try:
                collection = chroma_client.get_collection(
                    name=collection_name, embedding_function=openai_ef
                )
                results = collection.query(
                    query_texts=[query],
                    n_results=1,  # Adjust as needed
                    include=['documents', 'metadatas', 'distances'],
                )
                # Keep the distance so results can actually be ranked below.
                for doc, meta, dist in zip(
                    results['documents'][0],
                    results['metadatas'][0],
                    results['distances'][0],
                ):
                    book_name = (
                        meta.get('file_name')
                        or meta.get('source')
                        or meta.get('book_name')
                        or 'Unknown'
                    )
                    all_results.append((doc, book_name, collection_name, dist))
            except Exception:
                # A single unreadable collection shouldn't abort the search.
                continue

        if not all_results:
            return "No relevant documents found.", ""

        # BUGFIX: sort by embedding distance (ascending = most relevant first).
        # The previous code sorted alphabetically by document text, which is
        # not a relevance ordering at all.
        all_results.sort(key=lambda item: item[3])

        # Format each hit for display, cleaning the text with GPT first.
        combined_docs = []
        for i, (doc, book_name, col_name, _dist) in enumerate(all_results):
            cleaned_text = clean_text_with_gpt(doc)
            formatted_doc = f"""
Document {i+1}:
Collection: {col_name}
Book: {os.path.basename(book_name)}
Content: {cleaned_text}
---"""
            combined_docs.append(formatted_doc)

        # The answer is generated from the RAW chunks (not the cleaned ones),
        # so no information is lost to the cleaning pass.
        context = "\n".join([doc for doc, _, _, _ in all_results])

        model = ChatOpenAI(model="gpt-4o-mini", openai_api_key=openai_api_key)
        template = """Answer the question based only on the following context:
{context}
Question: {question}
After providing your answer, please add the following question:
"Based on your clinical judgment and the patient's complete history and current presentation, do you agree with this recommendation, or are there additional considerations or adjustments needed?"
"""
        prompt = ChatPromptTemplate.from_template(template)
        chain = prompt | model | StrOutputParser()
        response = chain.invoke({"context": context, "question": query})
        response_text = f"Response: {response}"

        # Log the chat interaction (best-effort).
        log_chat(query, response)

        return response_text, "\n".join(combined_docs)
    except Exception as e:
        return f"An error occurred: {str(e)}", ""
# Gradio interface: a single query box in, answer + supporting chunks out.
query_box = gr.Textbox(lines=1, placeholder="Enter your query here")
answer_box = gr.Textbox(lines=10, label="Answer")
chunks_box = gr.Textbox(lines=10, label="Relevant Document Chunks")

iface = gr.Interface(
    fn=query_all_collections,
    inputs=[query_box],
    outputs=[answer_box, chunks_box],
    title="Multi-Collection Document Retrieval and QA Chatbot",
    description="Ask questions based on the content across all collections in the document database.",
)

if __name__ == "__main__":
    iface.launch(share=True)