# Market_research_bot / front_end.py
# MrPatrickHenry's picture
# Update front_end.py
# 25f426a verified
import os
import gradio as gr
from pymongo import MongoClient
from rag import mongo_rag_tool
import logging
import io
# In-memory buffer attached to the root logger so that log output can be
# read back as text (query_rag_tool returns its contents to the UI).
log_stream = io.StringIO()
logging.basicConfig(
    stream=log_stream,
    format="%(asctime)s - %(message)s",
    level=logging.INFO,
)
# Constants
# Name of the MongoDB database whose collections are listed; it is used as
# the key when selecting the database from the client (client[DATABASE_NAME]).
DATABASE_NAME = "9d6b59e0-667c-5017-aae1-dd5989df421d"
def get_collection_names():
    """
    Fetch the list of collection names from the MongoDB database.

    The connection string is read from the ``MONGO_CONNECTION_STRING``
    environment variable and the database is selected with ``DATABASE_NAME``.

    Returns:
        list[str]: The collection names reported by the database. On any
        failure (missing environment variable, connection or server error)
        the error is logged and the single-element placeholder list
        ``["Error fetching collections"]`` is returned instead of raising.
    """
    try:
        mongo_connection_string = os.getenv("MONGO_CONNECTION_STRING")
        if not mongo_connection_string:
            raise ValueError("Missing MONGO_CONNECTION_STRING environment variable.")
        # Use the client as a context manager so its connection pool and
        # background monitor threads are released once the names are read;
        # previously the client was never closed.
        with MongoClient(mongo_connection_string) as client:
            return client[DATABASE_NAME].list_collection_names()
    except Exception as e:
        # Deliberate best-effort: callers always get a list back, so the
        # failure is recorded in the log and replaced by a placeholder entry.
        logging.error("Error fetching collections: %s", e)
        return ["Error fetching collections"]
# Fetch the collection names dynamically.
# NOTE: this call runs once, when the module is imported. If the lookup fails,
# get_collection_names() returns the placeholder list
# ["Error fetching collections"] rather than raising, so this is always a list.
collection_list = get_collection_names()
def query_rag_tool(query, collection_name):
    """
    Run mongo_rag_tool for the given Gradio inputs and report the captured logs.

    Returns a 3-tuple ``(answer, sources, logs)``. If anything raises, the
    tuple is ``("Error occurred", "No sources available", logs)`` where
    ``logs`` is the captured log text followed by the error message.
    """
    try:
        # Start from an empty buffer so only this call's log output is returned.
        log_stream.seek(0)
        log_stream.truncate()

        answer, sources = mongo_rag_tool(query, collection_name)
        return answer, sources, log_stream.getvalue()
    except Exception as e:
        captured = log_stream.getvalue()
        return "Error occurred", "No sources available", f"{captured}\nError: {e}"
# Gradio interface
# Pre-select the first collection when there is one; otherwise no default.
if collection_list:
    default_collection = collection_list[0]
else:
    default_collection = None

# Input widgets, in the order query_rag_tool receives them.
query_input = gr.Textbox(
    lines=1,
    label="Query",
    placeholder="What do people think about Caterpillar Vision Link?",
)
collection_input = gr.Dropdown(
    choices=collection_list,
    value=default_collection,
    label="Collection Name",
    interactive=True,
)

# Output widgets, in the order query_rag_tool returns its values.
answer_output = gr.Textbox(label="Answer", lines=5, placeholder="The answer will appear here...")
sources_output = gr.Textbox(label="Sources", lines=10, placeholder="The sources will appear here...")
logs_output = gr.Textbox(label="Logs", lines=15, placeholder="Logs for debugging will appear here...")

interface = gr.Interface(
    fn=query_rag_tool,
    inputs=[query_input, collection_input],
    outputs=[answer_output, sources_output, logs_output],
    title="Andiron's AI Market Research Tool",
    description=(
        "Interact with your clients data using Andirons' Retrieval Augmented Generation (RAG) model. "
        "Select a collection, ask your query."
    ),
)
# Launch the Gradio app only when this file is run as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    interface.launch()