saherPervaiz commited on
Commit
34b1335
·
verified ·
1 Parent(s): 43d6671

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +141 -3
app.py CHANGED
@@ -15,7 +15,7 @@ import seaborn as sns
15
  from io import BytesIO
16
 
17
  # Streamlit app title
18
- st.title("Model Training with Metrics and Correlation Heatmap")
19
 
20
  # File uploader
21
  uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])
@@ -52,10 +52,46 @@ if uploaded_file is not None:
52
  else:
53
  df[col].fillna(df[col].mode()[0], inplace=True)
54
 
55
- # Show cleaned dataset
56
- st.write("Cleaned Dataset:")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  st.dataframe(df)
58
 
 
 
 
 
 
 
 
 
 
59
  # Correlation Heatmap
60
  st.subheader("Correlation Heatmap")
61
  corr = df.corr()
@@ -63,6 +99,32 @@ if uploaded_file is not None:
63
  sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", cbar=True)
64
  st.pyplot(plt)
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  # Select target variable
67
  target = st.selectbox("Select Target Variable", df.columns)
68
  features = [col for col in df.columns if col != target]
@@ -102,6 +164,44 @@ if uploaded_file is not None:
102
  st.subheader("Classification Model Performance Metrics")
103
  st.dataframe(metrics_df)
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  else: # Regression
106
  st.subheader("Regression Model Training")
107
  regressors = {
@@ -131,5 +231,43 @@ if uploaded_file is not None:
131
  regression_metrics_df = pd.DataFrame(regression_metrics)
132
  st.subheader("Regression Model Performance Metrics")
133
  st.dataframe(regression_metrics_df)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  else:
135
  st.error("The target variable must contain at least two unique values for classification or regression. Please check your dataset.")
 
15
  from io import BytesIO
16
 
17
  # Streamlit app title
18
+ st.title("Model Training with Outlier Removal, Metrics, and Correlation Heatmap")
19
 
20
  # File uploader
21
  uploaded_file = st.file_uploader("Choose a CSV file", type=["csv"])
 
52
  else:
53
  df[col].fillna(df[col].mode()[0], inplace=True)
54
 
55
+ # Remove outliers using the IQR method
56
+ st.write("Removing Outliers Using IQR:")
57
+ def remove_outliers_iqr(data, column):
58
+ Q1 = data[column].quantile(0.25)
59
+ Q3 = data[column].quantile(0.75)
60
+ IQR = Q3 - Q1
61
+ lower_bound = Q1 - 1.5 * IQR
62
+ upper_bound = Q3 + 1.5 * IQR
63
+ return data[(data[column] >= lower_bound) & (data[column] <= upper_bound)]
64
+
65
+ numeric_cols = df.select_dtypes(include=['float64', 'int64']).columns
66
+ for col in numeric_cols:
67
+ original_count = len(df)
68
+ df = remove_outliers_iqr(df, col)
69
+ st.write(f"Removed outliers from **{col}**: {original_count - len(df)} rows removed.")
70
+
71
+ # Capping Extreme Values (based on 5% and 95% percentiles)
72
+ st.write("Handling Extreme Values (Capping):")
73
+ def cap_extreme_values(dataframe):
74
+ for col in dataframe.select_dtypes(include=[np.number]).columns:
75
+ lower_limit = dataframe[col].quantile(0.05)
76
+ upper_limit = dataframe[col].quantile(0.95)
77
+ dataframe[col] = np.clip(dataframe[col], lower_limit, upper_limit)
78
+ return dataframe
79
+
80
+ df = cap_extreme_values(df)
81
+
82
+ # Display dataset after cleaning
83
+ st.write("Dataset After Outlier Removal and Capping Extreme Values:")
84
  st.dataframe(df)
85
 
86
+ # Add clean data download option
87
+ st.subheader("Download Cleaned Dataset")
88
+ st.download_button(
89
+ label="Download Cleaned Dataset (CSV)",
90
+ data=df.to_csv(index=False),
91
+ file_name="cleaned_dataset.csv",
92
+ mime="text/csv"
93
+ )
94
+
95
  # Correlation Heatmap
96
  st.subheader("Correlation Heatmap")
97
  corr = df.corr()
 
99
  sns.heatmap(corr, annot=True, cmap="coolwarm", fmt=".2f", cbar=True)
100
  st.pyplot(plt)
101
 
102
+ # Save heatmap as PNG
103
+ buf = BytesIO()
104
+ plt.savefig(buf, format="png")
105
+ buf.seek(0)
106
+ st.download_button(
107
+ label="Download Correlation Heatmap as PNG",
108
+ data=buf,
109
+ file_name="correlation_heatmap.png",
110
+ mime="image/png"
111
+ )
112
+
113
+ # Highlight highly correlated pairs
114
+ st.subheader("Highly Correlated Features")
115
+ high_corr = corr.abs().unstack().sort_values(ascending=False).drop_duplicates()
116
+ high_corr = high_corr[high_corr.index.get_level_values(0) != high_corr.index.get_level_values(1)]
117
+ high_corr_df = pd.DataFrame(high_corr, columns=["Correlation"])
118
+ st.dataframe(high_corr_df)
119
+
120
+ # Download correlation table as CSV
121
+ st.download_button(
122
+ label="Download Correlation Table (CSV)",
123
+ data=high_corr_df.to_csv(index=True),
124
+ file_name="correlation_table.csv",
125
+ mime="text/csv"
126
+ )
127
+
128
  # Select target variable
129
  target = st.selectbox("Select Target Variable", df.columns)
130
  features = [col for col in df.columns if col != target]
 
164
  st.subheader("Classification Model Performance Metrics")
165
  st.dataframe(metrics_df)
166
 
167
+ # Save metrics as PNG (table form)
168
+ fig, ax = plt.subplots(figsize=(8, 4))
169
+ ax.axis('tight')
170
+ ax.axis('off')
171
+ table = plt.table(cellText=metrics_df.values, colLabels=metrics_df.columns, cellLoc='center', loc='center')
172
+ table.auto_set_font_size(False)
173
+ table.set_fontsize(10)
174
+ table.auto_set_column_width(col=list(range(len(metrics_df.columns))))
175
+ buf = BytesIO()
176
+ fig.savefig(buf, format="png")
177
+ buf.seek(0)
178
+ st.download_button(
179
+ label="Download Classification Metrics Table as PNG",
180
+ data=buf,
181
+ file_name="classification_metrics_table.png",
182
+ mime="image/png"
183
+ )
184
+
185
+ # Visualization (Bar Graphs for Classification)
186
+ st.subheader("Classification Model Performance Metrics Graph")
187
+ metrics_df.set_index('Model', inplace=True)
188
+ ax = metrics_df.plot(kind='bar', figsize=(10, 6), colormap='coolwarm', rot=45)
189
+ plt.title("Classification Models - Performance Metrics")
190
+ plt.ylabel("Scores")
191
+ plt.xlabel("Models")
192
+ st.pyplot(plt)
193
+
194
+ # Download button for the bar graph
195
+ buf = BytesIO()
196
+ ax.figure.savefig(buf, format="png")
197
+ buf.seek(0)
198
+ st.download_button(
199
+ label="Download Classification Performance Graph as PNG",
200
+ data=buf,
201
+ file_name="classification_performance_graph.png",
202
+ mime="image/png"
203
+ )
204
+
205
  else: # Regression
206
  st.subheader("Regression Model Training")
207
  regressors = {
 
231
  regression_metrics_df = pd.DataFrame(regression_metrics)
232
  st.subheader("Regression Model Performance Metrics")
233
  st.dataframe(regression_metrics_df)
234
+
235
+ # Save metrics as PNG (table form)
236
+ fig, ax = plt.subplots(figsize=(8, 4))
237
+ ax.axis('tight')
238
+ ax.axis('off')
239
+ table = plt.table(cellText=regression_metrics_df.values, colLabels=regression_metrics_df.columns, cellLoc='center', loc='center')
240
+ table.auto_set_font_size(False)
241
+ table.set_fontsize(10)
242
+ table.auto_set_column_width(col=list(range(len(regression_metrics_df.columns))))
243
+ buf = BytesIO()
244
+ fig.savefig(buf, format="png")
245
+ buf.seek(0)
246
+ st.download_button(
247
+ label="Download Regression Metrics Table as PNG",
248
+ data=buf,
249
+ file_name="regression_metrics_table.png",
250
+ mime="image/png"
251
+ )
252
+
253
+ # Visualization (Bar Graphs for Regression)
254
+ st.subheader("Regression Model Performance Metrics Graph")
255
+ regression_metrics_df.set_index('Model', inplace=True)
256
+ ax = regression_metrics_df.plot(kind='bar', figsize=(10, 6), colormap='coolwarm', rot=45)
257
+ plt.title("Regression Models - Performance Metrics")
258
+ plt.ylabel("Scores")
259
+ plt.xlabel("Models")
260
+ st.pyplot(plt)
261
+
262
+ # Download button for the bar graph
263
+ buf = BytesIO()
264
+ ax.figure.savefig(buf, format="png")
265
+ buf.seek(0)
266
+ st.download_button(
267
+ label="Download Regression Performance Graph as PNG",
268
+ data=buf,
269
+ file_name="regression_performance_graph.png",
270
+ mime="image/png"
271
+ )
272
  else:
273
  st.error("The target variable must contain at least two unique values for classification or regression. Please check your dataset.")