SoraRyuu commited on
Commit
163c547
·
verified ·
1 Parent(s): 78539e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -10
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import torch.nn as nn
3
- from torchvision.models import resnet18, resnet34, resnet50
4
  from torchvision import transforms
5
  from PIL import Image
6
  import gradio as gr
@@ -24,13 +24,7 @@ class ResNetPlantDisease(nn.Module):
24
  def __init__(self, num_classes=17, model_name='resnet50', pretrained=False):
25
  super().__init__()
26
 
27
- if model_name == 'resnet18':
28
- self.backbone = resnet18(weights=None)
29
- num_features = 512
30
- elif model_name == 'resnet34':
31
- self.backbone = resnet34(weights=None)
32
- num_features = 512
33
- elif model_name == 'resnet50':
34
  self.backbone = resnet50(weights=None)
35
  num_features = 2048
36
  else:
@@ -75,8 +69,8 @@ def predict(image):
75
  demo = gr.Interface(
76
  fn=predict,
77
  inputs=gr.Image(type="numpy"),
78
- outputs=gr.Label(num_top_classes=3),
79
- title="Plant Disease Detection - ResNet50",
80
  description="Upload a leaf image to detect crop disease."
81
  )
82
 
 
1
  import torch
2
  import torch.nn as nn
3
+ from torchvision.models import resnet50
4
  from torchvision import transforms
5
  from PIL import Image
6
  import gradio as gr
 
24
  def __init__(self, num_classes=17, model_name='resnet50', pretrained=False):
25
  super().__init__()
26
 
27
+ if model_name == 'resnet50':
 
 
 
 
 
 
28
  self.backbone = resnet50(weights=None)
29
  num_features = 2048
30
  else:
 
69
  demo = gr.Interface(
70
  fn=predict,
71
  inputs=gr.Image(type="numpy"),
72
+ outputs=gr.Label(num_top_classes=1),
73
+ title="Crop Disease Detection - ResNet",
74
  description="Upload a leaf image to detect crop disease."
75
  )
76