damndeepesh committed on
Commit
be587b5
·
verified ·
1 Parent(s): 1b86695

updated files

Browse files
Files changed (3) hide show
  1. README.md +0 -10
  2. app.py +497 -169
  3. requirements.txt +5 -1
README.md CHANGED
@@ -1,13 +1,3 @@
1
- ---
2
- license: mit
3
- title: AutoML
4
- sdk: streamlit
5
- emoji: 🏆
6
- colorFrom: green
7
- colorTo: yellow
8
- pinned: true
9
- sdk_version: 1.45.1
10
- ---
11
  # AutoML & Explainability Web Application
12
 
13
  This Streamlit web application empowers users to perform end-to-end machine learning tasks with ease. Upload your data, automatically train and compare various models, understand their predictions through SHAP explainability, and export the best model for your needs.
 
 
 
 
 
 
 
 
 
 
 
1
  # AutoML & Explainability Web Application
2
 
3
  This Streamlit web application empowers users to perform end-to-end machine learning tasks with ease. Upload your data, automatically train and compare various models, understand their predictions through SHAP explainability, and export the best model for your needs.
app.py CHANGED
@@ -1,16 +1,20 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
- from sklearn.model_selection import train_test_split, cross_val_score
5
  from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, RandomForestRegressor, GradientBoostingRegressor
6
  from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
7
  from sklearn.svm import SVC, SVR
8
  from sklearn.linear_model import LogisticRegression, LinearRegression, Ridge, ElasticNet
9
  from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
10
  from sklearn.naive_bayes import GaussianNB
11
- from sklearn.preprocessing import StandardScaler, LabelEncoder
12
  from sklearn.impute import SimpleImputer
13
  from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_auc_score, f1_score
 
 
 
 
14
  import shap
15
  import matplotlib.pyplot as plt
16
  import seaborn as sns
@@ -31,36 +35,15 @@ st.set_page_config(
31
  )
32
 
33
  # Custom CSS for better styling
34
- st.markdown("""
35
- <style>
36
- .main-header {
37
- font-size: 2.5rem;
38
- color: #1f77b4;
39
- text-align: center;
40
- margin-bottom: 2rem;
41
- }
42
- .metric-card {
43
- background-color: #f0f2f6;
44
- padding: 1rem;
45
- border-radius: 0.5rem;
46
- margin: 0.5rem 0;
47
- box-shadow: 0 2px 4px rgba(0,0,0,0.1);
48
- }
49
- .success-message {
50
- background-color: #d4edda;
51
- color: #155724;
52
- padding: 1rem;
53
- border-radius: 0.5rem;
54
- border: 1px solid #c3e6cb;
55
- }
56
- .stButton>button {
57
- width: 100%;
58
- border-radius: 0.5rem;
59
- }
60
- </style>
61
- """, unsafe_allow_html=True)
62
-
63
  # --- Helper Functions ---
 
 
 
 
 
 
 
 
64
  def get_model_metrics(y_true, y_pred, y_proba=None, problem_type='Classification'):
65
  metrics = {}
66
  if problem_type == "Classification":
@@ -141,7 +124,7 @@ def data_upload_page():
141
  st.session_state.problem_type = None
142
  st.session_state.source_data_type = 'single'
143
  except Exception as e:
144
- st.error(f"Error reading single file: {e}")
145
  return
146
  elif uploaded_train_file:
147
  try:
@@ -158,7 +141,7 @@ def data_upload_page():
158
  else:
159
  st.session_state.test_data = None # Explicitly set to None
160
  except Exception as e:
161
- st.error(f"Error reading train/test files: {e}")
162
  return
163
 
164
  if df is not None:
@@ -226,7 +209,7 @@ def data_upload_page():
226
  st.session_state.problem_type = "Regression"
227
  else:
228
  st.session_state.problem_type = "Unsupported Target Type"
229
- st.error("Target column type is not suitable for classification or regression.")
230
  return
231
 
232
  st.success(f"Target column '{target_column}' selected. Problem Type: {st.session_state.problem_type}")
@@ -239,7 +222,7 @@ def data_upload_page():
239
  col3_test.metric("Test Missing Values", st.session_state.test_data.isnull().sum().sum())
240
  st.dataframe(st.session_state.test_data.head(5), use_container_width=True)
241
  if target_column not in st.session_state.test_data.columns:
242
- st.error(f"Target column '{target_column}' not found in the uploaded test data. Please ensure column names match.")
243
  return # Stop further processing if target is missing in test data
244
 
245
  st.subheader(f"Target Column Distribution (in {'Training Data' if st.session_state.get('source_data_type') == 'separate' else 'Uploaded Data'}): {target_column}")
@@ -253,19 +236,17 @@ def data_upload_page():
253
  st.pyplot(fig)
254
 
255
  except Exception as e:
256
- st.error(f"Error reading or processing file: {e}")
257
  if auto_run_training and st.session_state.target_column:
258
  st.session_state.auto_run_triggered = True
259
  st.experimental_rerun() # Rerun to switch page or trigger training
260
 
261
  except Exception as e:
262
- st.error(f"Error processing data: {e}")
263
- import traceback
264
- st.error(traceback.format_exc())
265
  else:
266
  st.info("👆 Please upload a CSV or Excel file (or separate train/test files) to get started.")
267
 
268
- def preprocess_data(df, target_column):
269
  X = df.drop(columns=[target_column])
270
  y = df[target_column].copy() # Use .copy() to avoid SettingWithCopyWarning
271
 
@@ -299,6 +280,9 @@ def preprocess_data(df, target_column):
299
  if len(cat_cols) > 0:
300
  X[cat_cols] = cat_imputer.fit_transform(X[cat_cols])
301
 
 
 
 
302
  # Encode categorical features
303
  le_dict_features = {}
304
  for col in cat_cols:
@@ -328,10 +312,10 @@ def model_training_page():
328
  data_available = (st.session_state.data is not None) or \
329
  (st.session_state.train_data is not None)
330
  if not data_available or st.session_state.target_column is None:
331
- st.warning("⚠️ Please upload data (single or train/test) and select a target column first.")
332
  return
333
  if st.session_state.problem_type == "Unsupported Target Type":
334
- st.error("Cannot train models with the current target column type.")
335
  return
336
 
337
  target = st.session_state.target_column
@@ -343,7 +327,27 @@ def model_training_page():
343
  test_size = col1.slider("Test Size (if splitting single file)", 0.1, 0.5, 0.2, 0.05, disabled=disable_test_size)
344
  random_state = col1.number_input("Random State", value=42, min_value=0)
345
  cv_folds = col2.slider("Cross-Validation Folds", 3, 10, 5)
346
- scale_features = col2.checkbox("Scale Numeric Features", value=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
347
 
348
  # Auto-start training if triggered
349
  start_button_pressed = st.button("🎯 Start Training", type="primary", key='manual_start_train_button')
@@ -364,7 +368,7 @@ def model_training_page():
364
  if st.session_state.test_data is not None:
365
  df_test_processed = st.session_state.test_data.copy()
366
  if target not in df_test_processed.columns:
367
- st.error(f"Target column '{target}' not found in test data during preprocessing. Aborting.")
368
  return
369
  X_test, y_test = preprocess_data(df_test_processed, target) # Preprocess test data separately
370
  # Ensure X_test has same columns as X_train after preprocessing (esp. after one-hot encoding if added later)
@@ -384,54 +388,136 @@ def model_training_page():
384
  )
385
 
386
  if X_train is None or y_train is None:
387
- st.error("Training data (X_train, y_train) could not be prepared. Please check your data and selections.")
388
  return
389
 
390
  # Scaling should be fit on X_train and transformed on X_test
391
- if scale_features:
 
392
  num_cols_train = X_train.select_dtypes(include=np.number).columns
393
  if len(num_cols_train) > 0:
394
- scaler = StandardScaler()
395
- X_train[num_cols_train] = scaler.fit_transform(X_train[num_cols_train])
396
- st.session_state.scaler = scaler # Save the fitted scaler
397
- if X_test is not None:
398
- num_cols_test = X_test.select_dtypes(include=np.number).columns
399
- # Ensure test set uses the same numeric columns in the same order as train set for scaling
400
- cols_to_scale_in_test = [col for col in num_cols_train if col in X_test.columns]
401
- if len(cols_to_scale_in_test) > 0:
402
- # Create a DataFrame with columns in the order of num_cols_train
403
- X_test_subset_for_scaling = X_test[cols_to_scale_in_test]
404
- X_test_scaled_values = scaler.transform(X_test_subset_for_scaling)
405
- X_test[cols_to_scale_in_test] = X_test_scaled_values
406
- # Handle missing/extra columns if necessary, for now assume they match or subset
 
 
 
 
 
 
 
 
 
 
 
 
 
 
407
 
408
  st.session_state.update({'X_train': X_train, 'X_test': X_test, 'y_train': y_train, 'y_test': y_test})
409
 
410
  # Define models based on problem type
 
411
  if st.session_state.problem_type == "Classification":
412
- models_to_train = {
413
- "Logistic Regression": LogisticRegression(random_state=random_state, max_iter=1000),
414
- "Decision Tree": DecisionTreeClassifier(random_state=random_state),
415
- "Random Forest": RandomForestClassifier(random_state=random_state),
416
- "Gradient Boosting": GradientBoostingClassifier(random_state=random_state),
417
- "Support Vector Machine": SVC(random_state=random_state, probability=True),
418
- "K-Nearest Neighbors": KNeighborsClassifier(),
419
- "Gaussian Naive Bayes": GaussianNB()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  }
421
  scoring = 'accuracy'
422
  else: # Regression
423
- # Local imports for LinearRegression, Ridge, RandomForestRegressor, etc.
424
- # are removed as these models are now imported globally by the first search/replace block.
425
- # ElasticNet is also imported globally.
426
- models_to_train = {
427
- "Linear Regression": LinearRegression(),
428
- "Ridge Regression": Ridge(random_state=random_state),
429
- "ElasticNet Regression": ElasticNet(random_state=random_state),
430
- "Random Forest Regressor": RandomForestRegressor(random_state=random_state),
431
- "Gradient Boosting Regressor": GradientBoostingRegressor(random_state=random_state),
432
- "Decision Tree Regressor": DecisionTreeRegressor(random_state=random_state),
433
- "Support Vector Regressor": SVR(),
434
- "K-Nearest Neighbors Regressor": KNeighborsRegressor()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
435
  }
436
  scoring = 'r2'
437
 
@@ -440,22 +526,86 @@ def model_training_page():
440
  progress_bar = st.progress(0)
441
  status_text = st.empty()
442
 
443
- for i, (name, model) in enumerate(models_to_train.items()):
444
- status_text.text(f"Training {name}...")
445
- model.fit(X_train, y_train)
446
- trained_models[name] = model
447
-
448
- y_pred_test = model.predict(X_test)
449
- y_proba_test = model.predict_proba(X_test) if hasattr(model, 'predict_proba') and st.session_state.problem_type == "Classification" else None
450
-
451
- metrics = get_model_metrics(y_test, y_pred_test, y_proba_test, problem_type=st.session_state.problem_type)
452
- cv_score = cross_val_score(model, X_train, y_train, cv=cv_folds, scoring=scoring).mean()
453
-
454
- current_model_scores = {'CV Mean Score': cv_score}
455
- current_model_scores.update(metrics) # Add all relevant metrics
456
- model_scores_dict[name] = current_model_scores
457
-
458
- progress_bar.progress((i + 1) / len(models_to_train))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
459
 
460
  st.session_state.models = trained_models
461
  st.session_state.model_scores = model_scores_dict
@@ -476,9 +626,7 @@ def model_training_page():
476
  st.success(f"✅ Training completed! Best model: {best_model_name}")
477
 
478
  except Exception as e:
479
- st.error(f"Error during training: {e}")
480
- import traceback
481
- st.error(traceback.format_exc())
482
 
483
  def model_comparison_page():
484
  st.header("📊 Model Comparison")
@@ -519,20 +667,130 @@ def model_comparison_page():
519
  st.subheader(f"📋 Detailed Metrics for Best Model: {best_model_name}")
520
  best_model = st.session_state.best_model_info['model']
521
  y_pred = best_model.predict(st.session_state.X_test)
522
-
523
- col1, col2 = st.columns(2)
524
- with col1:
525
- st.text("Classification Report:")
526
- report_df = pd.DataFrame(classification_report(st.session_state.y_test, y_pred, output_dict=True)).transpose()
527
- st.dataframe(report_df.round(3), use_container_width=True)
528
- with col2:
529
- st.text("Confusion Matrix:")
530
- cm = confusion_matrix(st.session_state.y_test, y_pred)
531
- fig_cm, ax_cm = plt.subplots()
532
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax_cm)
533
- ax_cm.set_xlabel('Predicted')
534
- ax_cm.set_ylabel('Actual')
535
- st.pyplot(fig_cm)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
536
 
537
  def explainability_page():
538
  st.header("🔍 Model Explainability (SHAP)")
@@ -548,20 +806,36 @@ def explainability_page():
548
  with st.spinner("Generating SHAP explanations..."):
549
  try:
550
  # SHAP Explainer
551
- if isinstance(best_model, (RandomForestClassifier, GradientBoostingClassifier, DecisionTreeClassifier,
552
- RandomForestRegressor, GradientBoostingRegressor, DecisionTreeRegressor)):
553
- explainer = shap.TreeExplainer(best_model)
554
- elif isinstance(best_model, (LogisticRegression, LinearRegression, Ridge, ElasticNet)):
555
- explainer = shap.LinearExplainer(best_model, X_test_df) # Pass data for LinearExplainer
556
- elif isinstance(best_model, (SVC, SVR, KNeighborsClassifier, KNeighborsRegressor, GaussianNB)):
557
- # KernelExplainer can be slow or not directly applicable for some, use a subset of X_train for background data
558
- # For KNN and Naive Bayes, KernelExplainer is a common choice for SHAP if TreeExplainer/LinearExplainer aren't suitable.
559
- background_data = shap.sample(st.session_state.X_train, min(100, len(st.session_state.X_train)))
560
- if isinstance(background_data, np.ndarray):
561
- background_data = pd.DataFrame(background_data, columns=X_test_df.columns)
562
- explainer = shap.KernelExplainer(best_model.predict_proba if hasattr(best_model, 'predict_proba') else best_model.predict, background_data)
563
- else:
564
- st.error(f"SHAP explanations not supported for {best_model_name} with current setup.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
565
  return
566
 
567
  shap_values = explainer.shap_values(X_test_df)
@@ -614,9 +888,7 @@ def explainability_page():
614
  st.metric("Predicted Value", f"{predicted:.2f}")
615
 
616
  except Exception as e:
617
- st.error(f"Error generating SHAP explanations: {e}")
618
- import traceback
619
- st.error(traceback.format_exc())
620
 
621
  def model_export_page():
622
  st.header("💾 Model Export")
@@ -637,9 +909,31 @@ def model_export_page():
637
  steps = []
638
  if st.session_state.scaler:
639
  steps.append(('scaler', st.session_state.scaler))
640
- steps.append(('model', best_model))
641
- pipeline_to_export = Pipeline(steps)
642
- st.session_state.trained_pipeline = pipeline_to_export
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
643
 
644
  export_format = st.selectbox("Choose export format:", ["Joblib (.joblib)", "Pickle (.pkl)"])
645
  file_name_suggestion = f"{best_model_name.lower().replace(' ', '_')}_pipeline"
@@ -664,12 +958,28 @@ def model_export_page():
664
  )
665
  st.success("Model pipeline ready for download!")
666
  except Exception as e:
667
- st.error(f"Error exporting model: {e}")
668
 
669
  st.subheader("📖 How to use the exported pipeline:")
670
- st.code(f"""
671
- import joblib # or import pickle
 
 
 
 
 
672
  import pandas as pd
 
 
 
 
 
 
 
 
 
 
 
673
 
674
  # Load the pipeline
675
  pipeline = joblib.load('{file_name}{'.joblib' if 'Joblib' in export_format else '.pkl'}')
@@ -683,40 +993,58 @@ pipeline = joblib.load('{file_name}{'.joblib' if 'Joblib' in export_format else
683
  # Make predictions
684
  # predictions = pipeline.predict(new_data)
685
  # print(predictions)
686
- """, language='python')
687
 
688
- # --- Main Application ---
689
- def main():
690
- init_session_state()
691
- st.markdown('<h1 class="main-header">🤖 AutoML & Explainability Platform</h1>', unsafe_allow_html=True)
692
-
693
- st.sidebar.title("⚙️ Workflow")
694
- page_options = ["Data Upload & Preview", "Model Training", "Model Comparison", "Explainability", "Model Export"]
695
 
696
- # Handle auto-run navigation
697
- if st.session_state.get('auto_run_triggered') and st.session_state.target_column:
698
- st.session_state.auto_run_triggered = False # Reset trigger
699
- st.session_state.current_page = "Model Training"
700
- st.session_state.auto_run_triggered_for_training = True # Signal model_training_page to auto-start
701
 
702
- if 'current_page' not in st.session_state:
703
- st.session_state.current_page = "Data Upload & Preview"
704
-
705
- page = st.sidebar.radio("Navigate", page_options, key='navigation_radio', index=page_options.index(st.session_state.current_page))
706
- st.session_state.current_page = page # Update current page based on user selection
707
-
708
- if page == "Data Upload & Preview":
709
- data_upload_page()
710
- elif page == "Model Training":
711
- model_training_page()
712
- elif page == "Model Comparison":
713
- model_comparison_page()
714
- elif page == "Explainability":
715
- explainability_page()
716
- elif page == "Model Export":
717
- model_export_page()
718
-
719
- st.sidebar.markdown("---_Developed with Trae AI_---")
720
-
721
- if __name__ == "__main__":
722
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import numpy as np
4
+ from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, RandomizedSearchCV
5
  from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier, RandomForestRegressor, GradientBoostingRegressor
6
  from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
7
  from sklearn.svm import SVC, SVR
8
  from sklearn.linear_model import LogisticRegression, LinearRegression, Ridge, ElasticNet
9
  from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor
10
  from sklearn.naive_bayes import GaussianNB
11
+ from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder
12
  from sklearn.impute import SimpleImputer
13
  from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, roc_auc_score, f1_score
14
+ # Import advanced models
15
+ import xgboost as xgb
16
+ import lightgbm as lgb
17
+ import catboost as cb
18
  import shap
19
  import matplotlib.pyplot as plt
20
  import seaborn as sns
 
35
  )
36
 
37
  # Custom CSS for better styling
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  # --- Helper Functions ---
39
+ def display_error(e, context="An unexpected error occurred"):
40
+ """Displays a user-friendly error message."""
41
+ st.error(f"😕 Oops! Something went wrong. {context}. Please check your inputs or the data format.")
42
+ st.error(f"Details: {str(e)}")
43
+ # Optionally, log the full traceback for debugging, but don't show it to the user by default
44
+ # import traceback
45
+ # st.expander("See Full Error Traceback").error(traceback.format_exc())
46
+
47
  def get_model_metrics(y_true, y_pred, y_proba=None, problem_type='Classification'):
48
  metrics = {}
49
  if problem_type == "Classification":
 
124
  st.session_state.problem_type = None
125
  st.session_state.source_data_type = 'single'
126
  except Exception as e:
127
+ display_error(e, "Failed to read the uploaded single file")
128
  return
129
  elif uploaded_train_file:
130
  try:
 
141
  else:
142
  st.session_state.test_data = None # Explicitly set to None
143
  except Exception as e:
144
+ display_error(e, "Failed to read the uploaded train/test files")
145
  return
146
 
147
  if df is not None:
 
209
  st.session_state.problem_type = "Regression"
210
  else:
211
  st.session_state.problem_type = "Unsupported Target Type"
212
+ st.error("The selected target column has an unsupported data type. Please choose a numeric column for regression or a categorical/binary column for classification.")
213
  return
214
 
215
  st.success(f"Target column '{target_column}' selected. Problem Type: {st.session_state.problem_type}")
 
222
  col3_test.metric("Test Missing Values", st.session_state.test_data.isnull().sum().sum())
223
  st.dataframe(st.session_state.test_data.head(5), use_container_width=True)
224
  if target_column not in st.session_state.test_data.columns:
225
+ st.error(f"The target column '{target_column}' was not found in your uploaded test data. Please ensure the column names match exactly between your training and testing datasets.")
226
  return # Stop further processing if target is missing in test data
227
 
228
  st.subheader(f"Target Column Distribution (in {'Training Data' if st.session_state.get('source_data_type') == 'separate' else 'Uploaded Data'}): {target_column}")
 
236
  st.pyplot(fig)
237
 
238
  except Exception as e:
239
+ display_error(e, "An error occurred while reading or performing initial processing on the file")
240
  if auto_run_training and st.session_state.target_column:
241
  st.session_state.auto_run_triggered = True
242
  st.experimental_rerun() # Rerun to switch page or trigger training
243
 
244
  except Exception as e:
245
+ display_error(e, "An error occurred during data processing and analysis")
 
 
246
  else:
247
  st.info("👆 Please upload a CSV or Excel file (or separate train/test files) to get started.")
248
 
249
+ def preprocess_data(df, target_column, scaling_method="None"):
250
  X = df.drop(columns=[target_column])
251
  y = df[target_column].copy() # Use .copy() to avoid SettingWithCopyWarning
252
 
 
280
  if len(cat_cols) > 0:
281
  X[cat_cols] = cat_imputer.fit_transform(X[cat_cols])
282
 
283
+ # Scaling is handled in the model_training_page after splitting, so not here.
284
+ # This function will just do imputation and encoding.
285
+
286
  # Encode categorical features
287
  le_dict_features = {}
288
  for col in cat_cols:
 
312
  data_available = (st.session_state.data is not None) or \
313
  (st.session_state.train_data is not None)
314
  if not data_available or st.session_state.target_column is None:
315
+ st.warning("⚠️ Please upload your data and select a target column on the 'Data Upload & Preview' page before proceeding to model training.")
316
  return
317
  if st.session_state.problem_type == "Unsupported Target Type":
318
+ st.error("Cannot train models because the selected target column has an unsupported data type. Please go back and select a suitable target column.")
319
  return
320
 
321
  target = st.session_state.target_column
 
327
  test_size = col1.slider("Test Size (if splitting single file)", 0.1, 0.5, 0.2, 0.05, disabled=disable_test_size)
328
  random_state = col1.number_input("Random State", value=42, min_value=0)
329
  cv_folds = col2.slider("Cross-Validation Folds", 3, 10, 5)
330
+ # scale_features checkbox is replaced by a selectbox for scaling_method
331
+ scaling_method_options = ["None", "StandardScaler", "MinMaxScaler"]
332
+ scaling_method = col2.selectbox("Numeric Feature Scaling", options=scaling_method_options, index=1, key='scaling_method_selector') # Default to StandardScaler
333
+ st.session_state.scaling_method = scaling_method # Store for use during preprocessing
334
+
335
+ # Initialize session state variables if they don't exist
336
+ if 'tuning_method' not in st.session_state:
337
+ st.session_state.tuning_method = None
338
+ if 'n_iter' not in st.session_state:
339
+ st.session_state.n_iter = 50 # Default value
340
+
341
+ st.subheader("Hyperparameter Tuning")
342
+ enable_tuning = st.checkbox("Enable Hyperparameter Tuning", value=False)
343
+ if enable_tuning:
344
+ # The selectbox will automatically update st.session_state.tuning_method
345
+ tuning_method_selected = st.selectbox("Select Tuning Method", ["Grid Search", "Randomized Search"], key='tuning_method')
346
+ if tuning_method_selected == "Randomized Search":
347
+ st.session_state.n_iter = st.number_input("Number of Iterations (for Randomized Search)", min_value=10, value=50, step=10, key='n_iter_randomized_search')
348
+ else:
349
+ # When tuning is disabled, explicitly set tuning_method to None
350
+ st.session_state.tuning_method = None
351
 
352
  # Auto-start training if triggered
353
  start_button_pressed = st.button("🎯 Start Training", type="primary", key='manual_start_train_button')
 
368
  if st.session_state.test_data is not None:
369
  df_test_processed = st.session_state.test_data.copy()
370
  if target not in df_test_processed.columns:
371
+ st.error(f"The target column '{target}' is missing from your test dataset. Please ensure both train and test datasets have the target column with the same name. Aborting training.")
372
  return
373
  X_test, y_test = preprocess_data(df_test_processed, target) # Preprocess test data separately
374
  # Ensure X_test has same columns as X_train after preprocessing (esp. after one-hot encoding if added later)
 
388
  )
389
 
390
  if X_train is None or y_train is None:
391
+ st.error("The training data (features X_train, target y_train) could not be prepared. This might be due to issues in the uploaded data or preprocessing steps. Please review your data and selections.")
392
  return
393
 
394
  # Scaling should be fit on X_train and transformed on X_test
395
+ current_scaling_method = st.session_state.get('scaling_method', 'StandardScaler') # Get from session state
396
+ if current_scaling_method != "None":
397
  num_cols_train = X_train.select_dtypes(include=np.number).columns
398
  if len(num_cols_train) > 0:
399
+ if current_scaling_method == "StandardScaler":
400
+ scaler = StandardScaler()
401
+ elif current_scaling_method == "MinMaxScaler":
402
+ scaler = MinMaxScaler()
403
+ else:
404
+ scaler = None # Should not happen
405
+
406
+ if scaler:
407
+ X_train[num_cols_train] = scaler.fit_transform(X_train[num_cols_train])
408
+ st.session_state.scaler = scaler # Save the fitted scaler
409
+ st.info(f"Numeric features in training data scaled using {current_scaling_method}.")
410
+ if X_test is not None:
411
+ num_cols_test = X_test.select_dtypes(include=np.number).columns
412
+ # Ensure test set uses the same numeric columns in the same order as train set for scaling
413
+ cols_to_scale_in_test = [col for col in num_cols_train if col in X_test.columns]
414
+ if len(cols_to_scale_in_test) > 0:
415
+ # Create a DataFrame with columns in the order of num_cols_train
416
+ X_test_subset_for_scaling = X_test[cols_to_scale_in_test]
417
+ X_test_scaled_values = scaler.transform(X_test_subset_for_scaling)
418
+ X_test[cols_to_scale_in_test] = X_test_scaled_values
419
+ st.info(f"Numeric features in test data scaled using {current_scaling_method}.")
420
+ else:
421
+ st.session_state.scaler = None # Ensure it's None if no scaling applied
422
+ else:
423
+ st.session_state.scaler = None # Ensure it's None if no numeric columns
424
+ else:
425
+ st.session_state.scaler = None # Ensure it's None if scaling_method is "None"
426
 
427
  st.session_state.update({'X_train': X_train, 'X_test': X_test, 'y_train': y_train, 'y_test': y_test})
428
 
429
  # Define models based on problem type
430
+ # Define models and their parameter grids for tuning
431
  if st.session_state.problem_type == "Classification":
432
+ models_and_params = {
433
+ "Logistic Regression": {
434
+ 'model': LogisticRegression(random_state=random_state, max_iter=1000),
435
+ 'params': {'C': [0.1, 1.0, 10.0], 'solver': ['liblinear', 'lbfgs']}
436
+ },
437
+ "Decision Tree": {
438
+ 'model': DecisionTreeClassifier(random_state=random_state),
439
+ 'params': {'max_depth': [None, 10, 20, 30], 'min_samples_leaf': [1, 5, 10]}
440
+ },
441
+ "Random Forest": {
442
+ 'model': RandomForestClassifier(random_state=random_state),
443
+ 'params': {'n_estimators': [100, 200], 'max_depth': [10, 20]}
444
+ },
445
+ "Gradient Boosting": {
446
+ 'model': GradientBoostingClassifier(random_state=random_state),
447
+ 'params': {'n_estimators': [100, 200], 'learning_rate': [0.01, 0.1]}
448
+ },
449
+ "XGBoost": {
450
+ 'model': xgb.XGBClassifier(random_state=random_state, use_label_encoder=False, eval_metric='logloss'),
451
+ 'params': {'n_estimators': [100, 200], 'learning_rate': [0.01, 0.1], 'max_depth': [3, 6]}
452
+ },
453
+ "LightGBM": {
454
+ 'model': lgb.LGBMClassifier(random_state=random_state),
455
+ 'params': {'n_estimators': [100, 200], 'learning_rate': [0.01, 0.1], 'num_leaves': [31, 50]}
456
+ },
457
+ "CatBoost": {
458
+ 'model': cb.CatBoostClassifier(random_state=random_state, verbose=0),
459
+ 'params': {'iterations': [100, 200], 'learning_rate': [0.01, 0.1], 'depth': [4, 6]}
460
+ },
461
+ "Support Vector Machine": {
462
+ 'model': SVC(random_state=random_state, probability=True),
463
+ 'params': {'C': [0.1, 1.0, 10.0], 'kernel': ['linear', 'rbf']}
464
+ },
465
+ "K-Nearest Neighbors": {
466
+ 'model': KNeighborsClassifier(),
467
+ 'params': {'n_neighbors': [3, 5, 7], 'weights': ['uniform', 'distance']}
468
+ },
469
+ "Gaussian Naive Bayes": {
470
+ 'model': GaussianNB(),
471
+ 'params': {}
472
+ }
473
  }
474
  scoring = 'accuracy'
475
  else: # Regression
476
+ models_and_params = {
477
+ "Linear Regression": {
478
+ 'model': LinearRegression(),
479
+ 'params': {}
480
+ },
481
+ "Ridge Regression": {
482
+ 'model': Ridge(random_state=random_state),
483
+ 'params': {'alpha': [0.1, 1.0, 10.0]}
484
+ },
485
+ "ElasticNet Regression": {
486
+ 'model': ElasticNet(random_state=random_state),
487
+ 'params': {'alpha': [0.1, 1.0, 10.0], 'l1_ratio': [0.1, 0.5, 0.9]}
488
+ },
489
+ "Random Forest Regressor": {
490
+ 'model': RandomForestRegressor(random_state=random_state),
491
+ 'params': {'n_estimators': [100, 200], 'max_depth': [10, 20]}
492
+ },
493
+ "Gradient Boosting Regressor": {
494
+ 'model': GradientBoostingRegressor(random_state=random_state),
495
+ 'params': {'n_estimators': [100, 200], 'learning_rate': [0.01, 0.1]}
496
+ },
497
+ "XGBoost Regressor": {
498
+ 'model': xgb.XGBRegressor(random_state=random_state),
499
+ 'params': {'n_estimators': [100, 200], 'learning_rate': [0.01, 0.1], 'max_depth': [3, 6]}
500
+ },
501
+ "LightGBM Regressor": {
502
+ 'model': lgb.LGBMRegressor(random_state=random_state),
503
+ 'params': {'n_estimators': [100, 200], 'learning_rate': [0.01, 0.1], 'num_leaves': [31, 50]}
504
+ },
505
+ "CatBoost Regressor": {
506
+ 'model': cb.CatBoostRegressor(random_state=random_state, verbose=0),
507
+ 'params': {'iterations': [100, 200], 'learning_rate': [0.01, 0.1], 'depth': [4, 6]}
508
+ },
509
+ "Decision Tree Regressor": {
510
+ 'model': DecisionTreeRegressor(random_state=random_state),
511
+ 'params': {'max_depth': [None, 10, 20, 30], 'min_samples_leaf': [1, 5, 10]}
512
+ },
513
+ "Support Vector Regressor": {
514
+ 'model': SVR(),
515
+ 'params': {'C': [0.1, 1.0, 10.0], 'kernel': ['linear', 'rbf']}
516
+ },
517
+ "K-Nearest Neighbors Regressor": {
518
+ 'model': KNeighborsRegressor(),
519
+ 'params': {'n_neighbors': [3, 5, 7], 'weights': ['uniform', 'distance']}
520
+ }
521
  }
522
  scoring = 'r2'
523
 
 
526
  progress_bar = st.progress(0)
527
  status_text = st.empty()
528
 
529
+ tuning_enabled = st.session_state.get('tuning_method') is not None
530
+ n_iter = st.session_state.get('n_iter', 50) # Default for Randomized Search
531
+
532
+ for i, (name, model_info) in enumerate(models_and_params.items()):
533
+ try:
534
+ model = model_info['model']
535
+ params = model_info['params']
536
+
537
+ # Check if this is one of the newly added models
538
+ is_new_model = name in ["XGBoost", "LightGBM", "CatBoost"] or name in ["XGBoost Regressor", "LightGBM Regressor", "CatBoost Regressor"]
539
+
540
+ if is_new_model:
541
+ status_text.text(f"Initializing {name}...")
542
+
543
+ if tuning_enabled and params:
544
+ status_text.text(f"Tuning {name}...")
545
+ try:
546
+ if st.session_state.tuning_method == "Grid Search":
547
+ tuner = GridSearchCV(model, params, cv=cv_folds, scoring=scoring, n_jobs=-1)
548
+ else: # Randomized Search
549
+ tuner = RandomizedSearchCV(model, params, n_iter=n_iter, cv=cv_folds, scoring=scoring, random_state=random_state, n_jobs=-1)
550
+
551
+ tuner.fit(X_train, y_train)
552
+ best_model = tuner.best_estimator_
553
+ st.write(f"Best parameters for {name}: {tuner.best_params_}")
554
+ except Exception as e:
555
+ display_error(e, f"Error during hyperparameter tuning for {name}")
556
+ # Skip this model and continue with the next one
557
+ continue
558
+ else:
559
+ status_text.text(f"Training {name}...")
560
+ try:
561
+ best_model = model
562
+ best_model.fit(X_train, y_train)
563
+ except Exception as e:
564
+ display_error(e, f"Error during training for {name}")
565
+ # Skip this model and continue with the next one
566
+ continue
567
+
568
+ trained_models[name] = best_model
569
+
570
+ try:
571
+ y_pred_test = best_model.predict(X_test)
572
+
573
+ # Handle predict_proba for classification models
574
+ if st.session_state.problem_type == "Classification" and hasattr(best_model, 'predict_proba'):
575
+ try:
576
+ y_proba_test = best_model.predict_proba(X_test)
577
+ except Exception as e:
578
+ st.warning(f"Could not compute prediction probabilities for {name}: {str(e)}")
579
+ y_proba_test = None
580
+ else:
581
+ y_proba_test = None
582
+
583
+ metrics = get_model_metrics(y_test, y_pred_test, y_proba_test, problem_type=st.session_state.problem_type)
584
+
585
+ # For tuned models, cross_val_score on the best_estimator_ might be redundant if tuner already did CV
586
+ # But for consistency, we can still calculate it or use tuner.best_score_
587
+ try:
588
+ cv_score = cross_val_score(best_model, X_train, y_train, cv=cv_folds, scoring=scoring).mean()
589
+ except Exception as e:
590
+ st.warning(f"Could not compute cross-validation score for {name}: {str(e)}")
591
+ cv_score = float('nan') # Use NaN to indicate missing value
592
+
593
+ current_model_scores = {'CV Mean Score': cv_score}
594
+ current_model_scores.update(metrics) # Add all relevant metrics
595
+ model_scores_dict[name] = current_model_scores
596
+
597
+ if is_new_model:
598
+ st.success(f"{name} trained successfully!")
599
+ except Exception as e:
600
+ display_error(e, f"Error during prediction or evaluation for {name}")
601
+ # Skip adding this model to the scores dictionary
602
+ continue
603
+ except Exception as e:
604
+ display_error(e, f"Unexpected error with {name}")
605
+ # Skip this model entirely and continue with the next one
606
+ continue
607
+
608
+ progress_bar.progress((i + 1) / len(models_and_params))
609
 
610
  st.session_state.models = trained_models
611
  st.session_state.model_scores = model_scores_dict
 
626
  st.success(f"✅ Training completed! Best model: {best_model_name}")
627
 
628
  except Exception as e:
629
+ display_error(e, "An error occurred during the model training process")
 
 
630
 
631
  def model_comparison_page():
632
  st.header("📊 Model Comparison")
 
667
  st.subheader(f"📋 Detailed Metrics for Best Model: {best_model_name}")
668
  best_model = st.session_state.best_model_info['model']
669
  y_pred = best_model.predict(st.session_state.X_test)
670
+ y_test = st.session_state.y_test
671
+
672
+ # Confusion Matrix
673
+ st.write("#### Confusion Matrix")
674
+ cm = confusion_matrix(y_test, y_pred)
675
+ fig_cm, ax_cm = plt.subplots()
676
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax_cm)
677
+ ax_cm.set_xlabel('Predicted')
678
+ ax_cm.set_ylabel('Actual')
679
+ ax_cm.set_title('Confusion Matrix')
680
+ st.pyplot(fig_cm)
681
+
682
+ # Classification Report
683
+ st.write("#### Classification Report")
684
+ report = classification_report(y_test, y_pred, output_dict=True)
685
+ report_df = pd.DataFrame(report).transpose()
686
+ st.dataframe(report_df.round(4))
687
+
688
+ # ROC Curve and AUC
689
+ if hasattr(best_model, 'predict_proba'):
690
+ st.write("#### ROC Curve")
691
+ try:
692
+ y_proba = best_model.predict_proba(st.session_state.X_test)
693
+ if y_proba.shape[1] > 2: # Multi-class classification
694
+ # For multi-class, plot one-vs-rest ROC curves
695
+ from sklearn.preprocessing import LabelBinarizer
696
+ lb = LabelBinarizer()
697
+ y_test_binarized = lb.fit_transform(y_test)
698
+
699
+ fig_roc, ax_roc = plt.subplots()
700
+ for i in range(y_proba.shape[1]):
701
+ fpr, tpr, _ = roc_curve(y_test_binarized[:, i], y_proba[:, i])
702
+ roc_auc = auc(fpr, tpr)
703
+ ax_roc.plot(fpr, tpr, label=f'Class {lb.classes_[i]} (AUC = {roc_auc:.2f})')
704
+ ax_roc.plot([0, 1], [0, 1], 'k--', label='Random Classifier')
705
+ ax_roc.set_xlabel('False Positive Rate')
706
+ ax_roc.set_ylabel('True Positive Rate')
707
+ ax_roc.set_title('ROC Curve (One-vs-Rest)')
708
+ ax_roc.legend(loc='lower right')
709
+ st.pyplot(fig_roc)
710
+ else: # Binary classification
711
+ fpr, tpr, _ = roc_curve(y_test, y_proba[:, 1])
712
+ roc_auc = auc(fpr, tpr)
713
+ fig_roc, ax_roc = plt.subplots()
714
+ ax_roc.plot(fpr, tpr, label=f'ROC curve (area = {roc_auc:.2f})')
715
+ ax_roc.plot([0, 1], [0, 1], 'k--', label='Random Classifier')
716
+ ax_roc.set_xlabel('False Positive Rate')
717
+ ax_roc.set_ylabel('True Positive Rate')
718
+ ax_roc.set_title('Receiver Operating Characteristic (ROC) Curve')
719
+ ax_roc.legend(loc='lower right')
720
+ st.pyplot(fig_roc)
721
+ except Exception as e:
722
+ st.warning(f"Could not plot ROC curve: {e}")
723
+
724
+ # Precision-Recall Curve
725
+ st.write("#### Precision-Recall Curve")
726
+ try:
727
+ if y_proba.shape[1] > 2: # Multi-class classification
728
+ # For multi-class, plot one-vs-rest Precision-Recall curves
729
+ from sklearn.preprocessing import LabelBinarizer
730
+ lb = LabelBinarizer()
731
+ y_test_binarized = lb.fit_transform(y_test)
732
+
733
+ fig_pr, ax_pr = plt.subplots()
734
+ for i in range(y_proba.shape[1]):
735
+ precision, recall, _ = precision_recall_curve(y_test_binarized[:, i], y_proba[:, i])
736
+ pr_auc = auc(recall, precision)
737
+ ax_pr.plot(recall, precision, label=f'Class {lb.classes_[i]} (AUC = {pr_auc:.2f})')
738
+ ax_pr.set_xlabel('Recall')
739
+ ax_pr.set_ylabel('Precision')
740
+ ax_pr.set_title('Precision-Recall Curve (One-vs-Rest)')
741
+ ax_pr.legend(loc='lower left')
742
+ st.pyplot(fig_pr)
743
+ else: # Binary classification
744
+ precision, recall, _ = precision_recall_curve(y_test, y_proba[:, 1])
745
+ pr_auc = auc(recall, precision)
746
+ fig_pr, ax_pr = plt.subplots()
747
+ ax_pr.plot(recall, precision, label=f'Precision-Recall curve (area = {pr_auc:.2f})')
748
+ ax_pr.set_xlabel('Recall')
749
+ ax_pr.set_ylabel('Precision')
750
+ ax_pr.set_title('Precision-Recall Curve')
751
+ ax_pr.legend(loc='lower left')
752
+ st.pyplot(fig_pr)
753
+ except Exception as e:
754
+ st.warning(f"Could not plot Precision-Recall curve: {e}")
755
+ else:
756
+ st.info("Model does not support `predict_proba` for ROC/PR curves.")
757
+
758
+ elif st.session_state.problem_type == "Regression" and st.session_state.X_test is not None:
759
+ st.subheader(f"📋 Detailed Metrics for Best Model: {best_model_name}")
760
+ best_model = st.session_state.best_model_info['model']
761
+ y_pred = best_model.predict(st.session_state.X_test)
762
+ y_test = st.session_state.y_test
763
+
764
+ # Residual Plot
765
+ st.write("#### Residual Plot")
766
+ residuals = y_test - y_pred
767
+ fig_res, ax_res = plt.subplots()
768
+ ax_res.scatter(y_pred, residuals)
769
+ ax_res.axhline(y=0, color='r', linestyle='--')
770
+ ax_res.set_xlabel('Predicted Values')
771
+ ax_res.set_ylabel('Residuals')
772
+ ax_res.set_title('Residual Plot')
773
+ st.pyplot(fig_res)
774
+
775
+ # Actual vs. Predicted Plot
776
+ st.write("#### Actual vs. Predicted Plot")
777
+ fig_ap, ax_ap = plt.subplots()
778
+ ax_ap.scatter(y_test, y_pred)
779
+ ax_ap.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'k--', lw=2) # Diagonal line
780
+ ax_ap.set_xlabel('Actual Values')
781
+ ax_ap.set_ylabel('Predicted Values')
782
+ ax_ap.set_title('Actual vs. Predicted Plot')
783
+ st.pyplot(fig_ap)
784
+
785
+ # st.subheader("Cross-Validation Score Details")
786
+ # if st.session_state.model_scores:
787
+ # cv_scores_df = pd.DataFrame({
788
+ # 'Model': list(st.session_state.model_scores.keys()),
789
+ # 'CV Mean Score': [v.get('CV Mean Score', 'N/A') for v in st.session_state.model_scores.values()]
790
+ # })
791
+ # st.dataframe(cv_scores_df.round(4), use_container_width=True)
792
+ # else:
793
+ # st.info("No cross-validation scores available.")
794
 
795
  def explainability_page():
796
  st.header("🔍 Model Explainability (SHAP)")
 
806
  with st.spinner("Generating SHAP explanations..."):
807
  try:
808
  # SHAP Explainer
809
+ try:
810
+ # Check for the newly added models first
811
+ if isinstance(best_model, (xgb.XGBClassifier, xgb.XGBRegressor,
812
+ lgb.LGBMClassifier, lgb.LGBMRegressor,
813
+ cb.CatBoostClassifier, cb.CatBoostRegressor)):
814
+ st.info(f"Using TreeExplainer for {best_model_name}")
815
+ explainer = shap.TreeExplainer(best_model)
816
+ elif isinstance(best_model, (RandomForestClassifier, GradientBoostingClassifier, DecisionTreeClassifier,
817
+ RandomForestRegressor, GradientBoostingRegressor, DecisionTreeRegressor)):
818
+ explainer = shap.TreeExplainer(best_model)
819
+ elif isinstance(best_model, (LogisticRegression, LinearRegression, Ridge, ElasticNet)):
820
+ explainer = shap.LinearExplainer(best_model, X_test_df) # Pass data for LinearExplainer
821
+ elif isinstance(best_model, (SVC, SVR, KNeighborsClassifier, KNeighborsRegressor, GaussianNB)):
822
+ # KernelExplainer can be slow or not directly applicable for some, use a subset of X_train for background data
823
+ # For KNN and Naive Bayes, KernelExplainer is a common choice for SHAP if TreeExplainer/LinearExplainer aren't suitable.
824
+ background_data = shap.sample(st.session_state.X_train, min(100, len(st.session_state.X_train)))
825
+ if isinstance(background_data, np.ndarray):
826
+ background_data = pd.DataFrame(background_data, columns=X_test_df.columns)
827
+ explainer = shap.KernelExplainer(best_model.predict_proba if hasattr(best_model, 'predict_proba') else best_model.predict, background_data)
828
+ else:
829
+ st.warning(f"SHAP explanations might not be optimized for the model type '{best_model_name}'. Using KernelExplainer as fallback.")
830
+ # Fallback to KernelExplainer for unknown model types
831
+ background_data = shap.sample(st.session_state.X_train, min(100, len(st.session_state.X_train)))
832
+ if isinstance(background_data, np.ndarray):
833
+ background_data = pd.DataFrame(background_data, columns=X_test_df.columns)
834
+ predict_fn = best_model.predict_proba if hasattr(best_model, 'predict_proba') and st.session_state.problem_type == "Classification" else best_model.predict
835
+ explainer = shap.KernelExplainer(predict_fn, background_data)
836
+ except Exception as e:
837
+ display_error(e, f"Error creating SHAP explainer for {best_model_name}")
838
+ st.error(f"SHAP explanations are currently not supported for the model type '{best_model_name}'. We are working on expanding compatibility.")
839
  return
840
 
841
  shap_values = explainer.shap_values(X_test_df)
 
888
  st.metric("Predicted Value", f"{predicted:.2f}")
889
 
890
  except Exception as e:
891
+ display_error(e, "An error occurred while generating SHAP explanations")
 
 
892
 
893
  def model_export_page():
894
  st.header("💾 Model Export")
 
909
  steps = []
910
  if st.session_state.scaler:
911
  steps.append(('scaler', st.session_state.scaler))
912
+
913
+ # Check if the model is one of the newly added models
914
+ is_new_model = isinstance(best_model, (xgb.XGBClassifier, xgb.XGBRegressor,
915
+ lgb.LGBMClassifier, lgb.LGBMRegressor,
916
+ cb.CatBoostClassifier, cb.CatBoostRegressor))
917
+
918
+ if is_new_model:
919
+ st.info(f"Preparing {best_model_name} for export. These advanced models may require additional libraries when loading.")
920
+
921
+ # Add model-specific export notes
922
+ if isinstance(best_model, (xgb.XGBClassifier, xgb.XGBRegressor)):
923
+ st.info("Note: To load this XGBoost model, ensure 'xgboost' is installed in your environment.")
924
+ elif isinstance(best_model, (lgb.LGBMClassifier, lgb.LGBMRegressor)):
925
+ st.info("Note: To load this LightGBM model, ensure 'lightgbm' is installed in your environment.")
926
+ elif isinstance(best_model, (cb.CatBoostClassifier, cb.CatBoostRegressor)):
927
+ st.info("Note: To load this CatBoost model, ensure 'catboost' is installed in your environment.")
928
+
929
+ try:
930
+ steps.append(('model', best_model))
931
+ pipeline_to_export = Pipeline(steps)
932
+ st.session_state.trained_pipeline = pipeline_to_export
933
+ except Exception as e:
934
+ display_error(e, f"Error creating pipeline for {best_model_name}")
935
+ st.warning("Falling back to exporting model without pipeline wrapper. Some preprocessing steps may need to be applied manually.")
936
+ st.session_state.trained_pipeline = best_model
937
 
938
  export_format = st.selectbox("Choose export format:", ["Joblib (.joblib)", "Pickle (.pkl)"])
939
  file_name_suggestion = f"{best_model_name.lower().replace(' ', '_')}_pipeline"
 
958
  )
959
  st.success("Model pipeline ready for download!")
960
  except Exception as e:
961
+ display_error(e, "An error occurred while exporting the model pipeline")
962
 
963
  st.subheader("📖 How to use the exported pipeline:")
964
+ # Determine if the best model is one of the newly added models
965
+ is_xgboost = isinstance(best_model, (xgb.XGBClassifier, xgb.XGBRegressor))
966
+ is_lightgbm = isinstance(best_model, (lgb.LGBMClassifier, lgb.LGBMRegressor))
967
+ is_catboost = isinstance(best_model, (cb.CatBoostClassifier, cb.CatBoostRegressor))
968
+
969
+ # Create code example with appropriate imports based on the model type
970
+ code_example = f"""import joblib # or import pickle
971
  import pandas as pd
972
+ """
973
+
974
+ # Add model-specific imports if needed
975
+ if is_xgboost:
976
+ code_example += "import xgboost as xgb # Required for XGBoost models\n"
977
+ if is_lightgbm:
978
+ code_example += "import lightgbm as lgb # Required for LightGBM models\n"
979
+ if is_catboost:
980
+ code_example += "import catboost as cb # Required for CatBoost models\n"
981
+
982
+ code_example += f"""
983
 
984
  # Load the pipeline
985
  pipeline = joblib.load('{file_name}{'.joblib' if 'Joblib' in export_format else '.pkl'}')
 
993
  # Make predictions
994
  # predictions = pipeline.predict(new_data)
995
  # print(predictions)
 
996
 
997
+ # For classification models with probability output
998
+ # if hasattr(pipeline, 'predict_proba'):
999
+ # probabilities = pipeline.predict_proba(new_data)
1000
+ # print(probabilities)
1001
+ """
 
 
1002
 
1003
+ st.code(code_example, language='python')
 
 
 
 
1004
 
1005
+ # Add additional notes for advanced models
1006
+ if is_xgboost or is_lightgbm or is_catboost:
1007
+ st.info("⚠️ Note: When deploying this model in production, ensure all required libraries are installed in your deployment environment.")
1008
+ st.info("💡 Tip: Consider using Docker to create a consistent environment for model deployment.")
1009
+
1010
+ st.subheader("🚀 Generate Flask API Endpoint")
1011
+ if st.button("Generate Flask API Code", key='generate_flask_api_button'):
1012
+ if st.session_state.trained_pipeline and st.session_state.X_train is not None:
1013
+ # Ensure file_name and ext are defined in this scope, might need to get them from session_state or re-evaluate
1014
+ # For simplicity, let's assume they are available or we use a default/placeholder
1015
+ # This part might need adjustment based on how file_name and ext are handled in the download section
1016
+ current_export_format = st.session_state.get('current_export_format', "Joblib (.joblib)") # Assuming this is stored or re-queried
1017
+ current_file_name = st.session_state.get('current_file_name', f"{st.session_state.best_model_info['name'].lower().replace(' ', '_')}_pipeline")
1018
+
1019
+ ext_model = ".joblib" if "Joblib" in current_export_format else ".pkl"
1020
+ model_pipeline_name = f"{current_file_name}{ext_model}"
1021
+
1022
+ flask_app_code = generate_flask_app_code(model_pipeline_name, list(st.session_state.X_train.columns), st.session_state.problem_type, is_xgboost, is_lightgbm, is_catboost)
1023
+
1024
+ st.code(flask_app_code, language='python')
1025
+
1026
+ b64_flask_app = base64.b64encode(flask_app_code.encode()).decode()
1027
+ href_flask_app = f'<a href="data:file/text;base64,{b64_flask_app}" download="flask_api_app.py">Download flask_api_app.py</a>'
1028
+ st.markdown(href_flask_app, unsafe_allow_html=True)
1029
+ st.success("Flask API code generated and ready for download!")
1030
+ st.info("Remember to install Flask (`pip install Flask`) and other necessary libraries (e.g., pandas, scikit-learn, joblib, and model-specific libraries) in the environment where you run this Flask app.")
1031
+ else:
1032
+ st.warning("Please ensure a model pipeline is trained and available, and training data (X_train) context exists.")
1033
+
1034
+
1035
+ # --- Helper function to generate Flask app code ---
1036
+ def generate_flask_app_code(model_path, feature_columns, problem_type, is_xgboost, is_lightgbm, is_catboost):
1037
+ imports = [
1038
+ "from flask import Flask, request, jsonify",
1039
+ "import joblib",
1040
+ "import pandas as pd",
1041
+ "import numpy as np"
1042
+ ]
1043
+ if is_xgboost:
1044
+ imports.append("import xgboost as xgb")
1045
+ if is_lightgbm:
1046
+ imports.append("import lightgbm as lgb")
1047
+ if is_catboost:
1048
+ imports.append("import catboost as cb")
1049
+
1050
+ import_str = "\n".join(imports)
requirements.txt CHANGED
@@ -6,4 +6,8 @@ shap
6
  matplotlib
7
  seaborn
8
  joblib
9
- openpyxl # For .xlsx file support
 
 
 
 
 
6
  matplotlib
7
  seaborn
8
  joblib
9
+ openpyxl # For .xlsx file support
10
+ xgboost>=1.7.0
11
+ lightgbm>=4.0.0
12
+ catboost>=1.2.0
13
+ Flask