SoraRyuu commited on
Commit
0c6c8ff
·
verified ·
1 Parent(s): 9a7fc09

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -30
app.py CHANGED
@@ -1,13 +1,15 @@
1
  import torch
2
  import torch.nn as nn
3
- from torchvision.models import resnet18, resnet34, resnet50
4
  from torchvision import transforms
5
  from PIL import Image
6
  import gradio as gr
7
 
8
- # -----------------------
9
- # CLASS LABELS
10
- # -----------------------
 
 
11
  CLASS_LABELS = [
12
  'Corn_Common_Rust', 'Corn_Gray_Leaf_Spot', 'Corn_Healthy', 'Corn_Northern_Leaf_Blight',
13
  'Potato_Early_Blight', 'Potato_Healthy', 'Potato_Late_Blight',
@@ -18,21 +20,11 @@ CLASS_LABELS = [
18
 
19
  NUM_CLASSES = len(CLASS_LABELS)
20
 
21
-
22
- # -----------------------
23
- # MODEL ARCHITECTURE
24
- # -----------------------
25
  class ResNetPlantDisease(nn.Module):
26
  def __init__(self, num_classes=17, model_name='resnet50', pretrained=False):
27
  super().__init__()
28
 
29
- if model_name == 'resnet18':
30
- self.backbone = resnet18(weights=None)
31
- num_features = 512
32
- elif model_name == 'resnet34':
33
- self.backbone = resnet34(weights=None)
34
- num_features = 512
35
- elif model_name == 'resnet50':
36
  self.backbone = resnet50(weights=None)
37
  num_features = 2048
38
  else:
@@ -50,18 +42,11 @@ class ResNetPlantDisease(nn.Module):
50
  return self.backbone(x)
51
 
52
 
53
- # -----------------------
54
- # LOAD MODEL
55
- # -----------------------
56
  model = ResNetPlantDisease(num_classes=NUM_CLASSES, model_name='resnet50')
57
  state = torch.load("plant_disease_resnet_model.pth", map_location="cpu")
58
  model.load_state_dict(state)
59
  model.eval()
60
 
61
-
62
- # -----------------------
63
- # TRANSFORMS
64
- # -----------------------
65
  transform = transforms.Compose([
66
  transforms.Resize((224, 224)),
67
  transforms.ToTensor(),
@@ -71,10 +56,6 @@ transform = transforms.Compose([
71
  )
72
  ])
73
 
74
-
75
- # -----------------------
76
- # PREDICT FUNCTION
77
- # -----------------------
78
  def predict(image):
79
  img = Image.fromarray(image)
80
  img = transform(img).unsqueeze(0)
@@ -86,10 +67,6 @@ def predict(image):
86
  result = {CLASS_LABELS[i]: float(probs[i]) for i in range(NUM_CLASSES)}
87
  return result
88
 
89
-
90
- # -----------------------
91
- # GRADIO UI + API
92
- # -----------------------
93
  demo = gr.Interface(
94
  fn=predict,
95
  inputs=gr.Image(type="numpy"),
 
1
  import torch
2
  import torch.nn as nn
3
+ from torchvision.models import resnet50
4
  from torchvision import transforms
5
  from PIL import Image
6
  import gradio as gr
7
 
8
+ '''
9
+ This is only the first model, the one from CV, input an image, output will be top three results,
10
+ will later change to return best results only
11
+ '''
12
+
13
  CLASS_LABELS = [
14
  'Corn_Common_Rust', 'Corn_Gray_Leaf_Spot', 'Corn_Healthy', 'Corn_Northern_Leaf_Blight',
15
  'Potato_Early_Blight', 'Potato_Healthy', 'Potato_Late_Blight',
 
20
 
21
  NUM_CLASSES = len(CLASS_LABELS)
22
 
 
 
 
 
23
  class ResNetPlantDisease(nn.Module):
24
  def __init__(self, num_classes=17, model_name='resnet50', pretrained=False):
25
  super().__init__()
26
 
27
+ if model_name == 'resnet50':
 
 
 
 
 
 
28
  self.backbone = resnet50(weights=None)
29
  num_features = 2048
30
  else:
 
42
  return self.backbone(x)
43
 
44
 
 
 
 
45
  model = ResNetPlantDisease(num_classes=NUM_CLASSES, model_name='resnet50')
46
  state = torch.load("plant_disease_resnet_model.pth", map_location="cpu")
47
  model.load_state_dict(state)
48
  model.eval()
49
 
 
 
 
 
50
  transform = transforms.Compose([
51
  transforms.Resize((224, 224)),
52
  transforms.ToTensor(),
 
56
  )
57
  ])
58
 
 
 
 
 
59
  def predict(image):
60
  img = Image.fromarray(image)
61
  img = transform(img).unsqueeze(0)
 
67
  result = {CLASS_LABELS[i]: float(probs[i]) for i in range(NUM_CLASSES)}
68
  return result
69
 
 
 
 
 
70
  demo = gr.Interface(
71
  fn=predict,
72
  inputs=gr.Image(type="numpy"),