pavanmutha commited on
Commit
f78e140
·
verified ·
1 Parent(s): 7b22a65

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -6
app.py CHANGED
@@ -56,7 +56,7 @@ def format_insights(insights, visuals):
56
  """ for idx, (key, insight) in enumerate(insights.items())
57
  ])
58
 
59
- def format_analysis_report(raw_output, visuals, metrics=None, explainability_plots=None):
60
  try:
61
  # Ensure we have a dictionary to work with
62
  if isinstance(raw_output, str):
@@ -96,6 +96,25 @@ def format_analysis_report(raw_output, visuals, metrics=None, explainability_plo
96
  </div>
97
  """
98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
99
  # Explainability section
100
  explainability_section = ""
101
  if explainability_plots:
@@ -132,6 +151,7 @@ def format_analysis_report(raw_output, visuals, metrics=None, explainability_plo
132
  report = f"""
133
  <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
134
  <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
 
135
  {metrics_section}
136
  {explainability_section}
137
  {observations_section}
@@ -260,6 +280,7 @@ def analyze_data(csv_file, additional_notes="", perform_ml=True):
260
 
261
  metrics = None
262
  explainability_plots = None
 
263
 
264
  try:
265
  # Load and preprocess data
@@ -278,8 +299,21 @@ def analyze_data(csv_file, additional_notes="", perform_ml=True):
278
  if y.dtype == object:
279
  y = pd.factorize(y)[0]
280
 
 
 
 
 
 
 
 
 
 
 
 
 
 
281
  # Evaluate baseline model
282
- baseline_model = RandomForestClassifier(random_state=42, n_estimators=100)
283
  metrics = evaluate_model(X, y, baseline_model)
284
 
285
  # Generate explainability plots
@@ -308,14 +342,18 @@ def analyze_data(csv_file, additional_notes="", perform_ml=True):
308
  execution_time = time.time() - start_time
309
  final_memory = process.memory_info().rss / 1024 ** 2
310
  memory_usage = final_memory - initial_memory
311
- wandb.log({"execution_time_sec": execution_time, "memory_usage_mb": memory_usage})
 
 
 
 
312
 
313
  visuals = [os.path.join('./figures', f) for f in os.listdir('./figures') if f.endswith(('.png', '.jpg', '.jpeg'))]
314
  for viz in visuals:
315
  wandb.log({os.path.basename(viz): wandb.Image(viz)})
316
 
317
  run.finish()
318
- return format_analysis_report(analysis_result, visuals, metrics, explainability_plots)
319
 
320
  def objective(trial, csv_path):
321
  try:
@@ -343,6 +381,10 @@ def objective(trial, csv_path):
343
  'bootstrap': trial.suggest_categorical('bootstrap', [True, False])
344
  }
345
 
 
 
 
 
346
  # Split data
347
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
348
 
@@ -356,8 +398,19 @@ def objective(trial, csv_path):
356
  model.fit(X_train, y_train)
357
  y_pred = model.predict(X_test)
358
 
359
- # Return metric to optimize (F1 score in this case)
360
- return f1_score(y_test, y_pred, average='weighted')
 
 
 
 
 
 
 
 
 
 
 
361
 
362
  except Exception as e:
363
  print(f"Trial failed: {str(e)}")
@@ -379,11 +432,24 @@ def tune_hyperparameters(n_trials: int, csv_file):
379
  os.remove(temp_path)
380
  return "Dataset needs at least one feature and one target column."
381
 
 
 
 
 
382
  # Create study and optimize
383
  study = optuna.create_study(direction="maximize")
384
  study.optimize(lambda trial: objective(trial, temp_path), n_trials=n_trials)
385
 
 
 
 
 
 
 
 
 
386
  os.remove(temp_path)
 
387
  return f"""
388
  Best Hyperparameters: {study.best_params}
389
  Best F1 Score: {study.best_value:.4f}
 
56
  """ for idx, (key, insight) in enumerate(insights.items())
57
  ])
58
 
59
+ def format_analysis_report(raw_output, visuals, metrics=None, explainability_plots=None, hyperparams=None):
60
  try:
61
  # Ensure we have a dictionary to work with
62
  if isinstance(raw_output, str):
 
96
  </div>
97
  """
98
 
99
+ # Hyperparameters section
100
+ hyperparams_section = ""
101
+ if hyperparams:
102
+ hyperparams_section = f"""
103
+ <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
104
+ <h2 style="color: #2B547E;">⚙️ Model Hyperparameters</h2>
105
+ <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
106
+ {''.join([
107
+ f"""
108
+ <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
109
+ <h3 style="margin: 0 0 10px 0; color: #4A708B;">{key.replace('_', ' ').title()}</h3>
110
+ <p style="font-size: 18px; margin: 0;">{value}</p>
111
+ </div>
112
+ """ for key, value in hyperparams.items()
113
+ ])}
114
+ </div>
115
+ </div>
116
+ """
117
+
118
  # Explainability section
119
  explainability_section = ""
120
  if explainability_plots:
 
151
  report = f"""
152
  <div style="font-family: Arial, sans-serif; padding: 20px; color: #333;">
153
  <h1 style="color: #2B547E; border-bottom: 2px solid #2B547E; padding-bottom: 10px;">📊 Data Analysis Report</h1>
154
+ {hyperparams_section}
155
  {metrics_section}
156
  {explainability_section}
157
  {observations_section}
 
280
 
281
  metrics = None
282
  explainability_plots = None
283
+ hyperparams = None
284
 
285
  try:
286
  # Load and preprocess data
 
299
  if y.dtype == object:
300
  y = pd.factorize(y)[0]
301
 
302
+ # Define model hyperparameters
303
+ hyperparams = {
304
+ 'n_estimators': 100,
305
+ 'max_depth': None,
306
+ 'min_samples_split': 2,
307
+ 'min_samples_leaf': 1,
308
+ 'max_features': 'sqrt',
309
+ 'bootstrap': True
310
+ }
311
+
312
+ # Log hyperparameters to wandb
313
+ wandb.config.update({"model_hyperparameters": hyperparams})
314
+
315
  # Evaluate baseline model
316
+ baseline_model = RandomForestClassifier(random_state=42, **hyperparams)
317
  metrics = evaluate_model(X, y, baseline_model)
318
 
319
  # Generate explainability plots
 
342
  execution_time = time.time() - start_time
343
  final_memory = process.memory_info().rss / 1024 ** 2
344
  memory_usage = final_memory - initial_memory
345
+ wandb.log({
346
+ "execution_time_sec": execution_time,
347
+ "memory_usage_mb": memory_usage,
348
+ **({"model_metrics": metrics} if metrics else {})
349
+ })
350
 
351
  visuals = [os.path.join('./figures', f) for f in os.listdir('./figures') if f.endswith(('.png', '.jpg', '.jpeg'))]
352
  for viz in visuals:
353
  wandb.log({os.path.basename(viz): wandb.Image(viz)})
354
 
355
  run.finish()
356
+ return format_analysis_report(analysis_result, visuals, metrics, explainability_plots, hyperparams)
357
 
358
  def objective(trial, csv_path):
359
  try:
 
381
  'bootstrap': trial.suggest_categorical('bootstrap', [True, False])
382
  }
383
 
384
+ # Log hyperparameters to wandb
385
+ if wandb.run:
386
+ wandb.log({"trial_params": params})
387
+
388
  # Split data
389
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
390
 
 
398
  model.fit(X_train, y_train)
399
  y_pred = model.predict(X_test)
400
 
401
+ # Calculate metrics
402
+ f1 = f1_score(y_test, y_pred, average='weighted')
403
+ accuracy = accuracy_score(y_test, y_pred)
404
+
405
+ # Log metrics to wandb
406
+ if wandb.run:
407
+ wandb.log({
408
+ "trial_f1": f1,
409
+ "trial_accuracy": accuracy,
410
+ "trial_number": trial.number
411
+ })
412
+
413
+ return f1
414
 
415
  except Exception as e:
416
  print(f"Trial failed: {str(e)}")
 
432
  os.remove(temp_path)
433
  return "Dataset needs at least one feature and one target column."
434
 
435
+ # Initialize wandb run for hyperparameter tuning
436
+ wandb.login(key=os.environ.get('WANDB_API_KEY'))
437
+ tuning_run = wandb.init(project="huggingface-hyperparameter-tuning", reinit=True)
438
+
439
  # Create study and optimize
440
  study = optuna.create_study(direction="maximize")
441
  study.optimize(lambda trial: objective(trial, temp_path), n_trials=n_trials)
442
 
443
+ # Log best parameters and metrics
444
+ tuning_run.config.update({"best_hyperparameters": study.best_params})
445
+ tuning_run.log({
446
+ "best_f1_score": study.best_value,
447
+ "best_trial_number": study.best_trial.number
448
+ })
449
+
450
+ tuning_run.finish()
451
  os.remove(temp_path)
452
+
453
  return f"""
454
  Best Hyperparameters: {study.best_params}
455
  Best F1 Score: {study.best_value:.4f}