SoraRyuu commited on
Commit
381e2e8
·
verified ·
1 Parent(s): 503df44

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -30
app.py CHANGED
@@ -1,12 +1,12 @@
1
  import torch
2
  import torch.nn as nn
3
- from torchvision import models, transforms
4
  from PIL import Image
5
  import gradio as gr
6
 
7
- # ---------------------------------------------
8
  # CLASS LABELS (17 classes)
9
- # ---------------------------------------------
10
  CLASS_LABELS = [
11
  'Corn_Common_Rust', 'Corn_Gray_Leaf_Spot', 'Corn_Healthy', 'Corn_Northern_Leaf_Blight',
12
  'Potato_Early_Blight', 'Potato_Healthy', 'Potato_Late_Blight',
@@ -17,27 +17,36 @@ CLASS_LABELS = [
17
 
18
  NUM_CLASSES = len(CLASS_LABELS)
19
 
20
- # ---------------------------------------------
21
- # Load Your ResNet50 Model
22
- # ---------------------------------------------
23
- class PlantDiseaseResNet(nn.Module):
24
- def __init__(self, num_classes):
25
  super().__init__()
26
  self.model = models.resnet50(weights=None)
27
- self.model.fc = nn.Linear(2048, num_classes)
 
 
 
 
 
 
28
 
29
  def forward(self, x):
30
  return self.model(x)
31
 
32
- model = PlantDiseaseResNet(NUM_CLASSES)
 
 
 
33
  model.load_state_dict(torch.load("plant_disease_resnet_model.pth", map_location="cpu"))
34
  model.eval()
35
 
36
- # ---------------------------------------------
37
- # Image Transform (224x224)
38
- # ---------------------------------------------
39
  transform = transforms.Compose([
40
- transforms.Resize((224, 224)), # Your image size
41
  transforms.ToTensor(),
42
  transforms.Normalize(
43
  mean=[0.485, 0.456, 0.406],
@@ -45,29 +54,34 @@ transform = transforms.Compose([
45
  )
46
  ])
47
 
48
- # ---------------------------------------------
49
- # Prediction Function
50
- # ---------------------------------------------
51
- def predict(image):
52
- img = Image.fromarray(image)
53
  x = transform(img).unsqueeze(0)
54
 
55
  with torch.no_grad():
56
  logits = model(x)
57
- probs = logits.softmax(dim=1)[0]
58
 
59
- result = {CLASS_LABELS[i]: float(probs[i]) for i in range(NUM_CLASSES)}
60
- return result
 
 
 
61
 
62
- # ---------------------------------------------
63
- # Gradio Interface + HuggingFace API
64
- # ---------------------------------------------
 
 
65
  demo = gr.Interface(
66
- fn=predict,
67
  inputs=gr.Image(type="numpy"),
68
- outputs=gr.Label(num_top_classes=3),
69
- title="Plant Disease Detector (ResNet50)",
70
- description="Upload a crop leaf image to predict its disease."
71
  )
72
 
73
- demo.launch()
 
1
  import torch
2
  import torch.nn as nn
3
+ from torchvision import transforms, models
4
  from PIL import Image
5
  import gradio as gr
6
 
7
+ # --------------------------------------------------------------
8
  # CLASS LABELS (17 classes)
9
+ # --------------------------------------------------------------
10
  CLASS_LABELS = [
11
  'Corn_Common_Rust', 'Corn_Gray_Leaf_Spot', 'Corn_Healthy', 'Corn_Northern_Leaf_Blight',
12
  'Potato_Early_Blight', 'Potato_Healthy', 'Potato_Late_Blight',
 
17
 
18
  NUM_CLASSES = len(CLASS_LABELS)
19
 
20
+ # --------------------------------------------------------------
21
+ # RESNET50 MODEL DEFINITION
22
+ # --------------------------------------------------------------
23
+ class ResNetPlantDisease(nn.Module):
24
+ def __init__(self, num_classes=17):
25
  super().__init__()
26
  self.model = models.resnet50(weights=None)
27
+ self.model.fc = nn.Sequential(
28
+ nn.Dropout(0.5),
29
+ nn.Linear(2048, 512),
30
+ nn.ReLU(),
31
+ nn.Dropout(0.3),
32
+ nn.Linear(512, num_classes)
33
+ )
34
 
35
  def forward(self, x):
36
  return self.model(x)
37
 
38
+ # --------------------------------------------------------------
39
+ # LOAD MODEL
40
+ # --------------------------------------------------------------
41
+ model = ResNetPlantDisease(num_classes=NUM_CLASSES)
42
  model.load_state_dict(torch.load("plant_disease_resnet_model.pth", map_location="cpu"))
43
  model.eval()
44
 
45
+ # --------------------------------------------------------------
46
+ # PREPROCESSING
47
+ # --------------------------------------------------------------
48
  transform = transforms.Compose([
49
+ transforms.Resize((224, 224)),
50
  transforms.ToTensor(),
51
  transforms.Normalize(
52
  mean=[0.485, 0.456, 0.406],
 
54
  )
55
  ])
56
 
57
+ # --------------------------------------------------------------
58
+ # DISEASE CLASSIFICATION FUNCTION
59
+ # --------------------------------------------------------------
60
+ def classify(image):
61
+ img = Image.fromarray(image).convert("RGB")
62
  x = transform(img).unsqueeze(0)
63
 
64
  with torch.no_grad():
65
  logits = model(x)
66
+ probs = torch.softmax(logits[0], dim=0)
67
 
68
+ top_probs, top_idxs = torch.topk(probs, 5)
69
+ predictions = {
70
+ CLASS_LABELS[top_idxs[i].item()]: float(top_probs[i].item())
71
+ for i in range(5)
72
+ }
73
 
74
+ return predictions
75
+
76
+ # --------------------------------------------------------------
77
+ # GRADIO INTERFACE WITH API
78
+ # --------------------------------------------------------------
79
  demo = gr.Interface(
80
+ fn=classify,
81
  inputs=gr.Image(type="numpy"),
82
+ outputs=gr.Label(num_top_classes=5),
83
+ title="Plant Disease Classification (ResNet50)",
84
+ description="Upload a leaf image to detect plant disease."
85
  )
86
 
87
+ demo.launch()