EzekielMW commited on
Commit
3c7ea85
·
verified ·
1 Parent(s): b37aa6c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -64
app.py CHANGED
@@ -297,75 +297,17 @@ with torch.no_grad():
297
  train_preds = model(X_train_tensor).argmax(dim=1)
298
  train_acc = (train_preds == torch.tensor(y_train_raw)).float().mean().item()
299
 
300
- # === Gradio App ===
301
  with gr.Blocks() as demo:
302
- gr.Markdown("# 🥛 NIR Milk Spectroscopy Analysis App")
303
 
304
  with gr.Tabs():
305
  with gr.Tab("Preview Raw Data"):
306
- gr.DataFrame(df.head(50), label="Milk Spectra")
307
 
308
  with gr.Tab("Visualizations"):
309
- def plot_all():
310
- plots = []
311
- for i in range(8):
312
- fig, ax = plt.subplots()
313
- ax.plot(X[i])
314
- ax.set_title(f"Spectrum {i+1}")
315
- plots.append(fig)
316
- return plots
317
  plot_button = gr.Button("Generate Spectroscopy Visualizations")
318
- output_plots = [gr.Plot() for _ in range(8)]
319
- plot_button.click(fn=plot_all, inputs=[], outputs=output_plots)
320
-
321
- with gr.Tab("Models"):
322
- with gr.Tabs():
323
- with gr.Tab("Random Forest"):
324
- gr.Markdown(f"✅ Train Accuracy: **{accuracy_score(y_train, rf.predict(X_train)):.2f}**<br>🎯 Test Accuracy: **{accuracy_score(y_test, rf.predict(X_test)):.2f}**")
325
- fig_rf = plt.figure()
326
- sns.heatmap(confusion_matrix(y_test, rf.predict(X_test)), annot=True, fmt='d')
327
- plt.title("Random Forest Confusion Matrix")
328
- gr.Plot(fig_rf)
329
-
330
- with gr.Tab("Decision Tree"):
331
- gr.Markdown(f"✅ Train Accuracy: **{accuracy_score(y_train, dt.predict(X_train)):.2f}**<br>🎯 Test Accuracy: **{accuracy_score(y_test, dt.predict(X_test)):.2f}**")
332
- fig_dt = plt.figure()
333
- sns.heatmap(confusion_matrix(y_test, dt.predict(X_test)), annot=True, fmt='d')
334
- plt.title("Decision Tree Confusion Matrix")
335
- gr.Plot(fig_dt)
336
-
337
- with gr.Tab("1D CNN"):
338
- gr.Markdown(f"✅ Train Accuracy: **{train_acc:.2f}**<br>🎯 Test Accuracy: **{test_acc:.2f}**")
339
- fig_cnn = plt.figure()
340
- sns.heatmap(confusion_matrix(y_test_raw, test_preds), annot=True, fmt='d')
341
- plt.title("1D CNN Confusion Matrix")
342
- gr.Plot(fig_cnn)
343
-
344
- with gr.Tab("Prediction"):
345
- model_choice = gr.Dropdown(['Random Forest', 'Decision Tree', '1D CNN'], label="Choose Model")
346
- input_file = gr.File(label="Upload CSV (same format)")
347
- output_table = gr.DataFrame(label="Predictions")
348
-
349
- def predict(file, model_choice):
350
- df_new = pd.read_csv(file.name)
351
- if 'Label' in df_new.columns:
352
- df_new = df_new.drop(columns=['Label'])
353
- X_input = df_new.values
354
-
355
- if model_choice == "1D CNN":
356
- X_input_scaled = scaler.transform(X_input)
357
- tensor_input = torch.tensor(X_input_scaled, dtype=torch.float32).unsqueeze(1)
358
- with torch.no_grad():
359
- preds = model(tensor_input).argmax(dim=1).numpy()
360
- else:
361
- X_input_pca = pca.transform(scaler.transform(X_input))
362
- preds = rf.predict(X_input_pca) if model_choice == "Random Forest" else dt.predict(X_input_pca)
363
-
364
- df_new['Predicted Label'] = le.inverse_transform(preds)
365
- return df_new
366
-
367
- predict_btn = gr.Button("Predict")
368
- predict_btn.click(predict, inputs=[input_file, model_choice], outputs=[output_table])
369
 
370
  with gr.Tab("Takeaways"):
371
  gr.Markdown("## 🌾 Spectroscopy: Transforming the Dairy Sector")
@@ -392,5 +334,19 @@ with gr.Blocks() as demo:
392
  Stay curious. Stay healthy.
393
  """)
394
 
395
- # === Run the app ===
396
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  train_preds = model(X_train_tensor).argmax(dim=1)
298
  train_acc = (train_preds == torch.tensor(y_train_raw)).float().mean().item()
299
 
 
300
  with gr.Blocks() as demo:
301
+ gr.Markdown("# 🧪 Dataset Description")
302
 
303
  with gr.Tabs():
304
  with gr.Tab("Preview Raw Data"):
305
+ gr.DataFrame(df.head(50), label="Preview of Raw Data")
306
 
307
  with gr.Tab("Visualizations"):
 
 
 
 
 
 
 
 
308
  plot_button = gr.Button("Generate Spectroscopy Visualizations")
309
+ out_gallery = [gr.Plot() for _ in range(8)]
310
+ plot_button.click(fn=plot_all, inputs=[], outputs=out_gallery)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
 
312
  with gr.Tab("Takeaways"):
313
  gr.Markdown("## 🌾 Spectroscopy: Transforming the Dairy Sector")
 
334
  Stay curious. Stay healthy.
335
  """)
336
 
337
+ with gr.Tab("Random Forest"):
338
+ gr.Image(value="random_forest_result.png", label="Random Forest Output")
339
+
340
+ with gr.Tab("Decision Tree"):
341
+ gr.Markdown("**Confusion Matrix**")
342
+ gr.Image(value="decision_tree_confusion.png", label="Confusion Matrix")
343
+ gr.Markdown("**Decision Tree Visualization**")
344
+ gr.Image(value="decision_tree_model.png", label="Tree Structure")
345
+ gr.Image(value="decision_tree_extra.png", label="Additional Insight")
346
+
347
+ with gr.Tab("1D CNN (Raw Data)"):
348
+ gr.Image(value="cnn_result_image.png", label="1D CNN Output")
349
+
350
+ demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
351
+
352
+