dataviz / ui /app.py
Медведев Андрей Васильевич
update ui
4fc6cf4
import gradio as gr
import pandas as pd
import os
import tempfile
import re
import base64
import io
import zipfile
import logging
import asyncio
from PIL import Image
from docx import Document
from docx.shared import Inches
from agent.agent import DataVizAgent
from mcp_tools.client import DataVizClient
# Configure logging: every record goes both to ``dataviz_agent.log`` (a path
# relative to the current working directory) and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8: user chat messages are logged verbatim and may contain
        # non-ASCII text the platform's default file encoding cannot represent.
        logging.FileHandler('dataviz_agent.log', encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)

# Module-level instances, shared by every event handler in this process.
agent = DataVizAgent()
mcp_client = DataVizClient()
def b64_to_pil(b64_str):
    """Decode a base64-encoded image string into a PIL ``Image``.

    Args:
        b64_str: Base64 text (or bytes) holding an encoded image file.

    Returns:
        The image opened from an in-memory buffer.
    """
    raw_bytes = base64.b64decode(b64_str)
    buffer = io.BytesIO(raw_bytes)
    return Image.open(buffer)
def _describe_column(name, series):
    """Return the summary dict for one column: dtype, cardinality, missing count and numeric range."""
    info = {
        "name": name,
        "type": str(series.dtype),
        # Plain Python ints: numpy integer scalars (e.g. from ``.sum()``) are not
        # JSON-serializable, which would break any downstream JSON encoding.
        "unique_values": int(series.nunique()),
        "missing_values": int(series.isnull().sum()),
    }
    if pd.api.types.is_numeric_dtype(series):
        try:
            min_val = series.min()
            max_val = series.max()
            info["min"] = float(min_val) if pd.notna(min_val) else None
            info["max"] = float(max_val) if pd.notna(max_val) else None
        except (ValueError, TypeError):
            # e.g. values that cannot be converted to float (complex numbers).
            info["min"] = None
            info["max"] = None
        info["is_numeric"] = True
    else:
        info["is_numeric"] = False
    return info


def analyze_dataset(file_path):
    """Load a CSV/XLSX file and build a per-column summary of it.

    Args:
        file_path: Path of the uploaded file, or ``None`` if nothing was uploaded.
            The extension is matched case-insensitively.

    Returns:
        ``(df, summary)`` on success. ``summary`` is a dict with ``"columns"``
        (a list of per-column dicts with ``name``, ``type``, ``unique_values``,
        ``missing_values``, ``is_numeric`` and, for numeric columns, ``min`` and
        ``max``) and ``"row_count"``.
        ``(None, message)`` on failure, where ``message`` is a human-readable error.
    """
    if file_path is None:
        return None, "No file uploaded."

    lowered = str(file_path).lower()
    try:
        if lowered.endswith('.csv'):
            df = pd.read_csv(file_path)
        elif lowered.endswith('.xlsx'):
            df = pd.read_excel(file_path)
        else:
            return None, "Unsupported file format. Please upload CSV or Excel."
    except Exception as e:  # pandas raises many parser/IO error types; report any of them
        return None, f"Error loading file: {str(e)}"

    # Check for "no columns" first: ``df.empty`` is also True when there are no
    # columns, so checking it first would make the more specific message unreachable.
    if len(df.columns) == 0:
        return None, "Error: No columns found in the dataset."
    if df.empty:
        return None, "Error: The uploaded file is empty."
    if len(df) > 1_000_000:
        return None, "Error: Dataset is too large (>1M rows). Please use a smaller file."

    summary = {
        "columns": [_describe_column(name, series) for name, series in df.items()],
        "row_count": len(df),
    }
    return df, summary
def process_upload(file):
    """Load an uploaded dataset, save a Parquet copy of it and build a readable summary.

    Args:
        file: The ``gr.File`` value. Either an object exposing the upload's path
            as ``.name`` or a plain path string.

    Returns:
        ``(df, summary, summary_str, path)`` on success, where ``summary`` is the
        dict from ``analyze_dataset``, ``summary_str`` is Markdown-ish text for the
        "Dataset Info" panel and ``path`` is the temporary Parquet file.
        ``(None, {}, message, None)`` on failure, where ``message`` explains why.
    """
    # Accept both wrappers carrying the path in ``.name`` and plain path strings.
    file_path = getattr(file, "name", file)
    logger.info(f"Processing file upload: {file_path}")
    df, summary = analyze_dataset(file_path)
    if df is None:
        logger.error(f"Failed to load file: {file_path} ({summary})")
        # analyze_dataset returns a specific reason on failure; show it to the user.
        return None, {}, summary or "Error loading file.", None

    # Save dataframe to a temporary parquet file for the MCP tool
    fd, path = tempfile.mkstemp(suffix='.parquet')
    os.close(fd)
    try:
        df.to_parquet(path)
    except Exception as e:  # e.g. no parquet engine installed, or mixed-type object columns
        logger.exception("Failed to save dataset as parquet")
        try:
            os.remove(path)  # don't leak the empty temp file
        except OSError:
            pass  # best effort; the original error is what matters to the user
        return None, {}, f"Error saving dataset: {e}", None
    logger.info(f"Dataset saved to temp file: {path}")

    # Create a readable summary string (one line per column)
    parts = [f"Dataset Loaded: {len(df)} rows, {len(df.columns)} columns.\n\nColumns:\n"]
    for col in summary["columns"]:
        line = f"- {col['name']} ({col['type']}): {col['unique_values']} unique"
        if col['is_numeric'] and col.get('min') is not None and col.get('max') is not None:
            line += f", range: [{col['min']:.2f}, {col['max']:.2f}]"
        parts.append(line + "\n")
    summary_str = "".join(parts)
    return df, summary, summary_str, path
async def _run_blocking(func, *args, **kwargs):
    """Run the synchronous ``func(*args, **kwargs)`` in the default thread pool and await it.

    ``respond`` is a coroutine. A blocking call made directly inside it would stall
    the event loop, and every other event it serves, until the call returned.
    """
    import functools  # local import keeps this helper self-contained

    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))


async def respond(message, chat_history, state):
    """Handle one chat turn: ask the agent for a reply and, if it returns code, render a chart.

    Args:
        message: The user's text. A ``#<id>`` token (e.g. "make #2 blue") targets
            an existing chart for modification instead of creating a new one.
        chat_history: Chatbot history, a list of ``{"role", "content"}`` dicts;
            appended to in place.
        state: Per-session state dict (``dataframe``, ``columns_summary``,
            ``charts``, ``next_chart_id``, ``data_path``); mutated in place.

    Returns:
        ``(textbox_value, chat_history, gallery_update, state, dropdown_update)``.
        ``textbox_value`` is always ``""`` so the input box is cleared.
    """
    logger.info(f"User message: {message}")
    if state["dataframe"] is None:
        logger.warning("User attempted to chat without uploading dataset")
        chat_history.append({"role": "user", "content": message})
        chat_history.append({"role": "assistant", "content": "Please upload a dataset first."})
        return "", chat_history, gr.update(), state, _get_chart_choices(state)

    # Check for chart modification request ("#<id>" anywhere in the message)
    chart_id_match = re.search(r'#(\d+)', message)
    existing_code = None
    target_chart_id = None
    if chart_id_match:
        chart_id = int(chart_id_match.group(1))
        if chart_id not in state["charts"]:
            chat_history.append({"role": "user", "content": message})
            chat_history.append({"role": "assistant", "content": f"Chart #{chart_id} not found."})
            return "", chat_history, _get_gallery_items(state), state, _get_chart_choices(state)
        existing_code = state["charts"][chart_id]["code"]
        target_chart_id = chart_id
        logger.info(f"Modifying chart #{chart_id}")

    # The agent sees the history *before* this message is appended. It is a
    # synchronous call, so run it off the event loop.
    response = await _run_blocking(
        agent.generate_plot_code,
        message,
        state["columns_summary"],
        history=chat_history,
        existing_code=existing_code,
    )
    chat_history.append({"role": "user", "content": message})

    response_type = response["type"]
    if response_type == "error":
        logger.error(f"Agent error: {response['content']}")
        chat_history.append({"role": "assistant", "content": f"Error: {response['content']}"})
        return "", chat_history, _get_gallery_items(state), state, _get_chart_choices(state)

    if response_type == "message":
        # Conversational response - no code to execute
        logger.info("Agent provided conversational response")
        chat_history.append({"role": "assistant", "content": response["content"]})
        return "", chat_history, _get_gallery_items(state), state, _get_chart_choices(state)

    if response_type == "code":
        code = response["content"]
        logger.info("Executing generated code")
        # Execute code using MCP Tool against the session's parquet copy of the data
        result = await mcp_client.generate_plot(code, state["data_path"])
        if result.get("success"):
            if target_chart_id is not None:
                cid, action = target_chart_id, "Updated"
            else:
                cid, action = state["next_chart_id"], "Created"
                state["next_chart_id"] += 1
            description = await _run_blocking(agent.describe_chart, message, code)
            state["charts"][cid] = {
                "code": code,
                "image": result["image"],
                "description": description
            }
            response_text = f"{action} chart #{cid}: {description}"
            chat_history.append({"role": "assistant", "content": response_text})
            logger.info(f"{action} chart #{cid}")
            gallery_update = _get_gallery_items(state, selected_cid=cid)
        else:
            # ``or`` rather than nested ``.get`` defaults so that an empty stderr
            # still falls through to the "error" field.
            error_details = result.get('stderr') or result.get('error') or 'Unknown error occurred'
            error_msg = f"Failed to generate chart.\nError: {error_details}\n\nCode:\n```python\n{code}\n```"
            chat_history.append({"role": "assistant", "content": error_msg})
            logger.error(f"Chart generation failed: {error_details}")
            gallery_update = _get_gallery_items(state)
        return "", chat_history, gallery_update, state, _get_chart_choices(state)

    # Fallback: unknown response type - keep the UI consistent but leave a trace.
    logger.warning(f"Unexpected agent response type: {response_type!r}")
    return "", chat_history, _get_gallery_items(state), state, _get_chart_choices(state)
def _get_gallery_items(state, selected_cid=None):
items = []
selected_index = None
current_idx = 0
# Sort by ID
for cid in sorted(state["charts"].keys()):
chart = state["charts"][cid]
if chart["image"]:
img = b64_to_pil(chart["image"])
items.append((img, f"#{cid} {chart['description']}"))
if selected_cid is not None and cid == selected_cid:
selected_index = current_idx
current_idx += 1
if selected_cid is not None:
return gr.update(value=items, selected_index=selected_index)
return items
def _get_chart_choices(state):
    """Return a dropdown update listing every stored chart as ``"#<id>"``, in ascending id order."""
    chart_ids = sorted(state["charts"].keys())
    labels = ["#{}".format(chart_id) for chart_id in chart_ids]
    return gr.update(choices=labels)
def delete_chart(chart_str, chat_history, state):
    """Delete the chart chosen in the dropdown and refresh the gallery and dropdown.

    Args:
        chart_str: Dropdown value such as ``"#3"``; empty or ``None`` means nothing selected.
        chat_history: Chatbot history; a confirmation or "not found" message is appended.
        state: Session state; ``state["charts"]`` is mutated in place.

    Returns:
        ``(chat_history, gallery_update, state, dropdown_update)``.
    """
    if not chart_str:
        return chat_history, _get_gallery_items(state), state, _get_chart_choices(state)
    try:
        cid = int(chart_str.replace("#", ""))
    except ValueError:
        # Malformed selection: log it instead of silently swallowing every error.
        logger.warning(f"Ignoring malformed chart selection: {chart_str!r}")
        return chat_history, _get_gallery_items(state), state, _get_chart_choices(state)

    if cid in state["charts"]:
        del state["charts"][cid]
        chat_history.append({"role": "assistant", "content": f"🗑️ Chart #{cid} has been deleted."})
    else:
        chat_history.append({"role": "assistant", "content": f"Chart #{cid} not found."})

    # Also clear the selected value so the dropdown doesn't keep pointing at a deleted chart.
    choices_update = gr.update(
        choices=[f"#{c}" for c in sorted(state["charts"].keys())],
        value=None,
    )
    return chat_history, _get_gallery_items(state), state, choices_update
def download_zip(state):
    """Bundle every chart image into a ZIP archive for download.

    Args:
        state: Session state; ``state["charts"]`` maps chart id to a dict whose
            ``"image"`` is base64 image data (charts with an empty image are skipped).

    Returns:
        Path of the ZIP file, whose entries are ``chart_<id>.png`` in ascending id
        order, or ``None`` when there are no charts.
    """
    if not state["charts"]:
        return None
    # mkdtemp creates a fresh private directory atomically. That avoids the race
    # that makes tempfile.mktemp unsafe, and lets the file keep a readable name.
    zip_filename = os.path.join(tempfile.mkdtemp(), "dataviz_charts.zip")
    with zipfile.ZipFile(zip_filename, 'w') as zipf:
        for cid in sorted(state["charts"].keys()):
            image_b64 = state["charts"][cid]["image"]
            if image_b64:
                zipf.writestr(f"chart_{cid}.png", base64.b64decode(image_b64))
    return zip_filename
def download_report(state):
    """Build a Word (.docx) report with one section per chart, in ascending id order.

    Each section has a heading with the chart id and description, the chart image
    (6 inches wide), the generating code and a page break. Charts without an image
    are skipped.

    Args:
        state: Session state; only ``state["charts"]`` is read.

    Returns:
        Path of the saved report, or ``None`` when there are no charts.
    """
    if not state["charts"]:
        return None
    doc = Document()
    doc.add_heading('DataViz Agent Report', 0)
    for cid in sorted(state["charts"].keys()):
        chart = state["charts"][cid]
        if not chart["image"]:
            continue
        doc.add_heading(f"Chart #{cid}: {chart['description']}", level=1)
        # add_picture accepts a file-like object, so no temporary PNG on disk is needed.
        doc.add_picture(io.BytesIO(base64.b64decode(chart["image"])), width=Inches(6))
        doc.add_paragraph(f"Code:\n{chart['code']}")
        doc.add_page_break()
    # Fresh private directory instead of the race-prone tempfile.mktemp.
    doc_filename = os.path.join(tempfile.mkdtemp(), "dataviz_report.docx")
    doc.save(doc_filename)
    return doc_filename
def global_clear():
    """Reset the whole app to its pristine, nothing-uploaded state.

    Returns:
        Values in the order of the outputs wired to the "Global Clear" button:
        file, dataset info text, chat history, gallery, session state, chart
        dropdown update and download file.
    """
    logger.info("Global clear initiated")
    fresh_state = dict(
        dataframe=None,
        columns_summary={},
        charts={},
        next_chart_id=1,
        data_path=None,
    )
    empty_dropdown = gr.update(choices=[])
    info_text = "Upload a dataset to get started."
    return (None, info_text, [], [], fresh_state, empty_dropdown, None)
# Client-side JavaScript chained after every chat turn. After a 200 ms delay it
# finds the element with id "chat-input" (the message Textbox's elem_id) and
# focuses the <textarea>/<input> inside it, so the user can keep typing.
_REFOCUS_CHAT_INPUT_JS = "() => { setTimeout(() => { const el = document.getElementById('chat-input'); if (el) { const input = el.querySelector('textarea') || el.querySelector('input'); if (input) input.focus(); } }, 200); }"

with gr.Blocks(title="DataViz Agent", theme=gr.themes.Soft(), fill_height=True) as demo:
    # Per-session state:
    #   dataframe       - the loaded pandas DataFrame, or None without a loaded dataset
    #   columns_summary - summary from analyze_dataset, passed to the agent
    #   charts          - {chart_id: {"code": ..., "image": <base64>, "description": ...}}
    #   next_chart_id   - id given to the next newly created chart
    #   data_path       - temporary Parquet copy of the dataset used by the MCP tool
    state = gr.State({
        "dataframe": None,
        "columns_summary": {},
        "charts": {},
        "next_chart_id": 1,
        "data_path": None
    })

    # Section headers: chat on the left (scale 3), gallery on the right (scale 2).
    with gr.Row():
        with gr.Column(scale=3):
            gr.Markdown("## 🤖 DataViz Agent Chat")
        with gr.Column(scale=2):
            gr.Markdown("## 📊 Charts Gallery")

    with gr.Row():
        # Left column: dataset upload, chat history and message input.
        with gr.Column(scale=3):
            with gr.Row():
                with gr.Group():
                    file_upload = gr.File(label="Upload Dataset (CSV/XLSX)", file_types=[".csv", ".xlsx"])
                    with gr.Accordion("Dataset Info", open=False):
                        dataset_info = gr.Markdown("Upload a dataset to get started.")
            with gr.Row(scale=1, height=700):
                chatbot = gr.Chatbot(type="messages", height=700)
            with gr.Row(height=50, equal_height=True):
                msg = gr.Textbox(
                    placeholder="Ask to visualize data (e.g., 'Show distribution of age')",
                    show_label=False,
                    elem_id="chat-input",
                    lines=1,
                    max_lines=1,
                    scale=1
                )
                send_btn = gr.Button("Send", variant="primary", scale=0)
        # Right column: chart gallery plus delete / download controls.
        with gr.Column(scale=2):
            with gr.Row(height=626):
                gallery = gr.Gallery(label="Generated Charts", columns=1, object_fit="contain", height=626)
            with gr.Row():
                with gr.Group():
                    gr.Markdown("### Manage Charts")
                    with gr.Row():
                        chart_selector = gr.Dropdown(label="Select Chart to Delete", choices=[])
                        delete_btn = gr.Button("🗑️ Delete Chart", variant="stop")
                    with gr.Row():
                        dl_zip_btn = gr.Button("💾 Download All (ZIP)")
                        dl_report_btn = gr.Button("📄 Download Report (Word)")
                    with gr.Row(height=80):
                        dl_file = gr.File(label="Download", visible=True)

    # Global Clear (Bottom)
    with gr.Row():
        global_clear_btn = gr.Button("Global Clear (Reset All)", variant="stop")

    # ----- Event handlers -----
    def _discard_dataset_file(current_state):
        """Best-effort removal of the session's temporary Parquet file, if any."""
        old_path = current_state.get("data_path")
        if old_path and os.path.exists(old_path):
            try:
                os.remove(old_path)
                logger.info(f"Cleaned up old temp file: {old_path}")
            except OSError as e:
                logger.warning(f"Failed to remove temp file: {e}")

    def on_file_upload(file, current_state):
        """Load ``file`` into ``current_state``; return ``(state, dataset_info_text)``."""
        if file is None:
            return current_state, "Upload a dataset to get started."
        df, summary, summary_str, path = process_upload(file)
        if df is not None:
            current_state["dataframe"] = df
            current_state["columns_summary"] = summary
            current_state["data_path"] = path
        return current_state, summary_str

    def on_file_upload_wrapper(file, current_state):
        """Drop the previous dataset (temp file and state references), then load ``file``."""
        _discard_dataset_file(current_state)
        # Forget the old dataset *before* loading the new one. Otherwise a cleared
        # (None) or failed upload would leave ``data_path`` pointing at the file
        # just deleted, and later chat turns would hand that path to the MCP tool.
        current_state["dataframe"] = None
        current_state["columns_summary"] = {}
        current_state["data_path"] = None
        return on_file_upload(file, current_state)

    file_upload.change(
        on_file_upload_wrapper,
        inputs=[file_upload, state],
        outputs=[state, dataset_info]
    )

    # Chat interactions: pressing Enter and clicking "Send" behave identically,
    # then the JS snippet puts the focus back into the message box.
    for chat_trigger in (msg.submit, send_btn.click):
        chat_trigger(
            respond,
            inputs=[msg, chatbot, state],
            outputs=[msg, chatbot, gallery, state, chart_selector]
        ).then(
            None, None, None,
            js=_REFOCUS_CHAT_INPUT_JS
        )

    # Chart Management
    delete_btn.click(
        delete_chart,
        inputs=[chart_selector, chatbot, state],
        outputs=[chatbot, gallery, state, chart_selector]
    )
    dl_zip_btn.click(
        download_zip,
        inputs=[state],
        outputs=[dl_file]
    )
    dl_report_btn.click(
        download_report,
        inputs=[state],
        outputs=[dl_file]
    )

    # Remove the session's temp Parquet file first (global_clear replaces the state,
    # which would otherwise lose the only reference to it), then reset everything.
    global_clear_btn.click(
        _discard_dataset_file,
        inputs=[state],
        outputs=[]
    ).then(
        global_clear,
        inputs=[],
        outputs=[file_upload, dataset_info, chatbot, gallery, state, chart_selector, dl_file]
    )

if __name__ == "__main__":
    demo.launch()