Spaces:
Sleeping
Sleeping
Upload app.py
Browse files
app.py
CHANGED
|
@@ -730,11 +730,24 @@ def model_training_page():
|
|
| 730 |
|
| 731 |
# Determine best model
|
| 732 |
if st.session_state.problem_type == "Classification":
|
| 733 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 734 |
else: # Regression
|
| 735 |
-
|
| 736 |
-
|
|
|
|
|
|
|
| 737 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 738 |
st.session_state.best_model_info = {
|
| 739 |
'name': best_model_name,
|
| 740 |
'model': trained_models[best_model_name],
|
|
@@ -752,32 +765,40 @@ def model_comparison_page():
|
|
| 752 |
st.warning("⚠️ Please train models first.")
|
| 753 |
return
|
| 754 |
|
| 755 |
-
|
| 756 |
-
scores_df =
|
| 757 |
-
|
| 758 |
-
st.subheader("🏆 Model Leaderboard")
|
| 759 |
if st.session_state.problem_type == "Classification":
|
| 760 |
sort_by = 'Test Accuracy'
|
| 761 |
display_cols = ['CV Mean Score', 'Test Accuracy', 'Test F1-score', 'Test AUC']
|
| 762 |
else: # Regression
|
| 763 |
sort_by = 'R2'
|
| 764 |
-
display_cols = ['CV Mean Score', 'R2', 'MSE']
|
| 765 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 766 |
|
|
|
|
|
|
|
|
|
|
| 767 |
leaderboard = scores_df[display_cols].sort_values(by=sort_by, ascending=False)
|
| 768 |
leaderboard['Rank'] = range(1, len(leaderboard) + 1)
|
| 769 |
leaderboard = leaderboard[['Rank'] + display_cols]
|
| 770 |
st.dataframe(leaderboard.style.background_gradient(subset=[sort_by], cmap='RdYlGn'), use_container_width=True)
|
| 771 |
|
| 772 |
-
|
| 773 |
-
|
| 774 |
-
|
| 775 |
-
|
| 776 |
-
|
| 777 |
-
|
| 778 |
-
|
| 779 |
-
|
| 780 |
-
|
|
|
|
|
|
|
| 781 |
ax.set_title('Model Performance Comparison')
|
| 782 |
st.pyplot(fig)
|
| 783 |
|
|
|
|
| 730 |
|
| 731 |
# Determine best model
|
| 732 |
if st.session_state.problem_type == "Classification":
|
| 733 |
+
# Safely get metrics with default values if missing
|
| 734 |
+
best_model_name = max(
|
| 735 |
+
model_scores_dict,
|
| 736 |
+
key=lambda k: (
|
| 737 |
+
model_scores_dict[k].get('Test Accuracy', 0) or 0,
|
| 738 |
+
model_scores_dict[k].get('Test AUC', 0) or 0
|
| 739 |
+
)
|
| 740 |
+
)
|
| 741 |
else: # Regression
|
| 742 |
+
best_model_name = max(
|
| 743 |
+
model_scores_dict,
|
| 744 |
+
key=lambda k: model_scores_dict[k].get('R2', -float('inf'))
|
| 745 |
+
)
|
| 746 |
|
| 747 |
+
if not model_scores_dict:
|
| 748 |
+
st.error("No models were successfully trained. Please check your data and try again.")
|
| 749 |
+
return
|
| 750 |
+
|
| 751 |
st.session_state.best_model_info = {
|
| 752 |
'name': best_model_name,
|
| 753 |
'model': trained_models[best_model_name],
|
|
|
|
| 765 |
st.warning("⚠️ Please train models first.")
|
| 766 |
return
|
| 767 |
|
| 768 |
+
# Fill NaN with 0 for display and ensure all required columns exist
|
| 769 |
+
scores_df = pd.DataFrame(st.session_state.model_scores).T
|
| 770 |
+
|
|
|
|
| 771 |
if st.session_state.problem_type == "Classification":
|
| 772 |
sort_by = 'Test Accuracy'
|
| 773 |
display_cols = ['CV Mean Score', 'Test Accuracy', 'Test F1-score', 'Test AUC']
|
| 774 |
else: # Regression
|
| 775 |
sort_by = 'R2'
|
| 776 |
+
display_cols = ['CV Mean Score', 'R2', 'MSE']
|
| 777 |
+
|
| 778 |
+
# Ensure all display columns exist, add them with NaN if missing
|
| 779 |
+
for col in display_cols:
|
| 780 |
+
if col not in scores_df.columns:
|
| 781 |
+
scores_df[col] = np.nan
|
| 782 |
|
| 783 |
+
scores_df = scores_df.fillna(0).round(4)
|
| 784 |
+
|
| 785 |
+
st.subheader("🏆 Model Leaderboard")
|
| 786 |
leaderboard = scores_df[display_cols].sort_values(by=sort_by, ascending=False)
|
| 787 |
leaderboard['Rank'] = range(1, len(leaderboard) + 1)
|
| 788 |
leaderboard = leaderboard[['Rank'] + display_cols]
|
| 789 |
st.dataframe(leaderboard.style.background_gradient(subset=[sort_by], cmap='RdYlGn'), use_container_width=True)
|
| 790 |
|
| 791 |
+
if st.session_state.best_model_info:
|
| 792 |
+
best_model_name = st.session_state.best_model_info['name']
|
| 793 |
+
best_metric_val = st.session_state.best_model_info['metrics'].get(sort_by, 0)
|
| 794 |
+
st.markdown(f"<div class='success-message'><h4>🥇 Best Model: {best_model_name} ({sort_by}: {best_metric_val:.4f})</h4></div>", unsafe_allow_html=True)
|
| 795 |
+
|
| 796 |
+
st.subheader("📈 Performance Visualization")
|
| 797 |
+
fig, ax = plt.subplots(figsize=(10, 6))
|
| 798 |
+
plot_data = scores_df[sort_by].sort_values(ascending=True)
|
| 799 |
+
bars = ax.barh(plot_data.index, plot_data.values,
|
| 800 |
+
color=['#ff6b6b' if idx == best_model_name else '#4ecdc4' for idx in plot_data.index])
|
| 801 |
+
ax.set_xlabel(sort_by)
|
| 802 |
ax.set_title('Model Performance Comparison')
|
| 803 |
st.pyplot(fig)
|
| 804 |
|