RAVENOCC committed on
Commit
6636113
·
verified ·
1 Parent(s): 614091a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +1723 -0
  2. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,1723 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import numpy as np
5
+ import uuid
6
+ import time
7
+ import pickle
8
+ import base64
9
+ from io import BytesIO, StringIO
10
+ import sys
11
+ import operator
12
+ from typing import Literal, Sequence, TypedDict, Annotated, List, Dict, Tuple
13
+ import tempfile
14
+ import shutil
15
+ import plotly.io as pio
16
+ import io
17
+ import requests
18
+ import re
19
+ import json
20
+ import openai
21
+ # from fpdf import FPDF
22
+ import base64
23
+ from datetime import datetime
24
+ from reportlab.lib.pagesizes import letter
25
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image
26
+ from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
27
+ from reportlab.lib.units import inch
28
+ from PIL import Image as PILImage
29
+
30
+
31
+ # Import LangChain and LangGraph components
32
+
33
+ from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage
34
+ from langchain_core.prompts import ChatPromptTemplate
35
+ from langchain_openai import ChatOpenAI
36
+ from langchain_groq import ChatGroq
37
+ from langchain_experimental.utilities import PythonREPL
38
+ from langgraph.prebuilt import ToolInvocation, ToolExecutor
39
+ from langchain_core.tools import tool
40
+ from langgraph.prebuilt import InjectedState
41
+ from langgraph.graph import StateGraph, END
42
+ from reportlab.platypus import PageBreak
43
+ from PIL import Image as PILImage
44
+
45
# Set your Groq API key

# SECURITY FIX: this block previously hard-coded live Groq and OpenAI API
# keys directly in source control. Any key ever committed must be treated as
# compromised and rotated immediately. Keys are now taken from the process
# environment (set them via the shell, a .env loader, or Streamlit secrets);
# setdefault keeps an already-exported value untouched.
os.environ.setdefault("GROQ_API_KEY", "")
os.environ.setdefault("OPENAI_API_KEY", "")
51
+
52
# Create temporary directory for file storage

if 'temp_dir' not in st.session_state:
    # First run for this browser session: create a scratch directory plus the
    # nested folder where plotly figures are pickled by the analysis tool.
    scratch_root = tempfile.mkdtemp()
    st.session_state.temp_dir = scratch_root
    st.session_state.images_dir = os.path.join(scratch_root, "images/plotly_figures/pickle")
    os.makedirs(st.session_state.images_dir, exist_ok=True)
    print(f"Created temporary directory: {st.session_state.temp_dir}")
    print(f"Created images directory: {st.session_state.images_dir}")
60
+
61
# Define the system prompt
# Persona and guardrails for the analysis agent: constrains generated code to
# pandas/sklearn/plotly and requires figures to be collected in the
# `plotly_figures` list, which complete_python_task pickles automatically.
SYSTEM_PROMPT = """## Role
You are a professional data scientist helping a non-technical user understand, analyze, and visualize their data.

## Capabilities

1. **Execute python code** using the `complete_python_task` tool.

## Goals

1. Understand the user's objectives clearly.

2. Take the user on a data analysis journey, iterating to find the best way to visualize or analyse their data to solve their problems.

3. Investigate if the goal is achievable by running Python code via the `python_code` field.

4. Gain input from the user at every step to ensure the analysis is on the right track and to understand business nuances.

## Code Guidelines

- **ALL INPUT DATA IS LOADED ALREADY**, so use the provided variable names to access the data.

- **VARIABLES PERSIST BETWEEN RUNS**, so reuse previously defined variables if needed.

- **TO SEE CODE OUTPUT**, use `print()` statements. You won't be able to see outputs of `pd.head()`, `pd.describe()` etc. otherwise.

- **ONLY USE THE FOLLOWING LIBRARIES**:

- `pandas`

- `sklearn`

- `plotly`

All these libraries are already imported for you.

## Plotting Guidelines

- Always use the `plotly` library for plotting.

- Store all plotly figures inside a `plotly_figures` list, they will be saved automatically.

- Do not try and show the plots inline with `fig.show()`.

"""
107
+
108
# Define the State class
class AgentState(TypedDict):
    """Shared LangGraph state passed between the agent and tool nodes.

    Fields annotated with ``operator.add`` act as reducers: each node's
    partial update is appended to the existing list instead of replacing it.
    """
    messages: Annotated[Sequence[BaseMessage], operator.add]        # full conversation history
    input_data: Annotated[List[Dict], operator.add]                 # metadata for each uploaded dataset
    intermediate_outputs: Annotated[List[dict], operator.add]       # thought/code/output records per tool run
    current_variables: dict                                         # variables persisted between code executions
    output_image_paths: Annotated[List[str], operator.add]          # filenames of pickled plotly figures
115
+
116
# Initialize session state variables

# One table for every chat/dashboard default. Each key is assigned only when
# missing, so Streamlit reruns preserve whatever state already exists.
_SESSION_DEFAULTS = {
    "in_memory_datasets": {},
    "persistent_vars": {},
    "dataset_metadata_list": [],
    "chat_history": [],
    "dashboard_plots": [None, None, None, None],
    "columns": ["No columns available"],
    "custom_plots_to_save": {},
}

for _state_key, _default_value in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _default_value
138
+
139
# Set up the tools

repl = PythonREPL()  # NOTE(review): appears unused in this module — confirm before removing.

# Snippet exec'd inside the tool's execution namespace to pickle every figure
# the generated code appended to `plotly_figures`; expects `plotly_figures`
# and `images_dir` to already be bound in that namespace.
plotly_saving_code = """import pickle

import uuid
import os
for figure in plotly_figures:
    pickle_filename = f"{images_dir}/{uuid.uuid4()}.pickle"
    with open(pickle_filename, 'wb') as f:
        pickle.dump(figure, f)
"""
151
+
152
@tool
def complete_python_task(
    graph_state: Annotated[dict, InjectedState],
    thought: str,
    python_code: str
) -> Tuple[str, dict]:
    """Execute Python code for data analysis and visualization.

    Runs ``python_code`` with the agent's persisted variables and the uploaded
    datasets in scope, captures everything printed to stdout, and pickles any
    plotly figures the code collected in the ``plotly_figures`` list.

    Args:
        graph_state: Current LangGraph state (injected, not supplied by the LLM).
        thought: The model's stated reasoning for running this code (logged only).
        python_code: The Python source to execute.

    Returns:
        ``(stdout_text, state_update)``; on failure the first element is the
        error message and the update records the failed attempt.
    """
    current_variables = graph_state.get("current_variables", {})

    # Make every uploaded dataset available under its declared variable name.
    for input_dataset in graph_state.get("input_data", []):
        var_name = input_dataset.get("variable_name")
        if var_name and var_name not in current_variables and var_name in st.session_state.in_memory_datasets:
            print(f"Loading {var_name} from in-memory storage")
            current_variables[var_name] = st.session_state.in_memory_datasets[var_name]

    # Snapshot existing pickled figures so newly created ones can be detected.
    current_image_pickle_files = os.listdir(st.session_state.images_dir)

    old_stdout = sys.stdout
    try:
        # Capture stdout produced by the generated code.
        sys.stdout = StringIO()

        # Build the execution namespace: module globals + persisted vars + datasets.
        exec_globals = globals().copy()
        exec_globals.update(st.session_state.persistent_vars)
        exec_globals.update(current_variables)
        exec_globals.update({"plotly_figures": [], "images_dir": st.session_state.images_dir})
        # SECURITY NOTE: exec() runs arbitrary model-generated code in-process;
        # there is no sandboxing here.
        exec(python_code, exec_globals)

        # Persist any variables the executed code created for later runs.
        st.session_state.persistent_vars.update({k: v for k, v in exec_globals.items() if k not in globals()})

        output = sys.stdout.getvalue()

        updated_state = {
            "intermediate_outputs": [{"thought": thought, "code": python_code, "output": output}],
            "current_variables": st.session_state.persistent_vars
        }

        # Pickle any figures the code appended to `plotly_figures`.
        if 'plotly_figures' in exec_globals and exec_globals['plotly_figures']:
            exec(plotly_saving_code, exec_globals)

        # Report newly created figure files to the graph state.
        new_image_folder_contents = os.listdir(st.session_state.images_dir)
        new_image_files = [file for file in new_image_folder_contents if file not in current_image_pickle_files]
        if new_image_files:
            updated_state["output_image_paths"] = new_image_files
        st.session_state.persistent_vars["plotly_figures"] = []
        return output, updated_state

    except Exception as e:
        # Restore stdout before printing so the error reaches the real console.
        sys.stdout = old_stdout
        print(f"Error in complete_python_task: {str(e)}")
        return str(e), {"intermediate_outputs": [{"thought": thought, "code": python_code, "output": str(e)}]}
    finally:
        # BUGFIX: always restore stdout. The original only restored it on the
        # happy path and in the except branch, so an error raised after the
        # first restore (or a non-Exception exit) could leave stdout redirected.
        sys.stdout = old_stdout
213
+
214
# Set up the LLM and tools ( For testing purposes use the model names mentioned in the comments)
# Groq-hosted Gemma is the active model; temperature=0 for deterministic tool calls.
llm = ChatGroq(model="gemma2-9b-it", temperature=0)
# llm=ChatOpenAI(model="gpt-4-turbo", temperature=0)


# Alternative Groq model ids tried during development:
# "deepseek-r1-distill-llama-70b"
# "meta-llama/llama-4-scout-17b-16e-instruct"
# "deepseek-r1-distill-qwen-32b"


# Register the single python-execution tool with both the model (for tool
# calling) and the executor (for actually running the calls).
tools = [complete_python_task]

model = llm.bind_tools(tools)

tool_executor = ToolExecutor(tools)

# Load the prompt template
# The "placeholder" slot is filled with the running message list on every call.
chat_template = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    ("placeholder", "{messages}"),
])

# Final runnable: system prompt + conversation history -> tool-calling LLM.
model = chat_template | model
237
+
238
def create_data_summary(state: AgentState) -> str:
    """Build a plain-text overview of the data available to the agent.

    Lists each uploaded dataset (description plus a five-row sample) and any
    additional variables created by previous code executions.
    """
    summary = ""
    known_names = []

    # Describe every declared input dataset, with a small sample when the
    # frame is still held in memory.
    for dataset in state.get("input_data", []):
        name = dataset.get("variable_name")
        if not name:
            continue
        known_names.append(name)
        summary += f"\n\nVariable: {name}\n"
        summary += f"Description: {dataset.get('data_description', 'No description')}\n"
        if name in st.session_state.in_memory_datasets:
            frame = st.session_state.in_memory_datasets[name]
            summary += "\nSample Data (first 5 rows):\n"
            summary += frame.head(5).to_string()

    # Mention any other non-private variables created during the session.
    if "current_variables" in state:
        extras = [v for v in state["current_variables"] if v not in known_names and not v.startswith("_")]
        for extra in extras:
            value = state["current_variables"].get(extra)
            if isinstance(value, pd.DataFrame):
                summary += f"\n\nVariable: {extra} (DataFrame with shape {value.shape})"
            else:
                summary += f"\n\nVariable: {extra}"
    return summary
274
+
275
def route_to_tools(state: AgentState) -> Literal["tools", "__end__"]:
    """Determine if we should route to tools or end the chain"""
    messages = state.get("messages", [])
    # Guard clause: the router must always see at least one message.
    if not messages:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")

    ai_message = messages[-1]
    # Route to the tool node only when the last AI message requested tools.
    if hasattr(ai_message, "tool_calls") and len(ai_message.tool_calls) > 0:
        return "tools"
    return "__end__"
289
+
290
+
291
def call_model(state: AgentState):
    """Call the LLM to get a response"""
    # Prepend a fresh data overview so the model always sees current datasets.
    data_summary = create_data_summary(state)
    current_data_message = HumanMessage(
        content=f"The following data is available:\n{data_summary}"
    )
    conversation = [current_data_message] + state["messages"]
    llm_outputs = model.invoke({"messages": conversation})
    # Record the data overview in intermediate_outputs for debugging/reporting.
    return {"messages": [llm_outputs], "intermediate_outputs": [current_data_message.content]}
303
+
304
+
305
def call_tools(state: AgentState):
    """Execute tools called by the LLM.

    Runs every tool call from the last AI message, converts each result (or
    exception) into a ToolMessage, and merges the tools' partial state
    updates (lists are extended, dicts are merged, scalars are replaced).

    Returns:
        A partial AgentState update containing the tool messages and any
        merged state changes the tools produced.
    """
    last_message = state["messages"][-1]
    tool_invocations = []

    # NOTE(review): assumes the router only sends us here when the last
    # message is an AIMessage carrying tool_calls (see route_to_tools).
    if isinstance(last_message, AIMessage) and hasattr(last_message, 'tool_calls'):
        tool_invocations = [
            ToolInvocation(
                tool=tool_call["name"],
                tool_input={**tool_call["args"], "graph_state": state}
            ) for tool_call in last_message.tool_calls
        ]

    responses = tool_executor.batch(tool_invocations, return_exceptions=True)

    tool_messages = []
    state_updates = {}

    for tc, response in zip(last_message.tool_calls, responses):
        if isinstance(response, Exception):
            # Surface the failure to the model as a ToolMessage rather than crashing.
            print(f"Exception in tool execution: {str(response)}")
            tool_messages.append(ToolMessage(
                content=f"Error: {str(response)}",
                name=tc["name"],
                tool_call_id=tc["id"]
            ))
            continue

        message, updates = response
        tool_messages.append(ToolMessage(
            content=str(message),
            name=tc["name"],
            tool_call_id=tc["id"]
        ))

        # Merge updates instead of overwriting
        for key, value in updates.items():
            if key in state_updates:
                if isinstance(value, list) and isinstance(state_updates[key], list):
                    state_updates[key].extend(value)
                elif isinstance(value, dict) and isinstance(state_updates[key], dict):
                    state_updates[key].update(value)
                else:
                    state_updates[key] = value
            else:
                state_updates[key] = value

    # BUGFIX: the original guarded-initialised state_updates["messages"] to []
    # and then unconditionally overwrote it on the next line, making the guard
    # dead code. Tool results never include a "messages" key, so a plain
    # assignment preserves behaviour.
    state_updates["messages"] = tool_messages
    return state_updates
373
+
374
# Set up the graph
# Topology: agent -> (tools, when the last AI message requested any) -> agent
# ... repeating until the agent answers without tool calls, then END.
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", call_tools)
workflow.add_conditional_edges(
    "agent",
    route_to_tools,
    {
        "tools": "tools",
        "__end__": END
    }
)
# After tools run, control always returns to the agent for the next turn.
workflow.add_edge("tools", "agent")
workflow.set_entry_point("agent")

# Compiled runnable invoked by the chat handler.
chain = workflow.compile()
398
+
399
+
400
def process_file_upload(files):
    """Process uploaded files and return dataframe previews and column names.

    Resets the in-memory dataset store for a fresh analysis session, loads
    each CSV/Excel upload into a DataFrame keyed by a sanitized variable
    name, and collects metadata plus a markdown preview per dataset.

    Args:
        files: Iterable of uploaded file objects (must expose ``.name`` and be
            readable by pandas). CSV and .xls/.xlsx are supported.

    Returns:
        Tuple of (markdown preview text, dataset metadata list, unique column
        names across all datasets).
    """
    # Start a fresh analysis session for the new upload batch.
    st.session_state.in_memory_datasets = {}
    st.session_state.dataset_metadata_list = []
    st.session_state.persistent_vars.clear()

    if not files:
        return "No files uploaded.", [], ["No columns available"]

    results = []
    all_columns = []  # Track all columns from all datasets

    for file in files:
        try:
            # Use the file object directly with pandas.
            if file.name.endswith('.csv'):
                df = pd.read_csv(file)
            elif file.name.endswith(('.xls', '.xlsx')):
                df = pd.read_excel(file)
            else:
                results.append(f"Unsupported file format: {file.name}. Please upload CSV or Excel files.")
                continue

            # BUGFIX: derive the variable name from the full stem — the old
            # split('.')[0] truncated dotted names like "my.data.csv" to "my"
            # — and replace every non-word run so the result is a valid,
            # collision-resistant Python identifier.
            stem = os.path.splitext(file.name)[0]
            var_name = re.sub(r'\W+', '_', stem).lower()
            if not var_name or var_name[0].isdigit():
                var_name = f"_{var_name}"

            st.session_state.in_memory_datasets[var_name] = df

            # Collect all columns for the dropdowns.
            all_columns.extend(df.columns.tolist())

            # Create dataset metadata consumed by the agent's data summary.
            dataset_metadata = {
                "variable_name": var_name,
                "data_path": "in_memory",
                "data_description": f"Dataset containing {df.shape[0]} rows and {df.shape[1]} columns. Columns: {', '.join(df.columns.tolist())}",
                "original_filename": file.name
            }
            st.session_state.dataset_metadata_list.append(dataset_metadata)

            # Return preview of the dataset.
            preview = f"### Dataset: {file.name}\nVariable name: `{var_name}`\n\n"
            preview += df.head(10).to_markdown()
            results.append(preview)
            print(f"Successfully processed {file.name}")

        except Exception as e:
            print(f"Error processing {file.name}: {str(e)}")
            results.append(f"Error processing {file.name}: {str(e)}")

    # De-duplicate while preserving first-seen order.
    unique_columns = list(dict.fromkeys(all_columns))
    if not unique_columns:
        unique_columns = ["No columns available"]

    print(f"Found {len(unique_columns)} unique columns across datasets")
    return "\n\n".join(results), st.session_state.dataset_metadata_list, unique_columns
485
+
486
+
487
def get_columns():
    """Directly gets columns from in-memory datasets"""
    collected = []
    for frame in st.session_state.in_memory_datasets.values():
        if isinstance(frame, pd.DataFrame):
            collected.extend(frame.columns.tolist())

    # dict.fromkeys removes duplicates while preserving first-seen order.
    unique_columns = list(dict.fromkeys(collected))

    if not unique_columns:
        unique_columns = ["No columns available"]

    print(f"Populating dropdowns with {len(unique_columns)} columns")
    return unique_columns
515
+
516
# === FUNCTIONS ===

import openai
import pandas as pd
import json
import re

# SECURITY FIX: the original hard-coded a live OpenAI API key here. A key
# committed to source control is compromised and must be rotated. The key is
# now read from the environment; an empty string simply makes API calls fail
# with an auth error instead of leaking a secret.
openai.api_key = os.environ.get("OPENAI_API_KEY", "")
MODEL = "gpt-4-turbo"
525
+
526
def standard_clean(df):
    """Apply baseline, dataset-agnostic cleaning in place and return the frame.

    Normalises headers, drops duplicate rows and fully-empty rows/columns,
    and trims whitespace on text columns. The input frame is mutated.
    """
    # Headers: collapse non-word runs to "_", lower-case everything.
    normalised_headers = []
    for raw_name in df.columns:
        normalised_headers.append(re.sub(r'\W+', '_', raw_name).strip().lower())
    df.columns = normalised_headers

    # Remove exact duplicate rows, then columns/rows that are entirely empty.
    df.drop_duplicates(inplace=True)
    df.dropna(axis=1, how='all', inplace=True)
    df.dropna(axis=0, how='all', inplace=True)

    # Trim surrounding whitespace on text (object-dtype) columns.
    text_columns = df.select_dtypes(include='object').columns
    for name in text_columns:
        df[name] = df[name].astype(str).str.strip()
    return df
534
+
535
def query_openai(prompt):
    """Send a single-turn prompt to OpenAI chat and return the reply text.

    Returns the string "{}" on any API failure so callers can safely parse
    the result as an (empty) dict.
    """
    # NOTE(review): uses the legacy openai<1.0 ChatCompletion interface —
    # confirm the pinned openai version in requirements.txt.
    try:
        completion = openai.ChatCompletion.create(
            model=MODEL,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7
        )
    except Exception as e:
        print(f"OpenAI API Error: {e}")
        return "{}"
    return completion.choices[0].message.content
546
+
547
def llm_suggest_cleaning(df):
    """Ask the LLM for dataset-specific cleaning suggestions.

    Sends the first 10 rows to the model and expects a Python dict literal
    with rename/convert/fill/value-map instructions.

    Returns:
        Dict with keys rename_columns, convert_types, fill_missing and
        value_map (all empty when the response cannot be parsed).
    """
    import ast  # local import: only needed for parsing the model's reply

    sample = df.head(10).to_csv(index=False)
    prompt = f"""
You are a professional data wrangler. Below is a sample of a messy dataset.

Return a Python dictionary with the following keys:

1. rename_columns – fix unclear or inconsistent column names
2. convert_types – correct datatypes: int, float, str, or date
3. fill_missing – use 'mean', 'median', 'mode', or a constant like 'Unknown' or 0
4. value_map – map inconsistent values (e.g., yes/Yes/Y → Yes)

Do not drop any rows or columns. Your output must be a valid Python dict.

Example:
{{
"rename_columns": {{"dob": "date_of_birth"}},
"convert_types": {{"age": "int", "salary": "float", "signup_date": "date"}},
"fill_missing": {{"gender": "mode", "salary": -1}},
"value_map": {{
"gender": {{"M": "Male", "F": "Female"}},
"subscribed": {{"Y": "Yes", "N": "No"}}
}}
}}
Apart from these mentioned steps, study the data and also do whatever things are good and needed for that particular dataset and do the cleaning.
Sample data:
{sample}
"""
    raw_response = query_openai(prompt)
    # SECURITY FIX: the original ran eval() on the raw model response, which
    # executes arbitrary code returned by the API. ast.literal_eval accepts
    # only Python literals, so a malicious or malformed reply cannot run code
    # — it simply fails to parse and falls back to the empty suggestion set.
    try:
        suggestions = ast.literal_eval(raw_response)
        return suggestions
    except Exception:
        print("Could not parse suggestions.")
        return {
            "rename_columns": {},
            "convert_types": {},
            "fill_missing": {},
            "value_map": {}
        }
587
+
588
def apply_suggestions(df, suggestions):
    """Apply LLM cleaning suggestions to a dataframe in place.

    Supports four suggestion groups: rename_columns, convert_types
    (int/float/str/date), fill_missing (mean/median/mode or a constant) and
    value_map. Columns not present are skipped; individual failures are
    logged and the remaining steps still run.

    Returns:
        The same dataframe, modified in place.
    """
    df.rename(columns=suggestions.get("rename_columns", {}), inplace=True)

    for col, dtype in suggestions.get("convert_types", {}).items():
        if col not in df.columns:
            continue
        try:
            if dtype == "int":
                # Nullable Int64 keeps missing values representable after coercion.
                df[col] = pd.to_numeric(df[col], errors='coerce').astype("Int64")
            elif dtype == "float":
                df[col] = pd.to_numeric(df[col], errors='coerce')
            elif dtype == "str":
                df[col] = df[col].astype(str)
            elif dtype == "date":
                df[col] = pd.to_datetime(df[col], errors='coerce')
        except Exception:
            print(f"Failed to convert {col} to {dtype}")

    for col, method in suggestions.get("fill_missing", {}).items():
        if col not in df.columns:
            continue
        try:
            if method == "mean":
                df[col] = df[col].fillna(df[col].mean())
            elif method == "median":
                df[col] = df[col].fillna(df[col].median())
            elif method == "mode":
                df[col] = df[col].fillna(df[col].mode().iloc[0])
            else:
                # BUGFIX: the original only filled string constants
                # (isinstance(method, str)), silently ignoring numeric
                # constants such as the documented {"salary": -1}. Any other
                # value is now used as the fill constant directly.
                # (Also switched from chained-inplace fillna to assignment,
                # which is the pandas-recommended form.)
                df[col] = df[col].fillna(method)
        except Exception:
            print(f"Could not fill missing values for {col}")

    for col, mapping in suggestions.get("value_map", {}).items():
        if col in df.columns:
            try:
                df[col] = df[col].replace(mapping)
            except Exception:
                print(f"Could not map values in {col}")

    return df
629
+
630
+ # def generate_pdf_report():
631
+
632
+ # """Generate a PDF report with chat history and dashboard visualizations"""
633
+
634
+ # try:
635
+
636
+ # # Create PDF object with Unicode support
637
+
638
+ # pdf = FPDF()
639
+
640
+ # pdf.add_page()
641
+
642
+ # pdf.add_font('DejaVu', '', 'DejaVuSansCondensed.ttf', uni=True)
643
+
644
+
645
+
646
+ # # Set title
647
+
648
+ # pdf.set_font('Arial', 'B', 16)
649
+
650
+ # pdf.cell(0, 10, 'Data Analysis Report', 0, 1, 'C')
651
+
652
+ # pdf.ln(10)
653
+
654
+
655
+
656
+ # # Add timestamp
657
+
658
+ # pdf.set_font('Arial', 'I', 10)
659
+
660
+ # pdf.cell(0, 5, f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}', 0, 1, 'R')
661
+
662
+ # pdf.ln(10)
663
+
664
+
665
+
666
+ # # Add dashboard plots if available
667
+
668
+ # pdf.set_font('Arial', 'B', 14)
669
+
670
+ # pdf.cell(0, 10, 'Dashboard Visualizations', 0, 1, 'L')
671
+
672
+
673
+
674
+ # plot_count = 0
675
+
676
+ # for i, plot in enumerate(st.session_state.dashboard_plots):
677
+
678
+ # if plot is not None:
679
+
680
+ # plot_count += 1
681
+
682
+ # # Convert plotly figure to image
683
+
684
+ # img_bytes = io.BytesIO()
685
+
686
+ # plot.write_image(img_bytes, format='png', width=500, height=300)
687
+
688
+ # img_bytes.seek(0)
689
+
690
+
691
+
692
+ # # Create a temporary file for the image
693
+
694
+ # temp_img_path = f"{st.session_state.temp_dir}/plot_{i}.png"
695
+
696
+ # with open(temp_img_path, 'wb') as f:
697
+
698
+ # f.write(img_bytes.getvalue())
699
+
700
+
701
+
702
+ # # Add to PDF
703
+
704
+ # pdf.ln(5)
705
+
706
+ # pdf.cell(0, 5, f'Visualization {i+1}', 0, 1, 'L')
707
+
708
+ # pdf.image(temp_img_path, x=10, w=180)
709
+
710
+ # pdf.ln(5)
711
+
712
+
713
+
714
+ # if plot_count == 0:
715
+
716
+ # pdf.set_font('Arial', '', 10)
717
+
718
+ # pdf.cell(0, 10, 'No visualizations have been added to the dashboard.', 0, 1, 'L')
719
+
720
+
721
+
722
+ # # Add chat history
723
+
724
+ # pdf.add_page()
725
+
726
+ # pdf.set_font('Arial', 'B', 14)
727
+
728
+ # pdf.cell(0, 10, 'Analysis Conversation History', 0, 1, 'L')
729
+
730
+
731
+
732
+ # if st.session_state.chat_history:
733
+
734
+ # pdf.set_font('Arial', '', 10)
735
+
736
+ # for i, (user_msg, assistant_msg) in enumerate(st.session_state.chat_history):
737
+
738
+ # # Clean messages of emojis and other problematic characters
739
+
740
+ # user_msg_clean = ''.join(c for c in user_msg if ord(c) < 128)
741
+
742
+
743
+
744
+ # # Simplify assistant message (remove markdown and image references)
745
+
746
+ # assistant_msg_clean = assistant_msg.replace('![Visualization]', '[Visualization included in dashboard]')
747
+
748
+ # assistant_msg_clean = ''.join(c for c in assistant_msg_clean if ord(c) < 128)
749
+
750
+
751
+
752
+ # pdf.ln(5)
753
+
754
+ # pdf.set_font('Arial', 'B', 10)
755
+
756
+ # pdf.cell(0, 5, f'You: ', 0, 1, 'L')
757
+
758
+
759
+
760
+ # pdf.set_font('Arial', '', 10)
761
+
762
+ # pdf.multi_cell(0, 5, user_msg_clean)
763
+
764
+
765
+
766
+ # pdf.ln(3)
767
+
768
+ # pdf.set_font('Arial', 'B', 10)
769
+
770
+ # pdf.cell(0, 5, f'Assistant: ', 0, 1, 'L')
771
+
772
+
773
+
774
+ # pdf.set_font('Arial', '', 10)
775
+
776
+ # pdf.multi_cell(0, 5, assistant_msg_clean[:1000] + ('...' if len(assistant_msg_clean) > 1000 else ''))
777
+
778
+
779
+
780
+ # pdf.ln(5)
781
+
782
+ # else:
783
+
784
+ # pdf.set_font('Arial', '', 10)
785
+
786
+ # pdf.cell(0, 10, 'No conversation history available.', 0, 1, 'L')
787
+
788
+
789
+
790
+ # # Save PDF to a bytes buffer
791
+
792
+ # pdf_output = io.BytesIO()
793
+
794
+ # pdf.output(pdf_output)
795
+
796
+ # pdf_output.seek(0)
797
+
798
+
799
+
800
+ # return pdf_output.getvalue()
801
+
802
+
803
+
804
+ # except Exception as e:
805
+
806
+ # import traceback
807
+
808
+ # print(f"Error generating PDF report: {str(e)}")
809
+
810
+ # print(traceback.format_exc())
811
+
812
+ # return None
813
+
814
+
815
+
816
+ # import io
817
+
818
+ # from reportlab.lib.pagesizes import letter
819
+
820
+ # from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image
821
+
822
+ # from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
823
+
824
+ # from reportlab.lib.units import inch
825
+
826
+ # import plotly.io as pio
827
+
828
+ # from PIL import Image as PILImage
829
+
830
+ # import numpy as np
831
+
832
+ # import base64
833
+
834
+ # from datetime import datetime
835
+
836
+
837
+
838
def capture_dashboard_screenshot():
    """Capture the entire dashboard as a single image.

    Combines up to four stored dashboard figures into a 2x2 plotly subplot
    grid and writes it to a PNG in the session temp directory.

    Returns:
        The PNG file path, or None if anything fails.
    """
    try:
        # Create a figure that combines all dashboard plots
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots

        # Create a 2x2 subplot
        fig = make_subplots(rows=2, cols=2,
                            subplot_titles=["Visualization 1", "Visualization 2",
                                            "Visualization 3", "Visualization 4"])

        # Add each plot from the dashboard to the combined figure
        for i, plot in enumerate(st.session_state.dashboard_plots):
            if plot is not None:
                # Map dashboard slot index to the 2x2 grid position.
                row = (i // 2) + 1
                col = (i % 2) + 1

                # Extract traces from the original figure and add to our subplot
                for trace in plot.data:
                    fig.add_trace(trace, row=row, col=col)

                # Copy layout properties for each subplot
                for axis_type in ['xaxis', 'yaxis']:
                    # NOTE(review): this indexes the SOURCE figure's axes by the
                    # dashboard slot (i), but a standalone figure normally has
                    # only "xaxis"/"yaxis" — confirm this ever matches for i > 0.
                    axis_name = f"{axis_type}{i+1 if i > 0 else ''}"
                    # NOTE(review): plotly names subplot axes "xaxis2", "xaxis3",
                    # ... — "xaxis{row}{col}" (e.g. "xaxis12") does not look like
                    # a valid subplot axis key; verify this copy has any effect.
                    subplot_name = f"{axis_type}{row}{col}"

                    # Copy axis properties if they exist
                    if hasattr(plot.layout, axis_name):
                        axis_props = getattr(plot.layout, axis_name)
                        fig.update_layout({subplot_name: axis_props})

        # Update layout for better appearance
        fig.update_layout(
            height=800,
            width=1000,
            title_text="Dashboard Overview",
            showlegend=False,
        )

        # Save to a temporary file
        dashboard_path = f"{st.session_state.temp_dir}/dashboard_combined.png"
        # Higher scale for better resolution; requires the kaleido image engine.
        fig.write_image(dashboard_path, scale=2)
        return dashboard_path

    except Exception as e:
        import traceback
        print(f"Error capturing dashboard: {str(e)}")
        print(traceback.format_exc())
        return None
911
+
912
+
913
+
914
def generate_enhanced_pdf_report():
    """Generate an enhanced PDF report with chat history first and dashboard screenshot.

    Builds a ReportLab document entirely in memory: a title and timestamp,
    the chat transcript from ``st.session_state.chat_history``, a forced page
    break, and then the dashboard — either as one combined screenshot from
    ``capture_dashboard_screenshot()`` or, as a fallback, each plot in
    ``st.session_state.dashboard_plots`` rendered individually.

    Returns:
        bytes: the rendered PDF document, or ``None`` if generation failed
        (the error is printed with a traceback, not re-raised).
    """
    try:
        # Create a buffer for the PDF (the document is never written to disk)
        buffer = io.BytesIO()
        # Create the PDF document with adjusted page size if needed
        doc = SimpleDocTemplate(buffer, pagesize=letter,
                                leftMargin=36, rightMargin=36,
                                topMargin=36, bottomMargin=36)
        styles = getSampleStyleSheet()
        # Add custom styles
        styles.add(ParagraphStyle(name='ReportTitle',
                                  parent=styles['Heading1'],
                                  alignment=1))  # 1 is centered

        # Create the document content (flowables appended in reading order)
        elements = []
        # Add title
        elements.append(Paragraph('Data Analysis Report', styles['ReportTitle']))
        elements.append(Spacer(1, 0.25*inch))

        # Add timestamp
        timestamp = Paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',
                              styles['Italic'])
        elements.append(timestamp)
        elements.append(Spacer(1, 0.5*inch))

        # Add conversation history FIRST (before the dashboard, by design)
        elements.append(Paragraph('Analysis Conversation History', styles['Heading2']))
        elements.append(Spacer(1, 0.1*inch))

        if st.session_state.chat_history:
            # chat_history holds (user_msg, assistant_msg) tuples
            for i, (user_msg, assistant_msg) in enumerate(st.session_state.chat_history):
                elements.append(Paragraph(f'<b>You:</b>', styles['Normal']))
                elements.append(Paragraph(user_msg, styles['Normal']))
                elements.append(Spacer(1, 0.1*inch))
                elements.append(Paragraph(f'<b>Assistant:</b>', styles['Normal']))
                # Simplify assistant message (remove markdown and image references)
                simplified_msg = assistant_msg.replace('![Visualization]', '[Visualization included in dashboard]')
                # Truncate long replies to keep the PDF compact (1000-char cap)
                simplified_msg = simplified_msg[:1000] + ('...' if len(simplified_msg) > 1000 else '')
                elements.append(Paragraph(simplified_msg, styles['Normal']))
                elements.append(Spacer(1, 0.2*inch))
        else:
            elements.append(Paragraph('No conversation history available.', styles['Normal']))

        # Force a page break before the dashboard
        elements.append(PageBreak())
        # Add dashboard as a single screenshot (on a new page)
        elements.append(Paragraph('Dashboard Overview', styles['Heading2']))
        elements.append(Spacer(1, 0.1*inch))

        # Capture the dashboard as a single image (returns None on failure)
        dashboard_img_path = capture_dashboard_screenshot()

        if dashboard_img_path:
            # Calculate available width (accounting for margins)
            available_width = doc.width
            # Create PIL image to get dimensions (pixels)
            pil_img = PILImage.open(dashboard_img_path)
            img_width, img_height = pil_img.size

            # Calculate scaling factor to fit within page width
            scale_factor = available_width / img_width
            # Calculate new height based on aspect ratio
            new_height = img_height * scale_factor
            # Add the image with scaled dimensions
            img = Image(dashboard_img_path, width=available_width, height=new_height)
            elements.append(img)
        else:
            # Fallback: Add individual plots if combined dashboard fails
            elements.append(Paragraph('Dashboard Visualizations (Individual)', styles['Heading3']))
            plot_count = 0
            for i, plot in enumerate(st.session_state.dashboard_plots):
                # Empty dashboard slots are stored as None and skipped
                if plot is not None:
                    plot_count += 1
                    # Convert plotly figure to image
                    # NOTE(review): plot.write_image presumably needs the
                    # kaleido engine (listed in requirements) — confirm.
                    img_bytes = io.BytesIO()
                    plot.write_image(img_bytes, format='png', width=500, height=300)
                    img_bytes.seek(0)
                    # Create a temporary file for the image
                    temp_img_path = f"{st.session_state.temp_dir}/plot_{i}.png"
                    with open(temp_img_path, 'wb') as f:
                        f.write(img_bytes.getvalue())
                    # Add to PDF with appropriate scaling
                    img = Image(temp_img_path, width=5*inch, height=3*inch)
                    elements.append(Paragraph(f'Visualization {i+1}', styles['Heading4']))
                    elements.append(Spacer(1, 0.1*inch))
                    elements.append(img)
                    elements.append(Spacer(1, 0.2*inch))
            if plot_count == 0:
                elements.append(Paragraph('No visualizations have been added to the dashboard.',
                                          styles['Normal']))

        # Build the PDF (renders all flowables into the in-memory buffer)
        doc.build(elements)
        # Get the value of the buffer before closing it
        pdf_value = buffer.getvalue()
        buffer.close()

        return pdf_value

    except Exception as e:
        import traceback
        print(f"Error generating enhanced PDF report: {str(e)}")
        print(traceback.format_exc())
        return None
1083
+
1084
def chat_with_workflow(message, history, dataset_info):
    """Send user query to the workflow and get response.

    Args:
        message: the new user question (str).
        history: prior (user_msg, assistant_msg) exchange tuples used as
            conversational context.
        dataset_info: metadata describing the uploaded datasets; must be
            non-empty for the workflow to run.

    Returns:
        str: markdown response text; any generated visualizations are
        embedded as base64 data-URI images. On failure an error string is
        returned instead of raising.
    """
    if not dataset_info:
        return "Please upload at least one dataset before asking questions."

    print(f"Chat with workflow called with {len(dataset_info)} datasets")

    try:
        # Extract chat history for context
        previous_messages = []
        for exchange in history:
            if exchange[0]:  # User message
                previous_messages.append(HumanMessage(content=exchange[0]))
            if exchange[1]:  # AI response
                previous_messages.append(AIMessage(content=exchange[1]))

        # Initialize the workflow state
        state = AgentState(
            messages=previous_messages + [HumanMessage(content=message)],
            input_data=dataset_info,
            intermediate_outputs=[],
            current_variables=st.session_state.persistent_vars,
            output_image_paths=[]
        )

        # Execute the workflow (LangGraph chain defined elsewhere in the file)
        print("Executing workflow...")
        result = chain.invoke(state)
        print("Workflow execution completed")

        # Extract messages from the result
        messages = result["messages"]

        # Format the response
        # NOTE(review): this concatenates the content of every message in the
        # result — presumably the chain returns only new messages; if it
        # echoes the full history the reply will duplicate it. Confirm.
        response = ""
        for msg in messages:
            if hasattr(msg, "content"):
                response += msg.content + "\n\n"

        # Check if there are any visualization images
        if "output_image_paths" in result and result["output_image_paths"]:
            response += "### Visualizations\n\n"
            for img_path in result["output_image_paths"]:
                try:
                    full_path = os.path.join(st.session_state.images_dir, img_path)
                    # Figures are pickled plotly objects written by the workflow.
                    # SECURITY NOTE: pickle.load executes arbitrary code; safe
                    # only because these files are produced locally by this app.
                    with open(full_path, 'rb') as f:
                        fig = pickle.load(f)

                    # Convert plotly figure to image
                    img_bytes = BytesIO()
                    fig.update_layout(width=800, height=500)
                    pio.write_image(fig, img_bytes, format='png')
                    img_bytes.seek(0)

                    # Convert to base64 for markdown image
                    b64_img = base64.b64encode(img_bytes.read()).decode()
                    response += f"![Visualization](data:image/png;base64,{b64_img})\n\n"

                except Exception as e:
                    # Keep going: a broken figure should not lose the text reply
                    response += f"Error loading visualization: {str(e)}\n\n"

        return response

    except Exception as e:
        import traceback
        print(f"Error in chat_with_workflow: {str(e)}")
        print(traceback.format_exc())
        return f"Error executing workflow: {str(e)}"
1170
+
1171
def auto_generate_dashboard(dataset_info):
    """Generate an automatic dashboard with four plots.

    Prompts the LLM workflow to pick the four most insightful visualizations
    for the uploaded data, loads the pickled figures it produced, pads the
    list to exactly four entries with ``None``, and stores it in
    ``st.session_state.dashboard_plots``.

    Args:
        dataset_info: dataset metadata list; empty/None aborts early.

    Returns:
        tuple[str, list]: a status message and a list of exactly four
        plotly figures (missing slots are ``None``).
    """
    if not dataset_info:
        return "Please upload a dataset first.", [None, None, None, None]

    prompt = """
    You are a data visualization expert. Given a dataset, identify the top 4 most insightful plots using statistical reasoning or patterns (correlation, distribution, trends).
    Use plotly and store the plots in a list named plotly_figures.
    Include multivariate plots using color/size/facets when helpful.
    """

    state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[]
    )

    result = chain.invoke(state)
    figures = []

    if "output_image_paths" in result:
        # Only keep up to the first four figures the workflow produced
        for img_path in result["output_image_paths"][:4]:
            try:
                full_path = os.path.join(st.session_state.images_dir, img_path)
                # Pickled plotly figure written locally by the workflow
                with open(full_path, 'rb') as f:
                    fig = pickle.load(f)
                figures.append(fig)
            except Exception as e:
                print(f"Error loading figure: {e}")

    # Pad to exactly four slots so the dashboard grid always has 4 entries
    while len(figures) < 4:
        figures.append(None)

    st.session_state.dashboard_plots = figures
    return "Dashboard generated!", figures
1222
+
1223
+
1224
def generate_custom_plots_with_llm(dataset_info, x_col, y_col, facet_col):
    """Generate custom plots based on user-selected columns.

    Asks the LLM workflow for three plotly visualizations built around the
    chosen x/y columns (and an optional facet column), then loads the
    pickled figures it produced.

    Args:
        dataset_info: dataset metadata list; required.
        x_col: column name for the x-axis; required.
        y_col: column name for the y-axis; required.
        facet_col: facet column name, or the string 'None' for no faceting.

    Returns:
        list: exactly three entries; plotly figures where available,
        ``None`` for missing slots.
    """
    if not dataset_info or not x_col or not y_col:
        return [None, None, None]

    prompt = f"""
    You are a data visualization expert.
    Create 3 insightful visualizations using Plotly based on:
    - X-axis: {x_col}
    - Y-axis: {y_col}
    - Facet (optional): {facet_col if facet_col != 'None' else 'None'}
    Try to find interesting relationships, trends, or clusters using appropriate chart types.
    Use `plotly_figures` list and avoid using fig.show().
    """

    state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[]
    )

    result = chain.invoke(state)

    figures = []

    if "output_image_paths" in result:
        # Only the first three figures are kept for the custom-plot panel
        for img_path in result["output_image_paths"][:3]:
            try:
                full_path = os.path.join(st.session_state.images_dir, img_path)
                # Pickled plotly figure written locally by the workflow
                with open(full_path, 'rb') as f:
                    fig = pickle.load(f)
                figures.append(fig)
            except Exception as e:
                print(f"Error loading figure: {e}")

    # Pad to exactly three slots so the caller can index positions 0-2
    while len(figures) < 3:
        figures.append(None)

    return figures
1283
+
1284
+ # def add_custom_to_dashboard(fig, index):
1285
+
1286
+ # """Add a custom plot to the dashboard"""
1287
+
1288
+ # if fig is not None:
1289
+
1290
+ # # Find the first empty slot
1291
+
1292
+ # for i in range(len(st.session_state.dashboard_plots)):
1293
+
1294
+ # if st.session_state.dashboard_plots[i] is None:
1295
+
1296
+ # st.session_state.dashboard_plots[i] = fig
1297
+
1298
+ # break
1299
+
1300
def remove_plot(index):
    """Clear the dashboard slot at *index*; out-of-range indices are ignored."""
    slots = st.session_state.dashboard_plots
    if 0 <= index < len(slots):
        # Slots are never deleted, only blanked, so the grid keeps its shape.
        slots[index] = None
1306
+
1307
+
1308
def respond(message):
    """Answer a chat message via the workflow and refresh the page.

    Appends the (question, answer) pair to the session chat history and
    triggers a Streamlit rerun so the new exchange is rendered.
    """
    session = st.session_state
    if session.dataset_metadata_list:
        bot_message = chat_with_workflow(message, session.chat_history, session.dataset_metadata_list)
    else:
        # No data loaded yet: reply with guidance instead of invoking the LLM.
        bot_message = "Please upload at least one dataset before asking questions."
    session.chat_history.append((message, bot_message))
    st.rerun()
1320
+
1321
+
1322
def save_plot_to_dashboard(plot_index):
    """Copy a generated custom plot into the first empty dashboard slot.

    Used as a Streamlit button callback. If every slot is occupied the
    call is a silent no-op, matching the original behavior.
    """
    slots = st.session_state.dashboard_plots
    for slot, occupant in enumerate(slots):
        if occupant is None:
            # Found an empty slot — fill it and stop searching.
            slots[slot] = st.session_state.custom_plots_to_save[plot_index]
            return
1331
+
1332
+
1333
# Streamlit UI — page chrome and top-level tab layout.
st.set_page_config(page_title="Data Analysis Assistant", layout="wide")
st.title("Data Analysis Assistant")
st.markdown("Upload your datasets, ask questions, and generate visualizations to gain insights.")

# Create tabs: each tab body below renders one stage of the workflow.
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Upload Datasets", "Data Cleaning", "Chat with AI Assistant", "Auto Dashboard Generator", "Generate Report"])
1340
+
1341
+ # with tab1:
1342
+
1343
+ # st.header("Upload Datasets")
1344
+ # uploaded_files = st.file_uploader("Upload CSV or Excel Files",
1345
+ # accept_multiple_files=True,
1346
+ # type=['csv', 'xlsx', 'xls'])
1347
+
1348
+
1349
+
1350
+ # if uploaded_files and st.button("Process Uploaded Files"):
1351
+
1352
+ # with st.spinner("Processing files..."):
1353
+ # preview, _, columns = process_file_upload(uploaded_files)
1354
+ # st.markdown(preview)
1355
+ # st.session_state.columns = columns
1356
+ # st.rerun()
1357
+
1358
# Tab 1: file upload and dataset preview.
with tab1:
    st.header("Upload Datasets")
    uploaded_files = st.file_uploader("Upload CSV or Excel Files",
                                      accept_multiple_files=True,
                                      type=['csv', 'xlsx', 'xls'])

    if uploaded_files and st.button("Process Uploaded Files"):
        with st.spinner("Processing files..."):
            # process_file_upload (defined earlier in the file) parses the
            # files and populates st.session_state.in_memory_datasets.
            preview, metadata_list, columns = process_file_upload(uploaded_files)
            st.session_state.columns = columns

            # Display basic information about processed files
            st.success(f"✅ Successfully processed {len(uploaded_files)} file(s)")

            # Show detailed preview for each dataset
            st.subheader("Dataset Previews")

            for dataset_name, df in st.session_state.in_memory_datasets.items():
                with st.expander(f"Preview: {dataset_name}"):
                    # Display dataset info
                    st.write(f"**Rows:** {df.shape[0]} | **Columns:** {df.shape[1]}")

                    # Display column information (name, dtype, counts, samples)
                    col_info = pd.DataFrame({
                        'Column Name': df.columns,
                        'Data Type': df.dtypes.astype(str),
                        'Non-Null Count': df.count().values,
                        'Sample Values': [', '.join(df[col].dropna().astype(str).head(3).tolist()) for col in df.columns]
                    })

                    # Show column information in a compact table
                    st.write("**Column Information:**")
                    st.dataframe(col_info, use_container_width=True)

                    # Show actual data preview
                    st.write("**Data Preview (First 10 rows):**")
                    st.dataframe(df.head(10), use_container_width=True)

    # Provide hint for the next steps
    st.info("👆 Click on the dataset names above to see detailed previews. Then proceed to the Data Cleaning tab to clean your data or Chat with AI Assistant to analyze it.")
1398
+
1399
# Tab 2: LLM-assisted data cleaning with interactive corrections.
with tab2:
    st.header("Data Cleaning")

    # Lazily initialize the cleaning-related session keys.
    if 'cleaning_done' not in st.session_state:
        st.session_state.cleaning_done = False

    if 'cleaned_datasets' not in st.session_state:
        st.session_state.cleaned_datasets = {}

    if 'cleaning_summaries' not in st.session_state:
        st.session_state.cleaning_summaries = {}

    if st.session_state.get("in_memory_datasets"):
        if not st.session_state.cleaning_done:
            if st.button("Run Data Cleaning"):
                with st.spinner("Running LLM-assisted cleaning..."):
                    for name, df in st.session_state.in_memory_datasets.items():
                        # Pipeline: rule-based clean, then LLM suggestions,
                        # then apply the suggestions. Copies keep the raw
                        # dataset untouched for the before/after view.
                        raw_df = df.copy()
                        df_std = standard_clean(raw_df.copy())
                        suggestions = llm_suggest_cleaning(df_std.copy())
                        df_clean = apply_suggestions(df_std.copy(), suggestions)
                        st.session_state.cleaned_datasets[name] = df_clean
                        st.session_state.cleaning_summaries[name] = suggestions
                    st.session_state.cleaning_done = True
                    st.rerun()
            else:
                st.info("Click Run Data Cleaning to clean your datasets using the LLM.")
        else:
            # Cleaning already ran: show before/after plus refinement controls.
            for name, df_clean in st.session_state.cleaned_datasets.items():
                raw_df = st.session_state.in_memory_datasets[name]

                st.subheader(f"Dataset: {name}")
                col1, col2 = st.columns(2)

                with col1:
                    st.markdown("Original Data (First 5 Rows)")
                    st.dataframe(raw_df.head())

                with col2:
                    st.markdown("Cleaned Data (First 5 Rows)")
                    st.dataframe(df_clean.head())

                st.markdown("Summary of Cleaning Actions")
                suggestions = st.session_state.cleaning_summaries[name]
                summary_text = ""

                if suggestions:
                    for key, value in suggestions.items():
                        summary_text += f"**{key}**: {json.dumps(value, indent=2)}\n\n"
                st.markdown(summary_text)

                st.markdown("Refine the Cleaning (Natural Language Instructions)")
                user_input = st.text_input("Example: Convert 'dob' to datetime and fill missing with '2000-01-01'",
                                           key=f"user_input_{name}")

                # Per-dataset list of (instruction, generated code) pairs.
                if f'corrections_{name}' not in st.session_state:
                    st.session_state[f'corrections_{name}'] = []

                if st.button("Apply Correction", key=f'apply_correction_{name}'):
                    if user_input.strip():
                        correction_prompt = f"""
                        You are a data cleaning expert. Below is a previously cleaned dataset with these actions:

                        {summary_text}

                        The user now wants the following additional instruction:
                        \"{user_input.strip()}\"

                        Write only the Python code that modifies the pandas DataFrame `df` accordingly.
                        Do not include explanations or markdown.
                        """
                        correction_code = query_openai(correction_prompt)

                        try:
                            df = st.session_state.cleaned_datasets[name].copy()
                            local_vars = {"df": df}
                            # SECURITY NOTE: exec() runs LLM-generated code
                            # unsandboxed in this process — a malicious or
                            # hallucinated response can do anything the app
                            # can. Consider sandboxing or review before exec.
                            exec(correction_code, {}, local_vars)
                            df_updated = local_vars["df"]

                            st.session_state.cleaned_datasets[name] = df_updated
                            st.session_state[f'corrections_{name}'].append((user_input, correction_code))
                            st.success("Correction applied.")
                            st.rerun()

                        except Exception as e:
                            st.error(f"Failed to apply correction: {str(e)}")

                if st.session_state[f'corrections_{name}']:
                    st.markdown("Applied Corrections")
                    for i, (msg, code) in enumerate(st.session_state[f'corrections_{name}']):
                        st.markdown(f"**Instruction:** {msg}")
                        st.code(code, language='python')

            col1, col2 = st.columns([1, 2])
            with col1:
                if st.button("Reset Cleaning and Re-run"):
                    st.session_state.cleaning_done = False
                    st.rerun()

            with col2:
                if st.button("Finalize and Proceed to Visualizations"):
                    st.session_state.cleaning_finalized = True
                    st.rerun()
    else:
        st.info("Please upload and process datasets first.")
1504
+
1505
# Tab 3: conversational analysis interface.
with tab3:
    st.header("Chat with AI Assistant")
    st.markdown("""
    ## Example Questions
    - "What analysis can you perform on this dataset?"
    - "Show me basic statistics for all columns"
    - "Create a correlation heatmap"
    - "Plot the distribution of a specific column"
    - "What is the relationship between two columns?"
    """)

    # Display chat history (pairs of user / assistant messages)
    for exchange in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(exchange[0])
        with st.chat_message("assistant"):
            st.write(exchange[1])

    # Chat input — respond() appends to history and reruns the page
    if prompt := st.chat_input("Your question"):
        with st.spinner("Thinking..."):
            respond(prompt)
1538
+
1539
# Tab 4: auto-generated 2x2 dashboard plus a custom plot builder.
with tab4:
    st.header("Auto Dashboard Generator")

    # Dashboard controls
    # NOTE(review): dashboard_title is collected but not used below —
    # presumably intended for the PDF report; confirm.
    dashboard_title = st.text_input("Dashboard Title", placeholder="Enter your dashboard title")

    col1, col2 = st.columns(2)

    with col1:
        if st.button("Generate Suggested Dashboard (Auto)"):
            with st.spinner("Generating dashboard..."):
                message, figures = auto_generate_dashboard(st.session_state.dataset_metadata_list)
                st.success(message)

    with col2:
        if st.button("Refresh Column Options"):
            st.session_state.columns = get_columns()
            st.rerun()

    # Dashboard display — four fixed slots rendered as a 2x2 grid.
    st.subheader("Dashboard")

    # Row 1
    col1, col2 = st.columns(2)

    with col1:
        if st.session_state.dashboard_plots[0]:
            st.plotly_chart(st.session_state.dashboard_plots[0], use_container_width=True)
            if st.button("Remove Plot 1"):
                remove_plot(0)
                st.rerun()

    with col2:
        if st.session_state.dashboard_plots[1]:
            st.plotly_chart(st.session_state.dashboard_plots[1], use_container_width=True)
            if st.button("Remove Plot 2"):
                remove_plot(1)
                st.rerun()

    # Row 2
    col3, col4 = st.columns(2)

    with col3:
        if st.session_state.dashboard_plots[2]:
            st.plotly_chart(st.session_state.dashboard_plots[2], use_container_width=True)
            if st.button("Remove Plot 3"):
                remove_plot(2)
                st.rerun()

    with col4:
        if st.session_state.dashboard_plots[3]:
            st.plotly_chart(st.session_state.dashboard_plots[3], use_container_width=True)
            if st.button("Remove Plot 4"):
                remove_plot(3)
                st.rerun()

    # Custom plot generator
    st.subheader("Custom Plot Generator")

    # Column selection
    col1, col2, col3 = st.columns(3)

    with col1:
        x_axis = st.selectbox("X-axis Column", options=st.session_state.columns)

    with col2:
        y_axis = st.selectbox("Y-axis Column", options=st.session_state.columns)

    with col3:
        facet = st.selectbox("Facet (optional)", options=["None"] + st.session_state.columns)

    if st.button("Generate Custom Visualizations"):
        with st.spinner("Generating custom visualizations..."):
            custom_plots = generate_custom_plots_with_llm(st.session_state.dataset_metadata_list, x_axis, y_axis, facet)
            # Store plots in session state so the button callback
            # (save_plot_to_dashboard) can retrieve them by index.
            for i, plot in enumerate(custom_plots):
                if plot:
                    st.session_state.custom_plots_to_save[i] = plot

            # Display custom plots with add buttons
            for i, plot in enumerate(custom_plots):
                if plot:
                    st.plotly_chart(plot, use_container_width=True)
                    st.button(
                        f"Add Plot {i+1} to Dashboard",
                        key=f"add_plot_{i}",
                        on_click=save_plot_to_dashboard,
                        args=(i,)
                    )
1663
+
1664
# Tab 5: PDF report generation and download.
with tab5:
    st.header("Generate Analysis Report")
    st.markdown("""
    Generate a PDF report containing:
    - Dashboard visualizations
    - Chat conversation history
    """)
    # NOTE(review): report_title is collected but generate_enhanced_pdf_report
    # hard-codes its own title — presumably meant to be passed through; confirm.
    report_title = st.text_input("Report Title (Optional)", "Data Analysis Report")

    if st.button("Generate PDF Report"):
        with st.spinner("Generating report..."):
            pdf_data = generate_enhanced_pdf_report()

            if pdf_data:
                # Create download button for PDF (embedded as a base64 data URI)
                b64_pdf = base64.b64encode(pdf_data).decode('utf-8')

                # Create download link
                pdf_download_link = f'<a href="data:application/pdf;base64,{b64_pdf}" download="data_analysis_report.pdf">Download PDF Report</a>'
                st.markdown("### Your report is ready!")
                st.markdown(pdf_download_link, unsafe_allow_html=True)

                # Preview option (simplified)
                with st.expander("Preview Report"):
                    st.warning("PDF preview is not available in Streamlit, please download the report to view it.")
            else:
                st.error("Failed to generate the report. Please try again.")
1708
+
1709
+
1710
+
1711
# Cleanup on app exit: best-effort removal of the session temp directory.
def cleanup():
    """Delete the session's temporary directory; failures are only logged."""
    try:
        shutil.rmtree(st.session_state.temp_dir)
        print(f"Cleaned up temporary directory: {st.session_state.temp_dir}")
    except Exception as e:
        # Swallow deliberately — at interpreter exit session_state may
        # already be unavailable, and cleanup must never crash shutdown.
        print(f"Error cleaning up: {e}")

import atexit

atexit.register(cleanup)
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ pandas
3
+ numpy
4
+ plotly
5
+ langchain
6
+ langgraph==0.2.74
7
+ langchain-core
8
+ langchain-groq
9
+ openai
10
+ langchain-experimental
11
+ reportlab
12
+ Pillow
13
+ scikit-learn
14
+ tabulate
15
+ kaleido