"""Plot-Agent: a Streamlit chat app that searches the web (Tavily), extracts
structured data and Plotly code via OpenAI, and renders interactive charts
through a LangGraph ReAct agent with lightweight conversation memory.

Pipeline: search_web → extract_data → generate_plot_code → render_plot.
"""

import json
import os
import textwrap
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional

import streamlit as st
from dotenv import load_dotenv

# Load .env before anything reads the API keys.
load_dotenv()

OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")
if not OPENAI_API_KEY or not TAVILY_API_KEY:
    st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your .env file")
    st.stop()

# Export so downstream SDKs that read the environment also see the keys.
os.environ["OPENAI_API_KEY"] = OPENAI_API_KEY
os.environ["TAVILY_API_KEY"] = TAVILY_API_KEY

# These imports are deliberately placed AFTER the environment is populated:
# several of these clients read API keys at import/construction time.
from openai import OpenAI
from tavily import TavilyClient
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from langchain_core.tools import tool
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio

openai_client = OpenAI(api_key=OPENAI_API_KEY)
tavily_client = TavilyClient(TAVILY_API_KEY)
llm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Every rendered figure is written below ./plots as a standalone HTML file.
plots_dir = Path("./plots")
plots_dir.mkdir(exist_ok=True)

# Session state for conversation memory across Streamlit reruns.
if "conversation_history" not in st.session_state:
    st.session_state.conversation_history = []
if "current_data" not in st.session_state:
    st.session_state.current_data = None
if "current_plot_context" not in st.session_state:
    st.session_state.current_plot_context = {}


# ── TOOLS ────────────────────────────────────────────────────────────────

@tool
def search_web(query: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Return Tavily results (title, url, raw_content, score)."""
    return tavily_client.search(
        query=query,
        max_results=max_results,
        search_depth="advanced",
        chunks_per_source=3,
        include_raw_content=True,
    )["results"]


@tool
def extract_data(
    raw_results: List[Dict[str, Any]], schema: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Turn *raw_results* into structured JSON matching *schema*.

    If *schema* is None, a minimal list-of-dicts schema is inferred.

    Note: OpenAI's ``json_object`` response format forces a top-level JSON
    *object*, while the schema (and all downstream consumers, which slice the
    result like a list) expect a *list*.  The model therefore usually wraps
    the list in a single-key object — unwrap it here so callers always get
    the list they expect.
    """
    if schema is None:
        schema = '[{"OS":"string","MarketShare":"number"}]'
    sys = "Return ONLY valid JSON. No markdown."
    usr = (
        f"Raw:\n{json.dumps(raw_results, ensure_ascii=False)[:4000]}"
        f"\n\nSchema:\n{schema}"
    )
    res = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": sys},
            {"role": "user", "content": usr},
        ],
        temperature=0,
        max_tokens=2000,
        response_format={"type": "json_object"},
    )
    parsed = json.loads(res.choices[0].message.content.strip())
    # BUG FIX: json_object mode cannot return a top-level array; unwrap a
    # single-key {"data": [...]}-style wrapper so slicing (data[:2]) works.
    if isinstance(parsed, dict) and len(parsed) == 1:
        (inner,) = parsed.values()
        if isinstance(inner, list):
            return inner
    return parsed


@tool
def generate_plot_code(data: List[Dict[str, Any]], instructions: str) -> str:
    """Return RAW python defining create_plot(data) -> fig."""
    sys = (
        "Return ONLY python code (no markdown) that defines "
        "`create_plot(data)` and returns a Plotly figure."
    )
    usr = f"Data:\n{json.dumps(data, indent=2)}\n\nInstructions:\n{instructions}"
    res = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": sys},
            {"role": "user", "content": usr},
        ],
        temperature=0,
        max_tokens=1500,
        response_format={"type": "text"},
    )
    return res.choices[0].message.content.strip()


@tool
def render_plot(
    code: str, data: List[Dict[str, Any]], filename: str | None = None
) -> str:
    """Exec *code* and save the figure to HTML; returns the file path.

    SECURITY NOTE: ``exec`` runs LLM-generated code with no sandboxing.
    That is tolerable for a local, single-user demo but must NOT be exposed
    to untrusted users.
    """
    if filename is None:
        filename = f"plot_{uuid.uuid4().hex[:8]}.html"
    # Ensure the file is always saved inside the plots directory.
    filepath = plots_dir / filename
    ctx = {"px": px, "go": go, "pio": pio}
    exec(code, ctx)  # defines create_plot
    pio.write_html(ctx["create_plot"](data), str(filepath))
    # Remember data + code so follow-up requests can modify the current plot
    # without re-running the search/extract steps.
    st.session_state.current_data = data
    st.session_state.current_plot_context = {
        "code": code,
        "data": data,
        "filepath": str(filepath),
        "filename": filename,
    }
    return str(filepath)


# ── AGENT PROMPT ─────────────────────────────────────────────────────────

cheat_sheet = textwrap.dedent("""
    ┏━━ TOOL ARG GUIDE ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
    ┃ search_web         {{query:str, max_results?:int}}                       ┃
    ┃ extract_data       {{raw_results:…, schema?:str}}  ← schema optional     ┃
    ┃ generate_plot_code {{data:…, instructions:str}}                          ┃
    ┃ render_plot        {{code:…, data:… [,filename]}}  → then STOP           ┃
    ┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛
""")


def create_agent_prompt(conversation_history, current_data):
    """Create the dynamic agent system prompt.

    Injects (a) a sample of previously extracted data so follow-up requests
    can skip the search/extract steps, and (b) the last few conversation
    turns so the agent can resolve references like "make it a pie chart".
    """
    context_info = ""
    if current_data:
        context_info = f"""
CURRENT DATA CONTEXT:
You have access to previously extracted data:
{json.dumps(current_data[:2], indent=2)}...

If the user asks to modify the current plot, you can skip search_web and
extract_data steps and directly use this data.
"""

    conversation_context = ""
    if conversation_history:
        # BUG FIX: the original built this with an f-string containing a
        # backslash inside the expression part ({f"\nBot: ..."}) — a
        # SyntaxError on Python < 3.12.  Build the lines with a plain loop.
        lines = []
        for msg in conversation_history[-4:]:  # last 4 messages for context
            lines.append(f"User: {msg['user']}")
            if msg.get("bot"):
                lines.append(f"Bot: {msg['bot']}")
        history_text = "\n".join(lines)
        conversation_context = f"""
CONVERSATION HISTORY:
{history_text}
"""

    return f"""
You are Plot-Agent, an AI visualization assistant with conversation memory.
{context_info}
{conversation_context}
PIPELINE: search_web → extract_data → generate_plot_code → render_plot.

RULES
• If user asks to modify current plot and you have current data, skip search_web and extract_data.
• If extract_data gets no schema, that's OK; the tool will infer one.
• After render_plot, reply with the file path & a one-liner, then **end**.
• Use conversation context to understand user's intent better.

{cheat_sheet}
"""


agent_prompt = create_agent_prompt([], None)  # initial prompt (no context yet)
TOOLS = [search_web, extract_data, generate_plot_code, render_plot]
plot_agent = create_react_agent(llm_model, TOOLS, prompt=agent_prompt)


# ── STREAMLIT UI ─────────────────────────────────────────────────────────

def main():
    """Run the Streamlit UI: sidebar, chat form, agent stream, plot viewer."""
    st.set_page_config(
        page_title="Plot-Agent 🤖📊",
        page_icon="📊",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    # Custom CSS hook.  NOTE(review): the original inline styles were lost
    # when this file was mangled — restore project styling here if needed.
    st.markdown("", unsafe_allow_html=True)

    # Header
    st.markdown(
        """
        <div style="text-align:center">
          <h1>🤖 Plot-Agent: AI-Powered Data Visualization</h1>
          <p>Search the web, extract data, and create stunning visualizations automatically!</p>
        </div>
        """,
        unsafe_allow_html=True,
    )

    # ── Sidebar ──────────────────────────────────────────────────────────
    with st.sidebar:
        st.header("⚙️ Configuration")

        st.subheader("🔑 API Status")
        if OPENAI_API_KEY and TAVILY_API_KEY:
            st.success("✅ API Keys Loaded")
        else:
            st.error("❌ API Keys Missing")

        st.subheader("📁 Recent Plots")
        plot_files = list(plots_dir.glob("*.html"))
        if plot_files:
            plot_files.sort(key=lambda p: p.stat().st_mtime, reverse=True)
            for i, plot_file in enumerate(plot_files[:5]):
                if st.button(f"📊 {plot_file.stem}", key=f"recent_{i}"):
                    st.session_state.selected_plot = str(plot_file)
                    # Clear latest_plot so the selected one is displayed.
                    if "latest_plot" in st.session_state:
                        del st.session_state.latest_plot
                    st.rerun()
        else:
            st.info("No plots generated yet")

        if st.button("🗑️ Clear All Plots"):
            for plot_file in plot_files:
                plot_file.unlink()
            # Drop stale references so the viewer doesn't point at deleted files.
            if "latest_plot" in st.session_state:
                del st.session_state.latest_plot
            if "selected_plot" in st.session_state:
                del st.session_state.selected_plot
            st.success("All plots cleared!")
            st.rerun()

        if st.button("🗑️ Clear Conversation"):
            st.session_state.conversation_history = []
            st.session_state.current_data = None
            st.session_state.current_plot_context = {}
            st.success("Conversation cleared!")
            st.rerun()

        if st.session_state.current_data:
            st.subheader("💾 Current Data Context")
            st.success(
                f"📊 {len(st.session_state.current_data)} data points available"
            )
            with st.expander("View Data Sample"):
                st.json(st.session_state.current_data[:3])

        if st.session_state.conversation_history:
            st.subheader("💬 Conversation History")
            n_msgs = len(st.session_state.conversation_history)
            with st.expander(f"View History ({n_msgs} messages)"):
                for i, msg in enumerate(
                    st.session_state.conversation_history[-5:]
                ):
                    st.write(f"**{i + 1}. User:** {msg['user']}")
                    if msg.get("bot"):
                        st.write(f"**Bot:** {msg['bot']}")

    # ── Main interface ───────────────────────────────────────────────────
    col1, col2 = st.columns([1, 1])

    with col1:
        st.header("💬 Chat with Plot-Agent")

        with st.form("plot_request"):
            user_input = st.text_area(
                "What visualization would you like to create?",
                placeholder=(
                    "e.g., Create a bar chart of top 10 countries by GDP in 2024"
                ),
                height=100,
            )
            submitted = st.form_submit_button(
                "🚀 Generate Plot", use_container_width=True
            )

        st.subheader("💡 Example Prompts")
        base_examples = [
            "Create a line chart of Bitcoin price over the last 6 months",
            "Show a pie chart of global smartphone market share in 2024",
            "Make a bar chart of top 10 most populous cities in the world",
            "Create a scatter plot of countries by GDP vs population",
        ]
        # Context-aware examples appear only when data is already loaded.
        context_examples = []
        if st.session_state.current_data:
            context_examples = [
                "Change the current chart to a pie chart",
                "Make the bars horizontal instead of vertical",
                "Add different colors to each data point",
                "Change the title and add axis labels",
            ]

        all_examples = context_examples + base_examples
        for i, example in enumerate(all_examples[:6]):  # show max 6 examples
            prefix = "🔄" if i < len(context_examples) else "📝"
            if st.button(f"{prefix} {example}", key=f"example_{i}"):
                st.session_state.user_input = example
                submitted = True
                user_input = example

    with col2:
        st.header("🔄 Agent Activity")
        status_placeholder = st.empty()
        activity_placeholder = st.empty()

    # ── Process a request ────────────────────────────────────────────────
    if submitted and user_input:
        st.session_state.conversation_history.append(
            {"user": user_input, "timestamp": time.time()}
        )

        with status_placeholder.container():
            st.info("🚀 Starting Plot-Agent...")

        activity_container = activity_placeholder.container()

        try:
            # Rebuild the agent with the latest conversation context.
            dynamic_prompt = create_agent_prompt(
                st.session_state.conversation_history,
                st.session_state.current_data,
            )
            agent = create_react_agent(llm_model, TOOLS, prompt=dynamic_prompt)

            bot_response = ""
            tool_results = {}

            with activity_container:
                progress_bar = st.progress(0)
                step_counter = 0
                max_steps = 4  # search, extract, generate, render

                for chunk in agent.stream(
                    {"messages": [{"role": "user", "content": user_input}]},
                    stream_mode="updates",
                    config={"recursion_limit": 10},
                ):
                    node_name = next(iter(chunk))

                    if node_name == "agent":
                        node_messages = chunk[node_name].get("messages", [])
                        if not node_messages:
                            continue
                        message = node_messages[-1]

                        if getattr(message, "tool_calls", None):
                            for tool_call in message.tool_calls:
                                step_counter += 1
                                progress_bar.progress(
                                    min(step_counter / max_steps, 1.0)
                                )
                                st.markdown(
                                    f"🔧 **Using Tool:** `{tool_call['name']}`"
                                )
                                st.code(
                                    json.dumps(tool_call["args"], indent=2),
                                    language="json",
                                )
                                time.sleep(0.5)  # visual delay for better UX
                        elif getattr(message, "content", None):
                            bot_response = message.content
                            st.markdown(f"🤖 **Plot-Agent:** {message.content}")

                    elif node_name == "tools":
                        # BUG FIX: the tools node emits {"messages": [ToolMessage,
                        # ...]}; the original iterated .items() and compared the
                        # *key* ("messages") against tool names, so no branch ever
                        # matched and latest_plot was never set.
                        for tool_msg in chunk[node_name].get("messages", []):
                            tool_name = getattr(tool_msg, "name", "")
                            content = tool_msg.content
                            # Tool output arrives serialized as a string; recover
                            # structure when possible, else keep the raw text.
                            try:
                                result = (
                                    json.loads(content)
                                    if isinstance(content, str)
                                    else content
                                )
                            except (json.JSONDecodeError, TypeError):
                                result = content
                            tool_results[tool_name] = result

                            if tool_name == "search_web":
                                st.markdown("🔍 **Search Results**")
                                st.write(f"Found {len(result)} sources")
                                with st.expander("View Sources"):
                                    st.write(result)
                            elif tool_name == "extract_data":
                                st.markdown("📊 **Extracted Data**")
                                st.write(f"Processed {len(result)} data points")
                                with st.expander("View Data Sample"):
                                    sample = (
                                        result[:3]
                                        if isinstance(result, list)
                                        else result
                                    )
                                    st.code(
                                        json.dumps(sample, indent=2),
                                        language="json",
                                    )
                            elif tool_name == "generate_plot_code":
                                st.markdown("🎨 **Generated Plot Code**")
                                with st.expander("View Code"):
                                    st.code(f"{result[:500]}...", language="python")
                            elif tool_name == "render_plot":
                                st.markdown("✅ **Plot Rendered**")
                                st.write(f"File: {result}")
                                # Show this plot and drop any stale selection.
                                st.session_state.latest_plot = result
                                if "selected_plot" in st.session_state:
                                    del st.session_state.selected_plot

                progress_bar.progress(1.0)

            # Attach the bot's final reply to the conversation turn.
            if bot_response:
                st.session_state.conversation_history[-1]["bot"] = bot_response

            with status_placeholder.container():
                st.success("Plot generation completed!")

            time.sleep(0.5)  # small delay to ensure the file is fully written
            st.rerun()

        except Exception as e:
            with status_placeholder.container():
                st.error(f"Error: {str(e)}")

    # ── Display generated plot ───────────────────────────────────────────
    st.header("📊 Generated Visualization")

    # Latest plot has priority over a sidebar selection.
    plot_file = None
    if st.session_state.get("latest_plot"):
        plot_file = st.session_state.latest_plot
        st.info("🆕 **Latest Generated Plot**")
    elif st.session_state.get("selected_plot"):
        plot_file = st.session_state.selected_plot
        st.info(f"📁 **Selected Plot:** {Path(plot_file).stem}")

    if plot_file and Path(plot_file).exists():
        try:
            with open(plot_file, "r", encoding="utf-8") as f:
                html_content = f.read()
            st.components.v1.html(html_content, height=600, scrolling=True)
            st.download_button(
                label="📥 Download Plot",
                data=html_content,
                file_name=Path(plot_file).name,
                mime="text/html",
            )
        except Exception as e:
            st.error(f"Error loading plot: {str(e)}")
    elif "latest_plot" in st.session_state or "selected_plot" in st.session_state:
        st.error("Plot file not found! It may have been deleted.")
    else:
        st.info("👋 Generate a plot or select from recent plots to view here!")


if __name__ == "__main__":
    main()