RAVENOCC committed on
Commit
f3bbe97
·
verified ·
1 Parent(s): fb37d08

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +1618 -0
  2. requirements.txt +15 -0
app.py ADDED
@@ -0,0 +1,1618 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import streamlit as st
3
+ import pandas as pd
4
+ import pickle
5
+ import base64
6
+ from io import BytesIO, StringIO
7
+ import sys
8
+ import operator
9
+ from typing import Literal, Sequence, TypedDict, Annotated, List, Dict, Tuple
10
+ import tempfile
11
+ import shutil
12
+ import plotly.io as pio
13
+ import io
14
+ import re
15
+ import json
16
+ import openai
17
+ # from fpdf import FPDF
18
+ import base64
19
+ from datetime import datetime
20
+ from reportlab.lib.pagesizes import letter
21
+ from reportlab.platypus import SimpleDocTemplate, Paragraph, Spacer, Image
22
+ from reportlab.lib.styles import getSampleStyleSheet, ParagraphStyle
23
+ from reportlab.lib.units import inch
24
+ from PIL import Image as PILImage
25
+
26
+ # Import LangChain and LangGraph components
27
+ from langchain_core.messages import AIMessage, ToolMessage, HumanMessage, BaseMessage
28
+ from langchain_core.prompts import ChatPromptTemplate
29
+ from langchain_openai import ChatOpenAI
30
+ from langchain_experimental.utilities import PythonREPL
31
+ from langgraph.prebuilt import ToolInvocation, ToolExecutor
32
+ from langchain_core.tools import tool
33
+ from langgraph.prebuilt import InjectedState
34
+ from langgraph.graph import StateGraph, END
35
+ from reportlab.platypus import PageBreak
36
+
37
+
38
# Initialize session state for AI provider settings.
# Defaults: OpenAI provider, gpt-4 model, and no API key until the user
# supplies one in the settings tab.
if 'ai_provider' not in st.session_state:
    st.session_state.ai_provider = "openai"

if 'api_key' not in st.session_state:
    st.session_state.api_key = ""

if 'selected_model' not in st.session_state:
    st.session_state.selected_model = "gpt-4"

# Define model options for each provider (shown in the settings UI).
# NOTE(review): "gpt-4-mini" may be a typo for "gpt-4o-mini", and
# "llama3.3-70b-versatile" for "llama-3.3-70b-versatile" — confirm against
# the providers' current model lists.
OPENAI_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-4-mini", "gpt-3.5-turbo"]
GROQ_MODELS = ["llama3.3-70b-versatile", "gemma2-9b-it", "llama-3-8b-8192"]

# Create a per-session temporary directory for file storage. Pickled plotly
# figures produced by agent-executed code land under
# images/plotly_figures/pickle inside it.
if 'temp_dir' not in st.session_state:
    st.session_state.temp_dir = tempfile.mkdtemp()
    st.session_state.images_dir = os.path.join(st.session_state.temp_dir, "images/plotly_figures/pickle")
    os.makedirs(st.session_state.images_dir, exist_ok=True)
    print(f"Created temporary directory: {st.session_state.temp_dir}")
    print(f"Created images directory: {st.session_state.images_dir}")
59
+
60
# Define the system prompt handed to the agent LLM via chat_template below.
# It constrains the agent to the pre-imported libraries and to the
# `plotly_figures` list convention that complete_python_task relies on for
# saving figures.
SYSTEM_PROMPT = """## Role
You are a professional data scientist helping a non-technical user understand, analyze, and visualize their data.

## Capabilities
1. **Execute python code** using the `complete_python_task` tool.

## Goals
1. Understand the user's objectives clearly.
2. Take the user on a data analysis journey, iterating to find the best way to visualize or analyse their data to solve their problems.
3. Investigate if the goal is achievable by running Python code via the `python_code` field.
4. Gain input from the user at every step to ensure the analysis is on the right track and to understand business nuances.

## Code Guidelines
- **ALL INPUT DATA IS LOADED ALREADY**, so use the provided variable names to access the data.
- **VARIABLES PERSIST BETWEEN RUNS**, so reuse previously defined variables if needed.
- **TO SEE CODE OUTPUT**, use `print()` statements. You won't be able to see outputs of `pd.head()`, `pd.describe()` etc. otherwise.
- **ONLY USE THE FOLLOWING LIBRARIES**:
  - `pandas`
  - `sklearn` (including all major ML models)
  - `plotly`
  - `numpy`

All these libraries are already imported for you.

## Machine Learning Guidelines
- For regression tasks:
  - Linear Regression: `LinearRegression`
  - Logistic Regression: `LogisticRegression`
  - Ridge Regression: `Ridge`
  - Lasso Regression: `Lasso`
  - Random Forest Regression: `RandomForestRegressor`

- For classification tasks:
  - Logistic Regression: `LogisticRegression`
  - Decision Trees: `DecisionTreeClassifier`
  - Random Forests: `RandomForestClassifier`
  - Support Vector Machines: `SVC`
  - K-Nearest Neighbors: `KNeighborsClassifier`
  - Naive Bayes: `GaussianNB`

- For clustering:
  - K-Means: `KMeans`
  - DBSCAN: `DBSCAN`

- For dimensionality reduction:
  - PCA: `PCA`

- Always preprocess data appropriately:
  - Scale numerical features with `StandardScaler` or `MinMaxScaler`
  - Encode categorical variables with `OneHotEncoder` when needed
  - Handle missing values with `SimpleImputer`

- Always split data into training and testing sets using `train_test_split`
- Evaluate models using appropriate metrics:
  - For regression: `mean_squared_error`, `mean_absolute_error`, `r2_score`
  - For classification: `accuracy_score`, `confusion_matrix`, `classification_report`
  - For clustering: `silhouette_score`

- Consider using `cross_val_score` for more robust evaluation
- Visualize ML results with plotly when possible

## Plotting Guidelines
- Always use the `plotly` library for plotting.
- Store all plotly figures inside a `plotly_figures` list, they will be saved automatically.
- Do not try and show the plots inline with `fig.show()`.
"""
127
+
128
# Define the shared LangGraph state. Fields annotated with operator.add are
# reducers: partial updates returned by graph nodes are concatenated onto the
# existing list instead of replacing it; plain fields are overwritten.
class AgentState(TypedDict):
    # Full chat history (appended to by agent and tool nodes)
    messages: Annotated[Sequence[BaseMessage], operator.add]
    # Metadata entries for uploaded datasets (see process_file_upload)
    input_data: Annotated[List[Dict], operator.add]
    # Log of {thought, code, output} records from tool executions
    intermediate_outputs: Annotated[List[dict], operator.add]
    # Variables surviving between code runs (overwritten each update)
    current_variables: dict
    # Filenames of pickled plotly figures produced by tool runs
    output_image_paths: Annotated[List[str], operator.add]
135
+
136
# Initialize remaining session state containers used across the app.
# in_memory_datasets: variable name -> loaded DataFrame
if 'in_memory_datasets' not in st.session_state:
    st.session_state.in_memory_datasets = {}

# persistent_vars: variables carried between exec() runs in complete_python_task
if 'persistent_vars' not in st.session_state:
    st.session_state.persistent_vars = {}

# dataset_metadata_list: per-upload metadata dicts (fed into AgentState.input_data)
if 'dataset_metadata_list' not in st.session_state:
    st.session_state.dataset_metadata_list = []

if 'chat_history' not in st.session_state:
    st.session_state.chat_history = []

# Four dashboard slots, each holding a plotly figure or None
if 'dashboard_plots' not in st.session_state:
    st.session_state.dashboard_plots = [None, None, None, None]

# Column names offered in dropdowns; placeholder until data is uploaded
if 'columns' not in st.session_state:
    st.session_state.columns = ["No columns available"]

if 'custom_plots_to_save' not in st.session_state:
    st.session_state.custom_plots_to_save = {}
157
+
158
# Set up the tools
# NOTE(review): `repl` is never used below — complete_python_task calls exec()
# directly. Presumably a leftover from an earlier design; confirm before removing.
repl = PythonREPL()
# Snippet exec()'d after user code runs (see complete_python_task): pickles
# every figure collected in the `plotly_figures` list into `images_dir`
# under a fresh UUID filename.
plotly_saving_code = """import pickle

import uuid
import os
for figure in plotly_figures:
    pickle_filename = f"{images_dir}/{uuid.uuid4()}.pickle"
    with open(pickle_filename, 'wb') as f:
        pickle.dump(figure, f)
"""
169
+
170
@tool
def complete_python_task(
    graph_state: Annotated[dict, InjectedState],
    thought: str,
    python_code: str
) -> Tuple[str, dict]:
    """Execute Python code for data analysis and visualization.

    Runs `python_code` via exec() with the uploaded datasets, previously
    persisted variables, and a preloaded set of sklearn/numpy names in scope.
    Captures stdout as the tool's textual result, persists new variables into
    st.session_state.persistent_vars, and pickles any figures the code
    appended to `plotly_figures`.

    Args:
        graph_state: Current LangGraph state (injected, not LLM-supplied).
        thought: The model's stated reasoning for running this code (logged).
        python_code: The code to execute.

    Returns:
        (stdout_text, state_updates) on success; on failure, the error string
        and a state update recording the failed attempt.
    """

    current_variables = graph_state.get("current_variables", {})

    # Load datasets from in-memory storage so the code can reference them by
    # their declared variable names.
    for input_dataset in graph_state.get("input_data", []):
        var_name = input_dataset.get("variable_name")
        if var_name and var_name not in current_variables and var_name in st.session_state.in_memory_datasets:
            print(f"Loading {var_name} from in-memory storage")
            current_variables[var_name] = st.session_state.in_memory_datasets[var_name]
    # Snapshot of existing figure pickles, used below to detect new ones.
    current_image_pickle_files = os.listdir(st.session_state.images_dir)

    try:
        # Capture stdout so print() output from the executed code becomes the
        # tool's result text.
        old_stdout = sys.stdout
        sys.stdout = StringIO()

        # Execute the code and capture the result. Globals are copied so new
        # names can be diffed against the module namespace afterwards.
        exec_globals = globals().copy()
        exec_globals.update(st.session_state.persistent_vars)
        exec_globals.update(current_variables)

        # Add scikit-learn modules to execution environment
        import sklearn
        import numpy as np

        # Import scikit-learn components
        from sklearn.linear_model import LinearRegression, LogisticRegression, Ridge, Lasso  # type: ignore
        from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier  # type: ignore
        from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
        from sklearn.svm import SVC, SVR
        from sklearn.naive_bayes import GaussianNB
        from sklearn.decomposition import PCA
        from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
        from sklearn.cluster import KMeans, DBSCAN
        from sklearn.preprocessing import StandardScaler, MinMaxScaler, OneHotEncoder
        from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV
        from sklearn.metrics import (
            accuracy_score, confusion_matrix, classification_report,
            mean_squared_error, r2_score, mean_absolute_error, silhouette_score
        )
        from sklearn.pipeline import Pipeline
        from sklearn.impute import SimpleImputer

        # Update execution globals with all ML components, plus the
        # `plotly_figures` list and `images_dir` that the SYSTEM_PROMPT
        # conventions and plotly_saving_code rely on.
        exec_globals.update({
            "plotly_figures": [],
            "images_dir": st.session_state.images_dir,
            "np": np,
            # Linear models
            "LinearRegression": LinearRegression,
            "LogisticRegression": LogisticRegression,
            "Ridge": Ridge,
            "Lasso": Lasso,
            # Tree-based models
            "DecisionTreeClassifier": DecisionTreeClassifier,
            "DecisionTreeRegressor": DecisionTreeRegressor,
            "RandomForestClassifier": RandomForestClassifier,
            "RandomForestRegressor": RandomForestRegressor,
            "GradientBoostingClassifier": GradientBoostingClassifier,
            # SVM models
            "SVC": SVC,
            "SVR": SVR,
            # Other models
            "GaussianNB": GaussianNB,
            "PCA": PCA,
            "KNeighborsClassifier": KNeighborsClassifier,
            "KNeighborsRegressor": KNeighborsRegressor,
            "KMeans": KMeans,
            "DBSCAN": DBSCAN,
            # Preprocessing
            "StandardScaler": StandardScaler,
            "MinMaxScaler": MinMaxScaler,
            "OneHotEncoder": OneHotEncoder,
            "SimpleImputer": SimpleImputer,
            # Model selection and evaluation
            "train_test_split": train_test_split,
            "cross_val_score": cross_val_score,
            "GridSearchCV": GridSearchCV,
            "accuracy_score": accuracy_score,
            "confusion_matrix": confusion_matrix,
            "classification_report": classification_report,
            "mean_squared_error": mean_squared_error,
            "r2_score": r2_score,
            "mean_absolute_error": mean_absolute_error,
            "silhouette_score": silhouette_score,
            # Pipeline
            "Pipeline": Pipeline
        })

        # SECURITY NOTE: exec() runs model-generated code with full process
        # privileges — acceptable only because the code comes from the
        # session owner's own agent.
        exec(python_code, exec_globals)

        # Persist every name the executed code created (anything not in the
        # module namespace) so later runs can reuse it.
        st.session_state.persistent_vars.update({k: v for k, v in exec_globals.items() if k not in globals()})

        # Get the captured stdout
        output = sys.stdout.getvalue()

        # Restore stdout
        sys.stdout = old_stdout

        updated_state = {
            "intermediate_outputs": [{"thought": thought, "code": python_code, "output": output}],
            "current_variables": st.session_state.persistent_vars
        }

        # If the code produced figures, pickle them and report the new files.
        if 'plotly_figures' in exec_globals and exec_globals['plotly_figures']:
            exec(plotly_saving_code, exec_globals)

            # Check if any images were created
            new_image_folder_contents = os.listdir(st.session_state.images_dir)
            new_image_files = [file for file in new_image_folder_contents if file not in current_image_pickle_files]

            if new_image_files:
                updated_state["output_image_paths"] = new_image_files
            # Reset so already-saved figures are not pickled again next run.
            st.session_state.persistent_vars["plotly_figures"] = []
        return output, updated_state

    except Exception as e:
        sys.stdout = old_stdout  # Restore stdout in case of error
        print(f"Error in complete_python_task: {str(e)}")
        return str(e), {"intermediate_outputs": [{"thought": thought, "code": python_code, "output": str(e)}]}
297
+
298
# Function to initialize the LLM based on selected provider and model
def initialize_llm():
    """Build a chat model from the provider/model/key stored in session state.

    Returns:
        A ChatOpenAI or ChatGroq instance, or None when no API key is
        configured, the provider is unrecognised, or construction fails.
    """
    key = st.session_state.api_key
    chosen_model = st.session_state.selected_model

    if not key:
        return None

    try:
        provider = st.session_state.ai_provider
        if provider == "openai":
            os.environ["OPENAI_API_KEY"] = key
            return ChatOpenAI(model=chosen_model, temperature=0)
        if provider == "groq":
            os.environ["GROQ_API_KEY"] = key
            # Imported lazily so the groq dependency is optional.
            from langchain_groq import ChatGroq
            return ChatGroq(model=chosen_model, temperature=0)
    except Exception as e:
        print(f"Error initializing LLM: {str(e)}")
        return None
318
+
319
# Set up the tools the agent may call and the executor that runs them.
tools = [complete_python_task]
tool_executor = ToolExecutor(tools)

# Load the prompt template: the system prompt followed by the running
# message history (filled in at invoke time).
chat_template = ChatPromptTemplate.from_messages([
    ("system", SYSTEM_PROMPT),
    ("placeholder", "{messages}"),
])
328
+
329
def create_data_summary(state: AgentState) -> str:
    """Build a plain-text description of all data available to the agent.

    Lists each declared input dataset (name, description, first five rows
    when loaded in memory), then any other non-underscore variables persisted
    from previous code runs, noting DataFrame shapes.
    """
    pieces = []
    declared_names = []

    # Describe each uploaded dataset, with a sample when it is loaded.
    for entry in state.get("input_data", []):
        name = entry.get("variable_name")
        if not name:
            continue
        declared_names.append(name)
        pieces.append(f"\n\nVariable: {name}\n")
        pieces.append(f"Description: {entry.get('data_description', 'No description')}\n")
        if name in st.session_state.in_memory_datasets:
            frame = st.session_state.in_memory_datasets[name]
            pieces.append("\nSample Data (first 5 rows):\n")
            pieces.append(frame.head(5).to_string())

    # Mention leftover variables from earlier runs (skip private/underscore names).
    if "current_variables" in state:
        extras = [v for v in state["current_variables"] if v not in declared_names and not v.startswith("_")]
        for v in extras:
            value = state["current_variables"].get(v)
            if isinstance(value, pd.DataFrame):
                pieces.append(f"\n\nVariable: {v} (DataFrame with shape {value.shape})")
            else:
                pieces.append(f"\n\nVariable: {v}")

    return "".join(pieces)
360
+
361
def route_to_tools(state: AgentState) -> Literal["tools", "__end__"]:
    """Route to the tool node when the last AI message requested tool calls,
    otherwise end the graph run.

    Raises:
        ValueError: if the state carries no messages at all.
    """
    messages = state.get("messages", [])
    if not messages:
        raise ValueError(f"No messages found in input state to tool_edge: {state}")
    last = messages[-1]
    wants_tools = hasattr(last, "tool_calls") and len(last.tool_calls) > 0
    return "tools" if wants_tools else "__end__"
372
+
373
def call_model(state: AgentState):
    """Invoke the LLM on the chat history, prefixed with a data summary.

    Returns a partial state update: the model's reply (possibly containing
    tool calls) and a log entry holding the data summary that was sent.
    Returns a plain error message instead when no LLM could be initialized.
    """
    data_note = HumanMessage(
        content="""The following data is available:\n{data_summary}""".format(
            data_summary=create_data_summary(state)
        )
    )

    # Build the configured LLM; bail out with guidance if settings are missing.
    llm = initialize_llm()
    if llm is None:
        return {"messages": [AIMessage(content="Please configure a valid API key and model in the settings tab.")]}

    # System prompt + history, with tools bound so the model can call them.
    pipeline = chat_template | llm.bind_tools(tools)
    reply = pipeline.invoke({"messages": [data_note] + state["messages"]})
    return {"messages": [reply], "intermediate_outputs": [data_note.content]}
392
+
393
def call_tools(state: AgentState):
    """Execute every tool call on the last AI message and merge results.

    Each tool returns (message_text, state_updates); texts become ToolMessages
    and the updates are merged together (lists extended, dicts updated,
    scalars overwritten). Exceptions from individual tools become error
    ToolMessages instead of aborting the batch.

    Returns the merged partial state update, including the ToolMessages.
    """
    last_message = state["messages"][-1]
    tool_invocations = []

    if isinstance(last_message, AIMessage) and hasattr(last_message, 'tool_calls'):
        tool_invocations = [
            ToolInvocation(
                tool=tool_call["name"],
                tool_input={**tool_call["args"], "graph_state": state}
            ) for tool_call in last_message.tool_calls
        ]

    # Guard: nothing to run (routing normally prevents this, but avoid
    # crashing on a message without tool_calls).
    if not tool_invocations:
        return {"messages": []}

    responses = tool_executor.batch(tool_invocations, return_exceptions=True)

    tool_messages = []
    state_updates = {}

    for tc, response in zip(last_message.tool_calls, responses):
        if isinstance(response, Exception):
            print(f"Exception in tool execution: {str(response)}")
            tool_messages.append(ToolMessage(
                content=f"Error: {str(response)}",
                name=tc["name"],
                tool_call_id=tc["id"]
            ))
            continue

        message, updates = response
        tool_messages.append(ToolMessage(
            content=str(message),
            name=tc["name"],
            tool_call_id=tc["id"]
        ))

        # Merge updates instead of overwriting
        for key, value in updates.items():
            if key in state_updates:
                if isinstance(value, list) and isinstance(state_updates[key], list):
                    state_updates[key].extend(value)
                elif isinstance(value, dict) and isinstance(state_updates[key], dict):
                    state_updates[key].update(value)
                else:
                    state_updates[key] = value
            else:
                state_updates[key] = value

    # BUG FIX: previously tool_messages unconditionally overwrote
    # state_updates["messages"] (after a dead "if not in" check), discarding
    # any messages the tool updates had merged in; append instead.
    state_updates["messages"] = state_updates.get("messages", []) + tool_messages
    return state_updates
444
+
445
# Set up the graph: the agent node proposes actions; whenever its reply
# contains tool calls we run the tools node and loop back, otherwise end.
workflow = StateGraph(AgentState)
workflow.add_node("agent", call_model)
workflow.add_node("tools", call_tools)
workflow.add_conditional_edges(
    "agent",
    route_to_tools,
    {
        "tools": "tools",
        "__end__": END
    }
)
workflow.add_edge("tools", "agent")
workflow.set_entry_point("agent")

# Compiled runnable used by the chat UI.
chain = workflow.compile()
461
+
462
def process_file_upload(files):
    """Load uploaded CSV/Excel files into in-memory datasets.

    Resets all previously loaded datasets, metadata, and persisted variables,
    then loads each file into a DataFrame keyed by a sanitized variable name
    derived from the filename.

    Args:
        files: Uploaded file objects (must expose .name and be readable by
            pandas). Unsupported extensions are reported, not raised.

    Returns:
        (markdown_previews, dataset_metadata_list, unique_column_names)
    """
    st.session_state.in_memory_datasets = {}  # Clear previous datasets
    st.session_state.dataset_metadata_list = []  # Clear previous metadata
    st.session_state.persistent_vars.clear()  # Clear persistent variables for new session

    if not files:
        return "No files uploaded.", [], ["No columns available"]

    results = []
    all_columns = []  # Track all columns from all datasets

    for file in files:
        try:
            # Use file object directly
            if file.name.endswith('.csv'):
                df = pd.read_csv(file)
            elif file.name.endswith(('.xls', '.xlsx')):
                df = pd.read_excel(file)
            else:
                results.append(f"Unsupported file format: {file.name}. Please upload CSV or Excel files.")
                continue

            # BUG FIX: split('.')[0] truncated names containing dots
            # (e.g. "sales.2024.csv" -> "sales"); strip only the extension,
            # then sanitize separators into underscores.
            base = os.path.splitext(file.name)[0]
            var_name = base.replace('-', '_').replace(' ', '_').replace('.', '_').lower()
            st.session_state.in_memory_datasets[var_name] = df

            # Collect all columns
            all_columns.extend(df.columns.tolist())

            # Create dataset metadata
            dataset_metadata = {
                "variable_name": var_name,
                "data_path": "in_memory",
                "data_description": f"Dataset containing {df.shape[0]} rows and {df.shape[1]} columns. Columns: {', '.join(df.columns.tolist())}",
                "original_filename": file.name
            }

            st.session_state.dataset_metadata_list.append(dataset_metadata)

            # Return preview of the dataset
            preview = f"### Dataset: {file.name}\nVariable name: `{var_name}`\n\n"
            preview += df.head(10).to_markdown()
            results.append(preview)
            print(f"Successfully processed {file.name}")

        except Exception as e:
            print(f"Error processing {file.name}: {str(e)}")
            results.append(f"Error processing {file.name}: {str(e)}")

    # Order-preserving de-duplication of column names across all datasets.
    unique_columns = list(dict.fromkeys(all_columns))
    if not unique_columns:
        unique_columns = ["No columns available"]

    print(f"Found {len(unique_columns)} unique columns across datasets")
    return "\n\n".join(results), st.session_state.dataset_metadata_list, unique_columns
525
+
526
def get_columns():
    """Collect column names from every in-memory DataFrame.

    Names are de-duplicated in first-seen order; a placeholder list is
    returned when no dataset is loaded.
    """
    collected = []
    for _name, frame in st.session_state.in_memory_datasets.items():
        if isinstance(frame, pd.DataFrame):
            collected.extend(frame.columns.tolist())

    # dict.fromkeys preserves insertion order while dropping duplicates.
    unique_columns = list(dict.fromkeys(collected))
    if not unique_columns:
        unique_columns = ["No columns available"]

    print(f"Populating dropdowns with {len(unique_columns)} columns")
    return unique_columns
548
+
549
+ # === FUNCTIONS ===
550
+ import openai
551
+ import pandas as pd
552
+ import json
553
+ import re
554
+
555
def standard_clean(df):
    """Apply baseline cleaning to a DataFrame in place and return it.

    - Column names: runs of non-word characters collapsed to "_", lowercased.
    - Exact duplicate rows removed.
    - Columns and rows that are entirely empty dropped.
    - Leading/trailing whitespace stripped from string cells.
    """
    df.columns = [re.sub(r'\W+', '_', col).strip().lower() for col in df.columns]
    df.drop_duplicates(inplace=True)
    df.dropna(axis=1, how='all', inplace=True)
    df.dropna(axis=0, how='all', inplace=True)
    for col in df.select_dtypes(include='object').columns:
        # BUG FIX: astype(str).str.strip() converted NaN/None into the
        # literal strings "nan"/"None"; strip only real strings and leave
        # other values (including missing ones) untouched.
        df[col] = df[col].map(lambda v: v.strip() if isinstance(v, str) else v)
    return df
563
+
564
def query_openai(prompt):
    """Send a single-turn prompt to the provider configured in session state.

    Args:
        prompt: The user-role message text.

    Returns:
        The model's reply text, or the string "{}" when the provider is
        unrecognised or the request fails (callers parse the result as a
        Python dict literal, so "{}" is the safe empty plan).
    """
    try:
        # Use the configured API key and model from session state
        api_key = st.session_state.api_key
        model = st.session_state.selected_model

        if st.session_state.ai_provider == "openai":
            client = openai.OpenAI(api_key=api_key)
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7
            )
            return response.choices[0].message.content
        elif st.session_state.ai_provider == "groq":
            from groq import Groq
            client = Groq(api_key=api_key)
            response = client.chat.completions.create(
                model=model,
                messages=[{"role": "user", "content": prompt}],
                temperature=0.7
            )
            return response.choices[0].message.content
    except Exception as e:
        print(f"API Error: {e}")
        return "{}"
    # BUG FIX: an unknown provider previously fell off the end and returned
    # None, which callers cannot parse; return the same fallback as the
    # error path.
    return "{}"
590
+
591
def llm_suggest_cleaning(df):
    """Ask the configured LLM for a cleaning plan for `df`.

    Sends the first 10 rows as CSV and requests a Python dict describing
    renames, type conversions, missing-value strategies, value maps,
    duplicate/outlier handling, and text standardization.

    Returns:
        The parsed plan dict, or a safe empty-plan dict when the reply
        cannot be parsed.
    """
    sample = df.head(10).to_csv(index=False)
    prompt = f"""
You are a professional data wrangler. Below is a sample of a messy dataset.

Your task is to study the sample and provide a full data cleaning plan. Follow the structure below and provide **only a valid Python dictionary**.

### Return a dictionary with these keys:

1. **rename_columns** – rename unclear or inconsistent column names
2. **convert_types** – convert columns to appropriate datatypes: int, float, str, or date
3. **fill_missing** – fill missing values using the most suitable strategy for each column
4. **value_map** – map inconsistent values (e.g., yes/Yes/Y → Yes)
5. **handle_duplicates** – specify either "exact", "fuzzy", or "both"
6. **handle_outliers** – suggest per-column strategies for handling outliers
7. **text_standardization** – list columns for:
- "case_normalization"
- "remove_special_chars"
- "normalize_whitespace"

8. **report_summary** – provide an example summary in the following structure:
{{
"original_records": 4000,
"cleaned_records": 200,
"missing_values": {{"total": 0, "fixed": 0}},
"duplicates": {{"found": 3800, "removed": 3800}},
"outliers": {{"found": 60, "fixed": 13}},
"text_standardization": {{
"fields_modified": 600,
"case_fixed": 400,
"special_chars_fixed": 212,
"whitespace_fixed": 200
}},
"time_taken": "0.00 seconds"
}}

### For missing values, suggest the most suitable method based on column type and content:

- "mean", "median" → for numerical columns
- "mode" → for categorical columns
- "knn", "interpolate", "forward_fill", "backward_fill" → for time series
- constant value (e.g., 0, "Unknown") → if appropriate
- "drop_rows_with_many_missing" → if a row has too many nulls

### For outliers:
Per-column, suggest one of the following:
- "remove"
- "replace_with_mean"
- "replace_with_median"
- "cap_iqr"
- "cap_zscore"
- "log_transform"
- "flag_only"
- "keep"

### For duplicates:
Choose: "exact", "fuzzy", or "both"

### For text columns:
Suggest columns for:
- "case_normalization"
- "remove_special_chars"
- "normalize_whitespace"

Do NOT drop any columns unless it’s extremely necessary. Ensure your output is a valid Python dictionary only.

Sample data:
{sample}
"""
    raw_response = query_openai(prompt)
    try:
        # SECURITY FIX: the reply is untrusted model output — parse it with
        # ast.literal_eval (literals only) instead of eval(), which would
        # execute arbitrary code embedded in the response.
        import ast
        suggestions = ast.literal_eval(raw_response)
        return suggestions
    except Exception as e:
        print("Could not parse suggestions:", e)
        return {
            "rename_columns": {},
            "convert_types": {},
            "fill_missing": {},
            "value_map": {},
            "handle_duplicates": "exact",
            "handle_outliers": {},
            "text_standardization": {
                "case_normalization": [],
                "remove_special_chars": [],
                "normalize_whitespace": []
            },
            "report_summary": {}
        }
680
+
681
def apply_suggestions(df, suggestions):
    """Apply an LLM-generated cleaning plan to `df` and return it.

    Supported plan keys: rename_columns, convert_types (int/float/str/date),
    fill_missing ("mean"/"median"/"mode" or a constant string), value_map.
    Columns absent from the frame and steps that fail are skipped with a
    console note rather than raising.
    """
    df.rename(columns=suggestions.get("rename_columns", {}), inplace=True)

    for col, dtype in suggestions.get("convert_types", {}).items():
        if col not in df.columns:
            continue
        try:
            if dtype == "int":
                # Nullable Int64 keeps NaN for unparseable values
                df[col] = pd.to_numeric(df[col], errors='coerce').astype("Int64")
            elif dtype == "float":
                df[col] = pd.to_numeric(df[col], errors='coerce')
            elif dtype == "str":
                df[col] = df[col].astype(str)
            elif dtype == "date":
                df[col] = pd.to_datetime(df[col], errors='coerce')
        except Exception:
            # was a bare `except:`, which also swallowed SystemExit/KeyboardInterrupt
            print(f"Failed to convert {col} to {dtype}")

    for col, method in suggestions.get("fill_missing", {}).items():
        if col not in df.columns:
            continue
        try:
            # BUG FIX: df[col].fillna(..., inplace=True) is chained
            # assignment — deprecated and unreliable on copies under pandas
            # Copy-on-Write; assign the result back instead.
            if method == "mean":
                df[col] = df[col].fillna(df[col].mean())
            elif method == "median":
                df[col] = df[col].fillna(df[col].median())
            elif method == "mode":
                df[col] = df[col].fillna(df[col].mode().iloc[0])
            elif isinstance(method, str):
                # Any other string is treated as a constant fill value
                df[col] = df[col].fillna(method)
        except Exception:
            print(f"Could not fill missing values for {col}")

    for col, mapping in suggestions.get("value_map", {}).items():
        if col in df.columns:
            try:
                df[col] = df[col].replace(mapping)
            except Exception:
                print(f"Could not map values in {col}")

    return df
722
+
723
def capture_dashboard_screenshot():
    """Render the four dashboard plots into one 2x2 PNG.

    Copies each stored figure's traces (and axis titles) into a combined
    make_subplots figure and writes it to the session temp directory.

    Returns:
        The path of the written PNG, or None on any failure.
    """
    try:
        # Create a figure that combines all dashboard plots
        import plotly.graph_objects as go
        from plotly.subplots import make_subplots

        # Create a 2x2 subplot
        fig = make_subplots(rows=2, cols=2,
                            subplot_titles=["Visualization 1", "Visualization 2",
                                            "Visualization 3", "Visualization 4"])

        # Add each plot from the dashboard to the combined figure
        for i, plot in enumerate(st.session_state.dashboard_plots):
            if plot is None:
                continue
            row = (i // 2) + 1
            col = (i % 2) + 1

            # Extract traces from the original figure and add to our subplot
            for trace in plot.data:
                fig.add_trace(trace, row=row, col=col)

            # BUG FIX: the old code read "xaxis{i+1}" from the source figure
            # (a standalone figure only has "xaxis"/"yaxis") and wrote to
            # "xaxis{row}{col}" — but make_subplots numbers axes sequentially
            # (xaxis, xaxis2, xaxis3, xaxis4), so update_layout raised and
            # the whole capture returned None. Copy only the axis titles into
            # the correct sequential axis, preserving the subplot domains.
            target_suffix = '' if i == 0 else str(i + 1)
            for axis_type in ['xaxis', 'yaxis']:
                if hasattr(plot.layout, axis_type):
                    src_axis = getattr(plot.layout, axis_type)
                    if src_axis is not None and src_axis.title is not None and src_axis.title.text:
                        fig.update_layout({
                            f"{axis_type}{target_suffix}": {"title": {"text": src_axis.title.text}}
                        })

        # Update layout for better appearance
        fig.update_layout(
            height=800,
            width=1000,
            title_text="Dashboard Overview",
            showlegend=False,
        )

        # Save to a temporary file
        dashboard_path = f"{st.session_state.temp_dir}/dashboard_combined.png"
        fig.write_image(dashboard_path, scale=2)  # Higher scale for better resolution
        return dashboard_path

    except Exception as e:
        import traceback
        print(f"Error capturing dashboard: {str(e)}")
        print(traceback.format_exc())
        return None
773
+
774
def generate_enhanced_pdf_report():
    """Build the analysis PDF report and return it as raw bytes.

    The report contains the chat conversation history (with any inline
    base64-encoded visualizations decoded and embedded as images) followed
    by the dashboard rendered on its own page.

    Returns:
        bytes | None: The PDF document, or ``None`` on failure (the error
        and traceback are printed to stdout).
    """
    try:
        buffer = io.BytesIO()

        doc = SimpleDocTemplate(buffer, pagesize=letter,
                                leftMargin=36, rightMargin=36,
                                topMargin=36, bottomMargin=36)

        styles = _build_report_styles()

        elements = []
        elements.append(Paragraph('Data Analysis Report', styles['ReportTitle']))
        elements.append(Paragraph(f'Generated on: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}',
                                  styles['Timestamp']))
        elements.append(Spacer(1, 0.5*inch))

        _append_conversation_section(elements, styles)

        # Dashboard always starts on a fresh page.
        elements.append(PageBreak())
        _append_dashboard_section(elements, styles, doc)

        doc.build(elements)

        pdf_value = buffer.getvalue()
        buffer.close()
        return pdf_value

    except Exception as e:
        import traceback
        print(f"Error generating enhanced PDF report: {str(e)}")
        print(traceback.format_exc())
        return None


def _build_report_styles():
    """Return the sample stylesheet extended with this report's custom styles."""
    styles = getSampleStyleSheet()

    styles.add(ParagraphStyle(
        name='ReportTitle',
        parent=styles['Heading1'],
        fontSize=24,
        alignment=1,  # centered
        spaceAfter=20,
        textColor='#2C3E50'  # dark blue
    ))
    styles.add(ParagraphStyle(
        name='SectionHeader',
        parent=styles['Heading2'],
        fontSize=16,
        spaceBefore=15,
        spaceAfter=10,
        textColor='#2C3E50',
        borderWidth=1,
        borderColor='#95A5A6',
        borderPadding=5,
        borderRadius=5
    ))
    styles.add(ParagraphStyle(
        name='SubHeader',
        parent=styles['Heading3'],
        fontSize=14,
        spaceBefore=10,
        spaceAfter=8,
        textColor='#34495E',
        fontWeight='bold'
    ))
    styles.add(ParagraphStyle(
        name='UserMessage',
        parent=styles['Normal'],
        fontSize=11,
        leftIndent=10,
        spaceBefore=8,
        spaceAfter=4
    ))
    styles.add(ParagraphStyle(
        name='AssistantMessage',
        parent=styles['Normal'],
        fontSize=11,
        leftIndent=10,
        spaceBefore=4,
        spaceAfter=12,
        textColor='#2980B9'
    ))
    styles.add(ParagraphStyle(
        name='Timestamp',
        parent=styles['Italic'],
        fontSize=10,
        textColor='#7F8C8D',
        alignment=2  # right aligned
    ))
    return styles


def _append_conversation_section(elements, styles):
    """Append the chat history (text plus decoded inline images) to *elements*."""
    elements.append(Paragraph('Analysis Conversation History', styles['SectionHeader']))

    if not st.session_state.chat_history:
        elements.append(Paragraph('No conversation history available.', styles['Normal']))
        return

    # Assistant messages embed visualizations as markdown images carrying
    # base64-encoded PNG data; this pattern captures the raw payload.
    base64_pattern = r'!\[Visualization\]\(data:image\/png;base64,([^\)]+)\)'

    for i, (user_msg, assistant_msg) in enumerate(st.session_state.chat_history):
        elements.append(Paragraph(f'<b>You:</b>', styles['SubHeader']))
        elements.append(Paragraph(user_msg.replace('\n', '<br/>'), styles['UserMessage']))

        if '### Visualizations' in assistant_msg or re.search(base64_pattern, assistant_msg):
            # Split text from the visualization section.
            if '### Visualizations' in assistant_msg:
                parts = assistant_msg.split('### Visualizations', 1)
                text_part = parts[0]
                viz_part = "### Visualizations" + parts[1] if len(parts) > 1 else ""
            else:
                # No header but an inline image is present.
                match = re.search(base64_pattern, assistant_msg)
                text_part = assistant_msg[:match.start()]
                viz_part = assistant_msg[match.start():]

            elements.append(Paragraph(f'<b>Assistant:</b>', styles['SubHeader']))
            elements.append(Paragraph(text_part.replace('\n', '<br/>'),
                                      styles['AssistantMessage']))

            # Decode each embedded image to a temp file and place it.
            for j, base64_data in enumerate(re.findall(base64_pattern, viz_part)):
                try:
                    image_data = base64.b64decode(base64_data)
                    temp_img_path = f"{st.session_state.temp_dir}/chat_viz_{i}_{j}.png"
                    with open(temp_img_path, 'wb') as f:
                        f.write(image_data)

                    elements.append(Paragraph(f'<b>Visualization:</b>', styles['SubHeader']))
                    elements.append(Spacer(1, 0.1*inch))
                    elements.append(Image(temp_img_path, width=6*inch, height=4*inch))
                    elements.append(Spacer(1, 0.2*inch))
                except Exception as e:
                    print(f"Error processing base64 image: {str(e)}")
                    elements.append(Paragraph(f"[Error displaying visualization: {str(e)}]",
                                              styles['Normal']))
        else:
            # Plain-text reply; cap very long replies so they do not
            # dominate the report.
            elements.append(Paragraph(f'<b>Assistant:</b>', styles['SubHeader']))
            assistant_msg_formatted = assistant_msg.replace('\n', '<br/>')
            if len(assistant_msg_formatted) > 1500:
                assistant_msg_formatted = assistant_msg_formatted[:1500] + '...'
            elements.append(Paragraph(assistant_msg_formatted, styles['AssistantMessage']))

        elements.append(Spacer(1, 0.2*inch))


def _append_dashboard_section(elements, styles, doc):
    """Append the dashboard overview: combined image, or per-plot fallback."""
    elements.append(Paragraph('Dashboard Overview', styles['SectionHeader']))
    elements.append(Spacer(1, 0.2*inch))

    dashboard_img_path = capture_dashboard_screenshot()

    if dashboard_img_path:
        # Scale the combined image to the page width, preserving aspect ratio.
        available_width = doc.width
        pil_img = PILImage.open(dashboard_img_path)
        img_width, img_height = pil_img.size
        new_height = img_height * (available_width / img_width)
        elements.append(Image(dashboard_img_path, width=available_width, height=new_height))
        return

    # Fallback: embed each stored plot individually.
    plot_count = 0
    for i, plot in enumerate(st.session_state.dashboard_plots):
        if plot is None:
            continue
        plot_count += 1

        # Render the plotly figure to a PNG on disk.
        img_bytes = io.BytesIO()
        plot.write_image(img_bytes, format='png', width=500, height=300)
        img_bytes.seek(0)

        temp_img_path = f"{st.session_state.temp_dir}/plot_{i}.png"
        with open(temp_img_path, 'wb') as f:
            f.write(img_bytes.getvalue())

        elements.append(Paragraph(f'Dashboard Visualization {i+1}', styles['SubHeader']))
        elements.append(Spacer(1, 0.1*inch))
        elements.append(Image(temp_img_path, width=6.5*inch, height=4*inch))
        elements.append(Spacer(1, 0.3*inch))

    if plot_count == 0:
        elements.append(Paragraph('No visualizations have been added to the dashboard.',
                                  styles['Normal']))
997
+
998
def chat_with_workflow(message, history, dataset_info):
    """Run the analysis workflow for one user query and return the reply text.

    Args:
        message: The user's question.
        history: Prior (user, assistant) exchanges; only the last three
            are replayed as context.
        dataset_info: Metadata for the uploaded datasets; a falsy value
            short-circuits with a prompt to upload data first.

    Returns:
        str: The assistant reply, with any generated plots appended as
        inline base64 markdown images.
    """
    if not dataset_info:
        return "Please upload at least one dataset before asking questions."

    # A configured API key is required before the workflow can run.
    if not st.session_state.api_key:
        return "Please set up your API key and model in the Settings tab before chatting."

    print(f"Chat with workflow called with {len(dataset_info)} datasets")

    try:
        # Replay at most the last three exchanges as conversational context.
        context_messages = []
        for exchange in (history or [])[-3:]:
            if exchange[0]:
                context_messages.append(HumanMessage(content=exchange[0]))
            if exchange[1]:
                context_messages.append(AIMessage(content=exchange[1]))

        state = AgentState(
            messages=context_messages + [HumanMessage(content=message)],
            input_data=dataset_info,
            intermediate_outputs=[],
            current_variables=st.session_state.persistent_vars,
            output_image_paths=[]
        )

        print("Executing workflow...")
        result = chain.invoke(state)
        print("Workflow execution completed")

        messages = result["messages"]

        # Build the reply from the final message only.
        response = ""
        if messages and hasattr(messages[-1], "content"):
            content = messages[-1].content

            # Drop an echoed copy of the user's question, if present.
            if message in content:
                content = content.split(message)[-1].strip()

            # Strip any chat-history role markers the model emitted.
            kept = [line for line in content.split('\n')
                    if not line.strip().startswith(("You:", "User:", "Human:", "Assistant:"))]
            response = '\n'.join(kept).strip() + "\n\n"

        # Append any generated figures as inline base64 images.
        if "output_image_paths" in result and result["output_image_paths"]:
            response += "### Visualizations\n\n"
            for img_path in result["output_image_paths"]:
                try:
                    full_path = os.path.join(st.session_state.images_dir, img_path)
                    with open(full_path, 'rb') as fh:
                        fig = pickle.load(fh)

                    # Render the pickled plotly figure to PNG bytes.
                    img_bytes = BytesIO()
                    fig.update_layout(width=800, height=500)
                    pio.write_image(fig, img_bytes, format='png')
                    img_bytes.seek(0)

                    b64_img = base64.b64encode(img_bytes.read()).decode()
                    response += f"![Visualization](data:image/png;base64,{b64_img})\n\n"
                except Exception as e:
                    response += f"Error loading visualization: {str(e)}\n\n"

        return response

    except Exception as e:
        import traceback
        print(f"Error in chat_with_workflow: {str(e)}")
        print(traceback.format_exc())
        return f"Error executing workflow: {str(e)}"
1090
+
1091
def auto_generate_dashboard(dataset_info):
    """Ask the workflow for four insightful plots and install them on the dashboard.

    Returns a ``(status_message, figures)`` pair where ``figures`` is
    always a list of exactly four entries (``None``-padded).
    """
    if not dataset_info:
        return "Please upload a dataset first.", [None, None, None, None]

    prompt = """
    You are a data visualization expert. Given a dataset, identify the top 4 most insightful plots using statistical reasoning or patterns (correlation, distribution, trends).

    Use plotly and store the plots in a list named plotly_figures.

    Include multivariate plots using color/size/facets when helpful.
    """

    state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[]
    )

    result = chain.invoke(state)

    # Load up to four pickled figures produced by the workflow.
    figures = []
    if "output_image_paths" in result:
        for img_path in result["output_image_paths"][:4]:
            try:
                full_path = os.path.join(st.session_state.images_dir, img_path)
                with open(full_path, 'rb') as fh:
                    figures.append(pickle.load(fh))
            except Exception as exc:
                print(f"Error loading figure: {exc}")

    # Pad with None so the dashboard always has four slots.
    figures.extend([None] * (4 - len(figures)))

    st.session_state.dashboard_plots = figures
    return "Dashboard generated!", figures
1131
+
1132
def generate_custom_plots_with_llm(dataset_info, x_col, y_col, facet_col):
    """Generate up to three custom plots for the user-selected columns.

    Returns a list of exactly three entries (plotly figures or ``None``).
    Returns all-``None`` when datasets or axis columns are missing.
    """
    if not dataset_info or not x_col or not y_col:
        return [None, None, None]

    prompt = f"""
    You are a data visualization expert.

    Create 3 insightful visualizations using Plotly based on:

    - X-axis: {x_col}
    - Y-axis: {y_col}
    - Facet (optional): {facet_col if facet_col != 'None' else 'None'}

    Try to find interesting relationships, trends, or clusters using appropriate chart types.

    Use `plotly_figures` list and avoid using fig.show().
    """

    state = AgentState(
        messages=[HumanMessage(content=prompt)],
        input_data=dataset_info,
        intermediate_outputs=[],
        current_variables=st.session_state.persistent_vars,
        output_image_paths=[]
    )

    result = chain.invoke(state)

    # Load up to three pickled figures produced by the workflow.
    figures = []
    if "output_image_paths" in result:
        for img_path in result["output_image_paths"][:3]:
            try:
                full_path = os.path.join(st.session_state.images_dir, img_path)
                with open(full_path, 'rb') as fh:
                    figures.append(pickle.load(fh))
            except Exception as exc:
                print(f"Error loading figure: {exc}")

    # Pad with None so callers can always index three slots.
    figures.extend([None] * (3 - len(figures)))
    return figures
1176
+
1177
def remove_plot(index):
    """Clear the dashboard slot at *index*; out-of-range indices are a no-op."""
    plots = st.session_state.dashboard_plots
    if 0 <= index < len(plots):
        plots[index] = None
1181
+
1182
def respond(message):
    """Handle one chat submission: compute a reply, record it, rerun the app."""
    if not st.session_state.dataset_metadata_list:
        reply = "Please upload at least one dataset before asking questions."
    else:
        reply = chat_with_workflow(message,
                                   st.session_state.chat_history,
                                   st.session_state.dataset_metadata_list)

    st.session_state.chat_history.append((message, reply))
    st.rerun()
1191
+
1192
def save_plot_to_dashboard(plot_index):
    """Copy a generated custom plot into the first free dashboard slot.

    Silently does nothing when all four slots are occupied (matches the
    button callback contract: no return value).
    """
    slots = st.session_state.dashboard_plots
    for slot, current in enumerate(slots):
        if current is None:
            slots[slot] = st.session_state.custom_plots_to_save[plot_index]
            return
1199
+
1200
# New function to check if settings are valid
def is_settings_valid():
    """Check if API key and model are configured.

    Returns True when a non-empty API key is stored in session state
    (the model always has a default, so only the key is checked).
    """
    return st.session_state.api_key != ""
1204
+
1205
# Streamlit UI with left panel settings
st.set_page_config(page_title="QueryMind 🧠", layout="wide")

# Create side panel for settings
with st.sidebar:
    st.header("Settings")
    st.info("⚠️ Configure your API settings before uploading data")

    # AI Provider selection; the radio defaults to the provider saved in
    # session state.
    provider = st.radio("Select AI Provider",
                        options=["OpenAI", "Groq"],
                        index=0 if st.session_state.ai_provider == "openai" else 1,
                        horizontal=True)

    # Update session state based on selection
    st.session_state.ai_provider = provider.lower()

    # API Key input (persisted to session state only when "Save Settings"
    # is pressed below)
    api_key = st.text_input("Enter API Key",
                            value=st.session_state.api_key,
                            type="password",
                            help="Your API key for the selected provider")

    # Display different model options based on provider
    if st.session_state.ai_provider == "openai":
        model_options = OPENAI_MODELS
        model_help = "GPT-4 provides the best results but is slower. GPT-3.5-Turbo is faster but less capable."
    else:  # groq
        model_options = GROQ_MODELS
        model_help = "Llama 3.3 70B is most capable. Gemma 2 9B offers good balance. Llama 3 8B is fastest."

    # Model selection; falls back to the first option when the saved model
    # belongs to the other provider.
    selected_model = st.selectbox("Select Model",
                                  options=model_options,
                                  index=model_options.index(st.session_state.selected_model) if st.session_state.selected_model in model_options else 0,
                                  help=model_help)

    # Save button: persist the key/model and immediately smoke-test them.
    if st.button("Save Settings"):
        st.session_state.api_key = api_key
        st.session_state.selected_model = selected_model

        # Test the API key and model
        try:
            # Initialize LLM using the provided settings
            test_llm = initialize_llm()
            if test_llm:
                st.success(f"✅ Successfully configured {provider} with model: {selected_model}")
            else:
                st.error("Failed to initialize the AI provider. Please check your API key and model selection.")
        except Exception as e:
            st.error(f"Error testing settings: {str(e)}")

    # Display current settings
    st.subheader("Current Settings")
    settings_info = f"""
    - **Provider**: {st.session_state.ai_provider.upper()}
    - **Model**: {st.session_state.selected_model}
    - **API Key**: {'✅ Set' if st.session_state.api_key else '❌ Not Set'}
    """
    st.markdown(settings_info)

    # Provider-specific information
    if st.session_state.ai_provider == "openai":
        with st.expander("OpenAI Models Info"):
            st.info("""
            - **GPT-4**: Most powerful model, best for complex analysis and detailed explanations
            - **GPT-4-Turbo**: Faster than GPT-4 with similar capabilities
            - **GPT-4-Mini**: Economical option with good performance for standard tasks
            - **GPT-3.5-Turbo**: Fastest option, suitable for basic analysis and visualization
            """)
    else:
        with st.expander("Groq Models Info"):
            st.info("""
            - **llama3.3-70b-versatile**: Most powerful model for comprehensive analysis
            - **gemma2-9b-it**: Good balance of speed and capabilities
            - **llama-3-8b-8192**: Fastest option for basic analysis tasks
            """)

    # Integration instructions
    with st.expander("How to get API Keys"):
        if st.session_state.ai_provider == "openai":
            st.markdown("""
            ### Getting an OpenAI API Key

            1. Go to [OpenAI's platform](https://platform.openai.com)
            2. Sign up or log in to your account
            3. Navigate to the API section
            4. Create a new API key
            5. Copy the key and paste it above

            Note: OpenAI API usage incurs charges based on tokens used.
            """)
        else:
            st.markdown("""
            ### Getting a Groq API Key

            1. Go to [Groq's website](https://console.groq.com/keys)
            2. Sign up or log in to your account
            3. Navigate to API Keys section
            4. Create a new API key
            5. Copy the key and paste it above

            Note: Check Groq's pricing page for current rates.
            """)
1310
+
1311
# Main content
st.title(" QueryMind 🧠 - Data Assistant ")
st.markdown("Upload your datasets, ask questions, and generate visualizations to gain insights.")

# Check if settings are valid before showing tabs; st.stop() halts script
# execution here on every rerun until an API key is saved.
if not is_settings_valid():
    st.warning("⚠️ Please configure your API key in the sidebar settings panel before proceeding.")
    st.stop()

# Create main tabs - only if settings are valid
tab1, tab2, tab3, tab4, tab5 = st.tabs(["Upload Datasets", "Data Cleaning", "Chat with AI Assistant", "Auto Dashboard Generator", "Generate Report"])

with tab1:
    st.header("Upload Datasets")
    uploaded_files = st.file_uploader("Upload CSV or Excel Files",
                                      accept_multiple_files=True,
                                      type=['csv', 'xlsx', 'xls'])

    if uploaded_files and st.button("Process Uploaded Files"):
        with st.spinner("Processing files..."):
            # process_file_upload also populates st.session_state
            # (in_memory_datasets, dataset_metadata_list); only the column
            # list is kept from the return value here.
            preview, metadata_list, columns = process_file_upload(uploaded_files)
            st.session_state.columns = columns

            # Display basic information about processed files
            st.success(f"✅ Successfully processed {len(uploaded_files)} file(s)")

            # Show detailed preview for each dataset
            st.subheader("Dataset Previews")

            for dataset_name, df in st.session_state.in_memory_datasets.items():
                with st.expander(f"Preview: {dataset_name}"):
                    # Display dataset info
                    st.write(f"**Rows:** {df.shape[0]} | **Columns:** {df.shape[1]}")

                    # Per-column summary: dtype, non-null count, and up to
                    # three example values.
                    col_info = pd.DataFrame({
                        'Column Name': df.columns,
                        'Data Type': df.dtypes.astype(str),
                        'Non-Null Count': df.count().values,
                        'Sample Values': [', '.join(df[col].dropna().astype(str).head(3).tolist()) for col in df.columns]
                    })

                    # Show column information in a compact table
                    st.write("**Column Information:**")
                    st.dataframe(col_info, use_container_width=True)

                    # Show actual data preview
                    st.write("**Data Preview (First 10 rows):**")
                    st.dataframe(df.head(10), use_container_width=True)

            # Provide hint for the next steps
            st.info("👆 Click on the dataset names above to see detailed previews. Then proceed to the Data Cleaning tab to clean your data or Chat with AI Assistant to analyze it.")
1363
+
1364
with tab2:
    st.header("Data Cleaning")

    # Lazily initialize cleaning-related session state.
    if 'cleaning_done' not in st.session_state:
        st.session_state.cleaning_done = False

    if 'cleaned_datasets' not in st.session_state:
        st.session_state.cleaned_datasets = {}

    if 'cleaning_summaries' not in st.session_state:
        st.session_state.cleaning_summaries = {}

    if st.session_state.get("in_memory_datasets"):
        if not st.session_state.cleaning_done:
            if st.button("Run Data Cleaning"):
                with st.spinner("Running LLM-assisted cleaning..."):
                    # Pipeline per dataset: rule-based standard_clean, then
                    # LLM-proposed cleaning suggestions, then apply them.
                    for name, df in st.session_state.in_memory_datasets.items():
                        raw_df = df.copy()
                        df_std = standard_clean(raw_df.copy())
                        suggestions = llm_suggest_cleaning(df_std.copy())
                        df_clean = apply_suggestions(df_std.copy(), suggestions)
                        st.session_state.cleaned_datasets[name] = df_clean
                        st.session_state.cleaning_summaries[name] = suggestions
                    st.session_state.cleaning_done = True
                    st.rerun()
            else:
                st.info("Click Run Data Cleaning to clean your datasets using the LLM.")
        else:
            # Cleaning finished: show before/after previews and allow
            # natural-language refinements per dataset.
            for name, df_clean in st.session_state.cleaned_datasets.items():
                raw_df = st.session_state.in_memory_datasets[name]

                st.subheader(f"Dataset: {name}")
                col1, col2 = st.columns(2)

                with col1:
                    st.markdown("Original Data (First 5 Rows)")
                    st.dataframe(raw_df.head())

                with col2:
                    st.markdown("Cleaned Data (First 5 Rows)")
                    st.dataframe(df_clean.head())

                st.markdown("Summary of Cleaning Actions")
                suggestions = st.session_state.cleaning_summaries[name]
                summary_text = ""

                if suggestions:
                    for key, value in suggestions.items():
                        summary_text += f"**{key}**: {json.dumps(value, indent=2)}\n\n"
                st.markdown(summary_text)

                st.markdown("Refine the Cleaning (Natural Language Instructions)")
                user_input = st.text_input("Example: Convert 'dob' to datetime and fill missing with '2000-01-01'",
                                           key=f"user_input_{name}")

                if f'corrections_{name}' not in st.session_state:
                    st.session_state[f'corrections_{name}'] = []
                if st.button("Apply Correction", key=f'apply_correction_{name}'):
                    if user_input.strip():
                        correction_prompt = f"""
                        You are a data cleaning expert working with pandas. Here's a summary of the previous cleaning:

                        {summary_text}

                        The user now asks:
                        \"{user_input.strip()}\"

                        Please return Python code that modifies the existing pandas DataFrame `df`. The code must:
                        - Assume `df` is already loaded
                        - Perform the action as described
                        - End with `df` being the modified DataFrame
                        - Please include all the libaries that needed to imported in order to run the code in the same code snippet

                        ONLY return the code. Do not include explanations, markdown, or extra text.
                        """
                        correction_code = query_openai(correction_prompt)

                        # Clean LLM code block if wrapped in markdown syntax
                        if correction_code.startswith("```"):
                            correction_code = correction_code.strip().strip("`")  # remove backticks
                            correction_code = correction_code.replace("python", "", 1).strip()

                        try:
                            df = st.session_state.cleaned_datasets[name].copy()
                            local_vars = {"df": df}

                            # HACK/SECURITY: executes LLM-generated code with
                            # exec(); the code is not sandboxed beyond an
                            # empty globals dict — review before hardening.
                            exec(correction_code, {}, local_vars)
                            df_updated = local_vars.get("df")

                            if df_updated is not None and isinstance(df_updated, pd.DataFrame):
                                st.session_state.cleaned_datasets[name] = df_updated
                                st.session_state[f'corrections_{name}'].append((user_input, correction_code))
                                st.success("Correction applied.")
                                st.rerun()
                            else:
                                st.warning("LLM did not return a valid DataFrame. Here's the response:")
                                st.code(correction_code, language="python")

                        except Exception as e:
                            st.error(f"Failed to apply correction: {str(e)}")
                            st.code(correction_code, language="python")
1467
+
1468
with tab3:
    st.header("Chat with AI Assistant")

    st.markdown("""
    ## Example Questions
    - "What analysis can you perform on this dataset?"
    - "Show me basic statistics for all columns"
    - "Create a correlation heatmap"
    - "Plot the distribution of a specific column"
    - "What is the relationship between two columns?"
    """)

    # Display chat history (each entry is a (user, assistant) pair).
    for exchange in st.session_state.chat_history:
        with st.chat_message("user"):
            st.write(exchange[0])
        with st.chat_message("assistant"):
            st.write(exchange[1])

    # Chat input; respond() appends to history and triggers st.rerun().
    if prompt := st.chat_input("Your question"):
        with st.spinner("Thinking..."):
            respond(prompt)
1491
+
1492
with tab4:
    st.header("Auto Dashboard Generator")

    # Dashboard controls
    # NOTE(review): dashboard_title is collected but not used downstream in
    # this block — confirm whether it should feed the generated dashboard.
    dashboard_title = st.text_input("Dashboard Title", placeholder="Enter your dashboard title")

    col1, col2 = st.columns(2)

    with col1:
        if st.button("Generate Suggested Dashboard (Auto)"):
            with st.spinner("Generating dashboard..."):
                message, figures = auto_generate_dashboard(st.session_state.dataset_metadata_list)
                st.success(message)

    with col2:
        if st.button("Refresh Column Options"):
            st.session_state.columns = get_columns()
            st.rerun()

    # Dashboard display: four slots laid out as a 2x2 grid, each with its
    # own remove button.
    st.subheader("Dashboard")

    # Row 1
    col1, col2 = st.columns(2)

    with col1:
        if st.session_state.dashboard_plots[0]:
            st.plotly_chart(st.session_state.dashboard_plots[0], use_container_width=True)
            if st.button("Remove Plot 1"):
                remove_plot(0)
                st.rerun()

    with col2:
        if st.session_state.dashboard_plots[1]:
            st.plotly_chart(st.session_state.dashboard_plots[1], use_container_width=True)
            if st.button("Remove Plot 2"):
                remove_plot(1)
                st.rerun()

    # Row 2
    col3, col4 = st.columns(2)

    with col3:
        if st.session_state.dashboard_plots[2]:
            st.plotly_chart(st.session_state.dashboard_plots[2], use_container_width=True)
            if st.button("Remove Plot 3"):
                remove_plot(2)
                st.rerun()

    with col4:
        if st.session_state.dashboard_plots[3]:
            st.plotly_chart(st.session_state.dashboard_plots[3], use_container_width=True)
            if st.button("Remove Plot 4"):
                remove_plot(3)
                st.rerun()

    # Custom plot generator: user picks axes, the LLM proposes three plots.
    st.subheader("Custom Plot Generator")

    # Column selection
    col1, col2, col3 = st.columns(3)

    with col1:
        x_axis = st.selectbox("X-axis Column", options=st.session_state.columns)

    with col2:
        y_axis = st.selectbox("Y-axis Column", options=st.session_state.columns)

    with col3:
        facet = st.selectbox("Facet (optional)", options=["None"] + st.session_state.columns)

    if st.button("Generate Custom Visualizations"):
        with st.spinner("Generating custom visualizations..."):
            custom_plots = generate_custom_plots_with_llm(st.session_state.dataset_metadata_list, x_axis, y_axis, facet)
            # Store plots in session state so the on_click callback can
            # copy them onto the dashboard after a rerun.
            for i, plot in enumerate(custom_plots):
                if plot:
                    st.session_state.custom_plots_to_save[i] = plot

            # Display custom plots with add buttons
            for i, plot in enumerate(custom_plots):
                if plot:
                    st.plotly_chart(plot, use_container_width=True)
                    st.button(
                        f"Add Plot {i+1} to Dashboard",
                        key=f"add_plot_{i}",
                        on_click=save_plot_to_dashboard,
                        args=(i,)
                    )
1581
+
1582
with tab5:
    st.header("Generate Analysis Report")

    st.markdown("""
    Generate a PDF report containing:
    - Dashboard visualizations
    - Chat conversation history
    """)

    # NOTE(review): report_title is collected but generate_enhanced_pdf_report()
    # takes no arguments — confirm whether the title should be passed through.
    report_title = st.text_input("Report Title (Optional)", "Data Analysis Report")

    if st.button("Generate PDF Report"):
        with st.spinner("Generating report..."):
            pdf_data = generate_enhanced_pdf_report()
            if pdf_data:
                # Create download button for PDF: the bytes are embedded in
                # a base64 data-URI anchor tag.
                b64_pdf = base64.b64encode(pdf_data).decode('utf-8')
                # Create download link
                pdf_download_link = f'<a href="data:application/pdf;base64,{b64_pdf}" download="data_analysis_report.pdf">Download PDF Report</a>'
                st.markdown("### Your report is ready!")
                st.markdown(pdf_download_link, unsafe_allow_html=True)
                # Preview option (simplified)
                with st.expander("Preview Report"):
                    st.warning("PDF preview is not available in Streamlit, please download the report to view it.")
            else:
                st.error("Failed to generate the report. Please try again.")
1608
+
1609
# Cleanup on app exit
def cleanup():
    """Delete the session's temporary working directory at interpreter exit."""
    try:
        target = st.session_state.temp_dir
        shutil.rmtree(target)
        print(f"Cleaned up temporary directory: {target}")
    except Exception as exc:
        # Best-effort cleanup: never let teardown raise.
        print(f"Error cleaning up: {exc}")

import atexit
atexit.register(cleanup)
requirements.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ pandas
3
+ numpy
4
+ plotly
5
+ langchain
6
+ langgraph==0.2.74
7
+ langchain-core
8
+ langchain-groq
9
+ langchain_openai
10
+ langchain-experimental
11
+ reportlab
12
+ Pillow
13
+ scikit-learn
14
+ tabulate
15
+ kaleido