|
|
"""Gradio Chat Interface for Data Analyzer Agent with Image Support.""" |
|
|
import os |
|
|
import base64 |
|
|
import io |
|
|
import tempfile |
|
|
from PIL import Image |
|
|
|
|
|
import gradio as gr |
|
|
from openai import OpenAI |
|
|
from e2b_code_interpreter import Sandbox |
|
|
|
|
|
from src import coding_agent, execute_code_schema, tools |
|
|
|
|
|
|
|
|
# Credentials come from the environment; either may be None when unset.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
E2B_API_KEY = os.environ.get("E2B_API_KEY")

# Only construct the OpenAI client when a key is configured; respond() treats
# a falsy client as "not configured" and reports an error instead of running.
client = None
if OPENAI_API_KEY:
    client = OpenAI()

# E2B sandbox handle: starts empty, created on the first request and reused.
sbx = None

# System prompt handed to the coding agent on every request.
SYSTEM_PROMPT = (
    "You are a data analysis agent. Generate Python code to analyze data, "
    "perform statistical analysis, and create visualizations using matplotlib, "
    "pandas, numpy, and seaborn. Always use these libraries for professional "
    "data analysis."
)
|
|
|
|
|
|
|
|
def _extract_response_text(messages):
    """Return the text of the most recent ``type == "message"`` entry.

    List-valued ``content`` is reduced to the concatenation of its
    ``output_text`` parts; any other content is converted with ``str``.
    Returns ``""`` when no such entry exists.
    """
    for msg in reversed(messages):
        if isinstance(msg, dict) and msg.get("type") == "message":
            content = msg.get("content", [])
            if isinstance(content, list):
                return "".join(
                    item.get("text", "")
                    for item in content
                    if isinstance(item, dict) and item.get("type") == "output_text"
                )
            return str(content)
    return ""


def _save_images(encoded_images):
    """Decode base64 image payloads and save each as a PNG file.

    Files go into a fresh temporary directory created for this call, so
    concurrent requests never overwrite each other's plots (the previous
    fixed ``plot_{i}.png`` names in the shared temp dir could collide).
    Returns the list of written paths; empty when there are no images.
    """
    if not encoded_images:
        return []
    out_dir = tempfile.mkdtemp(prefix="data_analyzer_plots_")
    paths = []
    for i, png_data in enumerate(encoded_images):
        path = os.path.join(out_dir, f"plot_{i}.png")
        # Round-trip through PIL so the payload is validated and re-encoded
        # according to the .png extension.
        with Image.open(io.BytesIO(base64.b64decode(png_data))) as img:
            img.save(path)
        paths.append(path)
    return paths


def respond(message, history):
    """Run the coding agent on one chat message and yield its reply.

    Yields exactly one ``(response_text, image_paths)`` tuple:

    * ``response_text`` - text of the last ``type == "message"`` entry returned
      by ``coding_agent``, ``"Analysis complete."`` when there is none, or an
      error string (including the traceback) if anything raised.
    * ``image_paths`` - paths of PNG files decoded from ``metadata["images"]``
      (empty on error).

    If the OpenAI client or ``E2B_API_KEY`` is missing, a configuration error
    string is yielded instead, with no images.

    ``history`` is accepted to match the chat-callback signature but is not
    used: every call starts a new agent conversation (``messages=None``).
    """
    global sbx

    if not client or not E2B_API_KEY:
        yield "Error: Environment variables not set. Please set OPENAI_API_KEY and E2B_API_KEY.", []
        return

    try:
        # Lazily create one sandbox and reuse it across calls.
        # NOTE(review): `sbx` is a module-level global shared by every session
        # and concurrent request; two simultaneous first requests can each see
        # None and create a sandbox. Presumably the sandbox also expires after
        # `timeout` seconds, after which this cached handle is never replaced -
        # verify against the E2B docs.
        if sbx is None:
            sbx = Sandbox.create(timeout=3600)

        messages, metadata = coding_agent(
            client=client,
            query=message,
            system=SYSTEM_PROMPT,
            tools=tools,
            tools_schemas=[execute_code_schema],
            sbx=sbx,
            messages=None,
            max_steps=5,
        )

        response_text = _extract_response_text(messages) or "Analysis complete."
        image_files = _save_images(metadata.get("images"))
        result = (response_text, image_files)
    except Exception as e:
        # Top-level UI boundary: surface the failure in the chat rather than
        # crashing the Gradio event.
        import traceback

        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        result = (error_msg, [])

    yield result
|
|
|
|
|
|
|
|
|
|
|
# UI: a chat column (history, input box, buttons) beside a gallery of plots.
with gr.Blocks(title="Data Analyzer Agent") as demo:
    gr.Markdown("# Data Analyzer Agent")
    gr.Markdown("Ask me to analyze data and create visualizations!")

    with gr.Row():
        with gr.Column(scale=2):
            chatbot = gr.Chatbot(label="Chat", height=400, type="tuples")
            msg = gr.Textbox(
                label="Message",
                placeholder="Ask me to analyze data...",
                lines=2,
            )
            with gr.Row():
                submit = gr.Button("Submit", variant="primary")
                clear = gr.Button("Clear")

        with gr.Column(scale=1):
            gallery = gr.Gallery(label="Visualizations", columns=1, height=400)

    gr.Examples(
        examples=[
            "Calculate the mean of [1,2,3,4,5]",
            "Generate 50 random numbers from a normal distribution and plot a histogram",
            "Create a scatter plot of 20 random x,y points",
        ],
        inputs=msg,
    )

    def user_submit(user_message, history):
        """Append the user's message as a pending turn and clear the textbox.

        Blank (empty or whitespace-only) messages are dropped so they never
        reach the agent; the history is returned unchanged in that case.
        """
        history = history or []
        if not user_message or not user_message.strip():
            return "", history
        # `None` marks the turn as awaiting a bot reply (see bot_respond).
        return "", history + [[user_message, None]]

    def bot_respond(history):
        """Answer the pending last turn via respond() and stream UI updates.

        Yields ``(history, images)`` for the chatbot and gallery. When there is
        no pending turn (empty history, or the last turn already has a reply
        because user_submit dropped a blank message), the history is yielded
        unchanged and the gallery is left as is.
        """
        if not history or history[-1][1] is not None:
            yield history, gr.update()
            return

        user_message = history[-1][0]
        bot_response, images = None, []

        for response_text, image_files in respond(user_message, history[:-1]):
            bot_response, images = response_text, image_files
            history[-1][1] = bot_response
            yield history, images

        # Guarantee at least one UI update even if respond() yielded nothing;
        # otherwise the loop has already pushed the final state.
        if bot_response is None:
            yield history, images

    demo.queue(default_concurrency_limit=10)

    # Both Enter in the textbox and the Submit button: record the user turn
    # immediately (unqueued), then run the agent as a queued follow-up.
    msg.submit(user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_respond, chatbot, [chatbot, gallery]
    )
    submit.click(user_submit, [msg, chatbot], [msg, chatbot], queue=False).then(
        bot_respond, chatbot, [chatbot, gallery]
    )
    clear.click(lambda: ([], []), None, [chatbot, gallery], queue=False)
|
|
|
|
|
|
|
|
def _main():
    """Launch the Gradio app with error display enabled (``show_error=True``)."""
    demo.launch(show_error=True)


if __name__ == "__main__":
    _main()
|
|
|