damndeepesh committed on
Commit
3641375
·
verified ·
1 Parent(s): 04606f5

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -18
app.py CHANGED
@@ -730,11 +730,24 @@ def model_training_page():
730
 
731
  # Determine best model
732
  if st.session_state.problem_type == "Classification":
733
- best_model_name = max(model_scores_dict, key=lambda k: (model_scores_dict[k]['Test Accuracy'] or 0, model_scores_dict[k]['Test AUC'] or 0))
 
 
 
 
 
 
 
734
  else: # Regression
735
- # Ensure 'R2' exists and provide a default if not (e.g., for models where R2 might not be applicable or calculable)
736
- best_model_name = max(model_scores_dict, key=lambda k: model_scores_dict[k].get('R2', -float('inf')))
 
 
737
 
 
 
 
 
738
  st.session_state.best_model_info = {
739
  'name': best_model_name,
740
  'model': trained_models[best_model_name],
@@ -752,32 +765,40 @@ def model_comparison_page():
752
  st.warning("⚠️ Please train models first.")
753
  return
754
 
755
- scores_df = pd.DataFrame(st.session_state.model_scores).T.fillna(0) # Fill NaN with 0 for display
756
- scores_df = scores_df.round(4)
757
-
758
- st.subheader("🏆 Model Leaderboard")
759
  if st.session_state.problem_type == "Classification":
760
  sort_by = 'Test Accuracy'
761
  display_cols = ['CV Mean Score', 'Test Accuracy', 'Test F1-score', 'Test AUC']
762
  else: # Regression
763
  sort_by = 'R2'
764
- display_cols = ['CV Mean Score', 'R2', 'MSE'] # Add other relevant regression metrics if needed
765
- # Ensure MSE is present, if not, it will be filled with 0 by .fillna(0) earlier or handle missing more gracefully if needed
 
 
 
 
766
 
 
 
 
767
  leaderboard = scores_df[display_cols].sort_values(by=sort_by, ascending=False)
768
  leaderboard['Rank'] = range(1, len(leaderboard) + 1)
769
  leaderboard = leaderboard[['Rank'] + display_cols]
770
  st.dataframe(leaderboard.style.background_gradient(subset=[sort_by], cmap='RdYlGn'), use_container_width=True)
771
 
772
- best_model_name = st.session_state.best_model_info['name']
773
- best_metric_val = st.session_state.best_model_info['metrics'].get(sort_by, 'N/A')
774
- st.markdown(f"<div class='success-message'><h4>🥇 Best Model: {best_model_name} ({sort_by}: {best_metric_val:.4f})</h4></div>", unsafe_allow_html=True)
775
-
776
- st.subheader("📈 Performance Visualization")
777
- fig, ax = plt.subplots(figsize=(10, 6))
778
- plot_data = scores_df[sort_by].sort_values(ascending=True)
779
- bars = ax.barh(plot_data.index, plot_data.values, color=['#ff6b6b' if idx == best_model_name else '#4ecdc4' for idx in plot_data.index])
780
- ax.set_xlabel(sort_by)
 
 
781
  ax.set_title('Model Performance Comparison')
782
  st.pyplot(fig)
783
 
 
730
 
731
  # Determine best model
732
  if st.session_state.problem_type == "Classification":
733
+ # Safely get metrics with default values if missing
734
+ best_model_name = max(
735
+ model_scores_dict,
736
+ key=lambda k: (
737
+ model_scores_dict[k].get('Test Accuracy', 0) or 0,
738
+ model_scores_dict[k].get('Test AUC', 0) or 0
739
+ )
740
+ )
741
  else: # Regression
742
+ best_model_name = max(
743
+ model_scores_dict,
744
+ key=lambda k: model_scores_dict[k].get('R2', -float('inf'))
745
+ )
746
 
747
+ if not model_scores_dict:
748
+ st.error("No models were successfully trained. Please check your data and try again.")
749
+ return
750
+
751
  st.session_state.best_model_info = {
752
  'name': best_model_name,
753
  'model': trained_models[best_model_name],
 
765
  st.warning("⚠️ Please train models first.")
766
  return
767
 
768
+ # Fill NaN with 0 for display and ensure all required columns exist
769
+ scores_df = pd.DataFrame(st.session_state.model_scores).T
770
+
 
771
  if st.session_state.problem_type == "Classification":
772
  sort_by = 'Test Accuracy'
773
  display_cols = ['CV Mean Score', 'Test Accuracy', 'Test F1-score', 'Test AUC']
774
  else: # Regression
775
  sort_by = 'R2'
776
+ display_cols = ['CV Mean Score', 'R2', 'MSE']
777
+
778
+ # Ensure all display columns exist, add them with NaN if missing
779
+ for col in display_cols:
780
+ if col not in scores_df.columns:
781
+ scores_df[col] = np.nan
782
 
783
+ scores_df = scores_df.fillna(0).round(4)
784
+
785
+ st.subheader("🏆 Model Leaderboard")
786
  leaderboard = scores_df[display_cols].sort_values(by=sort_by, ascending=False)
787
  leaderboard['Rank'] = range(1, len(leaderboard) + 1)
788
  leaderboard = leaderboard[['Rank'] + display_cols]
789
  st.dataframe(leaderboard.style.background_gradient(subset=[sort_by], cmap='RdYlGn'), use_container_width=True)
790
 
791
+ if st.session_state.best_model_info:
792
+ best_model_name = st.session_state.best_model_info['name']
793
+ best_metric_val = st.session_state.best_model_info['metrics'].get(sort_by, 0)
794
+ st.markdown(f"<div class='success-message'><h4>🥇 Best Model: {best_model_name} ({sort_by}: {best_metric_val:.4f})</h4></div>", unsafe_allow_html=True)
795
+
796
+ st.subheader("📈 Performance Visualization")
797
+ fig, ax = plt.subplots(figsize=(10, 6))
798
+ plot_data = scores_df[sort_by].sort_values(ascending=True)
799
+ bars = ax.barh(plot_data.index, plot_data.values,
800
+ color=['#ff6b6b' if idx == best_model_name else '#4ecdc4' for idx in plot_data.index])
801
+ ax.set_xlabel(sort_by)
802
  ax.set_title('Model Performance Comparison')
803
  st.pyplot(fig)
804