pavanmutha committed on
Commit
01fbc32
·
verified ·
1 Parent(s): ca79b7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +299 -277
app.py CHANGED
@@ -9,15 +9,23 @@ import psutil
9
  import optuna
10
  import ast
11
  import pandas as pd
 
12
  from sklearn.model_selection import train_test_split
13
  from sklearn.ensemble import RandomForestClassifier
14
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
 
 
15
  import shap
16
  import lime
17
  import lime.lime_tabular
18
  import matplotlib.pyplot as plt
19
- import numpy as np
20
- from optuna.visualization import plot_optimization_history, plot_param_importances
 
 
 
 
 
21
 
22
  # Authenticate Hugging Face
23
  hf_token = os.getenv("HF_TOKEN")
@@ -26,312 +34,326 @@ login(token=hf_token, add_to_git_credential=True)
26
  # Initialize Model
27
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
28
 
29
- def format_analysis_report(raw_output, visuals):
30
- try:
31
- if isinstance(raw_output, dict):
32
- analysis_dict = raw_output
33
- else:
34
- try:
35
- analysis_dict = ast.literal_eval(str(raw_output))
36
- except (SyntaxError, ValueError) as e:
37
- print(f"Error parsing CodeAgent output: {e}")
38
- return str(raw_output), visuals # Return raw output as string
39
-
40
- report = f"""
41
- <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
42
- <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">πŸ“Š Data Analysis Report</h1>
43
- <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
44
- <h2 style="color: #2B547E;">πŸ” Key Observations</h2>
45
- {format_observations(analysis_dict.get('observations', {}))}
46
- </div>
47
- <div style="margin-top: 30px;">
48
- <h2 style="color: #2B547E;">πŸ’‘ Insights & Visualizations</h2>
49
- {format_insights(analysis_dict.get('insights', {}), visuals)}
50
- </div>
51
- </div>
52
- """
53
- return report, visuals
54
- except Exception as e:
55
- print(f"Error in format_analysis_report: {e}")
56
- return str(raw_output), visuals
57
 
58
- def format_observations(observations):
59
- return '\n'.join([
60
- f"""
61
- <div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
62
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">{key.replace('_', ' ').title()}</h3>
63
- <pre style="margin: 0; padding: 10px; background: #f8f9fa; border-radius: 4px;">{value}</pre>
64
- </div>
65
- """ for key, value in observations.items() if 'proportions' in key
66
- ])
67
 
68
- def format_insights(insights, visuals):
69
- return '\n'.join([
70
- f"""
71
- <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
72
- <div style="display: flex; align-items: center; gap: 10px;">
73
- <div style="background: #2B547E; color: white; width: 30px; height: 30px; border-radius: 50%; display: flex; align-items: center; justify-content: center;">{idx+1}</div>
74
- <p style="margin: 0; font-size: 16px;">{insight}</p>
75
- </div>
76
- {f'<img src="/file={visuals[idx]}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">' if idx < len(visuals) else ''}
77
- </div>
78
- """ for idx, (key, insight) in enumerate(insights.items())
79
- ])
80
 
81
- def analyze_data(csv_file, additional_notes=""):
82
- start_time = time.time()
83
- process = psutil.Process(os.getpid())
84
- initial_memory = process.memory_info().rss / 1024 ** 2
85
-
86
- if os.path.exists('./figures'):
87
- shutil.rmtree('./figures')
88
- os.makedirs('./figures', exist_ok=True)
89
-
90
- wandb.login(key=os.environ.get('WANDB_API_KEY'))
91
- run = wandb.init(project="huggingface-data-analysis", config={
92
- "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
93
- "additional_notes": additional_notes,
94
- "source_file": csv_file.name if csv_file else None
95
- })
96
 
97
- agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn", "sklearn"])
98
- analysis_result = agent.run("""
99
- You are an expert data analyst. Perform comprehensive analysis including:
100
- 1. Basic statistics and data quality checks
101
- 2. 3 insightful analytical questions about relationships in the data
102
- 3. Visualization of key patterns and correlations
103
- 4. Actionable real-world insights derived from findings.
104
- Generate publication-quality visualizations and save to './figures/'.
105
- Return the analysis results as a python dictionary that can be parsed by ast.literal_eval().
106
- The dictionary should have the following structure:
107
- {
108
- 'observations': {
109
- 'observation_1_key': 'observation_1_value',
110
- 'observation_2_key': 'observation_2_value',
111
- ...
112
- },
113
- 'insights': {
114
- 'insight_1_key': 'insight_1_value',
115
- 'insight_2_key': 'insight_2_value',
116
- ...
117
- }
118
- }
119
- """, additional_args={"additional_notes": additional_notes, "source_file": csv_file})
120
 
121
- execution_time = time.time() - start_time
122
- final_memory = process.memory_info().rss / 1024 ** 2
123
- memory_usage = final_memory - initial_memory
124
- wandb.log({"execution_time_sec": execution_time, "memory_usage_mb": memory_usage})
125
 
126
- visuals = [os.path.join('./figures', f) for f in os.listdir('./figures') if f.endswith(('.png', '.jpg', '.jpeg'))]
127
- for viz in visuals:
128
- wandb.log({os.path.basename(viz): wandb.Image(viz)})
 
 
 
129
 
130
- run.finish()
131
- return format_analysis_report(analysis_result, visuals)
132
 
133
- def objective(trial, X_train, y_train, X_test, y_test):
134
- # Enhanced hyperparameter space
135
- n_estimators = trial.suggest_int("n_estimators", 50, 500, step=50)
136
- max_depth = trial.suggest_int("max_depth", 3, 15)
137
- min_samples_split = trial.suggest_int("min_samples_split", 2, 10)
138
- min_samples_leaf = trial.suggest_int("min_samples_leaf", 1, 5)
139
- max_features = trial.suggest_categorical("max_features", ["sqrt", "log2", None])
140
- bootstrap = trial.suggest_categorical("bootstrap", [True, False])
141
- criterion = trial.suggest_categorical("criterion", ["gini", "entropy"])
142
-
143
- model = RandomForestClassifier(
144
- n_estimators=n_estimators,
145
- max_depth=max_depth,
146
- min_samples_split=min_samples_split,
147
- min_samples_leaf=min_samples_leaf,
148
- max_features=max_features,
149
- bootstrap=bootstrap,
150
- criterion=criterion,
151
- random_state=42,
152
- n_jobs=-1
153
- )
154
-
155
- model.fit(X_train, y_train)
156
- predictions = model.predict(X_test)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
- # Track multiple metrics
159
- accuracy = accuracy_score(y_test, predictions)
160
- precision = precision_score(y_test, predictions, average='weighted', zero_division=0)
161
- recall = recall_score(y_test, predictions, average='weighted', zero_division=0)
162
- f1 = f1_score(y_test, predictions, average='weighted', zero_division=0)
163
 
164
- # Log metrics to W&B
165
- wandb.log({
166
- "trial_accuracy": accuracy,
167
- "trial_precision": precision,
168
- "trial_recall": recall,
169
- "trial_f1": f1,
170
- "n_estimators": n_estimators,
171
- "max_depth": max_depth,
172
- "min_samples_split": min_samples_split,
173
- "min_samples_leaf": min_samples_leaf,
174
- "max_features": str(max_features),
175
- "bootstrap": bootstrap,
176
- "criterion": criterion
177
  })
178
 
179
- return accuracy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
- def tune_hyperparameters(csv_file, n_trials: int):
182
- # Initialize W&B run
183
- wandb.login(key=os.environ.get('WANDB_API_KEY'))
184
- run = wandb.init(project="hyperparameter-optimization",
185
- config={"n_trials": n_trials, "model_type": "RandomForest"})
 
 
 
 
 
 
186
 
187
- df = pd.read_csv(csv_file)
188
- y = df.iloc[:, -1]
189
- X = df.iloc[:, :-1]
190
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
191
 
192
- # Create study with enhanced settings
193
- study = optuna.create_study(
194
- direction="maximize",
195
- sampler=optuna.samplers.TPESampler(),
196
- pruner=optuna.pruners.MedianPruner(n_warmup_steps=5)
197
- )
198
-
199
- # Run optimization
200
- study.optimize(lambda trial: objective(trial, X_train, y_train, X_test, y_test),
201
- n_trials=n_trials,
202
- callbacks=[wandb_callback])
203
-
204
- # Get best trial results
205
- best_params = study.best_params
206
- best_value = study.best_value
207
-
208
- # Train final model with best parameters
209
- final_model = RandomForestClassifier(**best_params, random_state=42, n_jobs=-1)
210
- final_model.fit(X_train, y_train)
211
- final_predictions = final_model.predict(X_test)
212
-
213
- # Calculate final metrics
214
- accuracy = accuracy_score(y_test, final_predictions)
215
- precision = precision_score(y_test, final_predictions, average='weighted', zero_division=0)
216
- recall = recall_score(y_test, final_predictions, average='weighted', zero_division=0)
217
- f1 = f1_score(y_test, final_predictions, average='weighted', zero_division=0)
218
-
219
- # Generate optimization visualizations
220
- optimization_history = plot_optimization_history(study)
221
- param_importance = plot_param_importances(study)
222
-
223
- # Save visualizations
224
- os.makedirs('./figures', exist_ok=True)
225
- history_path = "./figures/optimization_history.png"
226
- importance_path = "./figures/param_importance.png"
227
 
228
- optimization_history.figure.savefig(history_path)
229
- param_importance.figure.savefig(importance_path)
230
 
231
- # Generate SHAP and LIME explanations
232
- shap_explainer = shap.TreeExplainer(final_model)
233
- shap_values = shap_explainer.shap_values(X_test)
234
- shap.summary_plot(shap_values, X_test, show=False)
235
- shap_fig_path = "./figures/shap_summary.png"
236
- plt.savefig(shap_fig_path)
237
- plt.clf()
238
 
239
- lime_explainer = lime.lime_tabular.LimeTabularExplainer(
240
- X_train.values,
241
- feature_names=X_train.columns,
242
- class_names=['target'],
243
- mode='classification'
244
- )
245
- lime_explanation = lime_explainer.explain_instance(
246
- X_test.iloc[0].values,
247
- final_model.predict_proba
248
- )
249
- lime_fig = lime_explanation.as_pyplot_figure()
250
- lime_fig_path = "./figures/lime_explanation.png"
251
- lime_fig.savefig(lime_fig_path)
252
- plt.clf()
253
 
254
- # Log everything to W&B
255
- wandb.log({
256
- "best_params": best_params,
257
- "best_accuracy": best_value,
258
- "final_accuracy": accuracy,
259
- "final_precision": precision,
260
- "final_recall": recall,
261
- "final_f1": f1,
262
- "optimization_history": wandb.Image(history_path),
263
- "parameter_importance": wandb.Image(importance_path),
264
- "shap_summary": wandb.Image(shap_fig_path),
265
- "lime_explanation": wandb.Image(lime_fig_path)
266
  })
267
 
268
- # Generate HTML report
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  report = f"""
270
- <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
271
- <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">🎯 Hyperparameter Optimization Results</h1>
272
 
273
- <div style="margin-top: 20px; background: #f8f9fa; padding: 15px; border-radius: 8px;">
274
- <h2 style="color: #2B547E;">πŸ“ˆ Performance Metrics</h2>
275
- <p><strong>Best Accuracy:</strong> {best_value:.4f}</p>
276
- <p><strong>Final Model Accuracy:</strong> {accuracy:.4f}</p>
277
- <p><strong>Precision:</strong> {precision:.4f}</p>
278
- <p><strong>Recall:</strong> {recall:.4f}</p>
279
- <p><strong>F1 Score:</strong> {f1:.4f}</p>
280
  </div>
281
 
282
- <div style="margin-top: 25px; background: #f8f9fa; padding: 15px; border-radius: 8px;">
283
- <h2 style="color: #2B547E;">βš™οΈ Best Parameters</h2>
284
- <pre style="background: white; padding: 10px; border-radius: 4px;">{best_params}</pre>
285
  </div>
 
 
 
 
 
 
 
 
 
 
 
 
286
 
287
- <div style="margin-top: 25px;">
288
- <h2 style="color: #2B547E;">πŸ“Š Optimization Process</h2>
289
- <img src="/file={history_path}" style="max-width: 100%; border-radius: 6px; margin-bottom: 15px;">
290
- <img src="/file={importance_path}" style="max-width: 100%; border-radius: 6px;">
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  </div>
292
  </div>
293
  """
294
-
295
- # Get visualization paths for gallery
296
- visuals = [
297
- history_path,
298
- importance_path,
299
- shap_fig_path,
300
- lime_fig_path
301
- ]
302
-
303
- run.finish()
304
- return report, visuals
305
-
306
- def wandb_callback(study, trial):
307
- """Callback to log study information to W&B after each trial"""
308
- wandb.log({
309
- "best_accuracy": study.best_value,
310
- "current_trial": trial.number,
311
- "current_accuracy": trial.value
312
- })
313
 
 
314
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
315
- gr.Markdown("## πŸ“Š AI Data Analysis Agent with Enhanced Hyperparameter Optimization")
316
- with gr.Row():
317
- with gr.Column():
318
- file_input = gr.File(label="Upload CSV Dataset", type="filepath")
319
- notes_input = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
320
- analyze_btn = gr.Button("Analyze", variant="primary")
321
- optuna_trials = gr.Number(
322
- label="Number of Hyperparameter Tuning Trials",
323
- value=50,
324
- minimum=10,
325
- maximum=200,
326
- step=5
327
- )
328
- tune_btn = gr.Button("Optimize Hyperparameters", variant="secondary")
329
- with gr.Column():
330
- analysis_output = gr.Markdown("### Analysis results will appear here...")
331
- optuna_output = gr.HTML(label="Hyperparameter Tuning Results")
332
- gallery = gr.Gallery(label="Optimization Visualizations", columns=2)
333
 
334
- analyze_btn.click(fn=analyze_data, inputs=[file_input, notes_input], outputs=[analysis_output, gallery])
335
- tune_btn.click(fn=tune_hyperparameters, inputs=[file_input, optuna_trials], outputs=[optuna_output, gallery])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
336
 
337
- demo.launch(debug=True)
 
9
  import optuna
10
  import ast
11
  import pandas as pd
12
+ import numpy as np
13
  from sklearn.model_selection import train_test_split
14
  from sklearn.ensemble import RandomForestClassifier
15
+ from sklearn.metrics import (accuracy_score, precision_score,
16
+ recall_score, f1_score, classification_report)
17
+ from sklearn.preprocessing import LabelEncoder
18
  import shap
19
  import lime
20
  import lime.lime_tabular
21
  import matplotlib.pyplot as plt
22
+ import seaborn as sns
23
+ from optuna.visualization import (plot_optimization_history,
24
+ plot_param_importances,
25
+ plot_parallel_coordinate)
26
+ from PIL import Image
27
+ import base64
28
+ from io import BytesIO
29
 
30
  # Authenticate Hugging Face
31
  hf_token = os.getenv("HF_TOKEN")
 
34
  # Initialize Model
35
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
36
 
37
+ # Initialize W&B
38
+ wandb.login(key=os.environ.get('WANDB_API_KEY'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
def save_figure(fig, filename):
    """Save a figure under ``./figures`` and return the path it was written to.

    Handles the three kinds of figure objects this app passes in:

    * Plotly figures (what ``optuna.visualization.plot_*`` returns). They have
      no ``savefig`` and are exported with ``write_image`` instead.
    * seaborn grids such as the ``PairGrid`` returned by ``sns.pairplot``.
      They expose ``savefig`` but only wrap the real matplotlib figure.
    * plain matplotlib figures.

    Args:
        fig: Figure-like object to persist.
        filename: File name (not a path) to create inside ``./figures``.

    Returns:
        The relative path of the written image, ``./figures/<filename>``.
    """
    os.makedirs('./figures', exist_ok=True)
    path = f"./figures/{filename}"

    if hasattr(fig, "write_image"):
        # Plotly static export; requires the optional "kaleido" package.
        fig.write_image(path)
        return path

    fig.savefig(path, bbox_inches='tight')

    # plt.close() only accepts real Figure objects, so unwrap seaborn grids
    # first (``.figure`` on current seaborn, ``.fig`` on older releases).
    mpl_fig = fig
    if not isinstance(mpl_fig, plt.Figure):
        mpl_fig = getattr(fig, "figure", None) or getattr(fig, "fig", None)
    if isinstance(mpl_fig, plt.Figure):
        plt.close(mpl_fig)
    return path
 
 
47
 
48
def encode_categorical_data(df):
    """Label-encode every ``object``/``category`` column of ``df``.

    Each such column is cast to ``str`` and replaced by the integer codes of
    its own ``LabelEncoder``. The work is done on a copy, so the frame passed
    in by the caller is never modified.

    Args:
        df: Input ``pandas.DataFrame``.

    Returns:
        A ``(encoded_df, encoders)`` tuple where ``encoders`` maps each
        encoded column name to the fitted ``LabelEncoder``.
    """
    encoded = df.copy()
    encoders = {}
    for col in encoded.select_dtypes(include=['object', 'category']).columns:
        le = LabelEncoder()
        encoded[col] = le.fit_transform(encoded[col].astype(str))
        encoders[col] = le
    return encoded, encoders
 
 
 
 
56
 
57
def generate_data_insights(df):
    """Ask a smolagents ``CodeAgent`` to analyse ``df`` and return its answer.

    The DataFrame is handed to the agent through ``additional_args`` under the
    name ``df``. The prompt asks for a dictionary with the keys ``insights``,
    ``visualizations``, ``quality`` and ``recommendations``; whatever the agent
    actually produces is returned unchanged.
    """
    instructions = """
    Analyze this dataset and provide:
    1. 5 key statistical insights about the data
    2. 5 suggested visualizations with explanations
    3. Data quality assessment
    4. Recommendations for preprocessing

    For each insight:
    - Explain its significance
    - Provide the Python code to verify it
    - Suggest potential actions

    Return the results as a dictionary with:
    - 'insights': List of 5 key insights
    - 'visualizations': List of 5 visualization descriptions with code
    - 'quality': Data quality assessment
    - 'recommendations': Preprocessing recommendations
    """

    allowed_imports = ["numpy", "pandas", "matplotlib.pyplot", "seaborn"]
    analyst = CodeAgent(
        tools=[],
        model=model,
        additional_authorized_imports=allowed_imports
    )
    return analyst.run(instructions, additional_args={"df": df})
 
85
 
86
def create_visualizations(df, insights):
    """Render the standard exploratory plots for ``df`` and return their paths.

    Up to five kinds of plot are produced, in this order: a missing-values
    heatmap (only when the frame has nulls), a correlation heatmap and a
    pairplot (only with more than one numeric column), histograms of the first
    three numeric columns, and bar charts of the first two categorical columns.

    Every plot is attempted independently, so one failing plot is reported on
    stdout and skipped instead of discarding all the plots that follow it.

    Args:
        df: ``pandas.DataFrame`` to visualise.
        insights: Agent output; currently unused, kept so callers need not
            change.

    Returns:
        List of image paths (under ``./figures``) that were written.
    """
    import re

    visuals = []
    numeric_cols = df.select_dtypes(include=np.number).columns
    cat_cols = df.select_dtypes(include=['object', 'category']).columns

    def slug(name):
        # Column names can contain spaces or path separators; keep the
        # generated file name inside ./figures.
        return re.sub(r'[^\w.-]+', '_', str(name))

    def missing_values_heatmap():
        if df.isnull().any().any():
            fig, ax = plt.subplots(figsize=(10, 6))
            sns.heatmap(df.isnull(), cbar=False, ax=ax)
            ax.set_title("Missing Values Heatmap")
            visuals.append(save_figure(fig, "missing_values.png"))

    def correlation_heatmap():
        fig, ax = plt.subplots(figsize=(10, 8))
        sns.heatmap(df[numeric_cols].corr(), annot=True, cmap='coolwarm', ax=ax)
        ax.set_title("Correlation Heatmap")
        visuals.append(save_figure(fig, "correlation_heatmap.png"))

    def distribution(col):
        fig, ax = plt.subplots(figsize=(10, 6))
        sns.histplot(df[col], kde=True, ax=ax)
        ax.set_title(f"Distribution of {col}")
        visuals.append(save_figure(fig, f"distribution_{slug(col)}.png"))

    def pairplot():
        # Sample large frames (seeded, so reruns give the same picture).
        sample = df[numeric_cols].sample(min(100, len(df)), random_state=42)
        grid = sns.pairplot(sample)
        # pairplot returns a PairGrid, not a Figure: save through the grid
        # and close the figure it wraps (plt.close rejects the grid itself).
        os.makedirs('./figures', exist_ok=True)
        path = "./figures/pairplot.png"
        grid.savefig(path)
        plt.close(getattr(grid, "figure", None) or grid.fig)
        visuals.append(path)

    def category_counts(col):
        fig, ax = plt.subplots(figsize=(10, 6))
        df[col].value_counts().plot(kind='bar', ax=ax)
        ax.set_title(f"Count of {col}")
        visuals.append(save_figure(fig, f"count_{slug(col)}.png"))

    jobs = [missing_values_heatmap]
    if len(numeric_cols) > 1:
        jobs.append(correlation_heatmap)
    # Bind col as a default so each lambda keeps its own column.
    jobs.extend(lambda col=col: distribution(col) for col in numeric_cols[:3])
    if len(numeric_cols) > 1:
        jobs.append(pairplot)
    jobs.extend(lambda col=col: category_counts(col) for col in cat_cols[:2])

    for job in jobs:
        try:
            job()
        except Exception as e:
            print(f"Visualization error: {e}")
            # Do not leak the half-built figure of the failed plot.
            plt.close('all')

    return visuals
129
+
130
def analyze_data(csv_file, additional_notes=""):
    """Run the full analysis pipeline for an uploaded CSV file.

    Loads the CSV, asks the agent for insights, renders the standard plots,
    logs plots and timing to a W&B run, and formats the HTML report.

    Args:
        csv_file: Value delivered by the ``gr.File`` component. Depending on
            the Gradio version/configuration this is either a path string or
            a temp-file wrapper exposing ``.name``; both are accepted.
        additional_notes: Free-text notes stored in the W&B run config.

    Returns:
        ``(report, visuals)``: the HTML report and the list of image paths.
        On failure the first element is an ``"Error: ..."`` message and the
        list is empty.
    """
    if not csv_file:
        # Nothing uploaded: answer immediately instead of opening a W&B run
        # only to fail inside read_csv.
        return "Error: no CSV file was provided.", []

    # A plain string has no ``.name``; fall back to the value itself.
    csv_path = getattr(csv_file, "name", csv_file)
    start_time = time.time()

    # Initialize W&B run
    run = wandb.init(project="data-analysis", config={
        "model": "Mixtral-8x7B",
        "notes": additional_notes,
        "file": csv_path
    })

    try:
        # Load data
        df = pd.read_csv(csv_path)

        # Generate insights with smolagent
        insights = generate_data_insights(df)

        # Create visualizations
        visuals = create_visualizations(df, insights)

        # Log to W&B
        for viz in visuals:
            wandb.log({"visualizations": wandb.Image(viz)})

        # Format report
        report = format_analysis_report(insights, visuals)

        # Track performance
        execution_time = time.time() - start_time
        wandb.log({"execution_time": execution_time})

        return report, visuals

    except Exception as e:
        # UI boundary: show the failure in the report pane instead of crashing.
        return f"Error: {str(e)}", []
    finally:
        run.finish()
168
 
169
def objective(trial, X, y):
    """Optuna objective: train a random forest and return hold-out accuracy.

    Samples one hyperparameter set from ``trial``, fits a
    ``RandomForestClassifier`` on a fixed 80/20 split of ``(X, y)``, logs the
    parameters together with accuracy, weighted precision, recall and F1 to
    the active W&B run, and returns the accuracy for Optuna to maximise.

    Args:
        trial: ``optuna.trial.Trial`` used to sample hyperparameters.
        X: Feature matrix.
        y: Target vector.

    Returns:
        Accuracy on the 20% hold-out split.
    """
    params = {
        'n_estimators': trial.suggest_int('n_estimators', 50, 500),
        'max_depth': trial.suggest_int('max_depth', 3, 15),
        'min_samples_split': trial.suggest_int('min_samples_split', 2, 10),
        'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 5),
        'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2', None]),
        'bootstrap': trial.suggest_categorical('bootstrap', [True, False]),
        'criterion': trial.suggest_categorical('criterion', ['gini', 'entropy'])
    }

    # Same seed on every trial, so all trials are scored on the same rows.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Named ``clf`` so it does not shadow the module-level LLM ``model``.
    clf = RandomForestClassifier(**params, random_state=42, n_jobs=-1)
    clf.fit(X_train, y_train)

    y_pred = clf.predict(X_test)

    # Track multiple metrics. zero_division=0 scores a class that was never
    # predicted as 0 instead of emitting a warning on every trial.
    metrics = {
        'accuracy': accuracy_score(y_test, y_pred),
        'precision': precision_score(y_test, y_pred, average='weighted', zero_division=0),
        'recall': recall_score(y_test, y_pred, average='weighted', zero_division=0),
        'f1': f1_score(y_test, y_pred, average='weighted', zero_division=0)
    }

    # Log to W&B. max_features is stringified so the logged column always
    # holds text ("sqrt" / "log2" / "None") rather than a mix of str and None.
    wandb.log({**params, 'max_features': str(params['max_features']), **metrics})

    return metrics['accuracy']
200
+
201
def tune_hyperparameters(csv_file, n_trials=50):
    """Tune a random forest with Optuna and explain the best model.

    The last CSV column is taken as the target. Categorical columns are
    label-encoded, an Optuna study maximises hold-out accuracy, and the best
    configuration is refitted on the training part of the same 80/20 split
    used by ``objective`` and ``format_tuning_results``, so the reported
    metrics are computed on rows the final model has not seen.

    Args:
        csv_file: Value delivered by ``gr.File``: a path string or an object
            exposing ``.name``.
        n_trials: Number of Optuna trials. UI sliders may deliver a float;
            it is truncated to ``int``.

    Returns:
        ``(report, visuals)``: the HTML report and the list of image paths.
        On failure the first element is an ``"Error: ..."`` message and the
        list is empty.
    """
    if not csv_file:
        return "Error: no CSV file was provided.", []

    n_trials = int(n_trials)
    # A plain string has no ``.name``; fall back to the value itself.
    csv_path = getattr(csv_file, "name", csv_file)

    run = wandb.init(project="hyperparameter-tuning", config={
        "n_trials": n_trials,
        "model": "RandomForest"
    })

    try:
        # Optuna's default plots are Plotly figures, which have no savefig().
        # The matplotlib backend returns Axes, whose parent figure can be
        # handed to save_figure.
        from optuna.visualization import matplotlib as optuna_mpl

        # Load and prepare data
        df = pd.read_csv(csv_path)
        df, _ = encode_categorical_data(df)

        y = df.iloc[:, -1]  # Assume last column is target
        X = df.iloc[:, :-1]

        # Optuna study
        study = optuna.create_study(
            direction='maximize',
            sampler=optuna.samplers.TPESampler(),
            pruner=optuna.pruners.MedianPruner()
        )

        study.optimize(lambda trial: objective(trial, X, y), n_trials=n_trials)

        # Generate visualizations
        visuals = []
        study_plots = (
            (optuna_mpl.plot_optimization_history, "optimization_history.png"),
            (optuna_mpl.plot_param_importances, "param_importance.png"),
            (optuna_mpl.plot_parallel_coordinate, "parallel_coordinate.png"),
        )
        for plot_fn, filename in study_plots:
            ax = plot_fn(study)
            visuals.append(save_figure(ax.figure, filename))

        # Train best model on the training split only (same seed as the
        # objective), keeping the test rows unseen for the final report.
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42
        )
        best_model = RandomForestClassifier(**study.best_params, random_state=42)
        best_model.fit(X_train, y_train)

        # SHAP explainability
        shap_explainer = shap.TreeExplainer(best_model)
        shap_values = shap_explainer.shap_values(X_test)

        plt.figure(figsize=(10, 8))
        shap.summary_plot(shap_values, X_test, show=False)
        # Save whichever figure SHAP actually drew on.
        visuals.append(save_figure(plt.gcf(), "shap_summary.png"))

        # LIME explainability; class names come from the fitted model so
        # multi-class targets are labelled correctly.
        lime_explainer = lime.lime_tabular.LimeTabularExplainer(
            X_train.values,
            feature_names=list(X.columns),
            class_names=[str(label) for label in best_model.classes_],
            mode='classification'
        )
        exp = lime_explainer.explain_instance(X_test.iloc[0].values, best_model.predict_proba)
        fig = exp.as_pyplot_figure()
        visuals.append(save_figure(fig, "lime_explanation.png"))

        # Format results
        report = format_tuning_results(study, best_model, X, y)

        return report, visuals

    except Exception as e:
        # UI boundary: show the failure in the report pane instead of crashing.
        return f"Error: {str(e)}", []
    finally:
        run.finish()
268
+
269
def _coerce_insights(raw):
    """Return the agent output as a dict, or ``None`` if it is not dict-like."""
    if isinstance(raw, dict):
        return raw
    try:
        parsed = ast.literal_eval(str(raw))
    except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
        return None
    return parsed if isinstance(parsed, dict) else None


def _render_insights(items):
    """Render insight entries as HTML cards, escaping their text."""
    import html

    if isinstance(items, dict):
        items = list(items.values())
    elif isinstance(items, str) or not hasattr(items, '__iter__'):
        items = [items]
    cards = [
        '<div style="margin: 10px 0; padding: 12px; background: white; '
        f'border-radius: 6px;">{html.escape(str(item))}</div>'
        for item in items
    ]
    return '\n'.join(cards) if cards else '<p>No insights were returned.</p>'


def _render_visuals(visuals):
    """Render one ``<img>`` tag per saved figure path."""
    import html

    return '\n'.join(
        f'<img src="/file={html.escape(str(path), quote=True)}" '
        'style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px;">'
        for path in visuals
    )


def format_analysis_report(insights, visuals):
    """Format the analysis report with insights and visuals.

    Args:
        insights: Agent output. A dict is used as-is; a string is parsed with
            ``ast.literal_eval``; anything that is not a dict after that is
            shown verbatim as a single insight, so a free-text answer still
            produces a report instead of an error.
        visuals: Iterable of image paths to embed.

    Returns:
        The report as an HTML string.
    """
    parsed = _coerce_insights(insights)
    if parsed is not None:
        entries = parsed.get('insights', [])
    elif insights is None:
        entries = []
    else:
        entries = [insights]

    report = f"""
    <div style="font-family: Arial; max-width: 1000px; margin: 0 auto;">
        <h1 style="color: #2B547E;">📊 Data Analysis Report</h1>

        <div style="margin-top: 20px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
            <h2 style="color: #2B547E;">🔍 Key Insights</h2>
            {_render_insights(entries)}
        </div>

        <div style="margin-top: 30px;">
            <h2 style="color: #2B547E;">📈 Visualizations</h2>
            {_render_visuals(visuals)}
        </div>
    </div>
    """
    return report
287
+
288
def format_tuning_results(study, model, X, y):
    """Format hyperparameter tuning results as an HTML report.

    Re-creates the 80/20 split (``random_state=42``) of ``(X, y)``, scores
    ``model`` on the 20% part and renders the study's best parameters, the
    headline metrics and the full classification report.

    NOTE(review): the figures are a genuine hold-out estimate only if
    ``model`` was not trained on the rows of that 20% part. That is the
    caller's responsibility.

    Args:
        study: Finished ``optuna`` study; only ``best_params`` is read.
        model: Fitted classifier exposing ``predict``.
        X: Feature matrix.
        y: Target vector.

    Returns:
        The report as an HTML string.
    """
    import html

    _, X_test, _, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
    y_pred = model.predict(X_test)

    # zero_division=0 scores never-predicted classes as 0 without warnings.
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_test, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)

    # Both blocks are free text dropped into HTML; escape them.
    best_params = html.escape(str(study.best_params))
    class_report = html.escape(classification_report(y_test, y_pred, zero_division=0))

    report = f"""
    <div style="font-family: Arial; max-width: 1000px; margin: 0 auto;">
        <h1 style="color: #2B547E;">⚙️ Hyperparameter Tuning Results</h1>

        <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin-top: 20px;">
            <div style="background: #f8f9fa; padding: 20px; border-radius: 8px;">
                <h2 style="color: #2B547E;">📊 Best Parameters</h2>
                <pre>{best_params}</pre>
            </div>

            <div style="background: #f8f9fa; padding: 20px; border-radius: 8px;">
                <h2 style="color: #2B547E;">📈 Performance Metrics</h2>
                <p>Accuracy: {accuracy:.4f}</p>
                <p>Precision: {precision:.4f}</p>
                <p>Recall: {recall:.4f}</p>
                <p>F1 Score: {f1:.4f}</p>
            </div>
        </div>

        <div style="margin-top: 30px;">
            <h2 style="color: #2B547E;">🔍 Classification Report</h2>
            <pre>{class_report}</pre>
        </div>
    </div>
    """
    return report
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
319
 
320
# Create Gradio interface: one tab for exploratory analysis, one for model
# tuning. Each tab has its own upload widget, report pane and image gallery.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🧠 Advanced Data Analysis with AI")

    with gr.Tab("Data Analysis"):
        with gr.Row():
            with gr.Column():
                data_file = gr.File(label="Upload CSV", file_types=[".csv"])
                notes = gr.Textbox(label="Analysis Notes (Optional)", lines=3)
                analyze_btn = gr.Button("Analyze Data", variant="primary")

            with gr.Column():
                analysis_report = gr.HTML(label="Analysis Report")
                viz_gallery = gr.Gallery(label="Visualizations")

    with gr.Tab("Model Tuning"):
        with gr.Row():
            with gr.Column():
                tune_file = gr.File(label="Upload CSV for Tuning", file_types=[".csv"])
                # step=1 keeps the trial count a whole number.
                trials = gr.Slider(10, 200, value=50, step=1, label="Number of Trials")
                tune_btn = gr.Button("Tune Hyperparameters", variant="primary")

            with gr.Column():
                tuning_report = gr.HTML(label="Tuning Results")
                tuning_viz = gr.Gallery(label="Tuning Visualizations")

    # Event handlers
    analyze_btn.click(
        fn=analyze_data,
        inputs=[data_file, notes],
        outputs=[analysis_report, viz_gallery]
    )

    tune_btn.click(
        fn=tune_hyperparameters,
        inputs=[tune_file, trials],
        outputs=[tuning_report, tuning_viz]
    )

demo.launch()