pavanmutha committed on
Commit
15bc3c6
·
verified ·
1 Parent(s): 9ff587a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +211 -460
app.py CHANGED
@@ -8,19 +8,15 @@ import time
8
  import psutil
9
  import optuna
10
  import ast
 
 
 
11
  import pandas as pd
12
  import numpy as np
 
13
  from sklearn.model_selection import train_test_split
14
- from sklearn.preprocessing import StandardScaler, OneHotEncoder
15
- from sklearn.compose import ColumnTransformer
16
- from sklearn.pipeline import Pipeline
17
- from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
18
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
19
- import matplotlib.pyplot as plt
20
- import seaborn as sns
21
- import shap
22
- import lime
23
- from lime import lime_tabular
24
 
25
  # Authenticate Hugging Face
26
  hf_token = os.getenv("HF_TOKEN")
@@ -29,13 +25,51 @@ login(token=hf_token, add_to_git_credential=True)
29
  # Initialize Model
30
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
31
 
32
- def format_analysis_report(raw_output, visuals):
33
  try:
34
  analysis_dict = raw_output if isinstance(raw_output, dict) else ast.literal_eval(str(raw_output))
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
  report = f"""
37
  <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
38
  <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
 
39
  <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
40
  <h2 style="color: #2B547E;">🔍 Key Observations</h2>
41
  {format_observations(analysis_dict.get('observations', {}))}
@@ -44,77 +78,90 @@ def format_analysis_report(raw_output, visuals):
44
  <h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
45
  {format_insights(analysis_dict.get('insights', {}), visuals)}
46
  </div>
 
47
  </div>
48
  """
49
  return report, visuals
50
  except:
51
  return raw_output, visuals
52
 
53
- def format_observations(observations):
54
- return '\n'.join([
55
- f"""
56
- <div style="margin: 15px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
57
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">{key.replace('_', ' ').title()}</h3>
58
- <pre style="margin: 0; padding: 10px; background: #f8f9fa; border-radius: 4px;">{value}</pre>
59
- </div>
60
- """ for key, value in observations.items() if 'proportions' in key
61
- ])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- def format_insights(insights, visuals):
64
- return '\n'.join([
65
- f"""
66
- <div style="margin: 20px 0; padding: 20px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
67
- <div style="display: flex; align-items: center; gap: 10px;">
68
- <div style="background: #2B547E; color: white; width: 30px; height: 30px; border-radius: 50%; display: flex; align-items: center; justify-content: center;">{idx+1}</div>
69
- <p style="margin: 0; font-size: 16px;">{insight}</p>
70
- </div>
71
- {f'<img src="/file={visuals[idx]}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">' if idx < len(visuals) else ''}
72
- </div>
73
- """ for idx, (key, insight) in enumerate(insights.items())
74
- ])
 
 
 
 
 
 
75
 
76
- def format_model_evaluation(metrics_dict, feature_importance_path=None, explainability_path=None):
77
- report = f"""
78
- <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
79
- <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">🧠 Model Evaluation Report</h1>
80
-
81
- <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
82
- <h2 style="color: #2B547E;">📈 Performance Metrics</h2>
83
- <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
84
- <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
85
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">Accuracy</h3>
86
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics_dict.get('accuracy', 'N/A'):.4f}</p>
87
- </div>
88
- <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
89
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">Precision</h3>
90
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics_dict.get('precision', 'N/A'):.4f}</p>
91
- </div>
92
- <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
93
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">Recall</h3>
94
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics_dict.get('recall', 'N/A'):.4f}</p>
95
- </div>
96
- <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
97
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">F1 Score</h3>
98
- <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics_dict.get('f1', 'N/A'):.4f}</p>
99
- </div>
100
- </div>
101
- </div>
102
-
103
- <div style="margin-top: 30px;">
104
- <h2 style="color: #2B547E;">📊 Feature Importance & Explainability</h2>
105
- {f'<img src="/file={feature_importance_path}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">' if feature_importance_path else ''}
106
- {f'<img src="/file={explainability_path}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">' if explainability_path else ''}
107
- </div>
108
-
109
- <div style="margin-top: 30px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
110
- <h2 style="color: #2B547E;">🔄 Hyperparameters</h2>
111
- <pre style="margin: 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">{metrics_dict.get('best_params', 'N/A')}</pre>
112
- </div>
113
- </div>
114
- """
115
- return report
116
 
117
- def analyze_data(csv_file, additional_notes=""):
118
  start_time = time.time()
119
  process = psutil.Process(os.getpid())
120
  initial_memory = process.memory_info().rss / 1024 ** 2
@@ -127,9 +174,35 @@ def analyze_data(csv_file, additional_notes=""):
127
  run = wandb.init(project="huggingface-data-analysis", config={
128
  "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
129
  "additional_notes": additional_notes,
130
- "source_file": csv_file.name if csv_file else None
 
131
  })
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"])
134
  analysis_result = agent.run("""
135
  You are an expert data analyst. Perform comprehensive analysis including:
@@ -150,403 +223,81 @@ def analyze_data(csv_file, additional_notes=""):
150
  wandb.log({os.path.basename(viz): wandb.Image(viz)})
151
 
152
  run.finish()
153
- return format_analysis_report(analysis_result, visuals)
154
 
155
- def preprocess_features(data, target_column, feature_engineering=True):
156
- """
157
- Preprocess features with optional feature engineering
158
- """
159
- # Check if data is loaded
160
- if data is None or not isinstance(data, pd.DataFrame):
161
- return None, None, None, None, None
162
-
163
- # Separate features and target
164
- if target_column not in data.columns:
165
- # Try to infer target column if it's not specified
166
- for col in ['target', 'label', 'class', 'outcome', 'y']:
167
- if col in data.columns:
168
- target_column = col
169
- break
170
- else:
171
- return None, None, None, None, None
172
-
173
- X = data.drop(target_column, axis=1)
174
- y = data[target_column]
175
 
176
- # Split data
 
 
 
 
177
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
178
 
179
- # Identify numerical and categorical columns
180
- numerical_cols = X.select_dtypes(include=['int64', 'float64']).columns.tolist()
181
- categorical_cols = X.select_dtypes(include=['object', 'category']).columns.tolist()
182
-
183
- # Basic preprocessing
184
- numeric_transformer = Pipeline(steps=[
185
- ('scaler', StandardScaler())
186
- ])
187
-
188
- categorical_transformer = Pipeline(steps=[
189
- ('onehot', OneHotEncoder(handle_unknown='ignore'))
190
- ])
191
 
192
- preprocessor = ColumnTransformer(
193
- transformers=[
194
- ('num', numeric_transformer, numerical_cols),
195
- ('cat', categorical_transformer, categorical_cols)
196
- ])
197
 
198
- # Feature engineering when enabled
199
- if feature_engineering:
200
- # Create interaction terms between numerical features
201
- for i, col1 in enumerate(numerical_cols):
202
- for col2 in numerical_cols[i+1:]:
203
- if len(numerical_cols) > 1:
204
- X_train[f'{col1}_{col2}_interaction'] = X_train[col1] * X_train[col2]
205
- X_test[f'{col1}_{col2}_interaction'] = X_test[col1] * X_test[col2]
206
-
207
- # Create polynomial features for numerical columns (quadratic terms)
208
- for col in numerical_cols:
209
- X_train[f'{col}_squared'] = X_train[col] ** 2
210
- X_test[f'{col}_squared'] = X_test[col] ** 2
211
-
212
- # Create aggregate features for categorical columns
213
- for col in categorical_cols:
214
- # For each categorical column, calculate mean of numerical columns grouped by categories
215
- for num_col in numerical_cols:
216
- if num_col in X_train.columns:
217
- agg_map = X_train.groupby(col)[num_col].mean().to_dict()
218
- X_train[f'{col}_{num_col}_agg'] = X_train[col].map(agg_map)
219
- X_test[f'{col}_{num_col}_agg'] = X_test[col].map(agg_map)
220
-
221
- return X_train, X_test, y_train, y_test, preprocessor
222
 
223
- def create_shap_plot(model, X_test, feature_names):
224
- """Create SHAP summary plot for model explainability"""
225
- plt.figure(figsize=(12, 8))
226
-
227
- # For tree-based models
228
- if hasattr(model, 'feature_importances_'):
229
- explainer = shap.TreeExplainer(model)
230
- shap_values = explainer.shap_values(X_test)
231
-
232
- # Handle multi-class case
233
- if isinstance(shap_values, list):
234
- shap_values = shap_values[1] # Use the positive class
235
 
236
- shap.summary_plot(shap_values, X_test, feature_names=feature_names, show=False)
237
- else:
238
- # Fallback for non-tree models
239
- explainer = shap.KernelExplainer(model.predict_proba, shap.sample(X_test, 50))
240
- shap_values = explainer.shap_values(X_test[:50])
241
-
242
- # Handle multi-class case
243
- if isinstance(shap_values, list):
244
- shap_values = shap_values[1] # Use the positive class
245
 
246
- shap.summary_plot(shap_values, X_test[:50], feature_names=feature_names, show=False)
247
-
248
- plt.tight_layout()
249
- file_path = './figures/shap_summary.png'
250
- plt.savefig(file_path)
251
- plt.close()
252
- return file_path
253
-
254
- def create_lime_explanation(model, X_train, X_test, feature_names):
255
- """Create LIME explanation for a sample instance"""
256
- # Create LIME explainer
257
- explainer = lime_tabular.LimeTabularExplainer(
258
- X_train,
259
- feature_names=feature_names,
260
- class_names=["Negative", "Positive"],
261
- mode="classification"
262
- )
263
-
264
- # Explain a sample instance
265
- instance_idx = 0
266
- exp = explainer.explain_instance(
267
- X_test[instance_idx],
268
- model.predict_proba,
269
- num_features=10
270
- )
271
-
272
- # Plot explanation
273
- plt.figure(figsize=(10, 6))
274
- exp.as_pyplot_figure()
275
- plt.tight_layout()
276
- file_path = './figures/lime_explanation.png'
277
- plt.savefig(file_path)
278
- plt.close()
279
- return file_path
280
-
281
- def create_feature_importance_plot(model, feature_names):
282
- """Create feature importance plot if model supports it"""
283
- if hasattr(model, 'feature_importances_'):
284
- importances = model.feature_importances_
285
- indices = np.argsort(importances)[-20:] # Top 20 features
286
-
287
- plt.figure(figsize=(12, 8))
288
- plt.title('Feature Importances')
289
- plt.barh(range(len(indices)), importances[indices], align='center')
290
- plt.yticks(range(len(indices)), [feature_names[i] for i in indices])
291
- plt.xlabel('Relative Importance')
292
- plt.tight_layout()
293
- file_path = './figures/feature_importance.png'
294
- plt.savefig(file_path)
295
- plt.close()
296
- return file_path
297
- return None
298
-
299
- def train_and_evaluate_model(csv_file, target_column, model_type, feature_eng_enabled=True, explainer_type="shap"):
300
- """Train, evaluate model with metrics and explainability"""
301
- if not csv_file:
302
- return "Please upload a CSV file", None, []
303
-
304
- # Load data
305
- try:
306
- data = pd.read_csv(csv_file)
307
- except Exception as e:
308
- return f"Error loading data: {str(e)}", None, []
309
-
310
- # Preprocess data
311
- X_train, X_test, y_train, y_test, preprocessor = preprocess_features(
312
- data, target_column, feature_engineering=feature_eng_enabled
313
- )
314
-
315
- if X_train is None:
316
- return f"Error: Could not identify target column '{target_column}'", None, []
317
-
318
- # Apply preprocessing
319
- X_train_processed = X_train
320
- X_test_processed = X_test
321
-
322
- # Select model
323
- if model_type == "random_forest":
324
- model = RandomForestClassifier(random_state=42)
325
- else: # Default to gradient boosting
326
- model = GradientBoostingClassifier(random_state=42)
327
-
328
- # Train model
329
- model.fit(X_train_processed, y_train)
330
-
331
- # Make predictions
332
- y_pred = model.predict(X_test_processed)
333
-
334
- # Calculate metrics
335
- metrics = {
336
- 'accuracy': accuracy_score(y_test, y_pred),
337
- 'precision': precision_score(y_test, y_pred, average='weighted'),
338
- 'recall': recall_score(y_test, y_pred, average='weighted'),
339
- 'f1': f1_score(y_test, y_pred, average='weighted'),
340
- }
341
-
342
- # Generate feature names
343
- feature_names = X_train_processed.columns.tolist()
344
-
345
- # Create feature importance plot
346
- feature_importance_path = create_feature_importance_plot(model, feature_names)
347
-
348
- # Create explainability visualization
349
- explainability_path = None
350
- if explainer_type == "shap":
351
- explainability_path = create_shap_plot(model, X_test_processed, feature_names)
352
- else: # LIME
353
- explainability_path = create_lime_explanation(model, X_train_processed.values,
354
- X_test_processed.values, feature_names)
355
-
356
- # Log to wandb
357
- wandb.login(key=os.environ.get('WANDB_API_KEY'))
358
- run = wandb.init(project="huggingface-model-evaluation", config={
359
- "model_type": model_type,
360
- "feature_engineering": feature_eng_enabled,
361
- "explainer": explainer_type,
362
- "metrics": metrics
363
- })
364
-
365
- wandb.log(metrics)
366
-
367
- if feature_importance_path:
368
- wandb.log({"feature_importance": wandb.Image(feature_importance_path)})
369
-
370
- if explainability_path:
371
- wandb.log({"explainability": wandb.Image(explainability_path)})
372
-
373
- run.finish()
374
-
375
- # Return results
376
- results = [feature_importance_path, explainability_path] if feature_importance_path and explainability_path else []
377
- return format_model_evaluation(metrics, feature_importance_path, explainability_path), None, results
378
-
379
- def objective(trial, csv_file, target_column, model_type, feature_eng_enabled=True):
380
- """Objective function for Optuna hyperparameter optimization"""
381
- try:
382
- # Load data
383
- data = pd.read_csv(csv_file)
384
-
385
- # Preprocess data
386
- X_train, X_test, y_train, y_test, preprocessor = preprocess_features(
387
- data, target_column, feature_engineering=feature_eng_enabled
388
- )
389
-
390
- if X_train is None:
391
- return 0.0
392
-
393
- # Apply preprocessing
394
- X_train_processed = X_train
395
- X_test_processed = X_test
396
-
397
- # Hyperparameters based on model type
398
- if model_type == "random_forest":
399
- model = RandomForestClassifier(
400
- n_estimators=trial.suggest_int("n_estimators", 50, 500),
401
- max_depth=trial.suggest_int("max_depth", 3, 20),
402
- min_samples_split=trial.suggest_int("min_samples_split", 2, 10),
403
- min_samples_leaf=trial.suggest_int("min_samples_leaf", 1, 4),
404
- bootstrap=trial.suggest_categorical("bootstrap", [True, False]),
405
- random_state=42
406
- )
407
- else: # Gradient Boosting
408
- model = GradientBoostingClassifier(
409
- learning_rate=trial.suggest_float("learning_rate", 0.01, 0.3),
410
- n_estimators=trial.suggest_int("n_estimators", 50, 500),
411
- max_depth=trial.suggest_int("max_depth", 3, 10),
412
- min_samples_split=trial.suggest_int("min_samples_split", 2, 10),
413
- min_samples_leaf=trial.suggest_int("min_samples_leaf", 1, 4),
414
- subsample=trial.suggest_float("subsample", 0.6, 1.0),
415
- random_state=42
416
- )
417
-
418
- # Train model
419
- model.fit(X_train_processed, y_train)
420
-
421
- # Evaluate model
422
- y_pred = model.predict(X_test_processed)
423
- f1 = f1_score(y_test, y_pred, average='weighted')
424
-
425
- return f1
426
-
427
  except Exception as e:
428
- print(f"Error in objective function: {str(e)}")
429
- return 0.0
430
-
431
- def tune_hyperparameters(csv_file, target_column, model_type, n_trials=10, feature_eng_enabled=True):
432
- """Run hyperparameter tuning with Optuna"""
433
- if not csv_file:
434
- return "Please upload a CSV file first"
435
-
436
- wandb.login(key=os.environ.get('WANDB_API_KEY'))
437
- run = wandb.init(project="huggingface-hyperparameter-tuning", config={
438
- "model_type": model_type,
439
- "feature_engineering": feature_eng_enabled,
440
- "n_trials": n_trials
441
- })
442
-
443
- study = optuna.create_study(direction="maximize")
444
- study.optimize(
445
- lambda trial: objective(trial, csv_file, target_column, model_type, feature_eng_enabled),
446
- n_trials=n_trials
447
- )
448
-
449
- # Log best parameters to wandb
450
- wandb.log({"best_params": study.best_params, "best_value": study.best_value})
451
-
452
- # Visualization of optimization history
453
- plt.figure(figsize=(10, 6))
454
- optuna.visualization.matplotlib.plot_optimization_history(study)
455
- plt.tight_layout()
456
- history_path = './figures/optuna_history.png'
457
- plt.savefig(history_path)
458
- plt.close()
459
-
460
- # Visualization of parameter importances
461
- plt.figure(figsize=(10, 6))
462
- optuna.visualization.matplotlib.plot_param_importances(study)
463
- plt.tight_layout()
464
- importance_path = './figures/optuna_importance.png'
465
- plt.savefig(importance_path)
466
- plt.close()
467
-
468
- # Log visualizations
469
- wandb.log({"optimization_history": wandb.Image(history_path)})
470
- wandb.log({"parameter_importance": wandb.Image(importance_path)})
471
-
472
- run.finish()
473
-
474
- # Return a formatted result
475
- result = f"""
476
- <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
477
- <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">⚙️ Hyperparameter Optimization Results</h1>
478
-
479
- <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
480
- <h2 style="color: #2B547E;">🏆 Best Parameters</h2>
481
- <pre style="margin: 10px 0; padding: 15px; background: white; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">{study.best_params}</pre>
482
-
483
- <h3 style="color: #4A708B; margin-top: 20px;">Best F1 Score</h3>
484
- <p style="font-size: 20px; font-weight: bold;">{study.best_value:.4f}</p>
485
- </div>
486
-
487
- <div style="margin-top: 30px;">
488
- <h2 style="color: #2B547E;">📈 Optimization Results</h2>
489
- <img src="/file={history_path}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
490
- <img src="/file={importance_path}" style="max-width: 100%; height: auto; margin-top: 10px; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">
491
- </div>
492
- </div>
493
- """
494
-
495
- # Return results and visualization paths for gallery
496
- return result, [history_path, importance_path]
497
 
498
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
499
- gr.Markdown("## 📊 AI Data Analysis & ML Experimentation Platform")
500
-
501
- with gr.Tab("Data Analysis"):
502
- with gr.Row():
503
- with gr.Column():
504
- file_input_analysis = gr.File(label="Upload CSV Dataset", type="filepath")
505
- notes_input = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
506
- analyze_btn = gr.Button("Analyze", variant="primary")
507
- with gr.Column():
508
- analysis_output = gr.HTML("### Analysis results will appear here...")
509
- gallery = gr.Gallery(label="Data Visualizations", columns=2)
510
-
511
- analyze_btn.click(fn=analyze_data, inputs=[file_input_analysis, notes_input], outputs=[analysis_output, gallery])
512
-
513
- with gr.Tab("ML Model Experimentation"):
514
- with gr.Row():
515
- with gr.Column():
516
- file_input_model = gr.File(label="Upload CSV Dataset", type="filepath")
517
- target_column = gr.Textbox(label="Target Column Name", placeholder="e.g., target, class, outcome")
518
- model_type = gr.Radio(["random_forest", "gradient_boosting"], label="Model Type", value="random_forest")
519
- feature_eng = gr.Checkbox(label="Enable Feature Engineering", value=True)
520
- explainer_type = gr.Radio(["shap", "lime"], label="Explainability Tool", value="shap")
521
- train_btn = gr.Button("Train & Evaluate Model", variant="primary")
522
- with gr.Column():
523
- model_output = gr.HTML("### Model evaluation results will appear here...")
524
- model_metrics = gr.Textbox(label="Raw Metrics", visible=False)
525
- model_gallery = gr.Gallery(label="Model Visualizations", columns=2)
526
-
527
- train_btn.click(
528
- fn=train_and_evaluate_model,
529
- inputs=[file_input_model, target_column, model_type, feature_eng, explainer_type],
530
- outputs=[model_output, model_metrics, model_gallery]
531
- )
532
-
533
- with gr.Tab("Hyperparameter Tuning"):
534
- with gr.Row():
535
- with gr.Column():
536
- file_input_hp = gr.File(label="Upload CSV Dataset", type="filepath")
537
- target_column_hp = gr.Textbox(label="Target Column Name", placeholder="e.g., target, class, outcome")
538
- model_type_hp = gr.Radio(["random_forest", "gradient_boosting"], label="Model Type", value="random_forest")
539
- feature_eng_hp = gr.Checkbox(label="Enable Feature Engineering", value=True)
540
- n_trials = gr.Slider(minimum=5, maximum=50, value=10, step=5, label="Number of Optimization Trials")
541
- tune_btn = gr.Button("Run Hyperparameter Optimization", variant="primary")
542
- with gr.Column():
543
- hp_output = gr.HTML("### Hyperparameter tuning results will appear here...")
544
- hp_gallery = gr.Gallery(label="Optimization Visualizations", columns=2)
545
-
546
- tune_btn.click(
547
- fn=tune_hyperparameters,
548
- inputs=[file_input_hp, target_column_hp, model_type_hp, n_trials, feature_eng_hp],
549
- outputs=[hp_output, hp_gallery]
550
- )
551
 
552
  demo.launch(debug=True)
 
8
  import psutil
9
  import optuna
10
  import ast
11
+ import shap
12
+ import lime
13
+ import lime.lime_tabular
14
  import pandas as pd
15
  import numpy as np
16
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
17
  from sklearn.model_selection import train_test_split
18
+ from sklearn.ensemble import RandomForestClassifier
19
+ from sklearn.preprocessing import StandardScaler, PolynomialFeatures
 
 
 
 
 
 
 
 
20
 
21
  # Authenticate Hugging Face
22
  hf_token = os.getenv("HF_TOKEN")
 
25
  # Initialize Model
26
  model = HfApiModel("mistralai/Mixtral-8x7B-Instruct-v0.1", token=hf_token)
27
 
28
def format_analysis_report(raw_output, visuals, metrics=None, explainability_plots=None):
    """Render the agent's analysis output as an HTML report.

    Args:
        raw_output: dict (or string repr of a dict) with 'observations' and
            'insights' keys; anything unparseable is returned unchanged.
        visuals: list of visualization file paths for the insights section.
        metrics: optional dict with 'accuracy', 'precision', 'recall', 'f1'.
        explainability_plots: optional list of SHAP/LIME image paths.

    Returns:
        (html_report, visuals) on success, or (raw_output, visuals) as a
        best-effort fallback when the output cannot be parsed or rendered.
    """
    try:
        analysis_dict = raw_output if isinstance(raw_output, dict) else ast.literal_eval(str(raw_output))

        metrics_section = ""
        if metrics:
            metrics_section = f"""
            <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
                <h2 style="color: #2B547E;">📈 Model Performance Metrics</h2>
                <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
                    <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
                        <h3 style="margin: 0 0 10px 0; color: #4A708B;">Accuracy</h3>
                        <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics['accuracy']:.2f}</p>
                    </div>
                    <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
                        <h3 style="margin: 0 0 10px 0; color: #4A708B;">Precision</h3>
                        <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics['precision']:.2f}</p>
                    </div>
                    <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
                        <h3 style="margin: 0 0 10px 0; color: #4A708B;">Recall</h3>
                        <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics['recall']:.2f}</p>
                    </div>
                    <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
                        <h3 style="margin: 0 0 10px 0; color: #4A708B;">F1 Score</h3>
                        <p style="font-size: 24px; font-weight: bold; margin: 0;">{metrics['f1']:.2f}</p>
                    </div>
                </div>
            </div>
            """

        explainability_section = ""
        if explainability_plots:
            explainability_section = f"""
            <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
                <h2 style="color: #2B547E;">🔍 Model Explainability</h2>
                <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
                    {''.join([f'<img src="/file={plot}" style="max-width: 100%; height: auto; border-radius: 6px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);">' for plot in explainability_plots])}
                </div>
            </div>
            """

        # NOTE(review): format_observations / format_insights must exist at
        # module level — confirm they were not removed in this revision.
        report = f"""
        <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
            <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
            {metrics_section}
            <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
                <h2 style="color: #2B547E;">🔍 Key Observations</h2>
                {format_observations(analysis_dict.get('observations', {}))}

                <h2 style="color: #2B547E;">💡 Insights & Visualizations</h2>
                {format_insights(analysis_dict.get('insights', {}), visuals)}
            </div>
            {explainability_section}
        </div>
        """
        return report, visuals
    except Exception:
        # BUG FIX: was a bare `except:`, which also swallowed SystemExit and
        # KeyboardInterrupt. Keep the deliberate best-effort fallback: show
        # the raw agent output rather than crash the UI.
        return raw_output, visuals
87
 
88
def preprocess_data(df, feature_engineering=True):
    """Handle missing values, categorical encoding, and feature engineering.

    Rows containing any NaN are dropped; object columns with at most 10
    distinct values are one-hot encoded; when ``feature_engineering`` is
    True, degree-2 interaction-only polynomial features are appended as
    ``poly_0 .. poly_k`` (the original numeric columns followed by each
    pairwise product — the same layout as
    ``PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)``).

    Args:
        df: input DataFrame.
        feature_engineering: whether to append the interaction features.

    Returns:
        The transformed DataFrame.
    """
    from itertools import combinations

    df = df.dropna()

    # One-hot encode low-cardinality categoricals; high-cardinality object
    # columns are left as-is.
    categorical_cols = df.select_dtypes(include=['object']).columns
    for col in categorical_cols:
        if len(df[col].unique()) <= 10:
            df = pd.concat([df, pd.get_dummies(df[col], prefix=col)], axis=1)
            df = df.drop(col, axis=1)

    if feature_engineering:
        num_cols = df.select_dtypes(include=['int64', 'float64']).columns
        if len(num_cols) > 0:
            base = df[num_cols].to_numpy(dtype=float)
            products = [base[:, i] * base[:, j]
                        for i, j in combinations(range(base.shape[1]), 2)]
            poly_features = np.column_stack([base] + products) if products else base
            poly_cols = [f"poly_{i}" for i in range(poly_features.shape[1])]
            # BUG FIX: build the new frame on df's index. After dropna() the
            # index may be non-contiguous, and concatenating a frame with a
            # fresh RangeIndex misaligns rows and injects NaNs.
            poly_df = pd.DataFrame(poly_features, columns=poly_cols, index=df.index)
            df = pd.concat([df, poly_df], axis=1)

    return df
112
 
113
def evaluate_model(X, y, model, test_size=0.2):
    """Fit *model* on a standardized train split and report weighted metrics.

    Splits with a fixed random_state=42 for reproducibility, scales using
    statistics learned from the training split only, then scores the
    held-out split.

    Returns:
        dict with 'accuracy', 'precision', 'recall' and 'f1'
        (precision/recall/f1 are weighted averages).
    """
    train_X, test_X, train_y, test_y = train_test_split(
        X, y, test_size=test_size, random_state=42
    )

    # Fit the scaler on the training split only to avoid test leakage.
    scaler = StandardScaler()
    train_X = scaler.fit_transform(train_X)
    test_X = scaler.transform(test_X)

    model.fit(train_X, train_y)
    predictions = model.predict(test_X)

    return {
        'accuracy': accuracy_score(test_y, predictions),
        'precision': precision_score(test_y, predictions, average='weighted'),
        'recall': recall_score(test_y, predictions, average='weighted'),
        'f1': f1_score(test_y, predictions, average='weighted'),
    }
131
 
132
def generate_explainability_plots(X, model, feature_names, output_dir='./figures'):
    """Generate SHAP and LIME explainability plots.

    Args:
        X: 2-D numpy array of samples to explain.
        model: fitted classifier exposing ``predict_proba``.
        feature_names: column names matching X's columns.
        output_dir: directory where the PNG files are written (created if
            missing).

    Returns:
        List of saved plot file paths.
    """
    # Local import: matplotlib.pyplot is not imported at module level in
    # this file, and shap draws onto the current pyplot figure.
    import matplotlib.pyplot as plt

    os.makedirs(output_dir, exist_ok=True)
    plot_paths = []

    # SHAP Analysis
    explainer = shap.Explainer(model)
    shap_values = explainer(X)

    # BUG FIX: shap.summary_plot(show=False) returns None (it draws onto the
    # current figure), so the original `plt = shap.summary_plot(...)` made
    # the subsequent plt.savefig() raise AttributeError. Save via pyplot.
    shap.summary_plot(shap_values, X, feature_names=feature_names, show=False)
    shap_path = os.path.join(output_dir, 'shap_summary.png')
    plt.savefig(shap_path, bbox_inches='tight')
    plt.close()
    plot_paths.append(shap_path)

    # LIME Analysis
    lime_explainer = lime.lime_tabular.LimeTabularExplainer(
        X,
        feature_names=feature_names,
        class_names=['class_0', 'class_1'],  # assumes binary labels — TODO confirm
        verbose=True,
        mode='classification'
    )

    # Explain the first instance as a representative example.
    exp = lime_explainer.explain_instance(X[0], model.predict_proba, num_features=5)
    lime_path = os.path.join(output_dir, 'lime_explanation.png')
    fig = exp.as_pyplot_figure()
    fig.savefig(lime_path, bbox_inches='tight')
    plt.close(fig)
    plot_paths.append(lime_path)

    return plot_paths
 
 
 
 
 
 
 
 
 
163
 
164
+ def analyze_data(csv_file, additional_notes="", perform_ml=True):
165
  start_time = time.time()
166
  process = psutil.Process(os.getpid())
167
  initial_memory = process.memory_info().rss / 1024 ** 2
 
174
  run = wandb.init(project="huggingface-data-analysis", config={
175
  "model": "mistralai/Mixtral-8x7B-Instruct-v0.1",
176
  "additional_notes": additional_notes,
177
+ "source_file": csv_file.name if csv_file else None,
178
+ "perform_ml": perform_ml
179
  })
180
 
181
+ # Load and preprocess data
182
+ df = pd.read_csv(csv_file)
183
+ processed_df = preprocess_data(df)
184
+
185
+ metrics = None
186
+ explainability_plots = None
187
+
188
+ if perform_ml and len(processed_df.columns) > 1:
189
+ try:
190
+ # Assume last column is target for demonstration
191
+ X = processed_df.iloc[:, :-1].values
192
+ y = processed_df.iloc[:, -1].values
193
+
194
+ # Evaluate baseline model
195
+ baseline_model = RandomForestClassifier(random_state=42)
196
+ metrics = evaluate_model(X, y, baseline_model)
197
+
198
+ # Generate explainability plots
199
+ feature_names = processed_df.columns[:-1]
200
+ explainability_plots = generate_explainability_plots(X[:100], baseline_model, feature_names)
201
+
202
+ wandb.log(metrics)
203
+ except Exception as e:
204
+ print(f"ML analysis failed: {str(e)}")
205
+
206
  agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"])
207
  analysis_result = agent.run("""
208
  You are an expert data analyst. Perform comprehensive analysis including:
 
223
  wandb.log({os.path.basename(viz): wandb.Image(viz)})
224
 
225
  run.finish()
226
+ return format_analysis_report(analysis_result, visuals, metrics, explainability_plots)
227
 
228
def objective(trial):
    """Optuna objective: train a RandomForest with trial-sampled hyperparameters.

    Returns:
        Weighted F1 score on a held-out 20% split (the metric Optuna maximizes).
    """
    # Search space for the random forest.
    params = {
        'n_estimators': trial.suggest_int('n_estimators', 50, 500),
        'max_depth': trial.suggest_int('max_depth', 3, 15),
        'min_samples_split': trial.suggest_int('min_samples_split', 2, 10),
        'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 5),
        'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2', None]),
        'bootstrap': trial.suggest_categorical('bootstrap', [True, False]),
    }

    # NOTE(review): the dataset comes from a fixed temp file written by
    # tune_hyperparameters; every trial re-reads and re-preprocesses it.
    frame = preprocess_data(pd.read_csv("temp_data.csv"))
    features = frame.iloc[:, :-1].values
    target = frame.iloc[:, -1].values
    train_X, test_X, train_y, test_y = train_test_split(
        features, target, test_size=0.2, random_state=42
    )

    # Standardize using training-split statistics only.
    scaler = StandardScaler()
    train_X = scaler.fit_transform(train_X)
    test_X = scaler.transform(test_X)

    forest = RandomForestClassifier(**params, random_state=42)
    forest.fit(train_X, train_y)

    return f1_score(test_y, forest.predict(test_X), average='weighted')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
def tune_hyperparameters(n_trials: int, csv_file):
    """Run Optuna hyperparameter optimization over the uploaded dataset.

    Args:
        n_trials: number of Optuna trials to run.
        csv_file: path string from gr.File(type="filepath"), or a file-like
            object with a ``.name`` attribute.

    Returns:
        A human-readable string with the best parameters / F1 score, or an
        error message (the UI expects a string either way).
    """
    try:
        if csv_file:
            # Copy the dataset to the fixed path that objective() reads.
            temp_path = "temp_data.csv"
            # BUG FIX: gr.File(type="filepath") passes a str path, which has
            # no .read(); the original csv_file.read() always raised
            # AttributeError and tuning never ran.
            src_path = csv_file if isinstance(csv_file, str) else csv_file.name
            with open(src_path, "rb") as src, open(temp_path, "wb") as dst:
                dst.write(src.read())

            try:
                study = optuna.create_study(direction="maximize")
                study.optimize(objective, n_trials=n_trials)
            finally:
                # BUG FIX: remove the temp file even when optimization fails,
                # so a stale dataset never leaks into the next run.
                os.remove(temp_path)

            return f"Best Hyperparameters: {study.best_params}\nBest F1 Score: {study.best_value:.4f}"
        else:
            return "Please upload a CSV file first for hyperparameter tuning."
    except Exception as e:
        return f"Hyperparameter tuning failed: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
276
 
277
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # --- Layout --------------------------------------------------------
    gr.Markdown("## 📊 AI Data Analysis Agent with Hyperparameter Optimization")
    with gr.Row():
        with gr.Column():
            dataset_file = gr.File(label="Upload CSV Dataset", type="filepath")
            dataset_notes = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
            ml_toggle = gr.Checkbox(label="Perform Machine Learning Analysis", value=True)
            run_analysis_btn = gr.Button("Analyze", variant="primary")
            trial_count = gr.Number(label="Number of Hyperparameter Tuning Trials", value=10)
            run_tuning_btn = gr.Button("Optimize Hyperparameters", variant="secondary")
        with gr.Column():
            report_view = gr.Markdown("### Analysis results will appear here...")
            tuning_result = gr.Textbox(label="Best Hyperparameters")
            visual_gallery = gr.Gallery(label="Data Visualizations", columns=2)

    # --- Event wiring ----------------------------------------------------
    run_analysis_btn.click(
        fn=analyze_data,
        inputs=[dataset_file, dataset_notes, ml_toggle],
        outputs=[report_view, visual_gallery]
    )
    run_tuning_btn.click(
        fn=tune_hyperparameters,
        inputs=[trial_count, dataset_file],
        outputs=[tuning_result]
    )

demo.launch(debug=True)