Spaces:
Build error
Build error
| import os | |
| import json | |
| import uuid | |
| import textwrap | |
| import streamlit as st | |
| from typing import List, Dict, Any, Optional | |
| from pathlib import Path | |
| import time | |
# Load environment variables (e.g. from a local .env file).
from dotenv import load_dotenv

load_dotenv()

# Read both API keys; the app halts with an error message if either is missing.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
TAVILY_API_KEY = os.getenv("TAVILY_API_KEY")

if not (OPENAI_API_KEY and TAVILY_API_KEY):
    st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your .env file")
    st.stop()

# Write the keys back into os.environ so clients constructed without an
# explicit key (e.g. ChatOpenAI below) can pick them up from there.
os.environ.update(
    OPENAI_API_KEY=OPENAI_API_KEY,
    TAVILY_API_KEY=TAVILY_API_KEY,
)
| # Imports after setting environment variables | |
| from openai import OpenAI | |
| from tavily import TavilyClient | |
| from langchain_openai import ChatOpenAI | |
| from langgraph.prebuilt import create_react_agent | |
| from langchain_core.tools import tool | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| import plotly.io as pio | |
# Initialize API clients shared by the tools below.
openai_client = OpenAI(api_key=OPENAI_API_KEY)
tavily_client = TavilyClient(TAVILY_API_KEY)
llm_model = ChatOpenAI(temperature=0, model="gpt-4o-mini")

# Directory that receives the rendered HTML plots (created if absent).
plots_dir = Path("./plots")
plots_dir.mkdir(exist_ok=True)


def _ensure_session_defaults() -> None:
    """Seed the conversation-memory keys that this session does not have yet."""
    defaults = {
        "conversation_history": [],
        "current_data": None,
        "current_plot_context": {},
    }
    for state_key, initial_value in defaults.items():
        if state_key not in st.session_state:
            st.session_state[state_key] = initial_value


_ensure_session_defaults()
# -- TOOLS -------------------------------------------------------------------
# Per-result cap on Tavily's ``raw_content``. The search result is returned to
# the agent LLM as a tool message and is then passed on again as extract_data's
# ``raw_results`` argument, so whole-page dumps would bloat every later step;
# extract_data only looks at a bounded prefix of its input anyway.
_MAX_RAW_CONTENT_CHARS = 2000


def search_web(query: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Return Tavily results (title, url, raw_content, score, ...).

    Each result's ``raw_content`` string is truncated to
    ``_MAX_RAW_CONTENT_CHARS`` characters; other fields are passed through.
    """
    results = tavily_client.search(
        query=query,
        max_results=max_results,
        search_depth="advanced",
        chunks_per_source=3,
        include_raw_content=True,
    )["results"]

    trimmed: List[Dict[str, Any]] = []
    for result in results:
        raw = result.get("raw_content")
        if isinstance(raw, str) and len(raw) > _MAX_RAW_CONTENT_CHARS:
            # Copy rather than mutate the client's response object.
            result = {**result, "raw_content": raw[:_MAX_RAW_CONTENT_CHARS]}
        trimmed.append(result)
    return trimmed
def _compact_sources(raw_results: Any, budget: int = 4000) -> str:
    """Serialize search results into at most *budget* characters.

    The budget is split evenly across the results, so one long page (e.g. a
    full ``raw_content`` dump) cannot crowd every other source out of the
    extraction prompt.
    """
    if not isinstance(raw_results, list) or not raw_results:
        return json.dumps(raw_results, ensure_ascii=False)[:budget]
    share = max(budget // len(raw_results), 1)
    pieces = [json.dumps(item, ensure_ascii=False)[:share] for item in raw_results]
    return "\n".join(pieces)[:budget]


def _coerce_records(parsed: Any) -> List[Dict[str, Any]]:
    """Normalize a decoded JSON payload into a list of records.

    JSON-object response mode always produces an object, so the record list is
    normally found under "data" (as requested in the prompt) or under whatever
    other key the model chose.
    """
    if isinstance(parsed, list):
        return parsed
    if isinstance(parsed, dict):
        if isinstance(parsed.get("data"), list):
            return parsed["data"]
        for value in parsed.values():
            if isinstance(value, list):
                return value
        # A single record (or an empty object) rather than a collection.
        return [parsed] if parsed else []
    return [parsed]


def extract_data(
    raw_results: List[Dict[str, Any]],
    schema: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Turn *raw_results* into a list of JSON records matching *schema*.

    Args:
        raw_results: Search results (typically from ``search_web``).
        schema: Description of the desired record shape. If None, the model is
            asked to infer a flat list-of-records schema that fits the data.

    Returns:
        The extracted records as a list (unwrapped from the JSON object that
        JSON-object response mode forces the model to return).

    Raises:
        ValueError: if the model returns no content.
        json.JSONDecodeError: if the model returns invalid JSON.
    """
    if schema is None:
        schema = ('Infer a flat list of records that fits the data, e.g. '
                  '[{"label":"string","value":"number"}]')
    system_prompt = (
        "Return ONLY valid JSON. No markdown. "
        'Respond with a JSON object of the form {"data": [...]} whose "data" '
        "list holds the records matching the schema."
    )
    user_prompt = f"Raw:\n{_compact_sources(raw_results)}\n\nSchema:\n{schema}"
    res = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "system", "content": system_prompt},
                  {"role": "user", "content": user_prompt}],
        temperature=0, max_tokens=2000,
        response_format={"type": "json_object"},
    )
    content = res.choices[0].message.content
    if not content:
        raise ValueError("extract_data: the model returned no content.")
    return _coerce_records(json.loads(content))
def _strip_code_fences(text: str) -> str:
    """Return the body of the first ``` fenced block in *text*, else *text* stripped."""
    import re

    # Opening fence may carry a language tag (```python); an unclosed fence
    # runs to the end of the text.
    match = re.search(r"```[\w+-]*[ \t]*\n(.*?)(?:```|\Z)", text, re.DOTALL)
    return (match.group(1) if match else text).strip()


def generate_plot_code(
    data: List[Dict[str, Any]],
    instructions: str
) -> str:
    """Return raw Python source that defines ``create_plot(data)`` -> Plotly figure.

    The model is told not to use markdown; any ``` fence it emits anyway is
    removed so the source can be exec'd directly by ``render_plot``.

    Args:
        data: Records to plot; serialized into the prompt as JSON.
        instructions: Natural-language description of the desired chart.
    """
    system_prompt = ("Return ONLY python code (no markdown) that defines "
                     "`create_plot(data)` and returns a Plotly figure.")
    user_prompt = f"Data:\n{json.dumps(data, indent=2)}\n\nInstructions:\n{instructions}"
    res = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "system", "content": system_prompt},
                  {"role": "user", "content": user_prompt}],
        temperature=0, max_tokens=1500,
        response_format={"type": "text"},
    )
    return _strip_code_fences(res.choices[0].message.content or "")
def render_plot(
    code: str,
    data: List[Dict[str, Any]],
    filename: Optional[str] = None
) -> str:
    """Exec *code*, call its ``create_plot(data)`` and save the figure as HTML.

    Args:
        code: Python source that must define ``create_plot(data)``; ``px``,
            ``go`` and ``pio`` are pre-bound in its namespace.
        data: Records passed to ``create_plot``.
        filename: Optional output name. Only its final path component is used
            (so a model-supplied name cannot write outside ``plots_dir``), and
            ``.html`` is appended when missing so the sidebar lists it.

    Returns:
        Path of the written HTML file inside ``plots_dir``.

    Raises:
        ValueError: if *code* does not define a callable ``create_plot``.
    """
    safe_name = Path(filename).name if filename else ""
    if not safe_name:
        safe_name = f"plot_{uuid.uuid4().hex[:8]}.html"
    elif not safe_name.lower().endswith(".html"):
        safe_name += ".html"
    filepath = plots_dir / safe_name

    # SECURITY(review): *code* is LLM-generated from web-sourced data and runs
    # with full interpreter privileges. Consider sandboxing it.
    ctx = {"px": px, "go": go, "pio": pio}
    exec(code, ctx)
    create_plot = ctx.get("create_plot")
    if not callable(create_plot):
        raise ValueError("Plot code must define a function `create_plot(data)`.")
    pio.write_html(create_plot(data), str(filepath))

    # Store current data and context for conversation memory.
    # NOTE(review): if the tool runs on a worker thread, these session_state
    # writes may not reach the user's session; confirm.
    st.session_state.current_data = data
    st.session_state.current_plot_context = {
        'code': code,
        'data': data,
        'filepath': str(filepath),
        'filename': safe_name
    }
    return str(filepath)
# -- AGENT PROMPT ------------------------------------------------------------
# Tool-argument quick reference appended to the agent prompt. It is inserted
# verbatim through an f-string placeholder, so braces are written singly here:
# doubled braces would reach the model as a literal "{{".
cheat_sheet = textwrap.dedent("""
    +-- TOOL ARG GUIDE --------------------------------------------------------+
    | search_web          {query:str, max_results?:int}                        |
    | extract_data        {raw_results:..., schema?:str}   (schema optional)   |
    | generate_plot_code  {data:..., instructions:str}                         |
    | render_plot         {code:..., data:... [,filename]}  -> then STOP       |
    +--------------------------------------------------------------------------+
""")
def create_agent_prompt(conversation_history, current_data):
    """Build the Plot-Agent system prompt with conversation context.

    Args:
        conversation_history: List of dicts with a ``'user'`` entry and an
            optional ``'bot'`` entry; only the last 4 are included.
        current_data: Records behind the current plot, or a falsy value when
            there is none. Embedded in full (not just a preview) so the agent
            can pass them straight back to generate_plot_code / render_plot.

    Returns:
        The prompt string handed to ``create_react_agent``.
    """
    context_info = ""
    if current_data:
        data_json = json.dumps(current_data, ensure_ascii=False)
        context_info = (
            "\nCURRENT DATA CONTEXT:\n"
            f"You have access to previously extracted data: {data_json}\n"
            "If the user asks to modify the current plot, you can skip search_web "
            "and extract_data steps and directly use this data (pass it unchanged "
            "as the `data` argument).\n"
        )

    conversation_context = ""
    if conversation_history:
        # Built outside the f-string: before Python 3.12 an f-string expression
        # may not contain a backslash such as "\n".
        lines = []
        for msg in conversation_history[-4:]:  # last 4 messages for context
            lines.append(f"User: {msg.get('user', '')}")
            if msg.get("bot"):
                lines.append(f"Bot: {msg['bot']}")
        conversation_context = "\nCONVERSATION HISTORY:\n" + "\n".join(lines) + "\n"

    return f"""
You are Plot-Agent, an AI visualization assistant with conversation memory.
{context_info}
{conversation_context}
PIPELINE: search_web -> extract_data -> generate_plot_code -> render_plot.
RULES
- If user asks to modify current plot and you have current data, skip search_web and extract_data.
- If extract_data gets no schema, that's OK; the tool will infer one.
- After render_plot, reply with the file path & a one-liner, then **end**.
- Use conversation context to understand user's intent better.
{cheat_sheet}
"""
# Tools exposed to the ReAct agent, plus a module-level agent built with an
# empty-context prompt (main() builds a fresh agent per request).
TOOLS = [search_web, extract_data, generate_plot_code, render_plot]
agent_prompt = create_agent_prompt(conversation_history=[], current_data=None)
plot_agent = create_react_agent(llm_model, TOOLS, prompt=agent_prompt)
# -- STREAMLIT UI ------------------------------------------------------------
def _esc(value: Any) -> str:
    """Return ``str(value)`` HTML-escaped for use inside ``unsafe_allow_html`` markup."""
    import html

    return html.escape(str(value))


def _pre_html(text: str) -> str:
    """Escape *text* for a ``<pre>`` inside an st.markdown HTML snippet.

    Newlines become ``&#10;`` so blank lines in the text (common in code) do not
    end the markdown HTML block early and spill the rest out as markdown.
    """
    return _esc(text).replace("\n", "&#10;")


def _tool_box(title_html: str, body_html: str = "") -> None:
    """Render one ``tool-box`` card; both arguments must already be safe HTML."""
    st.markdown(
        f'<div class="tool-box"><h4>{title_html}</h4>{body_html}</div>',
        unsafe_allow_html=True,
    )


def _content_text(content: Any) -> str:
    """Return a tool message's content as text (JSON-encoding non-string content)."""
    if isinstance(content, str):
        return content
    return json.dumps(content, ensure_ascii=False)


def _decode_json(text: str) -> Any:
    """Decode *text* as JSON, falling back to the text itself when it is not JSON."""
    try:
        return json.loads(text)
    except ValueError:
        return text


def _inject_styles() -> None:
    """Inject the custom CSS used by the header, status, tool and error boxes."""
    st.markdown("""
    <style>
    .main-header {
        background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
        padding: 1rem;
        border-radius: 10px;
        margin-bottom: 2rem;
    }
    .main-header h1 {
        color: white;
        margin: 0;
        text-align: center;
    }
    .status-box {
        border-left: 4px solid #4CAF50;
        background-color: #f9f9f9;
        padding: 10px;
        margin: 10px 0;
        border-radius: 5px;
        color: black !important;
    }
    .status-box * {
        color: black !important;
    }
    .tool-box {
        border: 1px solid #ddd;
        border-radius: 8px;
        padding: 15px;
        margin: 10px 0;
        background-color: #f8f9fa;
        color: black !important;
    }
    .tool-box * {
        color: black !important;
    }
    .error-box {
        border-left: 4px solid #f44336;
        background-color: #ffebee;
        padding: 10px;
        margin: 10px 0;
        border-radius: 5px;
        color: black !important;
    }
    .error-box * {
        color: black !important;
    }
    </style>
    """, unsafe_allow_html=True)


def _render_header() -> None:
    """Render the gradient page header."""
    st.markdown("""
    <div class="main-header">
    <h1>π€ Plot-Agent: AI-Powered Data Visualization</h1>
    <p style="text-align: center; color: white; margin: 0;">
    Search the web, extract data, and create stunning visualizations automatically!
    </p>
    </div>
    """, unsafe_allow_html=True)


def _render_sidebar() -> None:
    """Render API status, recent-plot shortcuts, clear buttons and memory views."""
    with st.sidebar:
        st.header("βοΈ Configuration")

        # API Status
        st.subheader("π API Status")
        if OPENAI_API_KEY and TAVILY_API_KEY:
            st.success("β API Keys Loaded")
        else:
            st.error("β API Keys Missing")

        # Recent Plots, newest first
        st.subheader("π Recent Plots")
        plot_files = sorted(
            plots_dir.glob("*.html"), key=lambda p: p.stat().st_mtime, reverse=True
        )
        if plot_files:
            for i, plot_file in enumerate(plot_files[:5]):
                if st.button(f"π {plot_file.stem}", key=f"recent_{i}"):
                    st.session_state.selected_plot = str(plot_file)
                    # latest_plot takes priority in the viewer, so drop it.
                    st.session_state.pop("latest_plot", None)
                    st.rerun()
        else:
            st.info("No plots generated yet")

        if st.button("ποΈ Clear All Plots"):
            for plot_file in plot_files:
                plot_file.unlink(missing_ok=True)
            st.session_state.pop("latest_plot", None)
            st.session_state.pop("selected_plot", None)
            st.success("All plots cleared!")
            st.rerun()

        if st.button("ποΈ Clear Conversation"):
            st.session_state.conversation_history = []
            st.session_state.current_data = None
            st.session_state.current_plot_context = {}
            st.success("Conversation cleared!")
            st.rerun()

        current_data = st.session_state.current_data
        if current_data:
            st.subheader("πΎ Current Data Context")
            st.success(f"π {len(current_data)} data points available")
            with st.expander("View Data Sample"):
                # current_data comes from tool arguments; only lists can be sliced.
                st.json(current_data[:3] if isinstance(current_data, list) else current_data)

        history = st.session_state.conversation_history
        if history:
            st.subheader("π¬ Conversation History")
            with st.expander(f"View History ({len(history)} messages)"):
                for i, msg in enumerate(history[-5:]):
                    st.write(f"**{i+1}. User:** {msg['user']}")
                    if msg.get('bot'):
                        st.write(f"**Bot:** {msg['bot']}")


def _render_chat_column() -> tuple[bool, str]:
    """Render the request form and example buttons; return (submitted, user_input)."""
    st.header("π¬ Chat with Plot-Agent")
    with st.form("plot_request"):
        user_input = st.text_area(
            "What visualization would you like to create?",
            placeholder="e.g., Create a bar chart of top 10 countries by GDP in 2024",
            height=100
        )
        submitted = st.form_submit_button("π Generate Plot", use_container_width=True)

    st.subheader("π‘ Example Prompts")
    base_examples = [
        "Create a line chart of Bitcoin price over the last 6 months",
        "Show a pie chart of global smartphone market share in 2024",
        "Make a bar chart of top 10 most populous cities in the world",
        "Create a scatter plot of countries by GDP vs population",
    ]
    # Offer plot-modification prompts first when there is data to modify.
    context_examples = [
        "Change the current chart to a pie chart",
        "Make the bars horizontal instead of vertical",
        "Add different colors to each data point",
        "Change the title and add axis labels",
    ] if st.session_state.current_data else []
    all_examples = context_examples + base_examples
    for i, example in enumerate(all_examples[:6]):  # show max 6 examples
        prefix = "π" if i < len(context_examples) else "π"
        if st.button(f"{prefix} {example}", key=f"example_{i}"):
            st.session_state.user_input = example
            submitted = True
            user_input = example
    return submitted, user_input


def _show_tool_call(tool_call: Dict[str, Any]) -> None:
    """Render a card announcing a tool call and its (truncated) arguments."""
    max_preview_chars = 2000  # e.g. raw_results arguments can be very large
    args_text = json.dumps(tool_call.get("args", {}), indent=2, ensure_ascii=False)
    if len(args_text) > max_preview_chars:
        args_text = args_text[:max_preview_chars] + "\n..."
    _tool_box(
        f"π§ Using Tool: {_esc(tool_call.get('name', ''))}",
        f"<p><strong>Arguments:</strong></p><pre>{_pre_html(args_text)}</pre>",
    )


def _handle_tool_message(tool_message: Any, pending_calls: Dict[str, Dict[str, Any]]) -> None:
    """Render one ToolMessage from the ``tools`` node and record rendered plots.

    *pending_calls* maps tool-call ids to the agent's tool-call dicts, so the
    arguments behind a ``render_plot`` result can be recovered.
    """
    tool_name = getattr(tool_message, "name", None) or ""
    text = _content_text(getattr(tool_message, "content", ""))

    if getattr(tool_message, "status", None) == "error":
        _tool_box(f"Tool error: {_esc(tool_name)}", f"<pre>{_pre_html(text)}</pre>")
        return

    if tool_name == "search_web":
        result = _decode_json(text)
        sources = result if isinstance(result, list) else []
        items = "".join(
            f"<li><strong>{_esc(item.get('title', 'N/A'))}</strong><br>"
            f"<small>{_esc(item.get('url', 'N/A'))}</small></li>"
            for item in sources[:3]  # show first 3 sources
            if isinstance(item, dict)
        )
        _tool_box(
            "π Search Results",
            f"<p>Found {len(sources)} sources</p>"
            f"<details><summary>View Sources</summary><ul>{items}</ul></details>",
        )
    elif tool_name == "extract_data":
        result = _decode_json(text)
        records = result if isinstance(result, list) else [result]
        sample = json.dumps(records[:3], indent=2, ensure_ascii=False)
        _tool_box(
            "π Extracted Data",
            f"<p>Processed {len(records)} data points</p>"
            "<details><summary>View Data Sample</summary>"
            f"<pre>{_pre_html(sample)}</pre></details>",
        )
    elif tool_name == "generate_plot_code":
        _tool_box(
            "π¨ Generated Plot Code",
            "<details><summary>View Code</summary>"
            f"<pre>{_pre_html(text[:500])}...</pre></details>",
        )
    elif tool_name == "render_plot":
        _tool_box("β Plot Rendered", f"<p><strong>File:</strong> {_esc(text)}</p>")
        # Show the new plot in the viewer instead of any selected one.
        st.session_state.latest_plot = text
        st.session_state.pop("selected_plot", None)
        # NOTE(review): tools may execute on a worker thread, where render_plot's
        # own session_state writes may not persist; mirror them here from the
        # script thread using the call's arguments.
        call = pending_calls.get(getattr(tool_message, "tool_call_id", None)) or {}
        args = call.get("args") or {}
        if args.get("data") is not None:
            st.session_state.current_data = args["data"]
            st.session_state.current_plot_context = {
                'code': args.get("code"),
                'data': args["data"],
                'filepath': text,
                'filename': Path(text).name,
            }


def _run_agent(user_input: str, status_placeholder: Any, activity_placeholder: Any) -> bool:
    """Stream one Plot-Agent run for *user_input*, logging its activity live.

    Appends the exchange to ``conversation_history``. Returns True when the run
    finishes, False when it raised (the error is shown in *status_placeholder*).
    """
    # Graph-step budget: the 4-tool pipeline alternates agent/tools and ends
    # with a final agent reply (~9 steps); the headroom allows a tool retry.
    recursion_limit = 16
    max_steps = 4  # search, extract, generate, render

    st.session_state.conversation_history.append({
        'user': user_input,
        'timestamp': time.time()
    })
    with status_placeholder.container():
        st.markdown('<div class="status-box">π <strong>Starting Plot-Agent...</strong></div>',
                    unsafe_allow_html=True)
    activity_container = activity_placeholder.container()

    try:
        # Fresh agent so the prompt carries the current conversation context.
        dynamic_prompt = create_agent_prompt(
            st.session_state.conversation_history,
            st.session_state.current_data
        )
        agent = create_react_agent(llm_model, TOOLS, prompt=dynamic_prompt)

        pending_calls: Dict[str, Dict[str, Any]] = {}
        bot_response = ""
        with activity_container:
            progress_bar = st.progress(0)
            step_counter = 0
            for chunk in agent.stream(
                {"messages": [{"role": "user", "content": user_input}]},
                stream_mode="updates",
                config={"recursion_limit": recursion_limit},
            ):
                for node_name, update in chunk.items():
                    if not isinstance(update, dict):
                        continue
                    node_messages = update.get("messages") or []
                    if node_name == "agent" and node_messages:
                        message = node_messages[-1]
                        tool_calls = getattr(message, "tool_calls", None) or []
                        if tool_calls:
                            for tool_call in tool_calls:
                                pending_calls[tool_call.get("id")] = tool_call
                                step_counter += 1
                                progress_bar.progress(min(step_counter / max_steps, 1.0))
                                _show_tool_call(tool_call)
                                time.sleep(0.5)  # visual delay for better UX
                        elif getattr(message, "content", None):
                            bot_response = message.content
                            st.markdown(
                                '<div class="status-box"><strong>π€ Plot-Agent:</strong> '
                                f"{_esc(bot_response)}</div>",
                                unsafe_allow_html=True,
                            )
                    elif node_name == "tools":
                        # The tools update carries a list of ToolMessages under
                        # "messages" (same shape as the agent update above).
                        for tool_message in node_messages:
                            _handle_tool_message(tool_message, pending_calls)
            progress_bar.progress(1.0)

        if bot_response:
            st.session_state.conversation_history[-1]['bot'] = bot_response
        with status_placeholder.container():
            st.markdown('<div class="status-box">β <strong>Plot generation completed!</strong></div>',
                        unsafe_allow_html=True)
        return True
    except Exception as e:
        with status_placeholder.container():
            st.markdown(f'<div class="error-box">β <strong>Error:</strong> {_esc(e)}</div>',
                        unsafe_allow_html=True)
        return False


def _display_plot() -> None:
    """Show the latest generated plot, else the selected one, with a download button."""
    import streamlit.components.v1 as components

    st.header("π Generated Visualization")

    # latest_plot has priority over a plot picked in the sidebar.
    plot_file = None
    if st.session_state.get("latest_plot"):
        plot_file = st.session_state.latest_plot
        st.info("π **Latest Generated Plot**")
    elif st.session_state.get("selected_plot"):
        plot_file = st.session_state.selected_plot
        st.info(f"π **Selected Plot:** {Path(plot_file).stem}")

    if plot_file and Path(plot_file).exists():
        try:
            html_content = Path(plot_file).read_text(encoding="utf-8")
            components.html(html_content, height=600, scrolling=True)
            st.download_button(
                label="π₯ Download Plot",
                data=html_content,
                file_name=Path(plot_file).name,
                mime="text/html"
            )
        except Exception as e:
            st.error(f"Error loading plot: {str(e)}")
    elif "latest_plot" in st.session_state or "selected_plot" in st.session_state:
        st.error("Plot file not found! It may have been deleted.")
    else:
        st.info("π Generate a plot or select from recent plots to view here!")


def main():
    """Streamlit entry point: page layout, agent run on submit, plot viewer."""
    st.set_page_config(
        page_title="Plot-Agent π€π",
        page_icon="π",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    _inject_styles()
    _render_header()
    _render_sidebar()

    col1, col2 = st.columns([1, 1])
    with col1:
        submitted, user_input = _render_chat_column()
    with col2:
        st.header("π Agent Activity")
        # Placeholders for real-time updates while the agent runs.
        status_placeholder = st.empty()
        activity_placeholder = st.empty()

    if submitted and (user_input or "").strip():
        if _run_agent(user_input, status_placeholder, activity_placeholder):
            # Rerun so elements drawn earlier in this pass (e.g. the sidebar's
            # Recent Plots list) reflect the new plot. Done outside the agent's
            # try/except so Streamlit's rerun signal cannot be swallowed.
            time.sleep(0.5)  # brief pause so the completion status is visible
            st.rerun()

    _display_plot()


if __name__ == "__main__":
    main()