prasannahf commited on
Commit
c51e04d
Β·
verified Β·
1 Parent(s): 157814f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +395 -0
app.py ADDED
@@ -0,0 +1,395 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import os
3
+ # from dotenv import load_dotenv
4
+ # Load environment variables (only for local development)
5
+ try:
6
+ from dotenv import load_dotenv
7
+ load_dotenv()
8
+ except ImportError:
9
+ # dotenv not available in Streamlit Cloud, use Streamlit secrets instead
10
+ pass
11
+ import io
12
+ import sys
13
+ from contextlib import redirect_stdout, redirect_stderr
14
+ import matplotlib.pyplot as plt
15
+
16
+ # Import required packages
17
+ from langchain_groq import ChatGroq
18
+ from langgraph.types import Command
19
+ from langgraph.prebuilt import create_react_agent
20
+ from langchain_core.tools import tool
21
+ from typing_extensions import Literal
22
+ from langgraph.graph import MessagesState, StateGraph, START, END
23
+ from langchain_core.messages import BaseMessage, HumanMessage
24
+ from typing import Annotated
25
+ from langchain_community.tools import DuckDuckGoSearchRun
26
+ from langchain_community.tools.tavily_search import TavilySearchResults
27
+ from langchain_experimental.utilities import PythonREPL
28
+ from pydantic import SecretStr
29
+ from langchain_core.runnables import RunnableConfig
30
+
31
+ # Page configuration
32
+ st.set_page_config(
33
+ page_title="AI Research & Chart Generator",
34
+ page_icon="πŸ“Š",
35
+ layout="wide",
36
+ initial_sidebar_state="expanded"
37
+ )
38
+
39
+ # Custom CSS
40
+ st.markdown("""
41
+ <style>
42
+ .main-header {
43
+ text-align: center;
44
+ padding: 1rem 0;
45
+ background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
46
+ color: white;
47
+ border-radius: 10px;
48
+ margin-bottom: 2rem;
49
+ }
50
+ .chat-message {
51
+ padding: 1rem;
52
+ border-radius: 10px;
53
+ margin: 1rem 0;
54
+ border-left: 4px solid #667eea;
55
+ }
56
+ .researcher-message {
57
+ background-color: #f0f8ff;
58
+ border-left-color: #4CAF50;
59
+ }
60
+ .chart-generator-message {
61
+ background-color: #fff5f5;
62
+ border-left-color: #FF6B6B;
63
+ }
64
+ .user-message {
65
+ background-color: #f9f9f9;
66
+ border-left-color: #667eea;
67
+ }
68
+ </style>
69
+ """, unsafe_allow_html=True)
70
+
71
+ # Initialize session state
72
+ if 'workflow_result' not in st.session_state:
73
+ st.session_state.workflow_result = None
74
+ if 'chart_generated' not in st.session_state:
75
+ st.session_state.chart_generated = False
76
+
77
+ @st.cache_resource
78
+ def initialize_workflow():
79
+ """Initialize the workflow with proper configuration."""
80
+
81
+ # Get API keys from environment or Streamlit secrets
82
+ groq_api_key = os.getenv("GROQ_API_KEY") or st.secrets.get("GROQ_API_KEY", "")
83
+ tavily_api_key = os.getenv("TAVILY_API_KEY") or st.secrets.get("TAVILY_API_KEY", "")
84
+
85
+ if not groq_api_key:
86
+ st.error("❌ Groq API key not found! Please add it in Streamlit secrets or .env file")
87
+ st.stop()
88
+
89
+ # Set up LLM
90
+ llm = ChatGroq(
91
+ model="llama3-70b-8192",
92
+ api_key=SecretStr(groq_api_key) if groq_api_key else None,
93
+ temperature=0.1
94
+ )
95
+
96
+ # Set up search tool
97
+ if tavily_api_key:
98
+ search_tool = TavilySearchResults(tavily_api_key=tavily_api_key)
99
+ else:
100
+ search_tool = DuckDuckGoSearchRun()
101
+
102
+ # Set up Python REPL
103
+ repl = PythonREPL()
104
+
105
+ # Define Python REPL tool
106
+ @tool
107
+ def python_repl_tool(
108
+ code: Annotated[str, "The python code to execute to generate your chart."],
109
+ ):
110
+ """Use this to execute python code. If you want to see the output of a value,
111
+ you should print it out with `print(...)`. This is visible to the user."""
112
+
113
+ try:
114
+ # Enhanced code with matplotlib backend for Streamlit
115
+ enhanced_code = f"""
116
+ import matplotlib
117
+ matplotlib.use('Agg')
118
+ import matplotlib.pyplot as plt
119
+ import pandas as pd
120
+ import numpy as np
121
+ import seaborn as sns
122
+ import warnings
123
+ warnings.filterwarnings('ignore')
124
+
125
+ # Configure for better display
126
+ plt.style.use('default')
127
+ plt.rcParams['figure.figsize'] = (12, 8)
128
+ plt.rcParams['figure.dpi'] = 100
129
+
130
+ {code}
131
+
132
+ # Check if plot was created and save
133
+ if plt.get_fignums():
134
+ plt.savefig('generated_chart.png', bbox_inches='tight', dpi=150)
135
+ print("Chart created and saved successfully!")
136
+ chart_created = True
137
+ else:
138
+ chart_created = False
139
+ print("No chart was created.")
140
+ """
141
+
142
+ # Execute the enhanced code
143
+ result = repl.run(enhanced_code)
144
+
145
+ # Check if chart file was actually created
146
+ if os.path.exists('generated_chart.png'):
147
+ st.session_state.chart_generated = True
148
+ # Display the chart immediately in Streamlit
149
+ try:
150
+ import matplotlib.pyplot as plt
151
+ import matplotlib.image as mpimg
152
+
153
+ # Read and display the saved image
154
+ img = mpimg.imread('generated_chart.png')
155
+ fig, ax = plt.subplots(figsize=(12, 8))
156
+ ax.imshow(img)
157
+ ax.axis('off')
158
+ st.pyplot(fig)
159
+ plt.close('all') # Clean up
160
+
161
+ except Exception as display_error:
162
+ st.warning(f"Chart saved but display failed: {display_error}")
163
+ else:
164
+ st.session_state.chart_generated = False
165
+
166
+ return f"Successfully executed:\n```python\n{code}\n```\nOutput: {result}\n\nIf you have completed all tasks, respond with FINAL ANSWER"
167
+
168
+ except Exception as e:
169
+ return f"Failed to execute. Error: {repr(e)}"
170
+
171
+ # System prompt function
172
+ def make_system_prompt(instruction: str) -> str:
173
+ return (
174
+ "You are a helpful AI assistant, collaborating with other assistants."
175
+ " Use the provided tools to progress towards answering the question."
176
+ " If you are unable to fully answer, that's OK, another assistant with different tools "
177
+ " will help where you left off. Execute what you can to make progress."
178
+ " If you or any of the other assistants have the final answer or deliverable,"
179
+ " prefix your response with FINAL ANSWER so the team knows to stop."
180
+ f"\n{instruction}"
181
+ )
182
+
183
+ # Node routing function
184
+ def get_next_node(last_message: BaseMessage, goto: str):
185
+ if "FINAL ANSWER" in last_message.content:
186
+ return END
187
+ return goto
188
+
189
+ # Agent 1: Research Node
190
+ def research_node(state: MessagesState) -> Command:
191
+ research_agent = create_react_agent(
192
+ llm,
193
+ tools=[search_tool],
194
+ prompt=make_system_prompt(
195
+ """You can only do research. You are working with a chart generator colleague.
196
+ Your job is to:
197
+ 1. Search for the requested data
198
+ 2. Gather specific numerical data, statistics, or information needed
199
+ 3. Present the data in a clear, structured format
200
+ 4. Do NOT attempt to create charts yourself
201
+
202
+ When you have sufficient data, clearly indicate that your chart_generator
203
+ colleague should take over to create the visualization."""
204
+ ),
205
+ )
206
+
207
+ result = research_agent.invoke(state)
208
+ goto = get_next_node(result["messages"][-1], "chart_generator")
209
+ result["messages"][-1] = HumanMessage(
210
+ content=result["messages"][-1].content,
211
+ name="researcher"
212
+ )
213
+ return Command(update={"messages": result["messages"]}, goto=goto)
214
+
215
+ # Agent 2: Chart Generator Node
216
+ def chart_node(state: MessagesState) -> Command:
217
+ chart_agent = create_react_agent(
218
+ llm,
219
+ tools=[python_repl_tool],
220
+ prompt=make_system_prompt(
221
+ """You can only generate charts. You are working with a researcher colleague.
222
+ Your job is to:
223
+ 1. Take the data provided by the researcher
224
+ 2. Create the requested visualization using matplotlib
225
+ 3. Use proper labels, titles, and formatting
226
+ 4. Once the chart is created successfully, respond with FINAL ANSWER
227
+
228
+ Available libraries: matplotlib, pandas, numpy, seaborn
229
+ IMPORTANT: Always include plt.show() at the end of your code to ensure the chart is displayed.
230
+
231
+ Do NOT search for additional data - use what the researcher provided.
232
+
233
+ Example chart code structure:
234
+ ```python
235
+ # Your data processing here
236
+ plt.figure(figsize=(12, 8))
237
+ # Your plotting code here
238
+ plt.title('Your Chart Title')
239
+ plt.xlabel('X Label')
240
+ plt.ylabel('Y Label')
241
+ plt.grid(True, alpha=0.3)
242
+ plt.tight_layout()
243
+ plt.show() # This is essential for display
244
+ ```"""
245
+ ),
246
+ )
247
+
248
+ result = chart_agent.invoke(state)
249
+ goto = get_next_node(result["messages"][-1], "researcher")
250
+ result["messages"][-1] = HumanMessage(
251
+ content=result["messages"][-1].content,
252
+ name="chart_generator"
253
+ )
254
+ return Command(update={"messages": result["messages"]}, goto=goto)
255
+
256
+ # Build the workflow
257
+ workflow = StateGraph(MessagesState)
258
+ workflow.add_node("researcher", research_node)
259
+ workflow.add_node("chart_generator", chart_node)
260
+ workflow.add_edge(START, "researcher")
261
+
262
+ # Compile the workflow
263
+ app = workflow.compile()
264
+
265
+ return app
266
+
267
+ def display_conversation(messages):
268
+ """Display the conversation in a nice format."""
269
+ for i, msg in enumerate(messages):
270
+ if hasattr(msg, 'name') and msg.name:
271
+ if msg.name == "researcher":
272
+ st.markdown(f"""
273
+ <div class="chat-message researcher-message">
274
+ <strong>πŸ” RESEARCHER:</strong><br>
275
+ {msg.content}
276
+ </div>
277
+ """, unsafe_allow_html=True)
278
+ elif msg.name == "chart_generator":
279
+ st.markdown(f"""
280
+ <div class="chat-message chart-generator-message">
281
+ <strong>πŸ“Š CHART GENERATOR:</strong><br>
282
+ {msg.content}
283
+ </div>
284
+ """, unsafe_allow_html=True)
285
+ elif i == 0: # User message
286
+ st.markdown(f"""
287
+ <div class="chat-message user-message">
288
+ <strong>πŸ‘€ USER:</strong><br>
289
+ {msg.content}
290
+ </div>
291
+ """, unsafe_allow_html=True)
292
+
293
+ # Main app
294
+ def main():
295
+ # Header
296
+ st.markdown("""
297
+ < ivory="e8e616e0-d894-4936-a3f5-391682ee794c" title="app.py" contentType="text/python">
298
+ <div class="main-header">
299
+ <h1>πŸ€– AI Research & Chart Generator</h1>
300
+ <p>Multi-Agent System for Intelligent Data Research and Visualization</p>
301
+ </div>
302
+ """, unsafe_allow_html=True)
303
+
304
+ # Sidebar
305
+ with st.sidebar:
306
+ st.header("πŸ“‹ How it works")
307
+ st.markdown("""
308
+ 1. **πŸ” Research Agent**: Searches for data online
309
+ 2. **πŸ“Š Chart Generator**: Creates visualizations
310
+ 3. **🀝 Collaboration**: Agents work together seamlessly
311
+ """)
312
+
313
+ st.header("πŸ’‘ Example Queries")
314
+ example_queries = [
315
+ "Show me top 10 most populated countries with a bar chart",
316
+ "What is UK's GDP in past 3 years, draw line chart",
317
+ "Create a line chart of Bitcoin price trend in last 6 months",
318
+ "IPL winners in last 5 years with their final match scores",
319
+ "Global temperature trends in last decade visualization"
320
+ ]
321
+
322
+ for query in example_queries:
323
+ if st.button(f"πŸ“ {query[:30]}...", key=query, use_container_width=True):
324
+ st.session_state.selected_query = query
325
+
326
+ # Initialize workflow
327
+ try:
328
+ app = initialize_workflow()
329
+ st.success("βœ… Multi-Agent System Initialized Successfully!")
330
+ except Exception as e:
331
+ st.error(f"❌ Failed to initialize: {str(e)}")
332
+ st.stop()
333
+
334
+ # Main input
335
+ col1, col2 = st.columns([4, 1])
336
+
337
+ with col1:
338
+ user_query = st.text_input(
339
+ "🎯 What would you like to research and visualize?",
340
+ value=st.session_state.get('selected_query', ''),
341
+ placeholder="e.g., Show me top 10 most populated countries with a bar chart"
342
+ )
343
+
344
+ with col2:
345
+ recursion_limit = st.number_input("Max Steps", min_value=5, max_value=2500, value=1500)
346
+
347
+ # Generate button
348
+ if st.button("πŸš€ Generate Research & Chart", type="primary", use_container_width=True):
349
+ if user_query:
350
+ st.session_state.chart_generated = False
351
+
352
+ with st.spinner("πŸ€– Agents are working together..."):
353
+ try:
354
+ config: RunnableConfig = {"recursion_limit": recursion_limit}
355
+ result = app.invoke(
356
+ {"messages": [("user", user_query)]},
357
+ config=config
358
+ )
359
+ st.session_state.workflow_result = result
360
+
361
+ except Exception as e:
362
+ st.error(f"❌ Error during execution: {str(e)}")
363
+ st.exception(e)
364
+ else:
365
+ st.warning("⚠️ Please enter a query!")
366
+
367
+ # Display results
368
+ if st.session_state.workflow_result:
369
+ st.header("πŸ—£οΈ Agent Conversation")
370
+
371
+ with st.expander("View Full Conversation", expanded=True):
372
+ display_conversation(st.session_state.workflow_result["messages"])
373
+
374
+ # Check for generated chart file
375
+ chart_path = "generated_chart.png"
376
+ if os.path.exists(chart_path):
377
+ st.success("πŸŽ‰ Chart generated successfully!")
378
+ st.image(chart_path, caption="Generated Chart", use_column_width=True)
379
+ st.session_state.chart_generated = True
380
+
381
+ # Download option
382
+ with open(chart_path, "rb") as file:
383
+ st.download_button(
384
+ label="πŸ“₯ Download Chart",
385
+ data=file.read(),
386
+ file_name="ai_generated_chart.png",
387
+ mime="image/png"
388
+ )
389
+ elif st.session_state.chart_generated:
390
+ st.warning("⚠️ Chart was generated but file not found.")
391
+ else:
392
+ st.info("ℹ️ No chart generated yet or agents are still working.")
393
+
394
+ if __name__ == "__main__":
395
+ main()