Rafa-bork commited on
Commit
58a2e7e
·
1 Parent(s): fa566c5
Files changed (4) hide show
  1. app.py +141 -4
  2. models/model.pth +3 -0
  3. models/resnet_model.pth +3 -0
  4. requirements.txt +6 -0
app.py CHANGED
@@ -1,7 +1,144 @@
 
 
 
 
 
1
  import gradio as gr
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision.transforms as transforms
5
+ from PIL import Image
6
  import gradio as gr
7
+ import os
8
 
9
+ class_names = ["Pinus pinaster", "Eucalyptus globulus", "Quercus suber"]
 
10
 
11
+ model_aliases = {
12
+ "model.pth": "GBIF Model",
13
+ "resnet_model.pth": "ResNet Model"
14
+ }
15
+
16
+ # Image transformations
17
+ transform = transforms.Compose([
18
+ transforms.Resize((128, 128)),
19
+ transforms.ToTensor()
20
+ ])
21
+
22
+ # ResNet model class
23
+ class ResnetModel(torch.nn.Module):
24
+ def __init__(self):
25
+ super().__init__()
26
+ self.model = torch.hub.load('pytorch/vision', 'resnet18', pretrained=False)
27
+ self.model.fc = torch.nn.Linear(self.model.fc.in_features, len(class_names))
28
+
29
+ def forward(self, x):
30
+ return self.model(x)
31
+
32
+ # Global variable for the loaded model
33
+ model = None
34
+
35
+ # Custom CNN model class
36
+ class DeeperCNN(nn.Module):
37
+ def __init__(self, num_classes):
38
+ super(DeeperCNN, self).__init__()
39
+
40
+ self.features = nn.Sequential(
41
+ nn.Conv2d(3, 32, kernel_size=3, padding=1),
42
+ nn.BatchNorm2d(32),
43
+ nn.ReLU(inplace=True),
44
+ nn.Conv2d(32, 32, kernel_size=3, padding=1),
45
+ nn.BatchNorm2d(32),
46
+ nn.ReLU(inplace=True),
47
+ nn.MaxPool2d(2), # Downsample from 128 to 64
48
+ nn.Dropout(0.25),
49
+
50
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
51
+ nn.BatchNorm2d(64),
52
+ nn.ReLU(inplace=True),
53
+ nn.Conv2d(64, 64, kernel_size=3, padding=1),
54
+ nn.BatchNorm2d(64),
55
+ nn.ReLU(inplace=True),
56
+ nn.MaxPool2d(2), # Downsample from 64 to 32
57
+ nn.Dropout(0.25),
58
+
59
+ nn.Conv2d(64, 128, kernel_size=3, padding=1),
60
+ nn.BatchNorm2d(128),
61
+ nn.ReLU(inplace=True),
62
+ nn.Conv2d(128, 128, kernel_size=3, padding=1),
63
+ nn.BatchNorm2d(128),
64
+ nn.ReLU(inplace=True),
65
+ nn.Dropout(0.25),
66
+ )
67
+
68
+ self.classifier = nn.Sequential(
69
+ nn.Flatten(),
70
+ nn.Linear(128 * 32 * 32, 256),
71
+ nn.ReLU(),
72
+ nn.Dropout(0.5),
73
+ nn.Linear(256, num_classes)
74
+ )
75
+
76
+ def forward(self, x):
77
+ x = self.features(x)
78
+ x = self.classifier(x)
79
+ return x
80
+
81
+ # Number of classes in your task
82
+ num_classes = 3
83
+
84
+ # Function to load model given a selected model filename
85
+ def load_model(selected_model):
86
+ global model
87
+ model_path = os.path.join("models", selected_model)
88
+ model_path = os.path.normpath(model_path)
89
+ try:
90
+ model = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
91
+ model.eval()
92
+ alias = model_aliases.get(selected_model, selected_model)
93
+ return f"{alias} loaded successfully."
94
+ except Exception as e:
95
+ return f"Error loading {model_aliases.get(selected_model, selected_model)}: {str(e)}"
96
+
97
+ # Prediction function
98
+ def predict(image):
99
+ if model is None:
100
+ return {}, "Please load a model first."
101
+ image = transform(image).unsqueeze(0)
102
+ with torch.no_grad():
103
+ outputs = model(image)
104
+ probs = F.softmax(outputs, dim=1)[0]
105
+ probs_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
106
+ max_prob = max(probs_dict.values())
107
+ if max_prob < 0.6:
108
+ final_label = "Unknown"
109
+ else:
110
+ final_label = max(probs_dict, key=probs_dict.get)
111
+ return probs_dict, final_label
112
+
113
+ # Gradio UI elements
114
+ model_selector = gr.Dropdown(
115
+ choices=[(alias, filename) for filename, alias in model_aliases.items()],
116
+ label="Select Model",
117
+ value="model.pth"
118
+ )
119
+
120
+ load_button = gr.Button("Load Model")
121
+
122
+ with gr.Blocks() as demo:
123
+ with gr.Row():
124
+ # Left column: model selector, load button, load status
125
+ with gr.Column(scale=1):
126
+ model_selector.render()
127
+ load_button.render()
128
+ load_output = gr.Textbox(label="Model Status")
129
+
130
+ # Center column: image input and predict button
131
+ with gr.Column(scale=1):
132
+ image_input = gr.Image(type="pil", interactive=True, label="", height=400)
133
+ predict_button = gr.Button("Predict")
134
+
135
+ # Right column: prediction output label
136
+ with gr.Column(scale=1):
137
+ label_output = gr.Label(label="")
138
+ final_output = gr.Textbox(label="Final Prediction", interactive=False)
139
+
140
+ # Events
141
+ load_button.click(fn=load_model, inputs=model_selector, outputs=load_output)
142
+ predict_button.click(fn=predict, inputs=image_input, outputs=[label_output, final_output])
143
+
144
+ demo.launch()
models/model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8f89dc5b6b5f4058c1af905a717c69951944b94df4d90cd2aa2a4b3acb89149
3
+ size 135396465
models/resnet_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1534522afd5b25113958edf021815101109ecc412817bb50e4d2128ba736c457
3
+ size 44803339
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torch.nn
3
+ torchvision
4
+ gradio
5
+ Pillow
6
+ os