MSU576 commited on
Commit
72a027e
·
verified ·
1 Parent(s): 4cc911c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -40
app.py CHANGED
@@ -218,42 +218,44 @@ def site_details_panel():
218
  # =============================
219
  # FEATURE MODULES
220
  # =============================
221
-
222
- # 1. Soil Recognizer (CNN Image Model)
223
- # 1. Soil Recognizer (ResNet18 Model)
224
-
225
  import torch
226
  import torch.nn as nn
227
  import torchvision.models as models
228
  import torchvision.transforms as T
229
  from PIL import Image
230
- import numpy as np
231
  import streamlit as st
232
 
233
  # ----------------------------
234
- # Load Soil Recognition Model
235
  # ----------------------------
236
  @st.cache_resource
237
- def load_soil_model():
 
238
  try:
239
- # Load ResNet18 (pretrained base, adjust if you used different arch)
240
- model = models.resnet18(weights=None)
241
  num_ftrs = model.fc.in_features
242
- model.fc = nn.Linear(num_ftrs, 5) # 5 soil classes
243
- state_dict = torch.load("soil_best_model.pth", map_location="cpu")
 
 
244
  model.load_state_dict(state_dict)
 
245
  model.eval()
246
- return model
247
  except Exception as e:
248
  st.error(f"⚠️ Could not load soil model: {e}")
249
- return None
250
 
251
- soil_model = load_soil_model()
252
 
253
- # Soil class labels
254
- SOIL_CLASSES = ["Sand", "Silt", "Clay", "Gravel", "Peat"]
 
 
255
 
256
- # Image preprocessing
257
  transform = T.Compose([
258
  T.Resize((224, 224)),
259
  T.ToTensor(),
@@ -261,45 +263,54 @@ transform = T.Compose([
261
  [0.229, 0.224, 0.225])
262
  ])
263
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
264
  # ----------------------------
265
  # Soil Recognizer Page
266
  # ----------------------------
267
  def soil_recognizer_page():
268
  st.header("🖼️ Soil Recognizer (ResNet18)")
269
 
270
- site = get_active_site()
271
  if site is None:
272
  st.warning("⚠️ No active site selected. Please add or select a site from the sidebar.")
273
  return
274
 
275
  uploaded = st.file_uploader("Upload soil image", type=["jpg", "jpeg", "png"])
276
  if uploaded is not None:
277
- img = Image.open(uploaded).convert("RGB")
278
  st.image(img, caption="Uploaded soil image", use_column_width=True)
279
 
280
- if soil_model:
281
- try:
282
- inp = transform(img).unsqueeze(0) # add batch dim
283
- with torch.no_grad():
284
- logits = soil_model(inp)
285
- probs = torch.softmax(logits, dim=1)[0]
286
- conf, pred = torch.max(probs, 0)
287
- predicted_class = SOIL_CLASSES[pred.item()]
288
- confidence = conf.item()
289
-
290
- st.success(f"✅ Predicted: **{predicted_class}** ({confidence:.2%} confidence)")
291
-
292
- if st.button("Save to site"):
293
- site["Soil Profile"] = predicted_class
294
- site["Soil Recognizer Confidence"] = confidence
295
- save_active_site(site)
296
- st.success("Saved prediction to active site memory.")
297
 
298
- except Exception as e:
299
- st.error(f"❌ Inference error: {e}")
300
- else:
301
- st.warning("⚠️ Soil model not loaded. Please check `soil_best_model.pth`.")
302
 
 
 
 
 
 
303
 
304
 
305
  # ----------------------------
 
218
  # =============================
219
  # FEATURE MODULES
220
  # =============================
221
+ # ----------------------------
222
+ # Soil Recognizer Page (Integrated 6-Class ResNet18)
223
+ # ----------------------------
 
224
  import torch
225
  import torch.nn as nn
226
  import torchvision.models as models
227
  import torchvision.transforms as T
228
  from PIL import Image
 
229
  import streamlit as st
230
 
231
  # ----------------------------
232
+ # Load Soil Model (6 Classes)
233
  # ----------------------------
234
  @st.cache_resource
235
+ def load_soil_model(path="/content/drive/MyDrive/soil_best_model.pth"):
236
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
237
  try:
238
+ model = models.resnet18(pretrained=False)
 
239
  num_ftrs = model.fc.in_features
240
+ model.fc = nn.Linear(num_ftrs, 6) # 6 soil classes
241
+
242
+ # Load checkpoint
243
+ state_dict = torch.load(path, map_location=device)
244
  model.load_state_dict(state_dict)
245
+ model = model.to(device)
246
  model.eval()
247
+ return model, device
248
  except Exception as e:
249
  st.error(f"⚠️ Could not load soil model: {e}")
250
+ return None, device
251
 
252
+ soil_model, device = load_soil_model()
253
 
254
+ # ----------------------------
255
+ # Soil Classes & Transform
256
+ # ----------------------------
257
+ SOIL_CLASSES = ["Clay", "Gravel", "Loam", "Peat", "Sand", "Silt"]
258
 
 
259
  transform = T.Compose([
260
  T.Resize((224, 224)),
261
  T.ToTensor(),
 
263
  [0.229, 0.224, 0.225])
264
  ])
265
 
266
+ # ----------------------------
267
+ # Prediction Function
268
+ # ----------------------------
269
+ def predict_soil(img: Image.Image):
270
+ if soil_model is None:
271
+ return "Model not loaded", {}
272
+
273
+ img = img.convert("RGB")
274
+ inp = transform(img).unsqueeze(0).to(device)
275
+
276
+ with torch.no_grad():
277
+ logits = soil_model(inp)
278
+ probs = torch.softmax(logits[0], dim=0)
279
+
280
+ top_idx = torch.argmax(probs).item()
281
+ predicted_class = SOIL_CLASSES[top_idx]
282
+
283
+ result = {SOIL_CLASSES[i]: float(probs[i]) for i in range(len(SOIL_CLASSES))}
284
+ return predicted_class, result
285
+
286
  # ----------------------------
287
  # Soil Recognizer Page
288
  # ----------------------------
289
  def soil_recognizer_page():
290
  st.header("🖼️ Soil Recognizer (ResNet18)")
291
 
292
+ site = get_active_site() # your existing site getter
293
  if site is None:
294
  st.warning("⚠️ No active site selected. Please add or select a site from the sidebar.")
295
  return
296
 
297
  uploaded = st.file_uploader("Upload soil image", type=["jpg", "jpeg", "png"])
298
  if uploaded is not None:
299
+ img = Image.open(uploaded)
300
  st.image(img, caption="Uploaded soil image", use_column_width=True)
301
 
302
+ predicted_class, confidence_scores = predict_soil(img)
303
+ st.success(f"✅ Predicted: **{predicted_class}**")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
 
305
+ st.subheader("Confidence Scores")
306
+ for cls, score in confidence_scores.items():
307
+ st.write(f"{cls}: {score:.2%}")
 
308
 
309
+ if st.button("Save to site"):
310
+ site["Soil Profile"] = predicted_class
311
+ site["Soil Recognizer Confidence"] = confidence_scores[predicted_class]
312
+ save_active_site(site)
313
+ st.success("Saved prediction to active site memory.")
314
 
315
 
316
  # ----------------------------