talk2data / gradio_app.py
amirkiarafiei's picture
feat: demo with yugabyte
0391cfb
raw
history blame
6.48 kB
from asyncio.log import logger
import yaml
from pathlib import Path
import gradio as gr
import asyncio
from langchain_mcp_client import lc_mcp_exec
from dotenv import load_dotenv
import os
import base64
from memory_store import MemoryStore
import logging
# ======================================= Load DB configs
def load_db_configs(configs_path="configs.yaml"):
    """Load the ``db_configs`` section from a YAML configuration file.

    Args:
        configs_path: Path (str or ``Path``) to the YAML file. Defaults to
            ``configs.yaml`` relative to the current working directory, which
            was the previously hard-coded location.

    Returns:
        The value stored under the file's top-level ``db_configs`` key.

    Raises:
        FileNotFoundError: If the file does not exist.
        KeyError: If the file is empty or has no ``db_configs`` key.
    """
    path = Path(configs_path)
    if not path.exists():
        raise FileNotFoundError(f"{path} not found")
    # Explicit encoding so parsing does not depend on the platform's locale.
    with path.open(encoding="utf-8") as f:
        # safe_load: never construct arbitrary Python objects from the file.
        configs = yaml.safe_load(f)
    # safe_load returns None for an empty document; normalise it so the failure
    # is a KeyError naming the missing section instead of an opaque TypeError.
    if configs is None:
        configs = {}
    return configs["db_configs"]
def image_to_base64_markdown(image_path, alt_text="Customer Status", mime_type=None):
    """Return a Markdown image tag that embeds *image_path* as a base64 data URI.

    Args:
        image_path: Path to the image file to embed.
        alt_text: Alt text placed inside ``![...]``.
        mime_type: MIME type for the data URI. When omitted it is guessed from
            the file extension; anything not recognised as ``image/*`` falls
            back to ``image/png`` (the previously hard-coded value).

    Returns:
        A string of the form ``![alt_text](data:<mime>;base64,<payload>)``.
    """
    import mimetypes  # local import: only this helper needs it

    if mime_type is None:
        guessed, _encoding = mimetypes.guess_type(str(image_path))
        # Keep the old PNG label for unknown / non-image extensions.
        mime_type = guessed if guessed and guessed.startswith("image/") else "image/png"
    with open(image_path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"![{alt_text}](data:{mime_type};base64,{encoded})"
# ====================================== Async-compatible wrapper
# Module logger, so errors are attributed to this app rather than to asyncio's
# internal "asyncio" logger.
_log = logging.getLogger(__name__)

# Filename pattern of charts the agent writes into PANDAS_EXPORTS_PATH.
_CHART_PREFIX = "temp_chart_"
_CHART_SUFFIX = ".png"


def _find_chart_files(directory):
    """Return paths of the ``temp_chart_*.png`` files directly inside *directory*."""
    return [
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.startswith(_CHART_PREFIX) and name.endswith(_CHART_SUFFIX)
    ]


def _remove_files(paths):
    """Best-effort delete of *paths*; failures are logged as warnings, never raised."""
    for path in paths:
        try:
            os.remove(path)
        except OSError as e:
            _log.warning("Could not remove chart file %s: %s", path, e)


async def run_agent(request, history=None):
    """Run the agent on one chat message and format the reply for the chat UI.

    Args:
        request: The user's message, as passed by ``gr.ChatInterface``.
        history: Chat history passed by Gradio. It is not used here; it is
            presumably kept by ``lc_mcp_exec`` itself (confirm).

    Returns:
        The agent's response text. If the agent wrote a ``temp_chart_*.png``
        into ``PANDAS_EXPORTS_PATH`` (default ``exports/charts``), the newest
        such chart is embedded above the text as a base64 Markdown image.
        On any error, an ``"Error: ..."`` string is returned instead of
        raising, so the message is shown in the chat.
    """
    try:
        # Process request using existing memory
        response, _messages = await lc_mcp_exec(request)

        load_dotenv()
        exports_path = os.getenv("PANDAS_EXPORTS_PATH", "exports/charts")
        # Ensure the exports directory exists
        os.makedirs(exports_path, exist_ok=True)

        chart_paths = _find_chart_files(exports_path)
        if not chart_paths:
            return response

        try:
            # os.listdir order is arbitrary; pick the most recently written chart.
            newest = max(chart_paths, key=os.path.getmtime)
            image_markdown = image_to_base64_markdown(newest)
        except Exception as e:
            _log.exception("Error processing image: %s", e)
            return response
        finally:
            # Remove every chart found, also when embedding failed. Otherwise a
            # leftover file would be attached to a later, unrelated answer.
            _remove_files(chart_paths)
        return f"{image_markdown}\n\n{response}"
    except Exception as e:
        _log.exception("Error in run_agent: %s", e)
        return f"Error: {str(e)}"
# ====================================== Gradio UI with history
# Path to the logo image. In this file it is only referenced by the
# commented-out gr.Image in the header row, so it is currently unused.
LOGO_PATH = "resources/pialogo.png"
# CSS customizations passed to gr.Blocks(css=...).
# ".container" is applied to the layout rows and ".chat-container" to the
# Chatbot via elem_classes. ".message-container" and ".markdown-content" are
# not referenced by any elem_classes in this file (they may target elements
# elsewhere; verify before removing).
custom_css = """
.container {
max-width: 2200px !important;
margin: auto;
padding: 20px;
}
.chat-container {
height: 1000px !important;
min-height: 1000px !important;
overflow-y: auto;
}
.message-container {
padding: 15px;
border-radius: 10px;
margin: 10px 0;
}
.markdown-content {
font-size: 18px;
line-height: 1.7;
}
"""
# Page layout: a title row, then a row with the chat (scale 3) beside an
# "Example Questions" sidebar (scale 1). Styling comes from custom_css.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    # Header row: centered title and subtitle.
    with gr.Row(elem_classes="container"):
        # with gr.Column(scale=0.5):
        #     gr.Image(value=LOGO_PATH, height=100, show_label=False, show_download_button=False, show_fullscreen_button=False)
        with gr.Column(scale=5):
            gr.Markdown(
                """
<h1 style='text-align: center; margin-bottom: 1rem'>Talk to Your Data</h1>
<p style='text-align: center'>Ask questions about your database, analyze and visualize data.</p>
"""
            )

    # Main row: chat interface on the left, example-question sidebar on the right.
    with gr.Row(elem_classes="container"):
        with gr.Column(scale=3):
            chat = gr.ChatInterface(
                # Async handler called with (message, history) for each submission.
                fn=run_agent,
                chatbot=gr.Chatbot(
                    height=1000,
                    show_label=False,
                    elem_classes="chat-container",
                    render_markdown=True,
                    # Keep the Chatbot's message format in sync with the
                    # ChatInterface's type="messages" below.
                    type="messages",
                ),
                textbox=gr.Textbox(
                    placeholder="Type your questions here...",
                    container=False,
                    scale=4,
                ),
                theme="soft",
                # Previous example set (earlier demo database):
                # examples=[
                #     "Describe the database",
                #     "List all tables in the database",
                #     "List all tables with columns and data types",
                #     "How many customers do you have?",
                #     "What are the statuses my of my customers",
                #     "Visualize with different colors and show legend",
                #     "What are the statues of my customers and how many are in each status, show it by percentage",
                #     "Total number of completed orders in six years by customer count show top most 10 customers",
                #     "In january how many products has been sold ? group them by year",
                #     "How many users and roles have been created in 2024"
                # ],
                # Example prompts for the current demo database (tickets, comments,
                # agents, customers).
                examples=[
                    "Describe the database",
                    "List all tables in the database",
                    "List all tables with columns and data types",
                    "How many comments are there per ticket channel (email, chat, portal)? Also Visualize it as a pie chart",
                    "Visualize with different colors and show legend",
                    "How many customers are in each industry?",
                    "List the 5 most active agents by ticket count in 2024.",
                    "How many tickets were reopened at least once?"
                ],
                save_history=True,
                type="messages",
            )
        with gr.Column(scale=1):
            with gr.Accordion("Example Questions", open=True):
                # gr.Markdown("""
                # - 📊 List all tables in database
                # - 👥 Total number of customers
                # - 📈 Visualize it with different colors
                # - 📋 Order statistics for last 6 years
                # - 📆 User and role counts in 2024
                # """)
                # NOTE(review): "Total number of customers" and "Order statistics
                # for last 6 years" echo the commented-out example set above, not
                # the active examples; confirm they still fit the current database.
                gr.Markdown("""
- 📊 List all tables in database
- 👥 Total number of customers
- 📈 Visualize it with different colors
- 📋 Order statistics for last 6 years
- 📆 Average ticket reopen count per year
""")
# TODO: Consider adding an MCP tool that validates query results (those converted to a DataFrame) so the ReAct agent passes correctly typed data to the visualization tool.
if __name__ == "__main__":
    # Launch options for running this script directly:
    # - server_name "0.0.0.0" listens on all network interfaces.
    # - share=True asks Gradio for a public share link.
    #   NOTE(review): that publicly exposes an agent that queries the database;
    #   confirm this is intended outside of demos.
    # - debug=True keeps the main thread blocked while the app runs.
    launch_kwargs = {
        "server_name": "0.0.0.0",
        # "server_port": 7860,
        "share": True,
        "debug": True,
    }
    demo.launch(**launch_kwargs)