pavanmutha committed on
Commit
15a30cc
·
verified ·
1 Parent(s): 111530c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -201
app.py CHANGED
@@ -96,23 +96,26 @@ def format_analysis_report(raw_output, visuals, metrics=None, explainability_plo
96
  </div>
97
  """
98
 
99
- hyperparams_items = ''.join([
100
- f"""
101
- <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
102
- <h3 style="margin: 0 0 10px 0; color: #4A708B;">{key.replace('_', ' ').title()}</h3>
103
- <p style="font-size: 18px; margin: 0;">{value}</p>
104
- </div>
105
- """ for key, value in hyperparams.items()
106
- ])
107
-
108
- hyperparams_section = f"""
109
- <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
110
- <h2 style="color: #2B547E;">⚙️ Model Hyperparameters</h2>
111
- <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
112
- {hyperparams_items}
113
- </div>
114
- </div>
115
- """
 
 
 
116
 
117
  # Explainability section
118
  explainability_section = ""
@@ -304,187 +307,4 @@ def analyze_data(csv_file, additional_notes="", perform_ml=True):
304
  'max_depth': None,
305
  'min_samples_split': 2,
306
  'min_samples_leaf': 1,
307
- 'max_features': 'sqrt',
308
- 'bootstrap': True
309
- }
310
-
311
- # Log hyperparameters to wandb
312
- wandb.config.update({"model_hyperparameters": hyperparams})
313
-
314
- # Evaluate baseline model
315
- baseline_model = RandomForestClassifier(random_state=42, **hyperparams)
316
- metrics = evaluate_model(X, y, baseline_model)
317
-
318
- # Generate explainability plots
319
- feature_names = processed_df.columns[:-1]
320
- explainability_plots = generate_explainability_plots(X, baseline_model, feature_names)
321
-
322
- wandb.log(metrics)
323
- except Exception as e:
324
- print(f"ML analysis failed: {str(e)}")
325
- wandb.log({"ml_error": str(e)})
326
-
327
- # Run the main analysis
328
- agent = CodeAgent(tools=[], model=model, additional_authorized_imports=["numpy", "pandas", "matplotlib.pyplot", "seaborn"])
329
- analysis_result = agent.run("""
330
- You are an expert data analyst. Perform comprehensive analysis including:
331
- 1. Basic statistics and data quality checks
332
- 2. 3 insightful analytical questions about relationships in the data
333
- 3. Visualization of key patterns and correlations
334
- 4. Actionable real-world insights derived from findings
335
- Generate publication-quality visualizations and save to './figures/'
336
- """, additional_args={"additional_notes": additional_notes, "source_file": csv_file})
337
-
338
- except Exception as e:
339
- analysis_result = f"Analysis failed: {str(e)}"
340
-
341
- execution_time = time.time() - start_time
342
- final_memory = process.memory_info().rss / 1024 ** 2
343
- memory_usage = final_memory - initial_memory
344
- wandb.log({
345
- "execution_time_sec": execution_time,
346
- "memory_usage_mb": memory_usage,
347
- **({"model_metrics": metrics} if metrics else {})
348
- })
349
-
350
- visuals = [os.path.join('./figures', f) for f in os.listdir('./figures') if f.endswith(('.png', '.jpg', '.jpeg'))]
351
- for viz in visuals:
352
- wandb.log({os.path.basename(viz): wandb.Image(viz)})
353
-
354
- run.finish()
355
- return format_analysis_report(analysis_result, visuals, metrics, explainability_plots, hyperparams)
356
-
357
- def objective(trial, csv_path):
358
- try:
359
- # Load and preprocess data
360
- df = pd.read_csv(csv_path)
361
- processed_df = preprocess_data(df)
362
-
363
- if len(processed_df.columns) <= 1:
364
- return 0.0 # No features to work with
365
-
366
- X = processed_df.iloc[:, :-1].values
367
- y = processed_df.iloc[:, -1].values
368
-
369
- # Convert y to numeric if needed
370
- if y.dtype == object:
371
- y = pd.factorize(y)[0]
372
-
373
- # Define hyperparameter space
374
- params = {
375
- 'n_estimators': trial.suggest_int('n_estimators', 50, 500),
376
- 'max_depth': trial.suggest_int('max_depth', 3, 15),
377
- 'min_samples_split': trial.suggest_int('min_samples_split', 2, 10),
378
- 'min_samples_leaf': trial.suggest_int('min_samples_leaf', 1, 5),
379
- 'max_features': trial.suggest_categorical('max_features', ['sqrt', 'log2']),
380
- 'bootstrap': trial.suggest_categorical('bootstrap', [True, False])
381
- }
382
-
383
- # Log hyperparameters to wandb
384
- if wandb.run:
385
- wandb.log({"trial_params": params})
386
-
387
- # Split data
388
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
389
-
390
- # Standardize features
391
- scaler = StandardScaler()
392
- X_train = scaler.fit_transform(X_train)
393
- X_test = scaler.transform(X_test)
394
-
395
- # Create and evaluate model
396
- model = RandomForestClassifier(**params, random_state=42)
397
- model.fit(X_train, y_train)
398
- y_pred = model.predict(X_test)
399
-
400
- # Calculate metrics
401
- f1 = f1_score(y_test, y_pred, average='weighted')
402
- accuracy = accuracy_score(y_test, y_pred)
403
-
404
- # Log metrics to wandb
405
- if wandb.run:
406
- wandb.log({
407
- "trial_f1": f1,
408
- "trial_accuracy": accuracy,
409
- "trial_number": trial.number
410
- })
411
-
412
- return f1
413
-
414
- except Exception as e:
415
- print(f"Trial failed: {str(e)}")
416
- return 0.0
417
-
418
- def tune_hyperparameters(n_trials: int, csv_file):
419
- try:
420
- if not csv_file:
421
- return "Please upload a CSV file first for hyperparameter tuning."
422
-
423
- # Save the uploaded file temporarily for Optuna
424
- temp_path = "temp_optuna_data.csv"
425
- with open(temp_path, "wb") as f:
426
- f.write(csv_file.read())
427
-
428
- # Verify the data can be loaded
429
- df = pd.read_csv(temp_path)
430
- if len(df.columns) <= 1:
431
- os.remove(temp_path)
432
- return "Dataset needs at least one feature and one target column."
433
-
434
- # Initialize wandb run for hyperparameter tuning
435
- wandb.login(key=os.environ.get('WANDB_API_KEY'))
436
- tuning_run = wandb.init(project="huggingface-hyperparameter-tuning", reinit=True)
437
-
438
- # Create study and optimize
439
- study = optuna.create_study(direction="maximize")
440
- study.optimize(lambda trial: objective(trial, temp_path), n_trials=n_trials)
441
-
442
- # Log best parameters and metrics
443
- tuning_run.config.update({"best_hyperparameters": study.best_params})
444
- tuning_run.log({
445
- "best_f1_score": study.best_value,
446
- "best_trial_number": study.best_trial.number
447
- })
448
-
449
- tuning_run.finish()
450
- os.remove(temp_path)
451
-
452
- return f"""
453
- Best Hyperparameters: {study.best_params}
454
- Best F1 Score: {study.best_value:.4f}
455
- """
456
- except Exception as e:
457
- return f"Hyperparameter tuning failed: {str(e)}"
458
-
459
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
460
- gr.Markdown("## 📊 AI Data Analysis Agent with Hyperparameter Optimization")
461
- with gr.Row():
462
- with gr.Column():
463
- file_input = gr.File(label="Upload CSV Dataset", type="filepath")
464
- notes_input = gr.Textbox(label="Dataset Notes (Optional)", lines=3)
465
- perform_ml = gr.Checkbox(label="Perform Machine Learning Analysis", value=True)
466
- analyze_btn = gr.Button("Analyze", variant="primary")
467
- with gr.Accordion("Hyperparameter Tuning", open=False):
468
- optuna_trials = gr.Number(label="Number of Trials", value=10, precision=0)
469
- tune_btn = gr.Button("Optimize Hyperparameters", variant="secondary")
470
- with gr.Column():
471
- analysis_output = gr.HTML("""<div style="font-family: Arial, sans-serif; padding: 20px;">
472
- <h2 style="color: #2B547E;">Analysis results will appear here...</h2>
473
- <p>Upload a CSV file and click "Analyze" to begin.</p>
474
- </div>""")
475
- optuna_output = gr.Textbox(label="Tuning Results", interactive=False)
476
- gallery = gr.Gallery(label="Data Visualizations", columns=2)
477
-
478
- analyze_btn.click(
479
- fn=analyze_data,
480
- inputs=[file_input, notes_input, perform_ml],
481
- outputs=[analysis_output, gallery]
482
- )
483
- tune_btn.click(
484
- fn=tune_hyperparameters,
485
- inputs=[optuna_trials, file_input],
486
- outputs=[optuna_output]
487
- )
488
-
489
- if __name__ == "__main__":
490
- demo.launch(debug=True)
 
96
  </div>
97
  """
98
 
99
+ # Hyperparameters section
100
+ hyperparams_section = ""
101
+ if hyperparams:
102
+ hyperparams_items = ''.join([
103
+ f"""
104
+ <div style="background: white; padding: 15px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.05);">
105
+ <h3 style="margin: 0 0 10px 0; color: #4A708B;">{key.replace('_', ' ').title()}</h3>
106
+ <p style="font-size: 18px; margin: 0;">{value}</p>
107
+ </div>
108
+ """ for key, value in hyperparams.items()
109
+ ])
110
+
111
+ hyperparams_section = f"""
112
+ <div style="margin-top: 25px; background: #f8f9fa; padding: 20px; border-radius: 8px;">
113
+ <h2 style="color: #2B547E;">⚙️ Model Hyperparameters</h2>
114
+ <div style="display: grid; grid-template-columns: repeat(2, 1fr); gap: 15px;">
115
+ {hyperparams_items}
116
+ </div>
117
+ </div>
118
+ """
119
 
120
  # Explainability section
121
  explainability_section = ""
 
307
  'max_depth': None,
308
  'min_samples_split': 2,
309
  'min_samples_leaf': 1,
310
+ 'max