SoraRyuu commited on
Commit
9a7fc09
·
verified ·
1 Parent(s): 381e2e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +54 -40
app.py CHANGED
@@ -1,12 +1,13 @@
1
  import torch
2
  import torch.nn as nn
3
- from torchvision import transforms, models
 
4
  from PIL import Image
5
  import gradio as gr
6
 
7
- # --------------------------------------------------------------
8
- # CLASS LABELS (17 classes)
9
- # --------------------------------------------------------------
10
  CLASS_LABELS = [
11
  'Corn_Common_Rust', 'Corn_Gray_Leaf_Spot', 'Corn_Healthy', 'Corn_Northern_Leaf_Blight',
12
  'Potato_Early_Blight', 'Potato_Healthy', 'Potato_Late_Blight',
@@ -17,34 +18,50 @@ CLASS_LABELS = [
17
 
18
  NUM_CLASSES = len(CLASS_LABELS)
19
 
20
- # --------------------------------------------------------------
21
- # RESNET50 MODEL DEFINITION
22
- # --------------------------------------------------------------
 
23
  class ResNetPlantDisease(nn.Module):
24
- def __init__(self, num_classes=17):
25
  super().__init__()
26
- self.model = models.resnet50(weights=None)
27
- self.model.fc = nn.Sequential(
 
 
 
 
 
 
 
 
 
 
 
 
28
  nn.Dropout(0.5),
29
- nn.Linear(2048, 512),
30
  nn.ReLU(),
31
  nn.Dropout(0.3),
32
  nn.Linear(512, num_classes)
33
  )
34
 
35
  def forward(self, x):
36
- return self.model(x)
37
 
38
- # --------------------------------------------------------------
 
39
  # LOAD MODEL
40
- # --------------------------------------------------------------
41
- model = ResNetPlantDisease(num_classes=NUM_CLASSES)
42
- model.load_state_dict(torch.load("plant_disease_resnet_model.pth", map_location="cpu"))
 
43
  model.eval()
44
 
45
- # --------------------------------------------------------------
46
- # PREPROCESSING
47
- # --------------------------------------------------------------
 
48
  transform = transforms.Compose([
49
  transforms.Resize((224, 224)),
50
  transforms.ToTensor(),
@@ -54,34 +71,31 @@ transform = transforms.Compose([
54
  )
55
  ])
56
 
57
- # --------------------------------------------------------------
58
- # DISEASE CLASSIFICATION FUNCTION
59
- # --------------------------------------------------------------
60
- def classify(image):
61
- img = Image.fromarray(image).convert("RGB")
62
- x = transform(img).unsqueeze(0)
 
63
 
64
  with torch.no_grad():
65
- logits = model(x)
66
- probs = torch.softmax(logits[0], dim=0)
67
 
68
- top_probs, top_idxs = torch.topk(probs, 5)
69
- predictions = {
70
- CLASS_LABELS[top_idxs[i].item()]: float(top_probs[i].item())
71
- for i in range(5)
72
- }
73
 
74
- return predictions
75
 
76
- # --------------------------------------------------------------
77
- # GRADIO INTERFACE WITH API
78
- # --------------------------------------------------------------
79
  demo = gr.Interface(
80
- fn=classify,
81
  inputs=gr.Image(type="numpy"),
82
- outputs=gr.Label(num_top_classes=5),
83
- title="Plant Disease Classification (ResNet50)",
84
- description="Upload a leaf image to detect plant disease."
85
  )
86
 
87
  demo.launch()
 
1
  import torch
2
  import torch.nn as nn
3
+ from torchvision.models import resnet18, resnet34, resnet50
4
+ from torchvision import transforms
5
  from PIL import Image
6
  import gradio as gr
7
 
8
+ # -----------------------
9
+ # CLASS LABELS
10
+ # -----------------------
11
  CLASS_LABELS = [
12
  'Corn_Common_Rust', 'Corn_Gray_Leaf_Spot', 'Corn_Healthy', 'Corn_Northern_Leaf_Blight',
13
  'Potato_Early_Blight', 'Potato_Healthy', 'Potato_Late_Blight',
 
18
 
19
  NUM_CLASSES = len(CLASS_LABELS)
20
 
21
+
22
+ # -----------------------
23
+ # MODEL ARCHITECTURE
24
+ # -----------------------
25
  class ResNetPlantDisease(nn.Module):
26
+ def __init__(self, num_classes=17, model_name='resnet50', pretrained=False):
27
  super().__init__()
28
+
29
+ if model_name == 'resnet18':
30
+ self.backbone = resnet18(weights=None)
31
+ num_features = 512
32
+ elif model_name == 'resnet34':
33
+ self.backbone = resnet34(weights=None)
34
+ num_features = 512
35
+ elif model_name == 'resnet50':
36
+ self.backbone = resnet50(weights=None)
37
+ num_features = 2048
38
+ else:
39
+ raise ValueError("Unsupported model name")
40
+
41
+ self.backbone.fc = nn.Sequential(
42
  nn.Dropout(0.5),
43
+ nn.Linear(num_features, 512),
44
  nn.ReLU(),
45
  nn.Dropout(0.3),
46
  nn.Linear(512, num_classes)
47
  )
48
 
49
  def forward(self, x):
50
+ return self.backbone(x)
51
 
52
+
53
+ # -----------------------
54
  # LOAD MODEL
55
+ # -----------------------
56
+ model = ResNetPlantDisease(num_classes=NUM_CLASSES, model_name='resnet50')
57
+ state = torch.load("plant_disease_resnet_model.pth", map_location="cpu")
58
+ model.load_state_dict(state)
59
  model.eval()
60
 
61
+
62
+ # -----------------------
63
+ # TRANSFORMS
64
+ # -----------------------
65
  transform = transforms.Compose([
66
  transforms.Resize((224, 224)),
67
  transforms.ToTensor(),
 
71
  )
72
  ])
73
 
74
+
75
+ # -----------------------
76
+ # PREDICT FUNCTION
77
+ # -----------------------
78
+ def predict(image):
79
+ img = Image.fromarray(image)
80
+ img = transform(img).unsqueeze(0)
81
 
82
  with torch.no_grad():
83
+ outputs = model(img)
84
+ probs = torch.softmax(outputs, dim=1)[0]
85
 
86
+ result = {CLASS_LABELS[i]: float(probs[i]) for i in range(NUM_CLASSES)}
87
+ return result
 
 
 
88
 
 
89
 
90
+ # -----------------------
91
+ # GRADIO UI + API
92
+ # -----------------------
93
  demo = gr.Interface(
94
+ fn=predict,
95
  inputs=gr.Image(type="numpy"),
96
+ outputs=gr.Label(num_top_classes=3),
97
+ title="Plant Disease Detection - ResNet50",
98
+ description="Upload a leaf image to detect crop disease."
99
  )
100
 
101
  demo.launch()