NagashreePai commited on
Commit
f5ac3ce
Β·
verified Β·
1 Parent(s): 8b3c396

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +11 -38
app.py CHANGED
@@ -4,8 +4,6 @@ import torch.nn as nn
4
  from torchvision import transforms
5
  from torchvision.models import swin_t
6
  from PIL import Image
7
- import matplotlib.pyplot as plt
8
- import io
9
 
10
  # πŸ”§ Model definition
11
  class MMIM(nn.Module):
@@ -31,7 +29,7 @@ model.load_state_dict(torch.load("MMIM_best.pth", map_location=device))
31
  model.to(device)
32
  model.eval()
33
 
34
- # βœ… Class names
35
  class_names = [
36
  "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium",
37
  "Prickly acacia", "Rubber vine", "Siam weed", "Snake weed"
@@ -43,57 +41,32 @@ transform = transforms.Compose([
43
  transforms.ToTensor()
44
  ])
45
 
46
- # πŸ” Prediction with confidence + top-3 chart
47
  def predict(img):
48
  img = img.convert('RGB')
49
  img_tensor = transform(img).unsqueeze(0).to(device)
50
 
51
  with torch.no_grad():
52
  outputs = model(img_tensor)
53
- probs = torch.softmax(outputs, dim=1).cpu().squeeze()
 
54
 
55
- conf, pred = torch.max(probs, 0)
56
  predicted_class = class_names[pred.item()]
57
  confidence = conf.item() * 100
58
 
59
- # βœ… Main prediction message
60
  if predicted_class.lower() == "negative":
61
- message = f"⚠️ Predicted as **Negative**\nConfidence: **{confidence:.2f}%**"
62
- else:
63
- message = f"βœ… Predicted class: **{predicted_class}**\nConfidence: **{confidence:.2f}%**"
64
 
65
- # πŸ“Š Plot top 3 predictions
66
- top_probs, top_idxs = torch.topk(probs, 3)
67
- top_classes = [class_names[i] for i in top_idxs]
68
- top_confidences = [p.item() * 100 for p in top_probs]
69
 
70
- fig, ax = plt.subplots(figsize=(6, 3))
71
- ax.barh(top_classes[::-1], top_confidences[::-1], color='skyblue')
72
- ax.set_xlim(0, 100)
73
- ax.set_xlabel("Confidence (%)")
74
- ax.set_title("Top 3 Predictions")
75
- plt.tight_layout()
76
-
77
- buf = io.BytesIO()
78
- plt.savefig(buf, format="png")
79
- plt.close(fig)
80
- buf.seek(0)
81
-
82
- from PIL import Image as PILImage
83
- chart_img = PILImage.open(buf)
84
-
85
- return message, chart_img
86
-
87
- # 🎨 Gradio UI
88
  interface = gr.Interface(
89
  fn=predict,
90
  inputs=gr.Image(type="pil"),
91
- outputs=[
92
- gr.Textbox(label="Prediction with Confidence"),
93
- gr.Image(type="pil", label="Top 3 Prediction Chart")
94
- ],
95
- title="Weed Image Classifier - Debug Mode",
96
- description="πŸ§ͺ Debug view enabled: see prediction confidence and top 3 predicted classes."
97
  )
98
 
99
  interface.launch()
 
 
4
  from torchvision import transforms
5
  from torchvision.models import swin_t
6
  from PIL import Image
 
 
7
 
8
  # πŸ”§ Model definition
9
  class MMIM(nn.Module):
 
29
  model.to(device)
30
  model.eval()
31
 
32
+ # βœ… Updated class names (match folder structure)
33
  class_names = [
34
  "Chinee apple", "Lantana", "Negative", "Parkinsonia", "Parthenium",
35
  "Prickly acacia", "Rubber vine", "Siam weed", "Snake weed"
 
41
  transforms.ToTensor()
42
  ])
43
 
44
+ # πŸ” Prediction function with negative detection
45
  def predict(img):
46
  img = img.convert('RGB')
47
  img_tensor = transform(img).unsqueeze(0).to(device)
48
 
49
  with torch.no_grad():
50
  outputs = model(img_tensor)
51
+ probs = torch.softmax(outputs, dim=1)
52
+ conf, pred = torch.max(probs, 1)
53
 
 
54
  predicted_class = class_names[pred.item()]
55
  confidence = conf.item() * 100
56
 
 
57
  if predicted_class.lower() == "negative":
58
+ return f"⚠️ This image is predicted as Negative.\nConfidence: {confidence:.2f}%"
 
 
59
 
60
+ return f"βœ… Predicted as a weed with class-{predicted_class}\nConfidence: {confidence:.2f}%"
 
 
 
61
 
62
+ # 🎨 Gradio Interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  interface = gr.Interface(
64
  fn=predict,
65
  inputs=gr.Image(type="pil"),
66
+ outputs="text",
67
+ title="Weed Image Classifier",
68
+ description="Upload a weed image to predict its class. If the model detects a non-weed image, it will return 'Negative'."
 
 
 
69
  )
70
 
71
  interface.launch()
72
+ remove the confidence level