Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -19,6 +19,9 @@ MIN_ROWS = 10
|
|
| 19 |
MIN_COLS = 2
|
| 20 |
MAX_FEATURES_TO_SHOW = 10
|
| 21 |
|
|
|
|
|
|
|
|
|
|
| 22 |
def update_dropdown(file):
|
| 23 |
if file is None:
|
| 24 |
return gr.update(choices=[], value=None)
|
|
@@ -98,14 +101,14 @@ def analyze_file(file, label_col, n_clusters):
|
|
| 98 |
ax = axes[i]
|
| 99 |
ax.scatter(X_test[feature], y_pred, alpha=0.5)
|
| 100 |
ax.set_xlabel(feature)
|
| 101 |
-
ax.set_ylabel('Predicted
|
| 102 |
-
ax.set_title(f'{feature} vs Predicted
|
| 103 |
ax = axes[3]
|
| 104 |
ax.scatter(y_test, y_pred, alpha=0.5)
|
| 105 |
ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', label='Perfect Prediction')
|
| 106 |
-
ax.set_xlabel('True
|
| 107 |
-
ax.set_ylabel('Predicted
|
| 108 |
-
ax.set_title('True vs Predicted
|
| 109 |
min_val = min(y_test.min(), y_pred.min())
|
| 110 |
max_val = max(y_test.max(), y_pred.max())
|
| 111 |
ax.set_xlim(min_val, max_val)
|
|
@@ -117,8 +120,9 @@ def analyze_file(file, label_col, n_clusters):
|
|
| 117 |
plt.close()
|
| 118 |
buf.seek(0)
|
| 119 |
model_img = Image.open(buf)
|
|
|
|
| 120 |
else:
|
| 121 |
-
# Classification
|
| 122 |
if len(y.unique()) < 2:
|
| 123 |
return ("Label must have at least 2 unique values.", None, None, None, None, None)
|
| 124 |
y_encoded, uniques = pd.factorize(y)
|
|
@@ -128,23 +132,19 @@ def analyze_file(file, label_col, n_clusters):
|
|
| 128 |
y_pred = model.predict(X_test)
|
| 129 |
cr = classification_report(y_test, y_pred, target_names=[str(u) for u in uniques])
|
| 130 |
results_text += "Classification Results:\n" + cr + "\n"
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
plt.savefig(buf, format="png", bbox_inches="tight")
|
| 145 |
-
plt.close()
|
| 146 |
-
buf.seek(0)
|
| 147 |
-
model_img = Image.open(buf)
|
| 148 |
except Exception as e:
|
| 149 |
results_text += f"\nError during model training: {e}"
|
| 150 |
|
|
@@ -218,6 +218,63 @@ def analyze_file(file, label_col, n_clusters):
|
|
| 218 |
|
| 219 |
return results_text, model_img, fi_img, kmeans_img, agg_img, diff_img
|
| 220 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 221 |
with gr.Blocks() as demo:
|
| 222 |
gr.Markdown("## Data Analysis Explorer")
|
| 223 |
gr.Markdown("Upload a CSV or XLSX file to explore classification, regression, and clustering. Select a column to predict and the number of clusters!")
|
|
@@ -242,7 +299,7 @@ with gr.Blocks() as demo:
|
|
| 242 |
|
| 243 |
with gr.TabItem("Prediction Plot"):
|
| 244 |
gr.Markdown("### Prediction Visualization")
|
| 245 |
-
gr.Markdown("For regression
|
| 246 |
model_img_output = gr.Image(label="Prediction Output")
|
| 247 |
|
| 248 |
with gr.TabItem("Feature Importances"):
|
|
@@ -265,6 +322,21 @@ with gr.Blocks() as demo:
|
|
| 265 |
gr.Markdown("Shows features that vary most between clusters, helping explain the groupings.")
|
| 266 |
diff_output = gr.Image(label="Differentiating Features")
|
| 267 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 268 |
analyze_btn.click(fn=analyze_file, inputs=[file_input, label_dropdown, clusters_slider],
|
| 269 |
outputs=[results_textbox, model_img_output, fi_output, kmeans_output, agg_output, diff_output])
|
| 270 |
|
|
|
|
| 19 |
MIN_COLS = 2
|
| 20 |
MAX_FEATURES_TO_SHOW = 10
|
| 21 |
|
| 22 |
+
# Global variable to store trained model and data
|
| 23 |
+
global_data = {'model': None, 'scaler': None, 'X_columns': None, 'y_type': None, 'uniques': None}
|
| 24 |
+
|
| 25 |
def update_dropdown(file):
|
| 26 |
if file is None:
|
| 27 |
return gr.update(choices=[], value=None)
|
|
|
|
| 101 |
ax = axes[i]
|
| 102 |
ax.scatter(X_test[feature], y_pred, alpha=0.5)
|
| 103 |
ax.set_xlabel(feature)
|
| 104 |
+
ax.set_ylabel('Predicted Value')
|
| 105 |
+
ax.set_title(f'{feature} vs Predicted')
|
| 106 |
ax = axes[3]
|
| 107 |
ax.scatter(y_test, y_pred, alpha=0.5)
|
| 108 |
ax.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', label='Perfect Prediction')
|
| 109 |
+
ax.set_xlabel('True Value')
|
| 110 |
+
ax.set_ylabel('Predicted Value')
|
| 111 |
+
ax.set_title('True vs Predicted')
|
| 112 |
min_val = min(y_test.min(), y_pred.min())
|
| 113 |
max_val = max(y_test.max(), y_pred.max())
|
| 114 |
ax.set_xlim(min_val, max_val)
|
|
|
|
| 120 |
plt.close()
|
| 121 |
buf.seek(0)
|
| 122 |
model_img = Image.open(buf)
|
| 123 |
+
global_data.update({'model': model, 'scaler': scaler, 'X_columns': X_processed.columns, 'y_type': 'regression', 'uniques': None})
|
| 124 |
else:
|
| 125 |
+
# Classification
|
| 126 |
if len(y.unique()) < 2:
|
| 127 |
return ("Label must have at least 2 unique values.", None, None, None, None, None)
|
| 128 |
y_encoded, uniques = pd.factorize(y)
|
|
|
|
| 132 |
y_pred = model.predict(X_test)
|
| 133 |
cr = classification_report(y_test, y_pred, target_names=[str(u) for u in uniques])
|
| 134 |
results_text += "Classification Results:\n" + cr + "\n"
|
| 135 |
+
# 2D Confusion Matrix
|
| 136 |
+
cm = confusion_matrix(y_test, y_pred)
|
| 137 |
+
plt.figure(figsize=(8, 6))
|
| 138 |
+
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=[str(u) for u in uniques], yticklabels=[str(u) for u in uniques])
|
| 139 |
+
plt.xlabel('Predicted')
|
| 140 |
+
plt.ylabel('True')
|
| 141 |
+
plt.title('Confusion Matrix')
|
| 142 |
+
buf = io.BytesIO()
|
| 143 |
+
plt.savefig(buf, format="png", bbox_inches="tight")
|
| 144 |
+
plt.close()
|
| 145 |
+
buf.seek(0)
|
| 146 |
+
model_img = Image.open(buf)
|
| 147 |
+
global_data.update({'model': model, 'scaler': scaler, 'X_columns': X_processed.columns, 'y_type': 'classification', 'uniques': uniques})
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
except Exception as e:
|
| 149 |
results_text += f"\nError during model training: {e}"
|
| 150 |
|
|
|
|
| 218 |
|
| 219 |
return results_text, model_img, fi_img, kmeans_img, agg_img, diff_img
|
| 220 |
|
| 221 |
+
def predict_interactive(**kwargs):
|
| 222 |
+
if global_data['model'] is None:
|
| 223 |
+
return "Please analyze a file first to train the model."
|
| 224 |
+
|
| 225 |
+
try:
|
| 226 |
+
# Create DataFrame from user inputs
|
| 227 |
+
input_data = pd.DataFrame([kwargs])
|
| 228 |
+
|
| 229 |
+
# Handle categorical variables with one-hot encoding
|
| 230 |
+
X_processed = pd.get_dummies(input_data)
|
| 231 |
+
|
| 232 |
+
# Ensure all expected columns are present
|
| 233 |
+
for col in global_data['X_columns']:
|
| 234 |
+
if col not in X_processed.columns:
|
| 235 |
+
X_processed[col] = 0
|
| 236 |
+
|
| 237 |
+
# Reorder columns to match training data
|
| 238 |
+
X_processed = X_processed[global_data['X_columns']]
|
| 239 |
+
|
| 240 |
+
# Scale the input
|
| 241 |
+
X_scaled = global_data['scaler'].transform(X_processed)
|
| 242 |
+
|
| 243 |
+
# Predict
|
| 244 |
+
prediction = global_data['model'].predict(X_scaled)
|
| 245 |
+
|
| 246 |
+
if global_data['y_type'] == 'classification':
|
| 247 |
+
pred_value = global_data['uniques'][int(prediction[0])]
|
| 248 |
+
return f"Predicted class: {pred_value}"
|
| 249 |
+
else:
|
| 250 |
+
return f"Predicted value: {prediction[0]:.3f}"
|
| 251 |
+
except Exception as e:
|
| 252 |
+
return f"Error in prediction: {str(e)}. Please ensure all inputs are valid numbers or categories."
|
| 253 |
+
|
| 254 |
+
def create_interactive_inputs(file, label_col):
|
| 255 |
+
if file is None or label_col is None:
|
| 256 |
+
return []
|
| 257 |
+
|
| 258 |
+
try:
|
| 259 |
+
if file.name.endswith('.csv'):
|
| 260 |
+
df = pd.read_csv(file.name)
|
| 261 |
+
elif file.name.endswith('.xlsx'):
|
| 262 |
+
df = pd.read_excel(file.name)
|
| 263 |
+
else:
|
| 264 |
+
return []
|
| 265 |
+
|
| 266 |
+
X = df.drop(columns=[label_col])
|
| 267 |
+
inputs = []
|
| 268 |
+
for col in X.columns:
|
| 269 |
+
examples = X[col].dropna().sample(min(3, len(X[col].dropna()))).tolist()
|
| 270 |
+
if pd.api.types.is_numeric_dtype(X[col]):
|
| 271 |
+
inputs.append(gr.Number(label=f"{col} (e.g., {', '.join(map(str, examples))})"))
|
| 272 |
+
else:
|
| 273 |
+
inputs.append(gr.Textbox(label=f"{col} (e.g., {', '.join(map(str, examples))})"))
|
| 274 |
+
return inputs
|
| 275 |
+
except Exception:
|
| 276 |
+
return []
|
| 277 |
+
|
| 278 |
with gr.Blocks() as demo:
|
| 279 |
gr.Markdown("## Data Analysis Explorer")
|
| 280 |
gr.Markdown("Upload a CSV or XLSX file to explore classification, regression, and clustering. Select a column to predict and the number of clusters!")
|
|
|
|
| 299 |
|
| 300 |
with gr.TabItem("Prediction Plot"):
|
| 301 |
gr.Markdown("### Prediction Visualization")
|
| 302 |
+
gr.Markdown("For regression: scatter plots of top 3 features vs. predicted values and true vs. predicted. For classification: confusion matrix.")
|
| 303 |
model_img_output = gr.Image(label="Prediction Output")
|
| 304 |
|
| 305 |
with gr.TabItem("Feature Importances"):
|
|
|
|
| 322 |
gr.Markdown("Shows features that vary most between clusters, helping explain the groupings.")
|
| 323 |
diff_output = gr.Image(label="Differentiating Features")
|
| 324 |
|
| 325 |
+
with gr.TabItem("Interactive"):
|
| 326 |
+
gr.Markdown("### Interactive Prediction")
|
| 327 |
+
gr.Markdown("Enter values for each feature to get a prediction based on the trained model.")
|
| 328 |
+
interactive_inputs = gr.State(value=[])
|
| 329 |
+
with gr.Column():
|
| 330 |
+
input_components = gr.DynamicLayout(fn=create_interactive_inputs, inputs=[file_input, label_dropdown], outputs=interactive_inputs)
|
| 331 |
+
predict_btn = gr.Button("Predict")
|
| 332 |
+
prediction_output = gr.Textbox(label="Prediction Result")
|
| 333 |
+
|
| 334 |
+
predict_btn.click(
|
| 335 |
+
fn=predict_interactive,
|
| 336 |
+
inputs=interactive_inputs,
|
| 337 |
+
outputs=prediction_output
|
| 338 |
+
)
|
| 339 |
+
|
| 340 |
analyze_btn.click(fn=analyze_file, inputs=[file_input, label_dropdown, clusters_slider],
|
| 341 |
outputs=[results_textbox, model_img_output, fi_output, kmeans_output, agg_output, diff_output])
|
| 342 |
|