Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -297,75 +297,17 @@ with torch.no_grad():
|
|
| 297 |
train_preds = model(X_train_tensor).argmax(dim=1)
|
| 298 |
train_acc = (train_preds == torch.tensor(y_train_raw)).float().mean().item()
|
| 299 |
|
| 300 |
-
# === Gradio App ===
|
| 301 |
with gr.Blocks() as demo:
|
| 302 |
-
gr.Markdown("#
|
| 303 |
|
| 304 |
with gr.Tabs():
|
| 305 |
with gr.Tab("Preview Raw Data"):
|
| 306 |
-
gr.DataFrame(df.head(50), label="
|
| 307 |
|
| 308 |
with gr.Tab("Visualizations"):
|
| 309 |
-
def plot_all():
|
| 310 |
-
plots = []
|
| 311 |
-
for i in range(8):
|
| 312 |
-
fig, ax = plt.subplots()
|
| 313 |
-
ax.plot(X[i])
|
| 314 |
-
ax.set_title(f"Spectrum {i+1}")
|
| 315 |
-
plots.append(fig)
|
| 316 |
-
return plots
|
| 317 |
plot_button = gr.Button("Generate Spectroscopy Visualizations")
|
| 318 |
-
|
| 319 |
-
plot_button.click(fn=plot_all, inputs=[], outputs=
|
| 320 |
-
|
| 321 |
-
with gr.Tab("Models"):
|
| 322 |
-
with gr.Tabs():
|
| 323 |
-
with gr.Tab("Random Forest"):
|
| 324 |
-
gr.Markdown(f"✅ Train Accuracy: **{accuracy_score(y_train, rf.predict(X_train)):.2f}**<br>🎯 Test Accuracy: **{accuracy_score(y_test, rf.predict(X_test)):.2f}**")
|
| 325 |
-
fig_rf = plt.figure()
|
| 326 |
-
sns.heatmap(confusion_matrix(y_test, rf.predict(X_test)), annot=True, fmt='d')
|
| 327 |
-
plt.title("Random Forest Confusion Matrix")
|
| 328 |
-
gr.Plot(fig_rf)
|
| 329 |
-
|
| 330 |
-
with gr.Tab("Decision Tree"):
|
| 331 |
-
gr.Markdown(f"✅ Train Accuracy: **{accuracy_score(y_train, dt.predict(X_train)):.2f}**<br>🎯 Test Accuracy: **{accuracy_score(y_test, dt.predict(X_test)):.2f}**")
|
| 332 |
-
fig_dt = plt.figure()
|
| 333 |
-
sns.heatmap(confusion_matrix(y_test, dt.predict(X_test)), annot=True, fmt='d')
|
| 334 |
-
plt.title("Decision Tree Confusion Matrix")
|
| 335 |
-
gr.Plot(fig_dt)
|
| 336 |
-
|
| 337 |
-
with gr.Tab("1D CNN"):
|
| 338 |
-
gr.Markdown(f"✅ Train Accuracy: **{train_acc:.2f}**<br>🎯 Test Accuracy: **{test_acc:.2f}**")
|
| 339 |
-
fig_cnn = plt.figure()
|
| 340 |
-
sns.heatmap(confusion_matrix(y_test_raw, test_preds), annot=True, fmt='d')
|
| 341 |
-
plt.title("1D CNN Confusion Matrix")
|
| 342 |
-
gr.Plot(fig_cnn)
|
| 343 |
-
|
| 344 |
-
with gr.Tab("Prediction"):
|
| 345 |
-
model_choice = gr.Dropdown(['Random Forest', 'Decision Tree', '1D CNN'], label="Choose Model")
|
| 346 |
-
input_file = gr.File(label="Upload CSV (same format)")
|
| 347 |
-
output_table = gr.DataFrame(label="Predictions")
|
| 348 |
-
|
| 349 |
-
def predict(file, model_choice):
|
| 350 |
-
df_new = pd.read_csv(file.name)
|
| 351 |
-
if 'Label' in df_new.columns:
|
| 352 |
-
df_new = df_new.drop(columns=['Label'])
|
| 353 |
-
X_input = df_new.values
|
| 354 |
-
|
| 355 |
-
if model_choice == "1D CNN":
|
| 356 |
-
X_input_scaled = scaler.transform(X_input)
|
| 357 |
-
tensor_input = torch.tensor(X_input_scaled, dtype=torch.float32).unsqueeze(1)
|
| 358 |
-
with torch.no_grad():
|
| 359 |
-
preds = model(tensor_input).argmax(dim=1).numpy()
|
| 360 |
-
else:
|
| 361 |
-
X_input_pca = pca.transform(scaler.transform(X_input))
|
| 362 |
-
preds = rf.predict(X_input_pca) if model_choice == "Random Forest" else dt.predict(X_input_pca)
|
| 363 |
-
|
| 364 |
-
df_new['Predicted Label'] = le.inverse_transform(preds)
|
| 365 |
-
return df_new
|
| 366 |
-
|
| 367 |
-
predict_btn = gr.Button("Predict")
|
| 368 |
-
predict_btn.click(predict, inputs=[input_file, model_choice], outputs=[output_table])
|
| 369 |
|
| 370 |
with gr.Tab("Takeaways"):
|
| 371 |
gr.Markdown("## 🌾 Spectroscopy: Transforming the Dairy Sector")
|
|
@@ -392,5 +334,19 @@ with gr.Blocks() as demo:
|
|
| 392 |
Stay curious. Stay healthy.
|
| 393 |
""")
|
| 394 |
|
| 395 |
-
|
| 396 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
train_preds = model(X_train_tensor).argmax(dim=1)
|
| 298 |
train_acc = (train_preds == torch.tensor(y_train_raw)).float().mean().item()
|
| 299 |
|
|
|
|
| 300 |
with gr.Blocks() as demo:
|
| 301 |
+
gr.Markdown("# 🧪 Dataset Description")
|
| 302 |
|
| 303 |
with gr.Tabs():
|
| 304 |
with gr.Tab("Preview Raw Data"):
|
| 305 |
+
gr.DataFrame(df.head(50), label="Preview of Raw Data")
|
| 306 |
|
| 307 |
with gr.Tab("Visualizations"):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
plot_button = gr.Button("Generate Spectroscopy Visualizations")
|
| 309 |
+
out_gallery = [gr.Plot() for _ in range(8)]
|
| 310 |
+
plot_button.click(fn=plot_all, inputs=[], outputs=out_gallery)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
|
| 312 |
with gr.Tab("Takeaways"):
|
| 313 |
gr.Markdown("## 🌾 Spectroscopy: Transforming the Dairy Sector")
|
|
|
|
| 334 |
Stay curious. Stay healthy.
|
| 335 |
""")
|
| 336 |
|
| 337 |
+
with gr.Tab("Random Forest"):
|
| 338 |
+
gr.Image(value="random_forest_result.png", label="Random Forest Output")
|
| 339 |
+
|
| 340 |
+
with gr.Tab("Decision Tree"):
|
| 341 |
+
gr.Markdown("**Confusion Matrix**")
|
| 342 |
+
gr.Image(value="decision_tree_confusion.png", label="Confusion Matrix")
|
| 343 |
+
gr.Markdown("**Decision Tree Visualization**")
|
| 344 |
+
gr.Image(value="decision_tree_model.png", label="Tree Structure")
|
| 345 |
+
gr.Image(value="decision_tree_extra.png", label="Additional Insight")
|
| 346 |
+
|
| 347 |
+
with gr.Tab("1D CNN (Raw Data)"):
|
| 348 |
+
gr.Image(value="cnn_result_image.png", label="1D CNN Output")
|
| 349 |
+
|
| 350 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
|
| 351 |
+
|
| 352 |
+
|