RAVENOCC committed on
Commit
f2ed321
·
verified ·
1 Parent(s): 4cfe76c

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +1571 -0
  2. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,1571 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import pickle
5
+ import base64
6
+ from io import BytesIO, StringIO
7
+ import sys
8
+ import operator
9
+ from typing import Literal, Sequence, TypedDict, Annotated, List, Dict, Tuple
10
+ import tempfile
11
+ import shutil
12
+ import plotly.io as pio
13
+ import io
14
+ import re
15
+ import json
16
+ import openai
17
+ # from fpdf import FPDF
18
+ import base64
19
+ from datetime import datetime
20
+ from reportlab.lib.pagesizes import letter
21
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image
22
+ from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
23
+ from reportlab.lib.units import inch
24
+ from PIL import Image as PILImage
25
+
26
+ # Import LangChain and LangGraph components
27
+ from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage
28
+ from langchain_core.prompts import ChatPromptTemplate
29
+ from langchain_openai import ChatOpenAI
30
+ from langchain_experimental.utilities import PythonREPL
31
+ from langgraph.prebuilt import ToolInvocation, ToolExecutor
32
+ from langchain_core.tools import tool
33
+ from langgraph.prebuilt import InjectedState
34
+ from langgraph.graph import StateGraph, END
35
+ from reportlab.platypus import PageBreak
36
+
37
+
38
# Initialize session state for AI provider settings.
# Streamlit reruns the whole script on every interaction, so configuration
# must live in st.session_state to survive reruns.
if 'ai_provider' not in st.session_state:
    st.session_state.ai_provider = "openai"

if 'api_key' not in st.session_state:
    st.session_state.api_key = ""

if 'selected_model' not in st.session_state:
    st.session_state.selected_model = "gpt-4"

# Define model options for each provider.
OPENAI_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-4-mini", "gpt-3.5-turbo"]
# FIX: corrected Groq model identifiers — the API expects
# "llama-3.3-70b-versatile" and "llama3-8b-8192"; the previous spellings
# ("llama3.3-70b-versatile", "llama-3-8b-8192") are rejected by Groq.
GROQ_MODELS = ["llama-3.3-70b-versatile", "gemma2-9b-it", "llama3-8b-8192"]

# Create a temporary directory for file storage; pickled plotly figures
# produced by the agent are written under images_dir.
if 'temp_dir' not in st.session_state:
    st.session_state.temp_dir = tempfile.mkdtemp()
    st.session_state.images_dir = os.path.join(st.session_state.temp_dir, "images/plotly_figures/pickle")
    os.makedirs(st.session_state.images_dir, exist_ok=True)
    print(f"Created temporary directory: {st.session_state.temp_dir}")
    print(f"Created images directory: {st.session_state.images_dir}")
59
+
60
# System prompt sent on every agent turn. It constrains the LLM to the
# pre-imported libraries and to the plotly_figures saving convention that
# complete_python_task relies on — keep the two in sync if either changes.
SYSTEM_PROMPT = """## Role
You are a professional data scientist helping a non-technical user understand, analyze, and visualize their data.

## Capabilities
1. **Execute python code** using the `complete_python_task` tool.

## Goals
1. Understand the user's objectives clearly.
2. Take the user on a data analysis journey, iterating to find the best way to visualize or analyse their data to solve their problems.
3. Investigate if the goal is achievable by running Python code via the `python_code` field.
4. Gain input from the user at every step to ensure the analysis is on the right track and to understand business nuances.

## Code Guidelines
- **ALL INPUT DATA IS LOADED ALREADY**, so use the provided variable names to access the data.
- **VARIABLES PERSIST BETWEEN RUNS**, so reuse previously defined variables if needed.
- **TO SEE CODE OUTPUT**, use `print()` statements. You won't be able to see outputs of `pd.head()`, `pd.describe()` etc. otherwise.
- **ONLY USE THE FOLLOWING LIBRARIES**:
- `pandas`
- `sklearn` (including all major ML models)
- `plotly`
- `numpy`

All these libraries are already imported for you.

## Machine Learning Guidelines
- For regression tasks:
- Linear Regression: `LinearRegression`
- Logistic Regression: `LogisticRegression`
- Ridge Regression: `Ridge`
- Lasso Regression: `Lasso`
- Random Forest Regression: `RandomForestRegressor`

- For classification tasks:
- Logistic Regression: `LogisticRegression`
- Decision Trees: `DecisionTreeClassifier`
- Random Forests: `RandomForestClassifier`
- Support Vector Machines: `SVC`
- K-Nearest Neighbors: `KNeighborsClassifier`
- Naive Bayes: `GaussianNB`

- For clustering:
- K-Means: `KMeans`
- DBSCAN: `DBSCAN`

- For dimensionality reduction:
- PCA: `PCA`

- Always preprocess data appropriately:
- Scale numerical features with `StandardScaler` or `MinMaxScaler`
- Encode categorical variables with `OneHotEncoder` when needed
- Handle missing values with `SimpleImputer`

- Always split data into training and testing sets using `train_test_split`
- Evaluate models using appropriate metrics:
- For regression: `mean_squared_error`, `mean_absolute_error`, `r2_score`
- For classification: `accuracy_score`, `confusion_matrix`, `classification_report`
- For clustering: `silhouette_score`

- Consider using `cross_val_score` for more robust evaluation
- Visualize ML results with plotly when possible

## Plotting Guidelines
- Always use the `plotly` library for plotting.
- Store all plotly figures inside a `plotly_figures` list, they will be saved automatically.
- Do not try and show the plots inline with `fig.show()`.
"""
127
+
128
# Define the State class shared by all LangGraph nodes. Fields annotated
# with operator.add are accumulated (list-concatenated) across graph steps;
# plain fields are replaced wholesale by each update.
class AgentState(TypedDict):
    messages: Annotated[Sequence[BaseMessage], operator.add]  # conversation so far
    input_data: Annotated[List[Dict], operator.add]  # uploaded-dataset metadata entries
    intermediate_outputs: Annotated[List[dict], operator.add]  # thought/code/output traces
    current_variables: dict  # live Python variables from tool runs (replaced, not merged)
    output_image_paths: Annotated[List[str], operator.add]  # pickled plotly figure filenames
135
+
136
# Initialize the remaining session-state slots exactly once per session.
# Each entry pairs a key with a zero-argument factory so every slot gets a
# fresh (unshared) default object.
_SESSION_DEFAULTS = (
    ('in_memory_datasets', dict),
    ('persistent_vars', dict),
    ('dataset_metadata_list', list),
    ('chat_history', list),
    ('dashboard_plots', lambda: [None, None, None, None]),
    ('columns', lambda: ["No columns available"]),
    ('custom_plots_to_save', dict),
)
for _key, _factory in _SESSION_DEFAULTS:
    if _key not in st.session_state:
        st.session_state[_key] = _factory()
157
+
158
# Set up the tools.
# A stateful Python REPL instance (kept for parity with the LangChain
# tooling; code execution below actually goes through exec()).
repl = PythonREPL()
# Snippet exec()'d after user code: persists every figure collected in the
# `plotly_figures` list as a uniquely-named pickle under `images_dir`.
# NOTE(review): internal indentation of this string was reconstructed — it
# must be valid Python when exec'd; confirm against the original file.
plotly_saving_code = """import pickle

import uuid
import os
for figure in plotly_figures:
    pickle_filename = f"{images_dir}/{uuid.uuid4()}.pickle"
    with open(pickle_filename, 'wb') as f:
        pickle.dump(figure, f)
"""
169
+
170
@tool
def complete_python_task(
    graph_state: Annotated[dict, InjectedState],
    thought: str,
    python_code: str
) -> Tuple[str, dict]:
    """Execute Python code for data analysis and visualization.

    Runs `python_code` in a namespace seeded with the uploaded datasets,
    variables persisted from earlier runs, and a pre-imported numpy /
    scikit-learn toolbox. Captured stdout becomes the tool output; figures
    the code appends to `plotly_figures` are pickled to disk and reported
    back via "output_image_paths" in the returned state update.

    Returns a ``(stdout_text, state_update)`` tuple; on failure, the error
    string plus a trace-only state update.
    """

    current_variables = graph_state.get("current_variables", {})

    # Load datasets from in-memory storage
    for input_dataset in graph_state.get("input_data", []):
        var_name = input_dataset.get("variable_name")
        if var_name and var_name not in current_variables and var_name in st.session_state.in_memory_datasets:
            print(f"Loading {var_name} from in-memory storage")
            current_variables[var_name] = st.session_state.in_memory_datasets[var_name]
    # Snapshot the pickle directory so newly created figures can be detected later.
    current_image_pickle_files = os.listdir(st.session_state.images_dir)

    try:
        # Capture stdout so print() output from the executed code is returned.
        old_stdout = sys.stdout
        sys.stdout = StringIO()

        # Execute the code and capture the result
        exec_globals = globals().copy()
        exec_globals.update(st.session_state.persistent_vars)
        exec_globals.update(current_variables)

        # Add scikit-learn modules to execution environment
        import sklearn
        import numpy as np

        # Import scikit-learn components
        from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge, Lasso  # type: ignore
        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier  # type: ignore
        from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
        from sklearn.svm import SVC, SVR
        from sklearn.naive_bayes import GaussianNB
        from sklearn.decomposition import PCA
        from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
        from sklearn.cluster import KMeans, DBSCAN
        from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
        from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
        from sklearn.metrics import (
            accuracy_score, confusion_matrix, classification_report,
            mean_squared_error, r2_score, mean_absolute_error, silhouette_score
        )
        from sklearn.pipeline import Pipeline
        from sklearn.impute import SimpleImputer

        # Update execution globals with all ML components so user code can
        # reference them by name (as promised in SYSTEM_PROMPT).
        exec_globals.update({
            "plotly_figures": [],
            "images_dir": st.session_state.images_dir,
            "np": np,
            # Linear models
            "LinearRegression": LinearRegression,
            "LogisticRegression": LogisticRegression,
            "Ridge": Ridge,
            "Lasso": Lasso,
            # Tree-based models
            "DecisionTreeClassifier": DecisionTreeClassifier,
            "DecisionTreeRegressor": DecisionTreeRegressor,
            "RandomForestClassifier": RandomForestClassifier,
            "RandomForestRegressor": RandomForestRegressor,
            "GradientBoostingClassifier": GradientBoostingClassifier,
            # SVM models
            "SVC": SVC,
            "SVR": SVR,
            # Other models
            "GaussianNB": GaussianNB,
            "PCA": PCA,
            "KNeighborsClassifier": KNeighborsClassifier,
            "KNeighborsRegressor": KNeighborsRegressor,
            "KMeans": KMeans,
            "DBSCAN": DBSCAN,
            # Preprocessing
            "StandardScaler": StandardScaler,
            "MinMaxScaler": MinMaxScaler,
            "OneHotEncoder": OneHotEncoder,
            "SimpleImputer": SimpleImputer,
            # Model selection and evaluation
            "train_test_split": train_test_split,
            "cross_val_score": cross_val_score,
            "GridSearchCV": GridSearchCV,
            "accuracy_score": accuracy_score,
            "confusion_matrix": confusion_matrix,
            "classification_report": classification_report,
            "mean_squared_error": mean_squared_error,
            "r2_score": r2_score,
            "mean_absolute_error": mean_absolute_error,
            "silhouette_score": silhouette_score,
            # Pipeline
            "Pipeline": Pipeline
        })

        # NOTE(review): exec of model-generated code is an intentional design
        # choice here, but it runs with full process privileges.
        exec(python_code, exec_globals)

        # Persist every name the executed code defined (anything not already a
        # module-level global) so later runs can reuse it.
        st.session_state.persistent_vars.update({k: v for k, v in exec_globals.items() if k not in globals()})

        # Get the captured stdout
        output = sys.stdout.getvalue()

        # Restore stdout
        sys.stdout = old_stdout

        updated_state = {
            "intermediate_outputs": [{"thought": thought, "code": python_code, "output": output}],
            "current_variables": st.session_state.persistent_vars
        }

        if 'plotly_figures' in exec_globals and exec_globals['plotly_figures']:
            # Pickle each collected figure into images_dir.
            exec(plotly_saving_code, exec_globals)

            # Check if any images were created
            new_image_folder_contents = os.listdir(st.session_state.images_dir)
            new_image_files = [file for file in new_image_folder_contents if file not in current_image_pickle_files]

            if new_image_files:
                updated_state["output_image_paths"] = new_image_files
            st.session_state.persistent_vars["plotly_figures"] = []
        return output, updated_state

    except Exception as e:
        # NOTE(review): if the failure happened before old_stdout was bound,
        # this line would raise NameError — presumed unreachable in practice
        # since the first statements in the try block cannot realistically
        # fail; verify.
        sys.stdout = old_stdout  # Restore stdout in case of error
        print(f"Error in complete_python_task: {str(e)}")
        return str(e), {"intermediate_outputs": [{"thought": thought, "code": python_code, "output": str(e)}]}
297
+
298
+ # Function to initialize the LLM based on selected provider and model
299
def initialize_llm():
    """Build a chat model for the provider/model currently selected in
    session state.

    Returns a ready-to-use chat model, or None when no API key is set,
    the provider is unrecognized, or construction fails.
    """
    key = st.session_state.api_key
    chosen = st.session_state.selected_model

    if not key:
        return None

    try:
        provider = st.session_state.ai_provider
        if provider == "openai":
            os.environ["OPENAI_API_KEY"] = key
            return ChatOpenAI(model=chosen, temperature=0)
        if provider == "groq":
            os.environ["GROQ_API_KEY"] = key
            # For Groq, import lazily so the app still loads when the
            # langchain_groq package is absent.
            from langchain_groq import ChatGroq
            return ChatGroq(model=chosen, temperature=0)
    except Exception as e:
        print(f"Error initializing LLM: {str(e)}")
        return None
318
+
319
# Set up the tools: register the tool(s) the agent may call and the
# executor that dispatches invocations to them.
tools = [complete_python_task]
tool_executor = ToolExecutor(tools)

# Load the prompt template: static system prompt followed by the running
# message list supplied at invocation time.
chat_template = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    ("placeholder", "{messages}"),
])
328
+
329
def create_data_summary(state: AgentState) -> str:
    """Render a plain-text overview of every uploaded dataset (description
    plus first five rows) and of any other live variables from prior runs,
    for injection into the model's context."""
    parts = []
    listed_names = []

    # Describe each uploaded dataset, with a small sample where available.
    for dataset in state.get("input_data", []):
        name = dataset.get("variable_name")
        if not name:
            continue

        listed_names.append(name)
        parts.append(f"\n\nVariable: {name}\n")
        parts.append(f"Description: {dataset.get('data_description', 'No description')}\n")

        if name in st.session_state.in_memory_datasets:
            frame = st.session_state.in_memory_datasets[name]
            parts.append("\nSample Data (first 5 rows):\n")
            parts.append(frame.head(5).to_string())

    # Mention every other (non-private) variable created by earlier code runs.
    if "current_variables" in state:
        leftovers = [
            name for name in state["current_variables"]
            if name not in listed_names and not name.startswith("_")
        ]
        for name in leftovers:
            value = state["current_variables"].get(name)
            if isinstance(value, pd.DataFrame):
                parts.append(f"\n\nVariable: {name} (DataFrame with shape {value.shape})")
            else:
                parts.append(f"\n\nVariable: {name}")

    return "".join(parts)
360
+
361
def route_to_tools(state: AgentState) -> Literal["tools", "__end__"]:
    """Decide whether the last AI message requested tool calls.

    Returns "tools" when the most recent message carries at least one tool
    call, "__end__" otherwise. Raises ValueError if the state holds no
    messages at all.
    """
    messages = state.get("messages", [])
    if not messages:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")

    last = messages[-1]
    if hasattr(last, "tool_calls") and len(last.tool_calls) > 0:
        return "tools"
    return "__end__"
372
+
373
def call_model(state: AgentState):
    """Run one LLM turn: prepend a summary of the available data to the
    conversation and invoke the tool-bound model.

    Returns a partial state update with the model's reply and the data
    summary recorded as an intermediate output.
    """
    summary_message = HumanMessage(
        content="""The following data is available:\n{data_summary}""".format(
            data_summary=create_data_summary(state)
        )
    )
    full_history = [summary_message] + state["messages"]

    # Get the initialized LLM; bail out with guidance if none is configured.
    llm = initialize_llm()
    if llm is None:
        return {"messages": [AIMessage(content="Please configure a valid API key and model in the settings tab.")]}

    # Compose the prompt template with the tool-bound model and invoke it.
    runnable = chat_template | llm.bind_tools(tools)
    llm_outputs = runnable.invoke({"messages": full_history})

    return {"messages": [llm_outputs], "intermediate_outputs": [summary_message.content]}
392
+
393
def call_tools(state: AgentState):
    """Execute every tool call requested by the last AI message.

    Builds one ToolInvocation per requested call (injecting the full graph
    state so the tool can read datasets and persistent variables), runs them
    as a batch, and returns a partial state update containing one ToolMessage
    per call plus any merged state updates the tools produced.
    """
    last_message = state["messages"][-1]
    tool_invocations = []

    if isinstance(last_message, AIMessage) and hasattr(last_message, 'tool_calls'):
        tool_invocations = [
            ToolInvocation(
                tool=tool_call["name"],
                tool_input={**tool_call["args"], "graph_state": state}
            ) for tool_call in last_message.tool_calls
        ]
    # return_exceptions=True so one failing tool doesn't abort the batch.
    responses = tool_executor.batch(tool_invocations, return_exceptions=True)

    tool_messages = []
    state_updates = {}

    for tc, response in zip(last_message.tool_calls, responses):
        if isinstance(response, Exception):
            # Surface the failure to the model as a ToolMessage rather than
            # crashing the graph.
            print(f"Exception in tool execution: {str(response)}")
            tool_messages.append(ToolMessage(
                content=f"Error: {str(response)}",
                name=tc["name"],
                tool_call_id=tc["id"]
            ))
            continue

        # Successful tools return (output_text, state_update_dict).
        message, updates = response
        tool_messages.append(ToolMessage(
            content=str(message),
            name=tc["name"],
            tool_call_id=tc["id"]
        ))

        # Merge updates instead of overwriting earlier tools' results:
        # lists are concatenated, dicts are merged, scalars replaced.
        for key, value in updates.items():
            if key in state_updates:
                if isinstance(value, list) and isinstance(state_updates[key], list):
                    state_updates[key].extend(value)
                elif isinstance(value, dict) and isinstance(state_updates[key], dict):
                    state_updates[key].update(value)
                else:
                    state_updates[key] = value
            else:
                state_updates[key] = value

    # FIX: removed dead code — the previous version seeded
    # state_updates["messages"] = [] only to overwrite it unconditionally
    # on the very next line.
    state_updates["messages"] = tool_messages
    return state_updates
444
+
445
# Set up the graph: a two-node loop — the agent proposes, tools execute,
# results flow back to the agent until no more tool calls are requested.
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", call_tools)
# After each agent turn, route_to_tools decides: run tools, or finish.
workflow.add_conditional_edges(
    "agent",
    route_to_tools,
    {
        "tools": "tools",
        "__end__": END
    }
)
# Tool results always return to the agent for interpretation.
workflow.add_edge("tools", "agent")
workflow.set_entry_point("agent")

# Compiled runnable used by the UI layer to drive a full agent loop.
chain = workflow.compile()
461
+
462
def process_file_upload(files):
    """Process uploaded files and return dataframe previews and column names.

    Loads each CSV/Excel upload into st.session_state.in_memory_datasets
    under a sanitized variable name, records per-dataset metadata for the
    agent, and returns a tuple of (markdown_previews, metadata_list,
    unique_column_names). All previously loaded datasets and persistent
    variables are cleared first — this starts a fresh analysis session.
    """
    st.session_state.in_memory_datasets = {}  # Clear previous datasets
    st.session_state.dataset_metadata_list = []  # Clear previous metadata
    st.session_state.persistent_vars.clear()  # Clear persistent variables for new session

    if not files:
        return "No files uploaded.", [], ["No columns available"]

    results = []
    all_columns = []  # Track all columns from all datasets

    for file in files:
        try:
            # Use file object directly
            if file.name.endswith('.csv'):
                df = pd.read_csv(file)
            elif file.name.endswith(('.xls', '.xlsx')):
                df = pd.read_excel(file)
            else:
                results.append(f"Unsupported file format: {file.name}. Please upload CSV or Excel files.")
                continue

            # Derive a Python-identifier-friendly variable name from the filename.
            var_name = file.name.split('.')[0].replace('-', '_').replace(' ', '_').lower()
            st.session_state.in_memory_datasets[var_name] = df

            # Collect all columns
            all_columns.extend(df.columns.tolist())

            # Create dataset metadata for the agent's data summary.
            dataset_metadata = {
                "variable_name": var_name,
                "data_path": "in_memory",
                "data_description": f"Dataset containing {df.shape[0]} rows and {df.shape[1]} columns. Columns: {', '.join(df.columns.tolist())}",
                "original_filename": file.name
            }

            st.session_state.dataset_metadata_list.append(dataset_metadata)

            # Return preview of the dataset
            preview = f"### Dataset: {file.name}\nVariable name: `{var_name}`\n\n"
            preview += df.head(10).to_markdown()
            results.append(preview)
            print(f"Successfully processed {file.name}")

        except Exception as e:
            print(f"Error processing {file.name}: {str(e)}")
            results.append(f"Error processing {file.name}: {str(e)}")

    # IDIOM: dict.fromkeys de-duplicates while preserving first-seen order,
    # replacing the previous manual seen-set loop.
    unique_columns = list(dict.fromkeys(all_columns))
    if not unique_columns:
        unique_columns = ["No columns available"]

    print(f"Found {len(unique_columns)} unique columns across datasets")
    return "\n\n".join(results), st.session_state.dataset_metadata_list, unique_columns
525
+
526
def get_columns():
    """Directly gets columns from in-memory datasets.

    Returns the ordered, de-duplicated union of column names across every
    loaded DataFrame, or ["No columns available"] when nothing is loaded —
    the sentinel the dropdown widgets expect.
    """
    all_columns = []

    for var_name, df in st.session_state.in_memory_datasets.items():
        if isinstance(df, pd.DataFrame):
            all_columns.extend(df.columns.tolist())

    # IDIOM: dict.fromkeys removes duplicates while preserving order,
    # replacing the previous manual seen-set loop.
    unique_columns = list(dict.fromkeys(all_columns))

    if not unique_columns:
        unique_columns = ["No columns available"]

    print(f"Populating dropdowns with {len(unique_columns)} columns")
    return unique_columns
548
+
549
+ # === FUNCTIONS ===
550
+ import openai
551
+ import pandas as pd
552
+ import json
553
+ import re
554
+
555
def standard_clean(df):
    """Apply baseline, dataset-agnostic cleaning to *df* in place and
    return it: snake_case column names, duplicate rows removed, fully-empty
    rows/columns dropped, and surrounding whitespace trimmed from text cells.
    """
    # Normalize column names: non-word runs become underscores, lowercase.
    df.rename(columns=lambda name: re.sub(r'\W+', '_', name).strip().lower(), inplace=True)
    df.drop_duplicates(inplace=True)
    # Drop columns, then rows, that are entirely empty.
    df.dropna(axis=1, how='all', inplace=True)
    df.dropna(axis=0, how='all', inplace=True)
    # Trim whitespace in every text column (values are coerced to str).
    for name in df.select_dtypes(include='object').columns:
        df[name] = df[name].astype(str).str.strip()
    return df
563
+
564
def query_openai(prompt):
    """Send a single-turn prompt to the currently configured chat provider
    and return the reply text. Returns "{}" on any API error (so callers
    parsing the reply as a dict degrade gracefully), and None when the
    provider is unrecognized."""
    try:
        # Use the configured API key and model from session state
        key = st.session_state.api_key
        chosen_model = st.session_state.selected_model
        provider = st.session_state.ai_provider

        client = None
        if provider == "openai":
            client = openai.OpenAI(api_key=key)
        elif provider == "groq":
            from groq import Groq
            client = Groq(api_key=key)
        if client is None:
            # Unknown provider — matches the previous implicit None return.
            return None

        response = client.chat.completions.create(
            model=chosen_model,
            messages=[{"role": "user", "content": prompt}],
            temperature=0.7
        )
        return response.choices[0].message.content
    except Exception as e:
        print(f"API Error: {e}")
        return "{}"
590
+
591
def llm_suggest_cleaning(df):
    """Ask the configured LLM for a cleaning plan for *df*.

    Sends the first 10 rows and expects back a Python dict literal with the
    keys rename_columns / convert_types / fill_missing / value_map. Returns
    that dict, or an all-empty plan when the reply cannot be parsed.
    """
    import ast  # stdlib; used to parse the LLM reply safely

    sample = df.head(10).to_csv(index=False)
    prompt = f"""
You are a professional data wrangler. Below is a sample of a messy dataset.

Return a Python dictionary with the following keys:

1. rename_columns – fix unclear or inconsistent column names
2. convert_types – correct datatypes: int, float, str, or date
3. fill_missing – use 'mean', 'median', 'mode', or a constant like 'Unknown' or 0
4. value_map – map inconsistent values (e.g., yes/Yes/Y → Yes)

Do not drop any rows or columns. Your output must be a valid Python dict.

Example:
{{
"rename_columns": {{"dob": "date_of_birth"}},
"convert_types": {{"age": "int", "salary": "float", "signup_date": "date"}},
"fill_missing": {{"gender": "mode", "salary": -1}},
"value_map": {{
"gender": {{"M": "Male", "F": "Female"}},
"subscribed": {{"Y": "Yes", "N": "No"}}
}}
}}
Apart from these mentioned steps, study the data and also do whatever things are good and needed for that particular dataset and do the cleaning.
Sample data:
{sample}
"""
    raw_response = query_openai(prompt)
    try:
        # SECURITY FIX: the reply is untrusted model output — eval() on it
        # could execute arbitrary code. ast.literal_eval accepts only Python
        # literals (dicts, lists, strings, numbers), which is exactly the
        # expected response shape.
        return ast.literal_eval(raw_response)
    except Exception:
        # Malformed / non-literal reply (or None from an unknown provider):
        # fall back to a no-op plan.
        print("Could not parse suggestions.")
        return {
            "rename_columns": {},
            "convert_types": {},
            "fill_missing": {},
            "value_map": {}
        }
631
+
632
def apply_suggestions(df, suggestions):
    """Apply an LLM-produced cleaning plan to *df* in place and return it.

    The plan is the dict produced by llm_suggest_cleaning: column renames,
    dtype conversions ("int"/"float"/"str"/"date"), missing-value fills
    ("mean"/"median"/"mode" or a constant), and per-column value mappings.
    Each step is best-effort: failures are logged and skipped so one bad
    suggestion cannot abort the whole cleaning pass.
    """
    df.rename(columns=suggestions.get("rename_columns", {}), inplace=True)

    # Dtype conversions — coerce errors to NaN/NaT rather than raising.
    for col, dtype in suggestions.get("convert_types", {}).items():
        if col not in df.columns:
            continue
        try:
            if dtype == "int":
                # Nullable Int64 keeps rows with unparseable values as <NA>.
                df[col] = pd.to_numeric(df[col], errors='coerce').astype("Int64")
            elif dtype == "float":
                df[col] = pd.to_numeric(df[col], errors='coerce')
            elif dtype == "str":
                df[col] = df[col].astype(str)
            elif dtype == "date":
                df[col] = pd.to_datetime(df[col], errors='coerce')
        except Exception:  # FIX: was a bare except (also caught SystemExit etc.)
            print(f"Failed to convert {col} to {dtype}")

    # Missing-value fills.
    for col, method in suggestions.get("fill_missing", {}).items():
        if col not in df.columns:
            continue
        try:
            # FIX: assign instead of fillna(..., inplace=True) on a column
            # slice — the inplace form is deprecated chained assignment in
            # modern pandas and may not write back to df.
            if method == "mean":
                df[col] = df[col].fillna(df[col].mean())
            elif method == "median":
                df[col] = df[col].fillna(df[col].median())
            elif method == "mode":
                df[col] = df[col].fillna(df[col].mode().iloc[0])
            elif isinstance(method, str):
                df[col] = df[col].fillna(method)
        except Exception:  # FIX: was a bare except
            print(f"Could not fill missing values for {col}")

    # Value normalization maps (e.g. {"M": "Male"}).
    for col, mapping in suggestions.get("value_map", {}).items():
        if col in df.columns:
            try:
                df[col] = df[col].replace(mapping)
            except Exception:  # FIX: was a bare except
                print(f"Could not map values in {col}")

    return df
673
+
674
def capture_dashboard_screenshot():
    """Capture the entire dashboard as a single image.

    Combines the (up to four) figures in st.session_state.dashboard_plots
    into one 2x2 plotly subplot figure and writes it as a PNG under the
    session temp directory. Returns the PNG path, or None on failure.
    """
    try:
        # Create a figure that combines all dashboard plots
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots

        # Create a 2x2 subplot
        fig = make_subplots(rows=2, cols=2,
                            subplot_titles=["Visualization 1", "Visualization 2",
                                            "Visualization 3", "Visualization 4"])

        # Add each plot from the dashboard to the combined figure
        for i, plot in enumerate(st.session_state.dashboard_plots):
            if plot is not None:
                # Map slot index 0..3 onto the 2x2 grid (row-major).
                row = (i // 2) + 1
                col = (i % 2) + 1

                # Extract traces from the original figure and add to our subplot
                for trace in plot.data:
                    fig.add_trace(trace, row=row, col=col)

                # Copy layout properties for each subplot
                for axis_type in ['xaxis', 'yaxis']:
                    # NOTE(review): this naming looks wrong on both sides —
                    # a standalone source figure normally only has "xaxis" /
                    # "yaxis" (so axis_name "xaxis2".. rarely exists), and
                    # make_subplots numbers target axes "xaxis".."xaxis4",
                    # not "xaxis{row}{col}" (e.g. "xaxis12"). Verify against
                    # plotly's layout axis keys; as written the copy is
                    # likely a silent no-op for i > 0.
                    axis_name = f"{axis_type}{i+1 if i > 0 else ''}"
                    subplot_name = f"{axis_type}{row}{col}"

                    # Copy axis properties if they exist
                    if hasattr(plot.layout, axis_name):
                        axis_props = getattr(plot.layout, axis_name)
                        fig.update_layout({subplot_name: axis_props})

        # Update layout for better appearance
        fig.update_layout(
            height=800,
            width=1000,
            title_text="Dashboard Overview",
            showlegend=False,
        )

        # Save to a temporary file
        dashboard_path = f"{st.session_state.temp_dir}/dashboard_combined.png"
        fig.write_image(dashboard_path, scale=2)  # Higher scale for better resolution
        return dashboard_path

    except Exception as e:
        import traceback
        print(f"Error capturing dashboard: {str(e)}")
        print(traceback.format_exc())
        return None
724
+
725
def generate_enhanced_pdf_report():
    """Generate an enhanced PDF report with proper handling of base64 image data.

    Builds a ReportLab (platypus) document containing:
      1. The chat conversation history, with any inline base64 PNG
         visualizations decoded and re-embedded as images.
      2. A dashboard section — a full-page screenshot when available,
         otherwise the individual dashboard plots.

    Returns:
        bytes: the raw PDF content on success, or None if generation failed.
    """
    try:
        # Create an in-memory buffer for the PDF
        buffer = io.BytesIO()

        # Create the PDF document (letter page, 36pt = 0.5in margins)
        doc = SimpleDocTemplate(buffer, pagesize=letter,
                                leftMargin=36, rightMargin=36,
                                topMargin=36, bottomMargin=36)

        # Create custom styles with better formatting
        styles = getSampleStyleSheet()

        # Add custom styles with improved formatting
        styles.add(ParagraphStyle(
            name='ReportTitle',
            parent=styles['Heading1'],
            fontSize=24,
            alignment=1,  # Centered
            spaceAfter=20,
            textColor='#2C3E50'  # Dark blue color
        ))

        styles.add(ParagraphStyle(
            name='SectionHeader',
            parent=styles['Heading2'],
            fontSize=16,
            spaceBefore=15,
            spaceAfter=10,
            textColor='#2C3E50',
            borderWidth=1,
            borderColor='#95A5A6',
            borderPadding=5,
            borderRadius=5
        ))

        styles.add(ParagraphStyle(
            name='SubHeader',
            parent=styles['Heading3'],
            fontSize=14,
            spaceBefore=10,
            spaceAfter=8,
            textColor='#34495E',
            fontWeight='bold'
        ))

        styles.add(ParagraphStyle(
            name='UserMessage',
            parent=styles['Normal'],
            fontSize=11,
            leftIndent=10,
            spaceBefore=8,
            spaceAfter=4
        ))

        styles.add(ParagraphStyle(
            name='AssistantMessage',
            parent=styles['Normal'],
            fontSize=11,
            leftIndent=10,
            spaceBefore=4,
            spaceAfter=12,
            textColor='#2980B9'
        ))

        styles.add(ParagraphStyle(
            name='Timestamp',
            parent=styles['Italic'],
            fontSize=10,
            textColor='#7F8C8D',
            alignment=2  # Right aligned
        ))

        # Create the document content
        elements = []

        # Add title
        elements.append(Paragraph('Data Analysis Report', styles['ReportTitle']))

        # Add timestamp
        elements.append(Paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',
                                  styles['Timestamp']))
        elements.append(Spacer(1, 0.5*inch))

        # Add conversation history with better formatting
        elements.append(Paragraph('Analysis Conversation History', styles['SectionHeader']))

        if st.session_state.chat_history:
            for i, (user_msg, assistant_msg) in enumerate(st.session_state.chat_history):
                # Format user message with proper styling
                elements.append(Paragraph(f'<b>You:</b>', styles['SubHeader']))
                user_msg_formatted = user_msg.replace('\n', '<br/>')
                elements.append(Paragraph(user_msg_formatted, styles['UserMessage']))

                # Process assistant message to handle visualization
                # Look for markdown image syntax with base64 data
                base64_pattern = r'!\[Visualization\]\(data:image\/png;base64,([^\)]+)\)'

                # Check if the message contains visualizations
                if '### Visualizations' in assistant_msg or re.search(base64_pattern, assistant_msg):
                    # Split the message at the Visualizations header if it exists
                    if '### Visualizations' in assistant_msg:
                        parts = assistant_msg.split('### Visualizations', 1)
                        text_part = parts[0]
                        # NOTE(review): this parses as
                        # ("### Visualizations" + parts[1]) if len(parts) > 1 else ""
                        # — safe here only because split on a delimiter known to be
                        # present always yields two parts.
                        viz_part = "### Visualizations" + parts[1] if len(parts) > 1 else ""
                    else:
                        # If no header but still has visualization
                        match = re.search(base64_pattern, assistant_msg)
                        text_part = assistant_msg[:match.start()]
                        viz_part = assistant_msg[match.start():]

                    # Format the text part
                    elements.append(Paragraph(f'<b>Assistant:</b>', styles['SubHeader']))
                    text_part = text_part.replace('\n', '<br/>')
                    elements.append(Paragraph(text_part, styles['AssistantMessage']))

                    # Process visualizations: decode each embedded base64 PNG to a
                    # temp file and insert it as a ReportLab Image flowable
                    matches = re.findall(base64_pattern, viz_part)
                    for j, base64_data in enumerate(matches):
                        try:
                            # Decode the base64 image
                            image_data = base64.b64decode(base64_data)

                            # Create a temporary file for the image
                            temp_img_path = f"{st.session_state.temp_dir}/chat_viz_{i}_{j}.png"

                            with open(temp_img_path, 'wb') as f:
                                f.write(image_data)

                            # Add the image to the PDF
                            elements.append(Paragraph(f'<b>Visualization:</b>', styles['SubHeader']))
                            elements.append(Spacer(1, 0.1*inch))
                            img = Image(temp_img_path, width=6*inch, height=4*inch)
                            elements.append(img)
                            elements.append(Spacer(1, 0.2*inch))
                        except Exception as e:
                            # A broken image must not abort the whole report
                            print(f"Error processing base64 image: {str(e)}")
                            elements.append(Paragraph(f"[Error displaying visualization: {str(e)}]",
                                                      styles['Normal']))
                else:
                    # No visualizations, just format the text
                    elements.append(Paragraph(f'<b>Assistant:</b>', styles['SubHeader']))
                    assistant_msg_formatted = assistant_msg.replace('\n', '<br/>')
                    # Truncate very long answers to keep the report readable
                    if len(assistant_msg_formatted) > 1500:
                        assistant_msg_formatted = assistant_msg_formatted[:1500] + '...'
                    elements.append(Paragraph(assistant_msg_formatted, styles['AssistantMessage']))

                elements.append(Spacer(1, 0.2*inch))
        else:
            elements.append(Paragraph('No conversation history available.', styles['Normal']))

        # Force a page break before the dashboard
        elements.append(PageBreak())

        # Add dashboard section header
        elements.append(Paragraph('Dashboard Overview', styles['SectionHeader']))
        elements.append(Spacer(1, 0.2*inch))

        # Capture the dashboard as a single image
        # (capture_dashboard_screenshot is defined earlier in this file;
        # it returns None on failure)
        dashboard_img_path = capture_dashboard_screenshot()

        if dashboard_img_path:
            # Calculate available width
            available_width = doc.width

            # Create PIL image to get dimensions
            pil_img = PILImage.open(dashboard_img_path)
            img_width, img_height = pil_img.size

            # Calculate scaling factor to fit within page width
            scale_factor = available_width / img_width

            # Calculate new height based on aspect ratio
            new_height = img_height * scale_factor

            # Add the image with scaled dimensions
            img = Image(dashboard_img_path, width=available_width, height=new_height)
            elements.append(img)
        else:
            # Fallback: Add individual plots if combined dashboard fails
            plot_count = 0
            for i, plot in enumerate(st.session_state.dashboard_plots):
                if plot is not None:
                    plot_count += 1

                    # Convert plotly figure to image (requires kaleido)
                    img_bytes = io.BytesIO()
                    plot.write_image(img_bytes, format='png', width=500, height=300)
                    img_bytes.seek(0)

                    # Create a temporary file for the image
                    temp_img_path = f"{st.session_state.temp_dir}/plot_{i}.png"

                    with open(temp_img_path, 'wb') as f:
                        f.write(img_bytes.getvalue())

                    # Add to PDF with appropriate caption and formatting
                    elements.append(Paragraph(f'Dashboard Visualization {i+1}', styles['SubHeader']))
                    elements.append(Spacer(1, 0.1*inch))

                    # Add the image with proper scaling
                    img = Image(temp_img_path, width=6.5*inch, height=4*inch)
                    elements.append(img)
                    elements.append(Spacer(1, 0.3*inch))

            if plot_count == 0:
                elements.append(Paragraph('No visualizations have been added to the dashboard.',
                                          styles['Normal']))

        # Build the PDF with improved formatting
        doc.build(elements)

        # Get the value of the buffer
        pdf_value = buffer.getvalue()
        buffer.close()

        return pdf_value

    except Exception as e:
        import traceback
        print(f"Error generating enhanced PDF report: {str(e)}")
        print(traceback.format_exc())
        return None
+
949
def chat_with_workflow(message, history, dataset_info):
    """Send user query to the workflow and get response.

    Args:
        message: the new user question (plain text).
        history: list of (user_msg, assistant_msg) tuples from prior turns.
        dataset_info: metadata for the uploaded dataset(s); an empty/falsy
            value short-circuits with a "please upload" message.

    Returns:
        str: the assistant's reply, optionally followed by a
        "### Visualizations" section with inline base64 PNG images,
        or an error string if the workflow raised.

    NOTE(review): `AgentState` and `chain` are defined earlier in this file
    (presumably a LangGraph workflow) — assumed, confirm against full file.
    """

    if not dataset_info:
        return "Please upload at least one dataset before asking questions."

    # Check if we have a valid API key and model
    if not st.session_state.api_key:
        return "Please set up your API key and model in the Settings tab before chatting."

    print(f"Chat with workflow called with {len(dataset_info)} datasets")

    try:
        # Extract chat history for context (last 3 exchanges)
        max_history = 3
        previous_messages = []

        if history:
            start_idx = max(0, len(history) - max_history)
            recent_history = history[start_idx:]

            for exchange in recent_history:
                if exchange[0]:  # User message
                    previous_messages.append(HumanMessage(content=exchange[0]))
                if exchange[1]:  # AI response
                    previous_messages.append(AIMessage(content=exchange[1]))

        # Initialize the workflow state
        state = AgentState(
            messages=previous_messages + [HumanMessage(content=message)],
            input_data=dataset_info,
            intermediate_outputs=[],
            current_variables=st.session_state.persistent_vars,
            output_image_paths=[]
        )

        # Execute the workflow
        print("Executing workflow...")
        result = chain.invoke(state)
        print("Workflow execution completed")

        # Extract messages from the result
        messages = result["messages"]

        # Format the response - only get the latest response
        response = ""
        if messages:
            latest_message = messages[-1]  # Get only the last message
            if hasattr(latest_message, "content"):
                content = latest_message.content

                # Clean up the response
                # Remove any instances where the user's message is repeated
                # (heuristic: keep only what follows the last occurrence)
                if message in content:
                    content = content.split(message)[-1].strip()

                # Remove any chat history markers
                content_lines = content.split('\n')
                filtered_lines = [line for line in content_lines
                                  if not line.strip().startswith(("You:", "User:", "Human:", "Assistant:"))]
                content = '\n'.join(filtered_lines)

                response = content.strip() + "\n\n"

        # Handle visualizations: each path points at a pickled Plotly figure
        # written by the workflow under st.session_state.images_dir
        if "output_image_paths" in result and result["output_image_paths"]:
            response += "### Visualizations\n\n"
            for img_path in result["output_image_paths"]:
                try:
                    full_path = os.path.join(st.session_state.images_dir, img_path)
                    with open(full_path, 'rb') as f:
                        fig = pickle.load(f)

                    # Convert plotly figure to image (requires kaleido)
                    img_bytes = BytesIO()
                    fig.update_layout(width=800, height=500)
                    pio.write_image(fig, img_bytes, format='png')
                    img_bytes.seek(0)

                    # Convert to base64 for markdown image
                    b64_img = base64.b64encode(img_bytes.read()).decode()
                    response += f"![Visualization](data:image/png;base64,{b64_img})\n\n"
                except Exception as e:
                    # Report per-figure failures inline rather than failing the turn
                    response += f"Error loading visualization: {str(e)}\n\n"

        return response

    except Exception as e:
        import traceback
        print(f"Error in chat_with_workflow: {str(e)}")
        print(traceback.format_exc())
        return f"Error executing workflow: {str(e)}"
1041
+
1042
def auto_generate_dashboard(dataset_info):
    """Ask the LLM workflow for the four most insightful Plotly figures.

    Args:
        dataset_info: metadata for the uploaded dataset(s).

    Returns:
        tuple[str, list]: a status message and a list of exactly four
        entries (Plotly figures or None placeholders). The list is also
        stored in ``st.session_state.dashboard_plots``.
    """
    if not dataset_info:
        return "Please upload a dataset first.", [None, None, None, None]

    prompt = """
    You are a data visualization expert. Given a dataset, identify the top 4 most insightful plots using statistical reasoning or patterns (correlation, distribution, trends).

    Use plotly and store the plots in a list named plotly_figures.

    Include multivariate plots using color/size/facets when helpful.
    """

    state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[],
    )

    result = chain.invoke(state)

    # Unpickle up to four figures written by the workflow; failures are
    # logged and skipped so a single bad figure never sinks the dashboard.
    figures = []
    for rel_path in result.get("output_image_paths", [])[:4]:
        fig_file = os.path.join(st.session_state.images_dir, rel_path)
        try:
            with open(fig_file, 'rb') as e_file:
                figures.append(pickle.load(e_file))
        except Exception as e:
            print(f"Error loading figure: {e}")

    # Pad to exactly four slots so the UI grid always has something to index.
    figures.extend([None] * (4 - len(figures)))

    st.session_state.dashboard_plots = figures
    return "Dashboard generated!", figures
1082
+
1083
def generate_custom_plots_with_llm(dataset_info, x_col, y_col, facet_col):
    """Ask the LLM workflow for three plots over user-selected columns.

    Args:
        dataset_info: metadata for the uploaded dataset(s).
        x_col: column for the x-axis.
        y_col: column for the y-axis.
        facet_col: optional facet column ('None' disables faceting).

    Returns:
        list: exactly three entries — Plotly figures or None placeholders.
        Returns [None, None, None] when inputs are missing.
    """
    if not dataset_info or not x_col or not y_col:
        return [None, None, None]

    prompt = f"""
    You are a data visualization expert.

    Create 3 insightful visualizations using Plotly based on:

    - X-axis: {x_col}
    - Y-axis: {y_col}
    - Facet (optional): {facet_col if facet_col != 'None' else 'None'}

    Try to find interesting relationships, trends, or clusters using appropriate chart types.

    Use `plotly_figures` list and avoid using fig.show().
    """

    state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[],
    )

    result = chain.invoke(state)

    # Load at most three pickled figures produced by the workflow; a bad
    # pickle is logged and skipped rather than aborting the whole request.
    figures = []
    for rel_path in result.get("output_image_paths", [])[:3]:
        fig_file = os.path.join(st.session_state.images_dir, rel_path)
        try:
            with open(fig_file, 'rb') as fig_handle:
                figures.append(pickle.load(fig_handle))
        except Exception as e:
            print(f"Error loading figure: {e}")

    # Always hand back exactly three slots.
    figures.extend([None] * (3 - len(figures)))
    return figures
1127
+
1128
def remove_plot(index):
    """Clear the dashboard slot at ``index``; out-of-range indices are a no-op."""
    plots = st.session_state.dashboard_plots
    if index < 0 or index >= len(plots):
        return
    # Mutates the shared session-state list in place.
    plots[index] = None
1132
+
1133
def respond(message):
    """Route a chat prompt through the workflow, record the exchange, rerun the UI.

    Appends the (question, answer) pair to ``st.session_state.chat_history``
    and triggers a Streamlit rerun so the new messages render.
    """
    metadata = st.session_state.dataset_metadata_list
    if metadata:
        bot_message = chat_with_workflow(message, st.session_state.chat_history, metadata)
    else:
        bot_message = "Please upload at least one dataset before asking questions."

    st.session_state.chat_history.append((message, bot_message))
    st.rerun()
1142
+
1143
def save_plot_to_dashboard(plot_index):
    """Button callback: put the chosen custom plot into the first free slot.

    Silently does nothing when every dashboard slot is already occupied.
    """
    slots = st.session_state.dashboard_plots
    for slot, occupant in enumerate(slots):
        if occupant is None:
            slots[slot] = st.session_state.custom_plots_to_save[plot_index]
            return
1150
+
1151
# New function to check if settings are valid
def is_settings_valid():
    """Return True when an API key has been entered.

    NOTE(review): only the API key is validated here; the model selection is
    not checked (the selectbox always holds a value).
    """
    return st.session_state.api_key != ""
1155
+
1156
# Streamlit UI with left panel settings
st.set_page_config(page_title="Data Analysis Assistant", layout="wide")

# Create side panel for settings: provider/model/API-key configuration that
# gates the rest of the app (see is_settings_valid check in the main body).
with st.sidebar:
    st.header("Settings")
    st.info("⚠️ Configure your API settings before uploading data")

    # AI Provider selection
    provider = st.radio("Select AI Provider",
                        options=["OpenAI", "Groq"],
                        index=0 if st.session_state.ai_provider == "openai" else 1,
                        horizontal=True)

    # Update session state based on selection (stored lowercase)
    st.session_state.ai_provider = provider.lower()

    # API Key input — only persisted to session state on "Save Settings"
    api_key = st.text_input("Enter API Key",
                            value=st.session_state.api_key,
                            type="password",
                            help="Your API key for the selected provider")

    # Display different model options based on provider
    # (OPENAI_MODELS / GROQ_MODELS are defined earlier in this file)
    if st.session_state.ai_provider == "openai":
        model_options = OPENAI_MODELS
        model_help = "GPT-4 provides the best results but is slower. GPT-3.5-Turbo is faster but less capable."
    else:  # groq
        model_options = GROQ_MODELS
        model_help = "Llama 3.3 70B is most capable. Gemma 2 9B offers good balance. Llama 3 8B is fastest."

    # Model selection — falls back to the first option when the previously
    # saved model isn't available for the newly chosen provider
    selected_model = st.selectbox("Select Model",
                                  options=model_options,
                                  index=model_options.index(st.session_state.selected_model) if st.session_state.selected_model in model_options else 0,
                                  help=model_help)

    # Save button — commits settings to session state and smoke-tests them
    if st.button("Save Settings"):
        st.session_state.api_key = api_key
        st.session_state.selected_model = selected_model

        # Test the API key and model
        try:
            # Initialize LLM using the provided settings
            test_llm = initialize_llm()
            if test_llm:
                st.success(f"✅ Successfully configured {provider} with model: {selected_model}")
            else:
                st.error("Failed to initialize the AI provider. Please check your API key and model selection.")
        except Exception as e:
            st.error(f"Error testing settings: {str(e)}")

    # Display current settings
    st.subheader("Current Settings")
    settings_info = f"""
    - **Provider**: {st.session_state.ai_provider.upper()}
    - **Model**: {st.session_state.selected_model}
    - **API Key**: {'✅ Set' if st.session_state.api_key else '❌ Not Set'}
    """
    st.markdown(settings_info)

    # Provider-specific information
    if st.session_state.ai_provider == "openai":
        with st.expander("OpenAI Models Info"):
            st.info("""
            - **GPT-4**: Most powerful model, best for complex analysis and detailed explanations
            - **GPT-4-Turbo**: Faster than GPT-4 with similar capabilities
            - **GPT-4-Mini**: Economical option with good performance for standard tasks
            - **GPT-3.5-Turbo**: Fastest option, suitable for basic analysis and visualization
            """)
    else:
        with st.expander("Groq Models Info"):
            st.info("""
            - **llama3.3-70b-versatile**: Most powerful model for comprehensive analysis
            - **gemma2-9b-it**: Good balance of speed and capabilities
            - **llama-3-8b-8192**: Fastest option for basic analysis tasks
            """)

    # Integration instructions
    with st.expander("How to get API Keys"):
        if st.session_state.ai_provider == "openai":
            st.markdown("""
            ### Getting an OpenAI API Key

            1. Go to [OpenAI's platform](https://platform.openai.com)
            2. Sign up or log in to your account
            3. Navigate to the API section
            4. Create a new API key
            5. Copy the key and paste it above

            Note: OpenAI API usage incurs charges based on tokens used.
            """)
        else:
            st.markdown("""
            ### Getting a Groq API Key

            1. Go to [Groq's website](https://console.groq.com/keys)
            2. Sign up or log in to your account
            3. Navigate to API Keys section
            4. Create a new API key
            5. Copy the key and paste it above

            Note: Check Groq's pricing page for current rates.
            """)
1262
# Main content
st.title("Data Analysis Assistant")
st.markdown("Upload your datasets, ask questions, and generate visualizations to gain insights.")

# Check if settings are valid before showing tabs; st.stop() aborts the
# script run here, so nothing below renders until an API key is saved.
if not is_settings_valid():
    st.warning("⚠️ Please configure your API key in the sidebar settings panel before proceeding.")
    st.stop()

# Create main tabs - only if settings are valid
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Upload Datasets", "Data Cleaning", "Chat with AI Assistant", "Auto Dashboard Generator", "Generate Report"])
1273
+
1274
# Tab 1: dataset upload and preview.
with tab1:
    st.header("Upload Datasets")
    uploaded_files = st.file_uploader("Upload CSV or Excel Files",
                                      accept_multiple_files=True,
                                      type=['csv', 'xlsx', 'xls'])

    if uploaded_files and st.button("Process Uploaded Files"):
        with st.spinner("Processing files..."):
            # process_file_upload (defined earlier in this file) populates
            # st.session_state.in_memory_datasets as a side effect.
            preview, metadata_list, columns = process_file_upload(uploaded_files)
            st.session_state.columns = columns

            # Display basic information about processed files
            st.success(f"✅ Successfully processed {len(uploaded_files)} file(s)")

            # Show detailed preview for each dataset
            st.subheader("Dataset Previews")

            for dataset_name, df in st.session_state.in_memory_datasets.items():
                with st.expander(f"Preview: {dataset_name}"):
                    # Display dataset info
                    st.write(f"**Rows:** {df.shape[0]} | **Columns:** {df.shape[1]}")

                    # Display column information (first 3 non-null values as samples)
                    col_info = pd.DataFrame({
                        'Column Name': df.columns,
                        'Data Type': df.dtypes.astype(str),
                        'Non-Null Count': df.count().values,
                        'Sample Values': [', '.join(df[col].dropna().astype(str).head(3).tolist()) for col in df.columns]
                    })

                    # Show column information in a compact table
                    st.write("**Column Information:**")
                    st.dataframe(col_info, use_container_width=True)

                    # Show actual data preview
                    st.write("**Data Preview (First 10 rows):**")
                    st.dataframe(df.head(10), use_container_width=True)

            # Provide hint for the next steps
            st.info("👆 Click on the dataset names above to see detailed previews. Then proceed to the Data Cleaning tab to clean your data or Chat with AI Assistant to analyze it.")
1314
+
1315
# Tab 2: LLM-assisted data cleaning with a natural-language refinement loop.
with tab2:
    st.header("Data Cleaning")

    # Lazily initialize the cleaning-related session-state slots.
    if 'cleaning_done' not in st.session_state:
        st.session_state.cleaning_done = False

    if 'cleaned_datasets' not in st.session_state:
        st.session_state.cleaned_datasets = {}

    if 'cleaning_summaries' not in st.session_state:
        st.session_state.cleaning_summaries = {}

    if st.session_state.get("in_memory_datasets"):
        if not st.session_state.cleaning_done:
            if st.button("Run Data Cleaning"):
                with st.spinner("Running LLM-assisted cleaning..."):
                    # Pipeline: standard_clean -> llm_suggest_cleaning ->
                    # apply_suggestions (all defined earlier in this file).
                    for name, df in st.session_state.in_memory_datasets.items():
                        raw_df = df.copy()
                        df_std = standard_clean(raw_df.copy())
                        suggestions = llm_suggest_cleaning(df_std.copy())
                        df_clean = apply_suggestions(df_std.copy(), suggestions)
                        st.session_state.cleaned_datasets[name] = df_clean
                        st.session_state.cleaning_summaries[name] = suggestions
                    st.session_state.cleaning_done = True
                    st.rerun()
            else:
                st.info("Click Run Data Cleaning to clean your datasets using the LLM.")
        else:
            # Cleaning already ran: show before/after views plus a
            # natural-language correction loop per dataset.
            for name, df_clean in st.session_state.cleaned_datasets.items():
                raw_df = st.session_state.in_memory_datasets[name]

                st.subheader(f"Dataset: {name}")
                col1, col2 = st.columns(2)

                with col1:
                    st.markdown("Original Data (First 5 Rows)")
                    st.dataframe(raw_df.head())

                with col2:
                    st.markdown("Cleaned Data (First 5 Rows)")
                    st.dataframe(df_clean.head())

                st.markdown("Summary of Cleaning Actions")
                suggestions = st.session_state.cleaning_summaries[name]
                summary_text = ""

                if suggestions:
                    for key, value in suggestions.items():
                        summary_text += f"**{key}**: {json.dumps(value, indent=2)}\n\n"
                st.markdown(summary_text)

                st.markdown("Refine the Cleaning (Natural Language Instructions)")
                user_input = st.text_input("Example: Convert 'dob' to datetime and fill missing with '2000-01-01'",
                                           key=f"user_input_{name}")

                # Per-dataset log of (instruction, generated code) pairs.
                if f'corrections_{name}' not in st.session_state:
                    st.session_state[f'corrections_{name}'] = []

                if st.button("Apply Correction", key=f'apply_correction_{name}'):
                    if user_input.strip():
                        correction_prompt = f"""
                        You are a data cleaning expert. Below is a previously cleaned dataset with these actions:

                        {summary_text}

                        The user now wants the following additional instruction:
                        \"{user_input.strip()}\"

                        Write only the Python code that modifies the pandas DataFrame `df` accordingly.
                        Do not include explanations or markdown.
                        """
                        correction_code = query_openai(correction_prompt)

                        try:
                            df = st.session_state.cleaned_datasets[name].copy()
                            local_vars = {"df": df}
                            # NOTE(review): exec of LLM-generated code — this runs
                            # arbitrary Python in-process. Acceptable only for a
                            # trusted single-user tool; consider sandboxing.
                            exec(correction_code, {}, local_vars)
                            df_updated = local_vars["df"]

                            st.session_state.cleaned_datasets[name] = df_updated
                            st.session_state[f'corrections_{name}'].append((user_input, correction_code))
                            st.success("Correction applied.")
                            st.rerun()

                        except Exception as e:
                            st.error(f"Failed to apply correction: {str(e)}")

                if st.session_state[f'corrections_{name}']:
                    st.markdown("Applied Corrections")
                    for i, (msg, code) in enumerate(st.session_state[f'corrections_{name}']):
                        st.markdown(f"**Instruction:** {msg}")
                        st.code(code, language='python')

            # Controls to restart the cleaning pass or lock it in.
            col1, col2 = st.columns([1, 2])
            with col1:
                if st.button("Reset Cleaning and Re-run"):
                    st.session_state.cleaning_done = False
                    st.rerun()

            with col2:
                if st.button("Finalize and Proceed to Visualizations"):
                    st.session_state.cleaning_finalized = True
                    st.rerun()
    else:
        st.info("Please upload and process datasets first.")
1420
+
1421
# Tab 3: conversational analysis via the LLM workflow.
with tab3:
    st.header("Chat with AI Assistant")

    st.markdown("""
    ## Example Questions
    - "What analysis can you perform on this dataset?"
    - "Show me basic statistics for all columns"
    - "Create a correlation heatmap"
    - "Plot the distribution of a specific column"
    - "What is the relationship between two columns?"
    """)

    # Display chat history (list of (user, assistant) tuples)
    for exchange in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(exchange[0])
        with st.chat_message("assistant"):
            st.write(exchange[1])

    # Chat input — respond() appends to history and reruns the app
    if prompt := st.chat_input("Your question"):
        with st.spinner("Thinking..."):
            respond(prompt)
1444
+
1445
# Tab 4: auto-generated four-slot dashboard plus a custom plot builder.
with tab4:
    st.header("Auto Dashboard Generator")

    # Dashboard controls
    # NOTE(review): dashboard_title is collected but not used anywhere in
    # the visible code — confirm whether it should feed the PDF report.
    dashboard_title = st.text_input("Dashboard Title", placeholder="Enter your dashboard title")

    col1, col2 = st.columns(2)

    with col1:
        if st.button("Generate Suggested Dashboard (Auto)"):
            with st.spinner("Generating dashboard..."):
                message, figures = auto_generate_dashboard(st.session_state.dataset_metadata_list)
                st.success(message)

    with col2:
        if st.button("Refresh Column Options"):
            st.session_state.columns = get_columns()
            st.rerun()

    # Dashboard display: a fixed 2x2 grid backed by
    # st.session_state.dashboard_plots (four slots, None = empty).
    st.subheader("Dashboard")

    # Row 1
    col1, col2 = st.columns(2)

    with col1:
        if st.session_state.dashboard_plots[0]:
            st.plotly_chart(st.session_state.dashboard_plots[0], use_container_width=True)
            if st.button("Remove Plot 1"):
                remove_plot(0)
                st.rerun()

    with col2:
        if st.session_state.dashboard_plots[1]:
            st.plotly_chart(st.session_state.dashboard_plots[1], use_container_width=True)
            if st.button("Remove Plot 2"):
                remove_plot(1)
                st.rerun()

    # Row 2
    col3, col4 = st.columns(2)

    with col3:
        if st.session_state.dashboard_plots[2]:
            st.plotly_chart(st.session_state.dashboard_plots[2], use_container_width=True)
            if st.button("Remove Plot 3"):
                remove_plot(2)
                st.rerun()

    with col4:
        if st.session_state.dashboard_plots[3]:
            st.plotly_chart(st.session_state.dashboard_plots[3], use_container_width=True)
            if st.button("Remove Plot 4"):
                remove_plot(3)
                st.rerun()

    # Custom plot generator
    st.subheader("Custom Plot Generator")

    # Column selection
    col1, col2, col3 = st.columns(3)

    with col1:
        x_axis = st.selectbox("X-axis Column", options=st.session_state.columns)

    with col2:
        y_axis = st.selectbox("Y-axis Column", options=st.session_state.columns)

    with col3:
        facet = st.selectbox("Facet (optional)", options=["None"] + st.session_state.columns)

    if st.button("Generate Custom Visualizations"):
        with st.spinner("Generating custom visualizations..."):
            custom_plots = generate_custom_plots_with_llm(st.session_state.dataset_metadata_list, x_axis, y_axis, facet)
            # Store plots in session state so the on_click callback below can
            # reach them after the rerun triggered by the Add button
            for i, plot in enumerate(custom_plots):
                if plot:
                    st.session_state.custom_plots_to_save[i] = plot

            # Display custom plots with add buttons
            for i, plot in enumerate(custom_plots):
                if plot:
                    st.plotly_chart(plot, use_container_width=True)
                    st.button(
                        f"Add Plot {i+1} to Dashboard",
                        key=f"add_plot_{i}",
                        on_click=save_plot_to_dashboard,
                        args=(i,)
                    )
1534
+
1535
# Tab 5: PDF report generation and download.
with tab5:
    st.header("Generate Analysis Report")

    st.markdown("""
    Generate a PDF report containing:
    - Dashboard visualizations
    - Chat conversation history
    """)

    # NOTE(review): report_title is collected but generate_enhanced_pdf_report
    # hard-codes 'Data Analysis Report' — confirm whether it should be wired in.
    report_title = st.text_input("Report Title (Optional)", "Data Analysis Report")

    if st.button("Generate PDF Report"):
        with st.spinner("Generating report..."):
            pdf_data = generate_enhanced_pdf_report()
            if pdf_data:
                # Create download button for PDF (inline data: URI link)
                b64_pdf = base64.b64encode(pdf_data).decode('utf-8')
                # Create download link
                pdf_download_link = f'<a href="data:application/pdf;base64,{b64_pdf}" download="data_analysis_report.pdf">Download PDF Report</a>'
                st.markdown("### Your report is ready!")
                st.markdown(pdf_download_link, unsafe_allow_html=True)
                # Preview option (simplified)
                with st.expander("Preview Report"):
                    st.warning("PDF preview is not available in Streamlit, please download the report to view it.")
            else:
                st.error("Failed to generate the report. Please try again.")
1561
+
1562
# Cleanup on app exit
def cleanup():
    """Best-effort removal of the session's temporary working directory."""
    try:
        tmp = st.session_state.temp_dir
        shutil.rmtree(tmp)
        print(f"Cleaned up temporary directory: {tmp}")
    except Exception as e:
        # Never let teardown raise during interpreter shutdown.
        print(f"Error cleaning up: {e}")

import atexit
atexit.register(cleanup)
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ pandas
3
+ numpy
4
+ plotly
5
+ langchain
6
+ langgraph==0.2.74
7
+ langchain-core
8
+ langchain-groq
9
+ langchain_openai
10
+ langchain-experimental
11
+ reportlab
12
+ Pillow
13
+ scikit-learn
14
+ tabulate
15
+ kaleido