# rag/app.py — RAG chatbot Space "poemsforaphrodite/rag" (commit e760570, verified)
import logging
import os
from urllib.parse import quote_plus

import chromadb
import gradio as gr
from chromadb.utils import embedding_functions
from dotenv import load_dotenv
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from pymongo import MongoClient
from pymongo.errors import ConnectionFailure
# Load environment variables (OPENAI_API_KEY, MONGO_URI) from a local .env file.
load_dotenv()
openai_api_key = os.getenv("OPENAI_API_KEY")

REPO_ID = "poemsforaphrodite/rag"  # Replace with your actual space ID
CHROMA_PATH = "./chroma_db"  # on-disk location of the Chroma vector store

# MongoDB connection details
MONGO_URI = os.getenv("MONGO_URI")

# Initialize MongoDB client and the collection used for chat logging.
# NOTE(review): MongoClient connects lazily, so a bad URI may only surface
# on first use rather than here — confirm against pymongo docs.
try:
    mongo_client = MongoClient(MONGO_URI)
    db = mongo_client.get_database("chatbot_db")
    chat_logs_collection = db.get_collection("chat_logs")
except ConnectionFailure as e:
    # Fail fast at import time: the app cannot log chats without Mongo.
    raise
def clean_text_with_gpt(text: str) -> str:
    """Ask GPT-4o-mini to strip noise from *text* and return only the main body.

    Builds a one-shot prompt | model | parser chain per call and returns the
    model's plain-string output.
    """
    llm = ChatOpenAI(model="gpt-4o-mini", openai_api_key=openai_api_key)
    cleaning_prompt = ChatPromptTemplate.from_template(
        """Clean the following text, only show the main text
{text}
Cleaned text:"""
    )
    pipeline = cleaning_prompt | llm | StrOutputParser()
    return pipeline.invoke({"text": text})
def log_chat(query: str, response: str):
    """Persist one Q/A exchange to MongoDB in OpenAI chat-message format.

    Logging is deliberately best-effort: a database failure must never break
    the chat UI, so errors are logged (instead of silently swallowed, as the
    previous `except Exception: pass` did) and then suppressed.

    :param query: the user's question as entered in the UI
    :param response: the assistant's generated answer
    """
    chat_entry = {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": query},
            {"role": "assistant", "content": response},
        ]
    }
    try:
        chat_logs_collection.insert_one(chat_entry)
    except Exception:
        # Best-effort: record the failure with traceback rather than hiding it.
        logging.getLogger(__name__).warning(
            "Failed to insert chat log entry", exc_info=True
        )
def query_all_collections(query: str) -> tuple[str, str]:
    """Retrieve the best chunk from every Chroma collection and answer *query*.

    Workflow: query each collection for its top match (with embedding
    distances), rank all matches by distance, clean each chunk with GPT, then
    answer the question from the combined raw context and log the exchange.

    :param query: the user's natural-language question
    :return: ``(response_text, formatted_document_chunks)``; on any failure
        returns ``("An error occurred: ...", "")`` so the Gradio UI still
        renders something.
    """
    try:
        # Initialize Chroma client over the on-disk store.
        chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
        # Embedding function must match the one used at indexing time.
        openai_ef = embedding_functions.OpenAIEmbeddingFunction(
            api_key=openai_api_key,
            model_name="text-embedding-ada-002",
        )
        collection_names = [col.name for col in chroma_client.list_collections()]

        # Each entry: (distance, document, book_name, collection_name).
        all_results = []
        for collection_name in collection_names:
            try:
                collection = chroma_client.get_collection(
                    name=collection_name, embedding_function=openai_ef
                )
                results = collection.query(
                    query_texts=[query],
                    n_results=1,  # top match per collection; adjust as needed
                    include=['documents', 'metadatas', 'distances'],
                )
                for doc, meta, dist in zip(
                    results['documents'][0],
                    results['metadatas'][0],
                    results['distances'][0],
                ):
                    # Metadata key for the source file varies by ingester.
                    book_name = (
                        meta.get('file_name')
                        or meta.get('source')
                        or meta.get('book_name')
                        or 'Unknown'
                    )
                    all_results.append((dist, doc, book_name, collection_name))
            except Exception:
                # Best-effort: one unreadable collection must not abort the search.
                continue

        if not all_results:
            return "No relevant documents found.", ""

        # BUGFIX: sort by embedding distance (smaller = more relevant). The
        # previous code sorted by the document text itself, which merely
        # ordered the results alphabetically.
        all_results.sort(key=lambda item: item[0])

        # Format each chunk for display, cleaning its text with GPT first.
        combined_docs = []
        for i, (dist, doc, book_name, col_name) in enumerate(all_results):
            cleaned_text = clean_text_with_gpt(doc)
            formatted_doc = f"""
Document {i+1}:
Collection: {col_name}
Book: {os.path.basename(book_name)}
Content: {cleaned_text}
---"""
            combined_docs.append(formatted_doc)

        # The LLM answers from the raw (uncleaned) chunks.
        context = "\n".join(doc for _, doc, _, _ in all_results)

        model = ChatOpenAI(model="gpt-4o-mini", openai_api_key=openai_api_key)
        template = """Answer the question based only on the following context:
{context}
Question: {question}
After providing your answer, please add the following question:
"Based on your clinical judgment and the patient's complete history and current presentation, do you agree with this recommendation, or are there additional considerations or adjustments needed?"
"""
        prompt = ChatPromptTemplate.from_template(template)
        chain = prompt | model | StrOutputParser()
        response = chain.invoke({"context": context, "question": query})
        response_text = f"Response: {response}"

        # Log the chat interaction (best-effort; see log_chat).
        log_chat(query, response)
        return response_text, "\n".join(combined_docs)
    except Exception as e:
        # Top-level boundary for the Gradio handler: surface the error in the UI.
        return f"An error occurred: {str(e)}", ""
# Gradio interface: one query box in, answer + supporting chunks out.
iface = gr.Interface(
    fn=query_all_collections,
    inputs=[gr.Textbox(lines=1, placeholder="Enter your query here")],
    outputs=[
        gr.Textbox(lines=10, label="Answer"),
        gr.Textbox(lines=10, label="Relevant Document Chunks"),
    ],
    title="Multi-Collection Document Retrieval and QA Chatbot",
    description="Ask questions based on the content across all collections in the document database.",
)

if __name__ == "__main__":
    # share=True exposes a public Gradio link when run outside the Space.
    iface.launch(share=True)