import os
import json
import uuid
import textwrap
import streamlit as st
from typing import List, Dict, Any, Optional
from pathlib import Path
import time
# Load environment variables
from dotenv import load_dotenv
load_dotenv()

# Read both service keys from the process environment.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# Both keys are mandatory: show an error in the page and halt the script run
# when either one is missing or empty.
if not (OPENAI_API_KEY and TAVILY_API_KEY):
    st.error("Please set OPENAI_API_KEY and TAVILY_API_KEY in your .env file")
    st.stop()

# Write the keys back to os.environ so libraries imported below that read the
# environment directly can pick them up.
os.environ.update(
    OPENAI_API_KEY=OPENAI_API_KEY,
    TAVILY_API_KEY=TAVILY_API_KEY,
)
# Imports after setting environment variables
from openai import OpenAI
from tavily import TavilyClient
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from langchain_core.tools import tool
import plotly.express as px
import plotly.graph_objects as go
import plotly.io as pio
# Initialize clients
openai_client = OpenAI(api_key=OPENAI_API_KEY)
tavily_client = TavilyClient(TAVILY_API_KEY)
llm_model = ChatOpenAI(model="gpt-4o-mini", temperature=0)

# Rendered figures are written as HTML files below ./plots.
plots_dir = Path("./plots")
plots_dir.mkdir(exist_ok=True)

# Seed the conversation memory only when a key is absent, so values stored by
# an earlier script run of the same session are kept.
if "conversation_history" not in st.session_state:
    st.session_state["conversation_history"] = []
if "current_data" not in st.session_state:
    st.session_state["current_data"] = None
if "current_plot_context" not in st.session_state:
    st.session_state["current_plot_context"] = {}
# ── TOOLS ────────────────────────────────────────────────────────────────
@tool
def search_web(query: str, max_results: int = 5) -> List[Dict[str, Any]]:
    """Return Tavily results (title, url, raw_content, score)."""
    # The docstring above is kept verbatim: LangChain's @tool exposes it to
    # the model as the tool description.
    search_options = {
        "query": query,
        "max_results": max_results,
        "search_depth": "advanced",
        "chunks_per_source": 3,
        "include_raw_content": True,
    }
    response = tavily_client.search(**search_options)
    # Only the result list is handed back to the agent.
    return response["results"]
def _records_from_json(parsed: Any) -> List[Dict[str, Any]]:
    """Normalise a decoded JSON payload to a list of records.

    JSON mode always yields a top-level object, so the records normally
    arrive wrapped (``{"data": [...]}``). A bare list is passed through, a
    single list-valued key is unwrapped, and any other object is treated as
    one record.

    Raises:
        ValueError: if *parsed* is neither a list nor a dict.
    """
    if isinstance(parsed, list):
        return parsed
    if isinstance(parsed, dict):
        wrapped = parsed.get("data")
        if isinstance(wrapped, list):
            return wrapped
        list_values = [value for value in parsed.values() if isinstance(value, list)]
        if len(list_values) == 1:
            return list_values[0]
        return [parsed]
    raise ValueError(
        f"extract_data: expected a JSON list or object, got {type(parsed).__name__}"
    )


@tool
def extract_data(
    raw_results: List[Dict[str, Any]],
    schema: Optional[str] = None
) -> List[Dict[str, Any]]:
    """Turn *raw_results* into structured JSON matching *schema*.
    If schema is None, a minimal list-of-dicts schema is inferred."""
    # The docstring above is kept verbatim: LangChain's @tool exposes it to
    # the model as the tool description.
    if schema is None:
        # Ask the model to infer the schema, as the docstring promises,
        # instead of forcing one fixed set of keys onto every query.
        schema = (
            "(none supplied) Infer a minimal list-of-objects schema from the "
            "raw results, using the same keys for every item."
        )
    # response_format=json_object only permits a top-level JSON object, so the
    # record list is requested under a known key and unwrapped afterwards.
    system_msg = (
        "Return ONLY valid JSON. No markdown. "
        'Respond with a JSON object of the form {"data": [...]} whose "data" '
        "array holds the records that match the schema."
    )
    # The raw dump is cut to 4000 characters to bound the prompt size.
    user_msg = (
        f"Raw:\n{json.dumps(raw_results, ensure_ascii=False)[:4000]}"
        f"\n\nSchema:\n{schema}"
    )
    res = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "system", "content": system_msg},
                  {"role": "user", "content": user_msg}],
        temperature=0, max_tokens=2000,
        response_format={"type": "json_object"},
    )
    content = res.choices[0].message.content
    if not content or not content.strip():
        raise ValueError("extract_data: the model returned no content")
    return _records_from_json(json.loads(content.strip()))
def _strip_code_fences(text: str) -> str:
    """Return *text* without a surrounding Markdown code fence, if present."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return stripped
    lines = stripped.splitlines()[1:]  # drop the opening fence and its language tag
    if lines and lines[-1].strip().startswith("```"):
        lines = lines[:-1]
    return "\n".join(lines).strip()


@tool
def generate_plot_code(
    data: List[Dict[str, Any]],
    instructions: str
) -> str:
    """Return RAW python defining create_plot(data)->fig."""
    # The docstring above is kept verbatim: LangChain's @tool exposes it to
    # the model as the tool description.
    system_msg = ("Return ONLY python code (no markdown) that defines "
                  "`create_plot(data)` and returns a Plotly figure.")
    user_msg = f"Data:\n{json.dumps(data, indent=2)}\n\nInstructions:\n{instructions}"
    res = openai_client.chat.completions.create(
        model="gpt-4o-mini",
        messages=[{"role": "system", "content": system_msg},
                  {"role": "user", "content": user_msg}],
        temperature=0, max_tokens=1500,
        response_format={"type": "text"},
    )
    content = res.choices[0].message.content
    if not content or not content.strip():
        raise ValueError("generate_plot_code: the model returned no code")
    # The reply is exec'd later by render_plot; a ```python fence that slips
    # through despite the system prompt would be a SyntaxError there.
    return _strip_code_fences(content)
@tool
def render_plot(
    code: str,
    data: List[Dict[str, Any]],
    filename: str | None = None
) -> str:
    """Exec *code* and save fig to HTML; returns filepath."""
    # The docstring above is kept verbatim: LangChain's @tool exposes it to
    # the model as the tool description.

    # Keep only the final path component so a supplied name cannot point
    # outside plots_dir, and force the .html suffix: the sidebar lists plots
    # with a "*.html" glob, so any other suffix would never show up there.
    safe_name = Path(filename).name if filename else ""
    if not safe_name:
        safe_name = f"plot_{uuid.uuid4().hex[:8]}.html"
    elif not safe_name.lower().endswith(".html"):
        safe_name = f"{safe_name}.html"
    filepath = plots_dir / safe_name

    ctx = {"px": px, "go": go, "pio": pio}
    # SECURITY: this executes model-generated code with full interpreter
    # privileges. Only run the app where that is acceptable, or sandbox it.
    exec(code, ctx)  # expected to define create_plot
    create_plot = ctx.get("create_plot")
    if not callable(create_plot):
        raise ValueError("Generated code did not define a callable create_plot(data)")
    pio.write_html(create_plot(data), str(filepath))

    # Store current data and context for conversation memory.
    # NOTE(review): if the agent framework runs tools on a worker thread,
    # Streamlit may not attach these writes to the user's session - verify.
    st.session_state.current_data = data
    st.session_state.current_plot_context = {
        'code': code,
        'data': data,
        'filepath': str(filepath),
        'filename': safe_name
    }
    return str(filepath)
# ── AGENT PROMPT ─────────────────────────────────────────────────────────
# Argument guide appended verbatim to the agent's system prompt.
# NOTE(review): the doubled braces look like template escaping; this is a plain
# string interpolated by value, so they presumably reach the model as literal
# "{{" / "}}" - confirm before changing.
cheat_sheet = textwrap.dedent("""
┏━━ TOOL ARG GUIDE ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┓
┃ search_web {{query:str, max_results?:int}} ┃
┃ extract_data {{raw_results:…, schema?:str}} ← schema optional ┃
┃ generate_plot_code {{data:…, instructions:str}} ┃
┃ render_plot {{code:…, data:… [,filename]}} → then STOP ┃
┗━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┛
""")


def create_agent_prompt(conversation_history, current_data):
    """Create dynamic agent prompt with conversation context.

    Args:
        conversation_history: list of dicts with a 'user' entry and an
            optional 'bot' entry; only the last four are quoted.
        current_data: previously extracted data, or a falsy value when there
            is none. For a list, the first two items are shown as a sample.

    Returns:
        The system prompt: optional data-context and history sections, the
        pipeline rules, and the tool argument guide.
    """
    context_info = ""
    if current_data:
        # Only lists can be sliced; any other JSON value is shown whole.
        sample = current_data[:2] if isinstance(current_data, list) else current_data
        context_info = (
            "\nCURRENT DATA CONTEXT:\n"
            f"You have access to previously extracted data: {json.dumps(sample, indent=2)}...\n"
            "If the user asks to modify the current plot, you can skip search_web "
            "and extract_data steps and directly use this data.\n"
        )

    conversation_context = ""
    if conversation_history:
        # Built outside the f-string: a backslash inside an f-string
        # replacement field is a SyntaxError before Python 3.12.
        transcript = []
        for msg in conversation_history[-4:]:  # last 4 messages for context
            transcript.append(f"User: {msg['user']}")
            if msg.get('bot'):
                transcript.append(f"Bot: {msg['bot']}")
        conversation_context = (
            "\nCONVERSATION HISTORY:\n" + "\n".join(transcript) + "\n"
        )

    return (
        "\nYou are Plot-Agent, an AI visualization assistant with conversation memory.\n"
        f"{context_info}\n"
        f"{conversation_context}\n"
        "PIPELINE: search_web → extract_data → generate_plot_code → render_plot.\n"
        "RULES\n"
        "• If user asks to modify current plot and you have current data, skip search_web and extract_data.\n"
        "• If extract_data gets no schema, that's OK; the tool will infer one.\n"
        "• After render_plot, reply with the file path & a one-liner, then **end**.\n"
        "• Use conversation context to understand user's intent better.\n"
        f"{cheat_sheet}\n"
    )
# Tools offered to the agent, listed in pipeline order.
TOOLS = [
    search_web,
    extract_data,
    generate_plot_code,
    render_plot,
]

# Baseline prompt and agent with no history and no extracted data.
agent_prompt = create_agent_prompt([], None)  # Initial prompt
plot_agent = create_react_agent(llm_model, TOOLS, prompt=agent_prompt)
# ── STREAMLIT UI ─────────────────────────────────────────────────────────
def _render_sidebar():
    """Draw the sidebar: API status, recent plots, reset buttons and memory overview."""
    with st.sidebar:
        st.header("⚙️ Configuration")

        # API Status
        st.subheader("🔑 API Status")
        if OPENAI_API_KEY and TAVILY_API_KEY:
            st.success("✅ API Keys Loaded")
        else:
            st.error("❌ API Keys Missing")

        # Recent Plots (newest first, at most five buttons)
        st.subheader("📁 Recent Plots")
        plot_files = sorted(
            plots_dir.glob("*.html"),
            key=lambda path: path.stat().st_mtime,
            reverse=True,
        )
        if plot_files:
            for i, plot_file in enumerate(plot_files[:5]):
                if st.button(f"📊 {plot_file.stem}", key=f"recent_{i}"):
                    st.session_state.selected_plot = str(plot_file)
                    # Clear latest_plot: it takes priority in the viewer.
                    if 'latest_plot' in st.session_state:
                        del st.session_state.latest_plot
                    st.rerun()
        else:
            st.info("No plots generated yet")

        # Clear plots
        if st.button("🗑️ Clear All Plots"):
            for plot_file in plot_files:
                plot_file.unlink(missing_ok=True)
            # Clear session state plot references
            if 'latest_plot' in st.session_state:
                del st.session_state.latest_plot
            if 'selected_plot' in st.session_state:
                del st.session_state.selected_plot
            st.success("All plots cleared!")
            st.rerun()

        # Clear conversation
        if st.button("🗑️ Clear Conversation"):
            st.session_state.conversation_history = []
            st.session_state.current_data = None
            st.session_state.current_plot_context = {}
            st.success("Conversation cleared!")
            st.rerun()

        # Show current context
        if st.session_state.current_data:
            st.subheader("💾 Current Data Context")
            st.success(f"📊 {len(st.session_state.current_data)} data points available")
            with st.expander("View Data Sample"):
                st.json(st.session_state.current_data[:3])

        # Show conversation history (last five exchanges)
        if st.session_state.conversation_history:
            st.subheader("💬 Conversation History")
            with st.expander(f"View History ({len(st.session_state.conversation_history)} messages)"):
                for i, msg in enumerate(st.session_state.conversation_history[-5:]):
                    st.write(f"**{i+1}. User:** {msg['user']}")
                    if msg.get('bot'):
                        st.write(f"**Bot:** {msg['bot']}")


def _render_request_panel():
    """Draw the prompt form and example buttons; return ``(submitted, user_input)``."""
    st.header("💬 Chat with Plot-Agent")

    # Input form
    with st.form("plot_request"):
        user_input = st.text_area(
            "What visualization would you like to create?",
            placeholder="e.g., Create a bar chart of top 10 countries by GDP in 2024",
            height=100
        )
        submitted = st.form_submit_button("🚀 Generate Plot", use_container_width=True)

    # Example prompts
    st.subheader("💡 Example Prompts")
    base_examples = [
        "Create a line chart of Bitcoin price over the last 6 months",
        "Show a pie chart of global smartphone market share in 2024",
        "Make a bar chart of top 10 most populous cities in the world",
        "Create a scatter plot of countries by GDP vs population",
    ]
    # "Modify the current plot" examples only make sense once data exists.
    context_examples = []
    if st.session_state.current_data:
        context_examples = [
            "Change the current chart to a pie chart",
            "Make the bars horizontal instead of vertical",
            "Add different colors to each data point",
            "Change the title and add axis labels",
        ]
    all_examples = context_examples + base_examples
    for i, example in enumerate(all_examples[:6]):  # Show max 6 examples
        prefix = "🔄" if i < len(context_examples) else "📝"
        if st.button(f"{prefix} {example}", key=f"example_{i}"):
            # A clicked example is treated like a submitted form.
            st.session_state.user_input = example
            submitted = True
            user_input = example
    return submitted, user_input


def _run_agent(user_input, status_placeholder, activity_placeholder):
    """Run the agent for one request, streaming its progress into the placeholders."""
    # Add user message to conversation history
    st.session_state.conversation_history.append({
        'user': user_input,
        'timestamp': time.time()
    })
    with status_placeholder.container():
        st.markdown('🚀 Starting Plot-Agent...', unsafe_allow_html=True)
    activity_container = activity_placeholder.container()

    try:
        # Build a fresh agent whose prompt carries the conversation context.
        dynamic_prompt = create_agent_prompt(
            st.session_state.conversation_history,
            st.session_state.current_data
        )
        agent = create_react_agent(llm_model, TOOLS, prompt=dynamic_prompt)

        bot_response = ""
        with activity_container:
            progress_bar = st.progress(0)
            step_counter = 0
            max_steps = 4  # search, extract, generate, render
            for chunk in agent.stream(
                {"messages": [{"role": "user", "content": user_input}]},
                stream_mode="updates",
                config={"recursion_limit": 10},
            ):
                node_name = next(iter(chunk))
                update = chunk[node_name]
                new_messages = (update.get("messages") or []) if isinstance(update, dict) else []

                if node_name == "agent" and new_messages:
                    message = new_messages[-1]
                    tool_calls = getattr(message, 'tool_calls', None)
                    if tool_calls:
                        # NOTE: the original HTML templates for the activity
                        # log lines were empty or unterminated literals; plain
                        # Markdown status lines are used instead.
                        for tool_call in tool_calls:
                            step_counter += 1
                            progress_bar.progress(min(step_counter / max_steps, 1.0))
                            st.markdown(f"🔧 **Calling tool:** `{tool_call['name']}`")
                            time.sleep(0.5)  # Visual delay for better UX
                    elif getattr(message, 'content', None):
                        # NOTE(review): model text is rendered with HTML enabled.
                        bot_response = message.content
                        st.markdown(
                            f"\n🤖 Plot-Agent: {message.content}\n",
                            unsafe_allow_html=True,
                        )

                elif node_name == "tools":
                    # The update is keyed by state field, not by tool name:
                    # each tool result is a message carrying the tool's
                    # name and its output as content.
                    for tool_msg in new_messages:
                        tool_name = getattr(tool_msg, "name", None)
                        if getattr(tool_msg, "status", "success") == "error":
                            st.markdown(f"⚠️ `{tool_name}` reported an error")
                            continue
                        st.markdown(f"✅ `{tool_name}` finished")
                        if tool_name == "render_plot":
                            # render_plot returns the saved file path.
                            st.session_state.latest_plot = str(tool_msg.content)
                            # Clear selected plot to show latest
                            if 'selected_plot' in st.session_state:
                                del st.session_state.selected_plot
            progress_bar.progress(1.0)

        # Update conversation history with bot response
        if bot_response:
            st.session_state.conversation_history[-1]['bot'] = bot_response
        with status_placeholder.container():
            st.markdown('✅ Plot generation completed!', unsafe_allow_html=True)
    except Exception as e:
        # UI boundary: report the failure in the page rather than crash the run.
        # NOTE(review): the exception text is rendered with HTML enabled.
        with status_placeholder.container():
            st.markdown(f'❌ Error: {str(e)}', unsafe_allow_html=True)
    else:
        # Force rerun to show the new plot immediately
        time.sleep(0.5)  # Small delay to ensure file is written
        st.rerun()


def _render_plot_viewer():
    """Show the latest generated plot, else the plot selected in the sidebar."""
    st.header("📊 Generated Visualization")

    # Determine which plot to show (latest has priority over selected)
    plot_file = None
    if st.session_state.get('latest_plot'):
        plot_file = st.session_state.latest_plot
        st.info("🆕 **Latest Generated Plot**")
    elif st.session_state.get('selected_plot'):
        plot_file = st.session_state.selected_plot
        st.info(f"📁 **Selected Plot:** {Path(plot_file).stem}")

    if plot_file and Path(plot_file).exists():
        # Display the HTML plot
        try:
            with open(plot_file, 'r', encoding='utf-8') as f:
                html_content = f.read()
            st.components.v1.html(html_content, height=600, scrolling=True)
            # Download button
            st.download_button(
                label="📥 Download Plot",
                data=html_content,
                file_name=Path(plot_file).name,
                mime="text/html"
            )
        except Exception as e:
            st.error(f"Error loading plot: {str(e)}")
    elif 'latest_plot' in st.session_state or 'selected_plot' in st.session_state:
        st.error("Plot file not found! It may have been deleted.")
    else:
        st.info("👋 Generate a plot or select from recent plots to view here!")


def main():
    """Render the Plot-Agent page and process a submitted request, if any.

    NOTE(review): the source had lost its indentation, so the nesting of the
    two columns and of the viewer below them is reconstructed - confirm
    against the intended layout.
    """
    st.set_page_config(
        page_title="Plot-Agent 🤖📊",
        page_icon="📊",
        layout="wide",
        initial_sidebar_state="expanded"
    )
    # Custom CSS for better styling (the stylesheet is currently empty)
    st.markdown("\n", unsafe_allow_html=True)
    # Header
    st.markdown(
        "\n🤖 Plot-Agent: AI-Powered Data Visualization\n"
        "Search the web, extract data, and create stunning visualizations automatically!\n",
        unsafe_allow_html=True,
    )

    _render_sidebar()

    # Main interface
    col1, col2 = st.columns([1, 1])
    with col1:
        submitted, user_input = _render_request_panel()
    with col2:
        st.header("🔄 Agent Activity")
        # Create placeholders for real-time updates
        status_placeholder = st.empty()
        activity_placeholder = st.empty()

    # Process request
    if submitted and user_input:
        _run_agent(user_input, status_placeholder, activity_placeholder)

    _render_plot_viewer()


if __name__ == "__main__":
    main()