singhn9 commited on
Commit
9a0d8df
·
verified ·
1 Parent(s): 5a05838

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +100 -193
src/streamlit_app.py CHANGED
@@ -360,208 +360,115 @@ with tabs[3]:
360
  st.subheader("Summary statistics (numeric features)")
361
  st.dataframe(df.describe().T.style.format("{:.3f}"), height=500)
362
 
 
363
  # ----- Ensemble + SHAP tab
364
  with tabs[4]:
365
- st.subheader("Ensemble modeling sandbox (fast) + SHAP explainability")
366
- # Feature & target selector
367
- target = st.selectbox("Target variable", numeric_cols, index=numeric_cols.index("furnace_temp") if "furnace_temp" in numeric_cols else 0)
368
- default_features = [c for c in numeric_cols if c != target][:50] # preselect up to 50 features default
369
- features = st.multiselect("Model input features (select many; start with defaults)", numeric_cols, default=default_features)
370
- sample_size = st.slider("Sample rows to use for training (speed vs fidelity)", min_value=200, max_value=min(4000, df.shape[0]), value=1000, step=100)
371
- train_button = st.button("Train ensemble & compute SHAP (recommended sample only)")
372
- # Model Remediation & Tuning Options
373
- st.markdown("### Model Remediation & Tuning Options")
374
- st.info("Use these to improve flat or low-variance predictions without editing code.")
375
-
376
- colA, colB, colC = st.columns(3)
377
- with colA:
378
- apply_scaling = st.checkbox("Apply StandardScaler()", value=False)
379
- feature_filter = st.checkbox("Use key furnace-relevant features", value=True)
380
- with colB:
381
- random_seed = st.number_input("Random Seed", min_value=0, max_value=9999, value=42)
382
- n_estimators = st.slider("n_estimators (trees)", 50, 600, 150, step=25)
383
- with colC:
384
- furnace_temp_sd = st.slider("Synthetic Furnace Temp σ (spread)", 20, 500, 50, step=10)
385
- arc_power_sd = st.slider("Synthetic Arc Power σ (spread)", 50, 700, 120, step=10)
386
- st.markdown("---")
387
 
388
- # --- Variance Controls UI ---
389
- st.markdown("#### Variance controls (global & per-feature)")
390
- global_var_mult = st.slider(
391
- "Global variance multiplier", 0.1, 5.0, 1.0, step=0.1,
392
- help="Multiply base standard deviations by this factor for all features."
393
- )
394
-
395
- # Optional: choose features to override
396
- feat_for_override = st.multiselect(
397
- "Select features to override variance (optional)", numeric_cols, max_selections=8
 
 
 
 
 
 
 
 
 
 
 
398
  )
399
- variance_overrides = {}
400
- if feat_for_override:
401
- st.markdown("Set multipliers for selected features")
402
- for f in feat_for_override:
403
- mult = st.number_input(
404
- f"Variance multiplier for {f}", min_value=0.1, max_value=10.0,
405
- value=1.0, step=0.1, key=f"mult_{f}"
406
- )
407
- variance_overrides[f] = float(mult)
408
-
409
- st.markdown("---")
410
-
411
- # --- Regeneration button ---
412
- if st.button("Regenerate Synthetic Dataset with Updated Variance"):
413
- with st.spinner("Regenerating synthetic data..."):
414
- variance_overrides.update({
415
- "furnace_temp": furnace_temp_sd / 50,
416
- "arc_power": arc_power_sd / 120
417
- })
418
- CSV_PATH, META_PATH, PDF_PATH = generate_advanced_flatfile(
419
- n_rows=3000,
420
- random_seed=int(random_seed),
421
- max_polynomial_new=60,
422
- global_variance_multiplier=float(global_var_mult),
423
- variance_overrides=variance_overrides,
424
- )
425
-
426
- # Clear cache and reload fresh
427
- st.cache_data.clear()
428
- df, meta_df = load_data(csv_path=CSV_PATH, meta_path=META_PATH)
429
-
430
- numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
431
-
432
- st.success(
433
- f"Synthetic dataset regenerated — {len(df)} rows × {len(df.columns)} features "
434
- f"(Global×{global_var_mult:.2f}; Overrides={len(variance_overrides)})"
435
- )
436
- st.caption(
437
- f"Mean furnace_temp: {df['furnace_temp'].mean():.2f}, "
438
- f"Std furnace_temp: {df['furnace_temp'].std():.2f}"
439
- )
440
-
441
-
442
- if train_button:
443
- with st.spinner("Preparing data and training ensemble..."):
444
- sub_df = df[features + [target]].sample(n=sample_size, random_state=42)
445
- X = sub_df[features].fillna(0)
446
- y = sub_df[target].fillna(0)
447
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
448
- # models
449
- models = {
450
- "Linear": LinearRegression(),
451
- "RandomForest": RandomForestRegressor(n_estimators=150, random_state=42, n_jobs=-1),
452
- "GradientBoosting": GradientBoostingRegressor(n_estimators=150, random_state=42),
453
- "ExtraTrees": ExtraTreesRegressor(n_estimators=150, random_state=42, n_jobs=-1)
454
- }
455
- preds = {}
456
- results = []
457
- for name, m in models.items():
458
- m.fit(X_train, y_train)
459
- p = m.predict(X_test)
460
- preds[name] = p
461
- results.append({"Model": name, "R2": r2_score(y_test, p), "RMSE": float(np.sqrt(mean_squared_error(y_test, p)))})
462
- # ensemble average
463
- ensemble_pred = np.column_stack(list(preds.values())).mean(axis=1)
464
- results.append({"Model": "EnsembleAvg", "R2": r2_score(y_test, ensemble_pred), "RMSE": float(np.sqrt(mean_squared_error(y_test, ensemble_pred)))})
465
- st.dataframe(pd.DataFrame(results).set_index("Model").round(4))
466
-
467
- # scatter
468
- fig, ax = plt.subplots(figsize=(8,4))
469
- ax.scatter(y_test, ensemble_pred, alpha=0.5)
470
- ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], "r--")
471
- ax.set_xlabel("Actual"); ax.set_ylabel("Predicted (Ensemble)")
472
- st.pyplot(fig)
473
 
474
- # save the models (lightweight)
475
- joblib.dump(models, ENSEMBLE_ARTIFACT)
476
- st.success(f"Saved ensemble models to {ENSEMBLE_ARTIFACT}")
 
 
477
 
478
- # ---------- SHAP explainability ----------
479
- st.markdown("### SHAP Explainability — pick a model to explain (Tree models recommended)")
480
- explain_model_name = st.selectbox("Model to explain", list(models.keys()), index= list(models.keys()).index("RandomForest") if "RandomForest" in models else 0)
481
- explainer_sample = st.slider("Number of rows to use for SHAP explanation (memory heavy)", 50, min(1500, sample_size), value=300, step=50)
 
482
 
483
- # Use a Tree explainer if possible; otherwise KernelExplainer (slow)
484
- model_to_explain = models[explain_model_name]
485
- X_shap = X_test.copy()
486
- if explainer_sample < X_shap.shape[0]:
487
- X_shap_for = X_shap.sample(n=explainer_sample, random_state=42)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
  else:
489
- X_shap_for = X_shap
490
-
491
- with st.spinner("Computing SHAP values (this may take a while for large SHAP sample)..."):
492
- try:
493
- if hasattr(model_to_explain, "predict") and (explain_model_name in ["RandomForest","ExtraTrees","GradientBoosting"]):
494
- explainer = shap.TreeExplainer(model_to_explain)
495
- shap_values = explainer.shap_values(X_shap_for)
496
- # summary plot
497
- import warnings
498
- warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
499
- fig_shap = plt.figure(figsize=(8,6))
500
- shap.summary_plot(shap_values, X_shap_for, show=False)
501
- st.pyplot(fig_shap)
502
- else:
503
- # fallback: use KernelExplainer on small sample (very slow)
504
- explainer = shap.KernelExplainer(model_to_explain.predict, shap.sample(X_train, 100))
505
- shap_values = explainer.shap_values(X_shap_for, nsamples=100)
506
- fig_shap = plt.figure(figsize=(8,6))
507
- shap.summary_plot(shap_values, X_shap_for, show=False)
508
- st.pyplot(fig_shap)
509
- st.success("SHAP summary plotted.")
510
- except Exception as e:
511
- st.error(f"SHAP failed: {e}")
512
- # per-instance explanation waterfall
513
- st.markdown("#### Explain a single prediction (waterfall):")
514
- idx_choice = st.number_input("Row index (0..n_test-1)", min_value=0, max_value=X_shap.shape[0]-1, value=0)
515
- try:
516
- row = X_shap_for.iloc[[idx_choice]]
517
- if explain_model_name in ["RandomForest","ExtraTrees","GradientBoosting"]:
518
- expl = shap.TreeExplainer(model_to_explain)
519
- shap_vals_row = expl.shap_values(row)
520
- exp_val = expl.expected_value
521
- shap_vals = shap_vals_row
522
-
523
- # Handle tree models returning arrays for single target
524
- if isinstance(exp_val, (list, np.ndarray)) and not np.isscalar(exp_val):
525
- exp_val = exp_val[0]
526
- if isinstance(shap_vals, list):
527
- shap_vals = shap_vals[0]
528
-
529
- exp_val = expl.expected_value
530
- shap_vals = shap_vals_row
531
-
532
- # Handle multi-output case
533
- if isinstance(exp_val, (list, np.ndarray)) and not np.isscalar(exp_val):
534
- exp_val = exp_val[0]
535
- if isinstance(shap_vals, list):
536
- shap_vals = shap_vals[0]
537
-
538
- # Plot safely across SHAP versions
539
- try:
540
- explanation = shap.Explanation(
541
- values=shap_vals[0],
542
- base_values=exp_val,
543
- data=row.iloc[0],
544
- feature_names=row.columns.tolist()
545
- )
546
- plot_obj = shap.plots.waterfall(explanation, show=False)
547
-
548
- # If SHAP returns Axes instead of Figure, wrap it
549
- import matplotlib.pyplot as plt
550
- if hasattr(plot_obj, "figure"):
551
- fig2 = plot_obj.figure
552
- else:
553
- fig2 = plt.gcf()
554
-
555
- st.pyplot(fig2)
556
- except Exception as e:
557
- st.warning(f"Waterfall plotting failed gracefully: {e}")
558
-
559
-
560
- else:
561
- st.info("Per-instance waterfall not available for this model type in fallback.")
562
- except Exception as e:
563
- st.warning(f"Could not plot waterfall: {e}")
564
 
 
565
 
566
 
567
  # ----- Target & Business Impact tab
 
360
  st.subheader("Summary statistics (numeric features)")
361
  st.dataframe(df.describe().T.style.format("{:.3f}"), height=500)
362
 
363
+
364
  # ----- Ensemble + SHAP tab
365
  with tabs[4]:
366
+ st.subheader("Autonomous Ensemble Modeling + SHAP Explainability")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
367
 
368
+ # --- Step 1: Basic UI selections ---
369
+ target = st.selectbox("Target variable", numeric_cols, index=numeric_cols.index("furnace_temp") if "furnace_temp" in numeric_cols else 0)
370
+ default_features = [c for c in numeric_cols if c != target][:60]
371
+ features = st.multiselect("Model input features", numeric_cols, default=default_features)
372
+ sample_size = st.slider("Sample rows for training", 500, min(4000, df.shape[0]), 1000, step=100)
373
+ sub_df = df[features + [target]].sample(n=sample_size, random_state=42)
374
+ X = sub_df[features].fillna(0)
375
+ y = sub_df[target].fillna(0)
376
+
377
+ # --- Step 2: Business / Process Objective selection ---
378
+ st.markdown("### 🎯 Select Operational Objective")
379
+ objective = st.selectbox(
380
+ "Optimization Objective",
381
+ [
382
+ "Maximize Accuracy (R²)",
383
+ "Minimize RMSE (Stable Control)",
384
+ "Maximize Yield Ratio (EAF/Inventory)",
385
+ "Minimize Energy Consumption (Efficiency)",
386
+ "Balanced (Accuracy + Efficiency)"
387
+ ],
388
+ index=0
389
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
391
+ # --- Step 3: Auto-tuning with Optuna ---
392
+ import optuna
393
+ from sklearn.model_selection import cross_val_score
394
+
395
+ st.markdown("### ⚙️ Auto Tuning in Progress")
396
 
397
+ def objective_fn(trial):
398
+ model_name = trial.suggest_categorical("model", ["RandomForest", "GradientBoosting", "ExtraTrees"])
399
+ n_estimators = trial.suggest_int("n_estimators", 100, 600)
400
+ max_depth = trial.suggest_int("max_depth", 3, 20)
401
+ learning_rate = trial.suggest_float("learning_rate", 0.01, 0.3, log=True)
402
 
403
+ if model_name == "RandomForest":
404
+ model = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1)
405
+ elif model_name == "GradientBoosting":
406
+ model = GradientBoostingRegressor(n_estimators=n_estimators, learning_rate=learning_rate, max_depth=max_depth)
407
+ else:
408
+ model = ExtraTreesRegressor(n_estimators=n_estimators, max_depth=max_depth, n_jobs=-1)
409
+
410
+ # Metric selection
411
+ scoring_metric = "r2"
412
+ if "RMSE" in objective:
413
+ scoring_metric = "neg_root_mean_squared_error"
414
+
415
+ score = cross_val_score(model, X, y, cv=3, scoring=scoring_metric).mean()
416
+ return score
417
+
418
+ if st.button("Run Auto Ensemble Optimization"):
419
+ with st.spinner("Optimizing models... please wait (~20–60s)"):
420
+ study = optuna.create_study(direction="maximize")
421
+ study.optimize(objective_fn, n_trials=20)
422
+
423
+ best_params = study.best_params
424
+ st.success("✅ Best Auto-Tuned Model Found")
425
+ st.json(best_params)
426
+
427
+ # Build best model
428
+ model_name = best_params.pop("model")
429
+ if model_name == "RandomForest":
430
+ model = RandomForestRegressor(**best_params)
431
+ elif model_name == "GradientBoosting":
432
+ model = GradientBoostingRegressor(**best_params)
433
  else:
434
+ model = ExtraTreesRegressor(**best_params)
435
+ model.fit(X, y)
436
+
437
+ # Save model
438
+ joblib.dump(model, ENSEMBLE_ARTIFACT)
439
+ st.caption(f"Model saved: {ENSEMBLE_ARTIFACT}")
440
+
441
+ # --- Auto Visualizations ---
442
+ st.markdown("### 📈 Optimization History")
443
+ fig_hist = optuna.visualization.matplotlib.plot_optimization_history(study)
444
+ st.pyplot(fig_hist)
445
+
446
+ # Predictions
447
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
448
+ y_pred = model.predict(X_test)
449
+
450
+ r2 = r2_score(y_test, y_pred)
451
+ rmse = mean_squared_error(y_test, y_pred, squared=False)
452
+
453
+ st.metric("R² Score", f"{r2:.3f}")
454
+ st.metric("RMSE", f"{rmse:.3f}")
455
+
456
+ # Scatter plot
457
+ fig, ax = plt.subplots(figsize=(7,4))
458
+ ax.scatter(y_test, y_pred, alpha=0.6)
459
+ ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], "r--")
460
+ ax.set_xlabel("Actual"); ax.set_ylabel("Predicted")
461
+ st.pyplot(fig)
462
+
463
+ # --- SHAP Explainability for Best Model ---
464
+ st.markdown("### 🔍 SHAP Explainability (Auto Model)")
465
+ explainer = shap.TreeExplainer(model)
466
+ shap_values = explainer.shap_values(X_test.sample(300))
467
+ fig_shap = plt.figure(figsize=(8,6))
468
+ shap.summary_plot(shap_values, X_test.sample(300), show=False)
469
+ st.pyplot(fig_shap)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
470
 
471
+ st.info("Auto tuning complete. Model performance and SHAP summary shown above.")
472
 
473
 
474
  # ----- Target & Business Impact tab