EzekielMW committed on
Commit
5cb4d7a
·
verified ·
1 Parent(s): 40ff2e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -74
app.py CHANGED
@@ -231,23 +231,27 @@ def plot_all():
231
 
232
  return plots
233
 
234
- # ---------- Prepare Data for Modeling ----------
 
 
 
235
  X = df.iloc[:, 1:].values
236
  scaler = StandardScaler()
237
  X_scaled = scaler.fit_transform(X)
 
 
238
  pca = PCA(n_components=2)
239
  X_pca = pca.fit_transform(X_scaled)
240
  X_train, X_test, y_train, y_test = train_test_split(X_pca, y, test_size=0.2, random_state=42)
241
 
242
- # ---------- Train Random Forest ----------
243
  rf = RandomForestClassifier(n_estimators=100, random_state=42)
244
  rf.fit(X_train, y_train)
245
 
246
- # ---------- Train Decision Tree ----------
247
  dt = DecisionTreeClassifier(random_state=42)
248
  dt.fit(X_train, y_train)
249
 
250
- # ---------- CNN on Raw Data ----------
251
  class MilkDataset(Dataset):
252
  def __init__(self, X, y):
253
  self.X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
@@ -255,12 +259,9 @@ class MilkDataset(Dataset):
255
  def __len__(self): return len(self.X)
256
  def __getitem__(self, idx): return self.X[idx], self.y[idx]
257
 
258
- X_raw_scaled = scaler.fit_transform(X)
259
- X_train_raw, X_test_raw, y_train_raw, y_test_raw = train_test_split(X_raw_scaled, y, test_size=0.2, random_state=42)
260
- train_dataset = MilkDataset(X_train_raw, y_train_raw)
261
- test_dataset = MilkDataset(X_test_raw, y_test_raw)
262
- train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
263
- test_loader = DataLoader(test_dataset, batch_size=16)
264
 
265
  class CNN1D(nn.Module):
266
  def __init__(self):
@@ -278,119 +279,118 @@ model = CNN1D()
278
  criterion = nn.CrossEntropyLoss()
279
  optimizer = optim.Adam(model.parameters(), lr=0.001)
280
 
281
- train_acc_list, test_acc_list = [], []
282
- for epoch in range(1, 11):
283
  model.train()
284
  for Xb, yb in train_loader:
285
  optimizer.zero_grad()
286
  loss = criterion(model(Xb), yb)
287
  loss.backward()
288
  optimizer.step()
289
- model.eval()
290
- with torch.no_grad():
291
- train_preds = torch.argmax(model(torch.cat([X for X, _ in train_loader], 0)), dim=1)
292
- test_preds = torch.argmax(model(torch.cat([X for X, _ in test_loader], 0)), dim=1)
293
- y_train_all = torch.cat([y for _, y in train_loader])
294
- y_test_all = torch.cat([y for _, y in test_loader])
295
- train_acc = (train_preds == y_train_all).float().mean().item()
296
- test_acc = (test_preds == y_test_all).float().mean().item()
297
- train_acc_list.append(train_acc)
298
- test_acc_list.append(test_acc)
299
-
300
- # ---------- Gradio Interface ----------
301
  with gr.Blocks() as demo:
302
- gr.Markdown("# 🧪 Dataset Description")
 
303
  with gr.Tabs():
304
  with gr.Tab("Preview Raw Data"):
305
- gr.DataFrame(df.head(50), label="Preview of Raw Data")
306
 
307
  with gr.Tab("Visualizations"):
 
 
 
 
 
 
 
 
308
  plot_button = gr.Button("Generate Spectroscopy Visualizations")
309
- out_gallery = [gr.Plot() for _ in range(8)]
310
- plot_button.click(fn=plot_all, inputs=[], outputs=out_gallery)
311
 
312
  with gr.Tab("Models"):
313
  with gr.Tabs():
314
  with gr.Tab("Random Forest"):
315
- gr.Markdown(f"""Train Accuracy: {accuracy_score(y_train, rf.predict(X_train)):.2f} \
316
- Test Accuracy: {accuracy_score(y_test, rf.predict(X_test)):.2f}""")
317
  fig_rf = plt.figure()
318
  sns.heatmap(confusion_matrix(y_test, rf.predict(X_test)), annot=True, fmt='d')
319
  plt.title("Random Forest Confusion Matrix")
320
  gr.Plot(fig_rf)
321
 
322
  with gr.Tab("Decision Tree"):
323
- gr.Markdown(f"""Train Accuracy: {accuracy_score(y_train, dt.predict(X_train)):.2f} \
324
- Test Accuracy: {accuracy_score(y_test, dt.predict(X_test)):.2f}""")
325
  fig_dt = plt.figure()
326
  sns.heatmap(confusion_matrix(y_test, dt.predict(X_test)), annot=True, fmt='d')
327
  plt.title("Decision Tree Confusion Matrix")
328
  gr.Plot(fig_dt)
329
 
330
- with gr.Tab("1D CNN (Raw Data)"):
331
- gr.Markdown(f"""Train Accuracy: {train_acc:.2f} \
332
- Test Accuracy: {test_acc:.2f}""")
333
  fig_cnn = plt.figure()
334
- sns.heatmap(confusion_matrix(y_test_all, test_preds), annot=True, fmt='d')
335
  plt.title("1D CNN Confusion Matrix")
336
  gr.Plot(fig_cnn)
337
 
338
  with gr.Tab("Prediction"):
339
- model_dropdown = gr.Dropdown(choices=['Random Forest', 'Decision Tree', '1D CNN'], label="Choose Model")
340
- input_file = gr.File(label="Upload CSV File (Same Format as Original Data)")
341
- output_df = gr.DataFrame(label="Predicted Labels")
342
-
343
- def predict(file, model_name):
344
- test_df = pd.read_csv(file.name)
345
- if 'Label' in test_df.columns:
346
- test_df = test_df.drop(columns=['Label'])
347
- X_input = test_df.values
348
- if model_name == '1D CNN':
349
- X_scaled = scaler.transform(X_input)
350
- X_tensor = torch.tensor(X_scaled, dtype=torch.float32).unsqueeze(1)
 
351
  with torch.no_grad():
352
- preds = torch.argmax(model(X_tensor), dim=1).numpy()
353
- preds = le.inverse_transform(preds)
354
  else:
355
- X_pca_input = pca.transform(scaler.transform(X_input))
356
- preds = rf.predict(X_pca_input) if model_name == 'Random Forest' else dt.predict(X_pca_input)
357
- preds = le.inverse_transform(preds)
358
- test_df['Predicted Label'] = preds
359
- return test_df
360
 
361
  predict_btn = gr.Button("Predict")
362
- predict_btn.click(fn=predict, inputs=[input_file, model_dropdown], outputs=[output_df])
363
 
364
  with gr.Tab("Takeaways"):
365
  gr.Markdown("## 🌾 Spectroscopy: Transforming the Dairy Sector")
366
-
367
  gr.Markdown("""
368
  ### 👨‍🌾 Farmers
369
- - Enables quick and non-destructive testing of milk quality.
370
- - Helps identify adulteration or spoilage early.
371
- - Boosts credibility and fair pricing in local and export markets.
372
 
373
  ### 🏧 Government
374
- - Supports enforcement of food safety and regulatory standards.
375
- - Aids in surveillance of quality at collection centers and cooperatives.
376
- - Encourages investment in agri-tech and rural innovation.
377
 
378
  ### 🏢 Businesses & Cooperatives
379
- - Enhances supply chain quality control.
380
- - Reduces reliance on expensive lab-based testing.
381
- - Increases transparency and trust with consumers.
382
-
383
- ### 🧠 Why Spectroscopy?
384
- - Non-invasive, fast, and cost-effective.
385
- - Adaptable for large-scale or smallholder use.
386
- - Unlocks new value in digitizing dairy analytics.
387
 
388
  ---
389
 
390
- ### 💡 Parting Shot: Health Starts With What You Consume
391
  > “Milk is nature’s first food – and it should remain pure. Spectroscopy empowers us to ensure it stays that way.”
392
  Stay curious. Stay healthy.
393
  """)
394
 
395
- # Run app
396
- demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
 
231
 
232
  return plots
233
 
234
+
235
+ # Encode labels
236
+ le = LabelEncoder()
237
+ y = le.fit_transform(df['Label'].values)
238
  X = df.iloc[:, 1:].values
239
  scaler = StandardScaler()
240
  X_scaled = scaler.fit_transform(X)
241
+
242
+ # === PCA reduction ===
243
  pca = PCA(n_components=2)
244
  X_pca = pca.fit_transform(X_scaled)
245
  X_train, X_test, y_train, y_test = train_test_split(X_pca, y, test_size=0.2, random_state=42)
246
 
247
+ # === Models ===
248
  rf = RandomForestClassifier(n_estimators=100, random_state=42)
249
  rf.fit(X_train, y_train)
250
 
 
251
  dt = DecisionTreeClassifier(random_state=42)
252
  dt.fit(X_train, y_train)
253
 
254
+ # === CNN ===
255
  class MilkDataset(Dataset):
256
  def __init__(self, X, y):
257
  self.X = torch.tensor(X, dtype=torch.float32).unsqueeze(1)
 
259
  def __len__(self): return len(self.X)
260
  def __getitem__(self, idx): return self.X[idx], self.y[idx]
261
 
262
+ X_train_raw, X_test_raw, y_train_raw, y_test_raw = train_test_split(X_scaled, y, test_size=0.2, random_state=42)
263
+ train_loader = DataLoader(MilkDataset(X_train_raw, y_train_raw), batch_size=16, shuffle=True)
264
+ test_loader = DataLoader(MilkDataset(X_test_raw, y_test_raw), batch_size=16)
 
 
 
265
 
266
  class CNN1D(nn.Module):
267
  def __init__(self):
 
279
  criterion = nn.CrossEntropyLoss()
280
  optimizer = optim.Adam(model.parameters(), lr=0.001)
281
 
282
+ for epoch in range(10):
 
283
  model.train()
284
  for Xb, yb in train_loader:
285
  optimizer.zero_grad()
286
  loss = criterion(model(Xb), yb)
287
  loss.backward()
288
  optimizer.step()
289
+
290
+ model.eval()
291
+ with torch.no_grad():
292
+ X_test_tensor = torch.tensor(X_test_raw, dtype=torch.float32).unsqueeze(1)
293
+ test_preds = model(X_test_tensor).argmax(dim=1)
294
+ test_acc = (test_preds == torch.tensor(y_test_raw)).float().mean().item()
295
+
296
+ X_train_tensor = torch.tensor(X_train_raw, dtype=torch.float32).unsqueeze(1)
297
+ train_preds = model(X_train_tensor).argmax(dim=1)
298
+ train_acc = (train_preds == torch.tensor(y_train_raw)).float().mean().item()
299
+
300
+ # === Gradio App ===
301
  with gr.Blocks() as demo:
302
+ gr.Markdown("# 🥛 NIR Milk Spectroscopy Analysis App")
303
+
304
  with gr.Tabs():
305
  with gr.Tab("Preview Raw Data"):
306
+ gr.DataFrame(df.head(50), label="Milk Spectra")
307
 
308
  with gr.Tab("Visualizations"):
309
+ def plot_all():
310
+ plots = []
311
+ for i in range(8):
312
+ fig, ax = plt.subplots()
313
+ ax.plot(X[i])
314
+ ax.set_title(f"Spectrum {i+1}")
315
+ plots.append(fig)
316
+ return plots
317
  plot_button = gr.Button("Generate Spectroscopy Visualizations")
318
+ output_plots = [gr.Plot() for _ in range(8)]
319
+ plot_button.click(fn=plot_all, inputs=[], outputs=output_plots)
320
 
321
  with gr.Tab("Models"):
322
  with gr.Tabs():
323
  with gr.Tab("Random Forest"):
324
+ gr.Markdown(f"Train Accuracy: **{accuracy_score(y_train, rf.predict(X_train)):.2f}**<br>🎯 Test Accuracy: **{accuracy_score(y_test, rf.predict(X_test)):.2f}**")
 
325
  fig_rf = plt.figure()
326
  sns.heatmap(confusion_matrix(y_test, rf.predict(X_test)), annot=True, fmt='d')
327
  plt.title("Random Forest Confusion Matrix")
328
  gr.Plot(fig_rf)
329
 
330
  with gr.Tab("Decision Tree"):
331
+ gr.Markdown(f"Train Accuracy: **{accuracy_score(y_train, dt.predict(X_train)):.2f}**<br>🎯 Test Accuracy: **{accuracy_score(y_test, dt.predict(X_test)):.2f}**")
 
332
  fig_dt = plt.figure()
333
  sns.heatmap(confusion_matrix(y_test, dt.predict(X_test)), annot=True, fmt='d')
334
  plt.title("Decision Tree Confusion Matrix")
335
  gr.Plot(fig_dt)
336
 
337
+ with gr.Tab("1D CNN"):
338
+ gr.Markdown(f"Train Accuracy: **{train_acc:.2f}**<br>🎯 Test Accuracy: **{test_acc:.2f}**")
 
339
  fig_cnn = plt.figure()
340
+ sns.heatmap(confusion_matrix(y_test_raw, test_preds), annot=True, fmt='d')
341
  plt.title("1D CNN Confusion Matrix")
342
  gr.Plot(fig_cnn)
343
 
344
  with gr.Tab("Prediction"):
345
+ model_choice = gr.Dropdown(['Random Forest', 'Decision Tree', '1D CNN'], label="Choose Model")
346
+ input_file = gr.File(label="Upload CSV (same format)")
347
+ output_table = gr.DataFrame(label="Predictions")
348
+
349
+ def predict(file, model_choice):
350
+ df_new = pd.read_csv(file.name)
351
+ if 'Label' in df_new.columns:
352
+ df_new = df_new.drop(columns=['Label'])
353
+ X_input = df_new.values
354
+
355
+ if model_choice == "1D CNN":
356
+ X_input_scaled = scaler.transform(X_input)
357
+ tensor_input = torch.tensor(X_input_scaled, dtype=torch.float32).unsqueeze(1)
358
  with torch.no_grad():
359
+ preds = model(tensor_input).argmax(dim=1).numpy()
 
360
  else:
361
+ X_input_pca = pca.transform(scaler.transform(X_input))
362
+ preds = rf.predict(X_input_pca) if model_choice == "Random Forest" else dt.predict(X_input_pca)
363
+
364
+ df_new['Predicted Label'] = le.inverse_transform(preds)
365
+ return df_new
366
 
367
  predict_btn = gr.Button("Predict")
368
+ predict_btn.click(predict, inputs=[input_file, model_choice], outputs=[output_table])
369
 
370
  with gr.Tab("Takeaways"):
371
  gr.Markdown("## 🌾 Spectroscopy: Transforming the Dairy Sector")
 
372
  gr.Markdown("""
373
  ### 👨‍🌾 Farmers
374
+ - Quick, non-destructive testing of milk quality.
375
+ - Early detection of spoilage or adulteration.
376
+ - Enables fairer pricing in cooperative and market setups.
377
 
378
  ### 🏧 Government
379
+ - Strengthens food safety monitoring.
380
+ - Ensures consistent quality across the supply chain.
381
+ - Fosters innovation in rural/agricultural tech.
382
 
383
  ### 🏢 Businesses & Cooperatives
384
+ - Real-time quality control in logistics.
385
+ - Cost-effective compared to traditional labs.
386
+ - Enhances trust through transparency.
 
 
 
 
 
387
 
388
  ---
389
 
390
+ ### 💡 Final Note on Healthy Living
391
  > “Milk is nature’s first food – and it should remain pure. Spectroscopy empowers us to ensure it stays that way.”
392
  Stay curious. Stay healthy.
393
  """)
394
 
395
+ # === Run the app ===
396
+ demo.launch(server_name="0.0.0.0", server_port=7860)