NagashreePai committed on
Commit
429b365
·
verified ·
1 Parent(s): 5957fe9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -77
app.py CHANGED
@@ -5,83 +5,6 @@ from torchvision import transforms
5
  from torchvision.models import swin_t
6
  from PIL import Image
7
 
8
 - # 🔧 Model definition
9
- class MMIM(nn.Module):
10
- def __init__(self, num_classes=4):
11
- super(MMIM, self).__init__()
12
- print("[INFO] Initializing MMIM model...")
13
- self.backbone = swin_t(weights='IMAGENET1K_V1')
14
- self.backbone.head = nn.Identity()
15
- self.classifier = nn.Sequential(
16
- nn.Linear(768, 512),
17
- nn.ReLU(),
18
- nn.Dropout(0.3),
19
- nn.Linear(512, num_classes)
20
- )
21
-
22
- def forward(self, x):
23
- features = self.backbone(x)
24
- print(f"[DEBUG] Feature shape: {features.shape}")
25
- return self.classifier(features)
26
-
27
 - # ✅ Load model
28
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
- print(f"[INFO] Using device: {device}")
30
-
31
- model = MMIM(num_classes=4)
32
- model.load_state_dict(torch.load("MMIM_best3.pth", map_location=device))
33
- model.to(device)
34
- model.eval()
35
- print("[INFO] Model loaded successfully.")
36
-
37
 - # ✅ Class names
38
- class_names = []
39
-
40
- # πŸ” Transform
41
- transform = transforms.Compose([
42
- transforms.Resize((224, 224)),
43
- transforms.ToTensor()
44
- ])
45
-
46
- # πŸ” Prediction
47
- def predict(img):
48
- print("[INFO] Image received for prediction.")
49
- img = img.convert('RGB')
50
- img_tensor = transform(img).unsqueeze(0).to(device)
51
- print(f"[DEBUG] Tensor shape: {img_tensor.shape}")
52
-
53
- with torch.no_grad():
54
- outputs = model(img_tensor)
55
- probs = torch.softmax(outputs, dim=1)
56
- conf, pred = torch.max(probs, 1)
57
-
58
- predicted_class = class_names[pred.item()]
59
- confidence = conf.item() * 100
60
- print(f"[INFO] Predicted: {predicted_class}, Confidence: {confidence:.2f}%")
61
-
62
 - return f"✅ Predicted as a weed with class-{predicted_class}\nConfidence: {confidence:.2f}%"
63
-
64
- # 🎨 Gradio Interface
65
- interface = gr.Interface(
66
- fn=predict,
67
 - inputs=gr.Image(type="pil", label="Upload Weed Image"), # ✅ tool removed
68
- outputs="text",
69
- title="Weed Image Classifier",
70
- description="Upload a weed image to predict its class. If the model detects a non-weed image, it will return 'Negative'.",
71
- allow_flagging="manual",
72
- live=True
73
- )
74
-
75
- interface.launch()
76
-
77
-
78
- import gradio as gr
79
- import torch
80
- import torch.nn as nn
81
- from torchvision import transforms
82
- from torchvision.models import swin_t
83
- from PIL import Image
84
-
85
  # 🔧 Model definition
86
  class MMIM(nn.Module):
87
  def __init__(self, num_classes=9):
 
5
  from torchvision.models import swin_t
6
  from PIL import Image
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  # 🔧 Model definition
9
  class MMIM(nn.Module):
10
  def __init__(self, num_classes=9):