NagashreePai commited on
Commit
90fa633
Β·
verified Β·
1 Parent(s): 350548c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -19
app.py CHANGED
@@ -7,9 +7,8 @@ from PIL import Image
7
 
8
  # πŸ”§ Model definition
9
  class MMIM(nn.Module):
10
- def __init__(self, num_classes=12):
11
  super(MMIM, self).__init__()
12
- print("[INFO] Initializing MMIM model...")
13
  self.backbone = swin_t(weights='IMAGENET1K_V1')
14
  self.backbone.head = nn.Identity()
15
  self.classifier = nn.Sequential(
@@ -21,21 +20,16 @@ class MMIM(nn.Module):
21
 
22
  def forward(self, x):
23
  features = self.backbone(x)
24
- print(f"[DEBUG] Backbone output shape: {features.shape}")
25
  return self.classifier(features)
26
 
27
  # βœ… Load model
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- print(f"[INFO] Using device: {device}")
30
-
31
  model = MMIM(num_classes=12)
32
- print("[INFO] Loading model weights from MMIM_best2.pth...")
33
  model.load_state_dict(torch.load("MMIM_best2.pth", map_location=device))
34
  model.to(device)
35
  model.eval()
36
- print("[INFO] Model loaded and ready.")
37
 
38
- # βœ… Class names
39
  class_names = [
40
  'Black grass', 'Charlock', 'Cleavers', 'Common Chickweed', 'Common Wheat',
41
  'Fat Hen', 'Loose Silky-bent', 'Maize', 'Scentless Mayweed',
@@ -48,30 +42,31 @@ transform = transforms.Compose([
48
  transforms.ToTensor()
49
  ])
50
 
51
- # πŸ” Prediction function
52
  def predict(img):
53
- print("[INFO] Received image for prediction.")
54
  img = img.convert('RGB')
55
  img_tensor = transform(img).unsqueeze(0).to(device)
56
- print(f"[DEBUG] Image tensor shape: {img_tensor.shape}")
57
 
58
  with torch.no_grad():
59
  outputs = model(img_tensor)
60
- print(f"[DEBUG] Raw model outputs: {outputs}")
61
- _, pred = torch.max(outputs, 1)
62
- predicted_class = class_names[pred.item()]
63
- print(f"[INFO] Predicted class: {predicted_class} (index {pred.item()})")
 
64
 
65
- return predicted_class
 
 
 
66
 
67
  # 🎨 Gradio Interface
68
  interface = gr.Interface(
69
  fn=predict,
70
  inputs=gr.Image(type="pil"),
71
- outputs="label",
72
  title="Weed Image Classifier",
73
- description="Upload a weed image to predict its class"
74
  )
75
 
76
  interface.launch()
77
-
 
7
 
8
  # πŸ”§ Model definition
9
  class MMIM(nn.Module):
10
+ def __init__(self, num_classes=9):
11
  super(MMIM, self).__init__()
 
12
  self.backbone = swin_t(weights='IMAGENET1K_V1')
13
  self.backbone.head = nn.Identity()
14
  self.classifier = nn.Sequential(
 
20
 
21
  def forward(self, x):
22
  features = self.backbone(x)
 
23
  return self.classifier(features)
24
 
25
  # βœ… Load model
26
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
27
  model = MMIM(num_classes=12)
 
28
  model.load_state_dict(torch.load("MMIM_best2.pth", map_location=device))
29
  model.to(device)
30
  model.eval()
 
31
 
32
+ # βœ… Updated class names (match folder structure)
33
  class_names = [
34
  'Black grass', 'Charlock', 'Cleavers', 'Common Chickweed', 'Common Wheat',
35
  'Fat Hen', 'Loose Silky-bent', 'Maize', 'Scentless Mayweed',
 
42
  transforms.ToTensor()
43
  ])
44
 
45
+ # πŸ” Prediction function with negative detection
46
  def predict(img):
 
47
  img = img.convert('RGB')
48
  img_tensor = transform(img).unsqueeze(0).to(device)
 
49
 
50
  with torch.no_grad():
51
  outputs = model(img_tensor)
52
+ probs = torch.softmax(outputs, dim=1)
53
+ conf, pred = torch.max(probs, 1)
54
+
55
+ predicted_class = class_names[pred.item()]
56
+ confidence = conf.item() * 100
57
 
58
+ if predicted_class.lower() == "negative":
59
+ return f"⚠️ This image is predicted as Negative.\nConfidence: {confidence:.2f}%"
60
+
61
+ return f"βœ… Predicted as a weed with class-{predicted_class}\nConfidence: {confidence:.2f}%"
62
 
63
  # 🎨 Gradio Interface
64
  interface = gr.Interface(
65
  fn=predict,
66
  inputs=gr.Image(type="pil"),
67
+ outputs="text",
68
  title="Weed Image Classifier",
69
+ description="Upload a weed image to predict its class. If the model detects a non-weed image, it will return 'Negative'."
70
  )
71
 
72
  interface.launch()