jaker86 committed on
Commit
2fe6c63
·
verified ·
1 Parent(s): f0ece10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -24
app.py CHANGED
@@ -19,6 +19,9 @@ MIN_ROWS = 10
19
  MIN_COLS = 2
20
  MAX_FEATURES_TO_SHOW = 10
21
 
 
 
 
22
  def update_dropdown(file):
23
  if file is None:
24
  return gr.update(choices=[], value=None)
@@ -98,14 +101,14 @@ def analyze_file(file, label_col, n_clusters):
98
  ax = axes[i]
99
  ax.scatter(X_test[feature], y_pred, alpha=0.5)
100
  ax.set_xlabel(feature)
101
- ax.set_ylabel('Predicted SalePrice')
102
- ax.set_title(f'{feature} vs Predicted SalePrice')
103
  ax = axes[3]
104
  ax.scatter(y_test, y_pred, alpha=0.5)
105
  ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', label='Perfect Prediction')
106
- ax.set_xlabel('True SalePrice')
107
- ax.set_ylabel('Predicted SalePrice')
108
- ax.set_title('True vs Predicted SalePrice')
109
  min_val = min(y_test.min(), y_pred.min())
110
  max_val = max(y_test.max(), y_pred.max())
111
  ax.set_xlim(min_val, max_val)
@@ -117,8 +120,9 @@ def analyze_file(file, label_col, n_clusters):
117
  plt.close()
118
  buf.seek(0)
119
  model_img = Image.open(buf)
 
120
  else:
121
- # Classification (unchanged)
122
  if len(y.unique()) < 2:
123
  return ("Label must have at least 2 unique values.", None, None, None, None, None)
124
  y_encoded, uniques = pd.factorize(y)
@@ -128,23 +132,19 @@ def analyze_file(file, label_col, n_clusters):
128
  y_pred = model.predict(X_test)
129
  cr = classification_report(y_test, y_pred, target_names=[str(u) for u in uniques])
130
  results_text += "Classification Results:\n" + cr + "\n"
131
- fi = pd.Series(model.feature_importances_, index=X_processed.columns).sort_values(ascending=False)
132
- if len(fi) < 3:
133
- results_text += "\nNot enough features for a 3D plot with the next two most important features."
134
- else:
135
- next_two_features = fi.index[1:3]
136
- fig = plt.figure(figsize=(10, 8))
137
- ax = fig.add_subplot(111, projection='3d')
138
- scatter = ax.scatter(X_test[next_two_features[0]], X_test[next_two_features[1]], y_test, c=y_test, cmap='viridis', marker='o')
139
- ax.set_xlabel(next_two_features[0])
140
- ax.set_ylabel(next_two_features[1])
141
- ax.set_zlabel(label_col + " (encoded)")
142
- ax.set_title("3D Plot: Label vs Next Two Most Important Features")
143
- buf = io.BytesIO()
144
- plt.savefig(buf, format="png", bbox_inches="tight")
145
- plt.close()
146
- buf.seek(0)
147
- model_img = Image.open(buf)
148
  except Exception as e:
149
  results_text += f"\nError during model training: {e}"
150
 
@@ -218,6 +218,63 @@ def analyze_file(file, label_col, n_clusters):
218
 
219
  return results_text, model_img, fi_img, kmeans_img, agg_img, diff_img
220
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
221
  with gr.Blocks() as demo:
222
  gr.Markdown("## Data Analysis Explorer")
223
  gr.Markdown("Upload a CSV or XLSX file to explore classification, regression, and clustering. Select a column to predict and the number of clusters!")
@@ -242,7 +299,7 @@ with gr.Blocks() as demo:
242
 
243
  with gr.TabItem("Prediction Plot"):
244
  gr.Markdown("### Prediction Visualization")
245
- gr.Markdown("For regression, shows scatter plots of the top three features vs. predicted values and a plot of true vs. predicted values. For classification, shows a 3D plot of the label vs. next two features.")
246
  model_img_output = gr.Image(label="Prediction Output")
247
 
248
  with gr.TabItem("Feature Importances"):
@@ -265,6 +322,21 @@ with gr.Blocks() as demo:
265
  gr.Markdown("Shows features that vary most between clusters, helping explain the groupings.")
266
  diff_output = gr.Image(label="Differentiating Features")
267
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
268
  analyze_btn.click(fn=analyze_file, inputs=[file_input, label_dropdown, clusters_slider],
269
  outputs=[results_textbox, model_img_output, fi_output, kmeans_output, agg_output, diff_output])
270
 
 
19
  MIN_COLS = 2
20
  MAX_FEATURES_TO_SHOW = 10
21
 
22
+ # Global variable to store trained model and data
23
+ global_data = {'model': None, 'scaler': None, 'X_columns': None, 'y_type': None, 'uniques': None}
24
+
25
  def update_dropdown(file):
26
  if file is None:
27
  return gr.update(choices=[], value=None)
 
101
  ax = axes[i]
102
  ax.scatter(X_test[feature], y_pred, alpha=0.5)
103
  ax.set_xlabel(feature)
104
+ ax.set_ylabel('Predicted Value')
105
+ ax.set_title(f'{feature} vs Predicted')
106
  ax = axes[3]
107
  ax.scatter(y_test, y_pred, alpha=0.5)
108
  ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', label='Perfect Prediction')
109
+ ax.set_xlabel('True Value')
110
+ ax.set_ylabel('Predicted Value')
111
+ ax.set_title('True vs Predicted')
112
  min_val = min(y_test.min(), y_pred.min())
113
  max_val = max(y_test.max(), y_pred.max())
114
  ax.set_xlim(min_val, max_val)
 
120
  plt.close()
121
  buf.seek(0)
122
  model_img = Image.open(buf)
123
+ global_data.update({'model': model, 'scaler': scaler, 'X_columns': X_processed.columns, 'y_type': 'regression', 'uniques': None})
124
  else:
125
+ # Classification
126
  if len(y.unique()) < 2:
127
  return ("Label must have at least 2 unique values.", None, None, None, None, None)
128
  y_encoded, uniques = pd.factorize(y)
 
132
  y_pred = model.predict(X_test)
133
  cr = classification_report(y_test, y_pred, target_names=[str(u) for u in uniques])
134
  results_text += "Classification Results:\n" + cr + "\n"
135
+ # 2D Confusion Matrix
136
+ cm = confusion_matrix(y_test, y_pred)
137
+ plt.figure(figsize=(8, 6))
138
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=[str(u) for u in uniques], yticklabels=[str(u) for u in uniques])
139
+ plt.xlabel('Predicted')
140
+ plt.ylabel('True')
141
+ plt.title('Confusion Matrix')
142
+ buf = io.BytesIO()
143
+ plt.savefig(buf, format="png", bbox_inches="tight")
144
+ plt.close()
145
+ buf.seek(0)
146
+ model_img = Image.open(buf)
147
+ global_data.update({'model': model, 'scaler': scaler, 'X_columns': X_processed.columns, 'y_type': 'classification', 'uniques': uniques})
 
 
 
 
148
  except Exception as e:
149
  results_text += f"\nError during model training: {e}"
150
 
 
218
 
219
  return results_text, model_img, fi_img, kmeans_img, agg_img, diff_img
220
 
221
def predict_interactive(**kwargs):
    """Predict a single class or value from user-supplied feature values.

    Relies on the module-level ``global_data`` dict, which must already hold
    a fitted ``model``, a fitted ``scaler``, the training column layout
    (``X_columns``), the task type (``y_type``) and, for classification, the
    label lookup (``uniques``).

    Args:
        **kwargs: One entry per raw feature, keyed by feature name. Values
            are placed in a one-row DataFrame and one-hot encoded with
            ``pd.get_dummies``.

    Returns:
        str: ``"Predicted class: ..."`` or ``"Predicted value: ..."`` on
        success, a prompt to analyze a file first when no model is stored,
        or an ``"Error in prediction: ..."`` message if anything raises.

    NOTE(review): this function only accepts keyword arguments. Confirm that
    the UI event which calls it really passes feature values by keyword;
    a positional call raises ``TypeError`` before the body runs.
    """
    if global_data['model'] is None:
        return "Please analyze a file first to train the model."

    try:
        # One-row frame from the user inputs; categorical values become
        # one-hot indicator columns.
        input_data = pd.DataFrame([kwargs])
        encoded = pd.get_dummies(input_data)

        # Align to the training layout in a single step: columns missing
        # from the input are created and filled with 0, columns unknown to
        # the model are dropped, and the training column order is restored.
        aligned = encoded.reindex(columns=global_data['X_columns'], fill_value=0)

        X_scaled = global_data['scaler'].transform(aligned)
        prediction = global_data['model'].predict(X_scaled)

        if global_data['y_type'] == 'classification':
            # The model predicts an integer code; map it back to the label.
            pred_value = global_data['uniques'][int(prediction[0])]
            return f"Predicted class: {pred_value}"
        return f"Predicted value: {prediction[0]:.3f}"
    except Exception as e:
        # UI callback: report the failure as text instead of raising.
        return f"Error in prediction: {e}. Please ensure all inputs are valid numbers or categories."
253
+
254
def create_interactive_inputs(file, label_col):
    """Build one Gradio input component per feature column of an uploaded file.

    Args:
        file: Uploaded file object exposing a ``name`` attribute holding the
            path on disk, or ``None``.
        label_col: Name of the column to predict; it is excluded from the
            generated inputs. ``None`` yields no inputs.

    Returns:
        list: One ``gr.Number`` per numeric column and one ``gr.Textbox`` per
        other column, each labelled with the column name and up to three
        randomly sampled non-null example values. An empty list is returned
        when an argument is ``None``, the extension is not ``.csv``/``.xlsx``
        (compared case-insensitively), or reading/processing the file fails.
    """
    if file is None or label_col is None:
        return []

    try:
        path = file.name
        lowered = path.lower()  # accept upper-case extensions such as ".CSV"
        if lowered.endswith('.csv'):
            df = pd.read_csv(path)
        elif lowered.endswith('.xlsx'):
            df = pd.read_excel(path)
        else:
            return []

        X = df.drop(columns=[label_col])
        inputs = []
        for col in X.columns:
            # Compute the non-null view once; it feeds both the sample size
            # and the sample itself.
            non_null = X[col].dropna()
            examples = non_null.sample(min(3, len(non_null))).tolist()
            label = f"{col} (e.g., {', '.join(map(str, examples))})"
            if pd.api.types.is_numeric_dtype(X[col]):
                inputs.append(gr.Number(label=label))
            else:
                inputs.append(gr.Textbox(label=label))
        return inputs
    except Exception:
        # Best-effort UI helper: any read or processing failure simply
        # results in no inputs being offered.
        return []
277
+
278
  with gr.Blocks() as demo:
279
  gr.Markdown("## Data Analysis Explorer")
280
  gr.Markdown("Upload a CSV or XLSX file to explore classification, regression, and clustering. Select a column to predict and the number of clusters!")
 
299
 
300
  with gr.TabItem("Prediction Plot"):
301
  gr.Markdown("### Prediction Visualization")
302
+ gr.Markdown("For regression: scatter plots of top 3 features vs. predicted values and true vs. predicted. For classification: confusion matrix.")
303
  model_img_output = gr.Image(label="Prediction Output")
304
 
305
  with gr.TabItem("Feature Importances"):
 
322
  gr.Markdown("Shows features that vary most between clusters, helping explain the groupings.")
323
  diff_output = gr.Image(label="Differentiating Features")
324
 
325
+ with gr.TabItem("Interactive"):
326
+ gr.Markdown("### Interactive Prediction")
327
+ gr.Markdown("Enter values for each feature to get a prediction based on the trained model.")
328
+ interactive_inputs = gr.State(value=[])
329
+ with gr.Column():
330
+ input_components = gr.DynamicLayout(fn=create_interactive_inputs, inputs=[file_input, label_dropdown], outputs=interactive_inputs)
331
+ predict_btn = gr.Button("Predict")
332
+ prediction_output = gr.Textbox(label="Prediction Result")
333
+
334
+ predict_btn.click(
335
+ fn=predict_interactive,
336
+ inputs=interactive_inputs,
337
+ outputs=prediction_output
338
+ )
339
+
340
  analyze_btn.click(fn=analyze_file, inputs=[file_input, label_dropdown, clusters_slider],
341
  outputs=[results_textbox, model_img_output, fi_output, kmeans_output, agg_output, diff_output])
342