saherPervaiz commited on
Commit
3d3a6dd
·
verified ·
1 Parent(s): f581892

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -32
app.py CHANGED
@@ -12,9 +12,10 @@ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_sc
12
  import numpy as np
13
  import matplotlib.pyplot as plt
14
  import seaborn as sns
 
15
 
16
  # File uploader
17
- st.title("Model Training with Metrics")
18
  uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])
19
 
20
  if uploaded_file is not None:
@@ -86,13 +87,27 @@ if uploaded_file is not None:
86
  # Correlation Heatmap
87
  st.subheader("Correlation Heatmap")
88
  corr = df.corr()
89
- plt.figure(figsize=(10, 6))
90
- sns.heatmap(corr, annot=True, cmap='coolwarm', fmt='.2f')
91
  st.pyplot(plt)
92
-
93
- # Correlation Metrics
94
- st.subheader("Correlation Metrics")
95
- st.dataframe(corr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
97
  target = st.selectbox("Select Target Variable", df.columns)
98
  features = [col for col in df.columns if col != target]
@@ -129,24 +144,16 @@ if uploaded_file is not None:
129
  st.subheader("Classification Model Performance Metrics")
130
  st.dataframe(metrics_df)
131
 
132
- # Download as CSV
133
- st.download_button(
134
- label="Download Classification Report as CSV",
135
- data=metrics_df.to_csv(index=False),
136
- file_name="classification_report.csv",
137
- mime="text/csv"
138
- )
139
-
140
- # Download as PNG
141
  fig, ax = plt.subplots()
142
- ax.axis('off')
143
- table = ax.table(cellText=metrics_df.values, colLabels=metrics_df.columns, loc='center', cellLoc='center')
144
- table.auto_set_font_size(False)
145
- table.set_fontsize(10)
146
- plt.savefig("classification_report.png")
147
  st.download_button(
148
  label="Download Classification Report as PNG",
149
- data=open("classification_report.png", "rb"),
150
  file_name="classification_report.png",
151
  mime="image/png"
152
  )
@@ -179,15 +186,16 @@ if uploaded_file is not None:
179
  st.subheader("Regression Model Performance Metrics")
180
  st.dataframe(regression_metrics_df)
181
 
182
- # Download as CSV
 
 
 
 
 
 
183
  st.download_button(
184
- label="Download Regression Report as CSV",
185
- data=regression_metrics_df.to_csv(index=False),
186
- file_name="regression_report.csv",
187
- mime="text/csv"
188
  )
189
-
190
- # Download as PNG
191
- fig, ax = plt.subplots()
192
- ax.axis('off')
193
- table = ax.table(cellText=regression_metrics_df.values
 
12
  import numpy as np
13
  import matplotlib.pyplot as plt
14
  import seaborn as sns
15
+ from io import BytesIO
16
 
17
  # File uploader
18
+ st.title("Model Training with Metrics and Correlation Heatmap")
19
  uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])
20
 
21
  if uploaded_file is not None:
 
87
  # Correlation Heatmap
88
  st.subheader("Correlation Heatmap")
89
  corr = df.corr()
90
+ plt.figure(figsize=(10, 8))
91
+ sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", cbar=True)
92
  st.pyplot(plt)
93
+
94
+ # Save heatmap as PNG
95
+ buf = BytesIO()
96
+ plt.savefig(buf, format="png")
97
+ buf.seek(0)
98
+ st.download_button(
99
+ label="Download Correlation Heatmap as PNG",
100
+ data=buf,
101
+ file_name="correlation_heatmap.png",
102
+ mime="image/png"
103
+ )
104
+
105
+ # Highlight highly correlated pairs
106
+ st.subheader("Highly Correlated Features")
107
+ high_corr = corr.abs().unstack().sort_values(ascending=False).drop_duplicates()
108
+ high_corr = high_corr[high_corr >= 0.8]
109
+ high_corr_df = high_corr[high_corr.index.get_level_values(0) != high_corr.index.get_level_values(1)]
110
+ st.write(high_corr_df)
111
 
112
  target = st.selectbox("Select Target Variable", df.columns)
113
  features = [col for col in df.columns if col != target]
 
144
  st.subheader("Classification Model Performance Metrics")
145
  st.dataframe(metrics_df)
146
 
147
+ # Save metrics as PNG
 
 
 
 
 
 
 
 
148
  fig, ax = plt.subplots()
149
+ sns.barplot(data=metrics_df, x="Model", y="Accuracy", ax=ax)
150
+ ax.set_title("Classification Model Performance")
151
+ buf = BytesIO()
152
+ fig.savefig(buf, format="png")
153
+ buf.seek(0)
154
  st.download_button(
155
  label="Download Classification Report as PNG",
156
+ data=buf,
157
  file_name="classification_report.png",
158
  mime="image/png"
159
  )
 
186
  st.subheader("Regression Model Performance Metrics")
187
  st.dataframe(regression_metrics_df)
188
 
189
+ # Save metrics as PNG
190
+ fig, ax = plt.subplots()
191
+ sns.barplot(data=regression_metrics_df, x="Model", y="R² Score", ax=ax)
192
+ ax.set_title("Regression Model Performance")
193
+ buf = BytesIO()
194
+ fig.savefig(buf, format="png")
195
+ buf.seek(0)
196
  st.download_button(
197
+ label="Download Regression Report as PNG",
198
+ data=buf,
199
+ file_name="regression_report.png",
200
+ mime="image/png"
201
  )