File size: 4,823 Bytes
58a2e7e
 
 
 
 
fa566c5
58a2e7e
fa566c5
58a2e7e
 
 
 
 
 
 
 
 
 
 
4d6400d
58a2e7e
 
 
 
 
 
 
 
4d6400d
58a2e7e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e4453b5
 
 
d1afbc4
e4453b5
 
 
 
 
 
 
 
 
4d6400d
e4453b5
58a2e7e
 
 
 
 
 
4d6400d
58a2e7e
 
e4453b5
 
58a2e7e
 
 
4d6400d
58a2e7e
 
4d6400d
 
58a2e7e
4d6400d
58a2e7e
 
 
4d6400d
58a2e7e
4d6400d
 
 
 
 
 
58a2e7e
4d6400d
 
58a2e7e
 
 
 
 
77828cb
58a2e7e
 
 
4d6400d
58a2e7e
 
 
 
 
 
4d6400d
 
 
2204b1c
4d6400d
 
b60a2aa
2204b1c
4d6400d
58a2e7e
 
4d6400d
58a2e7e
4d6400d
58a2e7e
4d6400d
58a2e7e
 
b60a2aa
 
 
 
 
 
 
 
 
2204b1c
58a2e7e
 
 
e4453b5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import os

class DeeperCNN(nn.Module):
    """Three-stage convolutional classifier for 128x128 RGB images.

    Architecture: three pairs of Conv-BN-ReLU units (32, 64, 128 channels);
    the first two stages end in 2x2 max-pooling, every stage ends in dropout.
    A flatten + two-layer fully-connected head maps the 128 x 32 x 32 feature
    map to the class logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        def conv_pair(in_ch, out_ch):
            # Two stacked Conv-BN-ReLU units at the given channel width.
            return [
                nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]

        # Stages 1 and 2 halve spatial resolution; stage 3 keeps it, so a
        # 128x128 input reaches the head as 128 channels of 32x32.
        self.features = nn.Sequential(
            *conv_pair(3, 32),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            *conv_pair(32, 64),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            *conv_pair(64, 128),
            nn.Dropout(0.25),
        )

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 32 * 32, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw class logits for a batch of (N, 3, 128, 128) images."""
        return self.classifier(self.features(x))

# Species labels in the index order the model was trained with;
# position i corresponds to logit/probability index i.
class_names = ["Eucalyptus globulus", "Pinus pinaster", "Quercus suber"]

# Maps checkpoint filenames (under the "models" directory) to the
# human-readable names shown in the UI dropdown.
model_aliases = {
    "model.pth": "GBIF CNN Model",
    "resnet_model.pth": "ResNet Model"
}

# Inference-time preprocessing: resize to the fixed 128x128 input size
# (matching DeeperCNN's classifier head: two 2x2 pools -> 32x32 feature
# maps) and convert to a [0, 1] float tensor.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor()
])

# Currently loaded model; populated by load_model() and read by predict().
model = None
# Kept for reference but unused below — torch.load unpickles the full
# model object rather than rebuilding DeeperCNN(num_classes) here.
num_classes = 3

def load_model(selected_model):
    """Load the selected checkpoint from the "models" directory.

    Sets the module-level ``model`` used by ``predict`` and returns a
    human-readable status message for the UI.

    Args:
        selected_model: Checkpoint filename; must be a key of
            ``model_aliases`` (the dropdown only offers those).

    Returns:
        A status string — success with the model's alias, or an error
        description.
    """
    global model
    # Reject anything outside the known catalogue.  The dropdown only sends
    # alias keys, but this endpoint is network-reachable, and normpath alone
    # does not stop traversal inputs like "../secrets.pth".
    if selected_model not in model_aliases:
        return f"Error: Unknown model '{selected_model}'"
    model_path = os.path.normpath(os.path.join("models", selected_model))
    try:
        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # acceptable only because the files are shipped with the app, never
        # user-supplied.
        model = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
        model.eval()
        return f"{model_aliases[selected_model]} loaded successfully."
    except FileNotFoundError:
        return f"Error: Model file not found at {model_path}"
    except Exception as e:
        return f"Error loading {model_aliases[selected_model]}: {str(e)}"

def predict(image_path):
    """Classify the image at ``image_path`` with the currently loaded model.

    Args:
        image_path: Path to an image file (Gradio ``type="filepath"``), or
            None when no image has been provided yet.

    Returns:
        A (probabilities, label) pair: a dict mapping each class name to its
        softmax probability, and the final label — either the top class or
        an "inconclusive" message when confidence is low.
    """
    if model is None:
        return {}, "Please load a model first."
    if image_path is None:
        # Predict was clicked before an image was uploaded.
        return {}, "Please provide an image first."

    # Context manager closes the underlying file handle promptly instead of
    # leaking it until garbage collection.
    with Image.open(image_path) as img:
        image = transform(img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        outputs = model(image)
        probs = F.softmax(outputs, dim=1)[0]

    probs_dict = {class_names[i]: float(probs[i]) for i in range(len(class_names))}
    sorted_probs = sorted(probs_dict.items(), key=lambda x: x[1], reverse=True)
    top1_class, top1_prob = sorted_probs[0]
    top2_class, top2_prob = sorted_probs[1]

    # Abstain when the winner is both weak (< 0.6) and close to the
    # runner-up (margin < 0.2); otherwise report the top class.
    if top1_prob < 0.6 and (top1_prob - top2_prob) < 0.2:
        final_label = "Inconclusive, low confidence"
    else:
        final_label = top1_class

    return probs_dict, final_label

# Model picker: shows the human-readable alias, submits the checkpoint
# filename (Gradio (label, value) choice tuples).
model_selector = gr.Dropdown(
    choices=[(alias, filename) for filename, alias in model_aliases.items()],
    label="Select Model",
    value="resnet_model.pth"
)

load_button = gr.Button("Load Model")
# Filepath-typed image input so predict() receives a path rather than a
# decoded array; rendered later inside the Blocks layout.
image_input = gr.Image(type="filepath", interactive=True, label="Input Image", sources=["upload", "clipboard"], height=400)

# Three-column layout: model controls + examples | image input | results.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column(scale=1):
            model_selector.render()
            load_button.render()
            load_output = gr.Textbox(label="Model Status", interactive=False)

            # Build the example list defensively: a missing "examples"
            # directory previously crashed the app at startup; sorting makes
            # the gallery order deterministic across filesystems.
            examples_dir = "examples"
            example_images = sorted(
                os.path.join(examples_dir, f)
                for f in (os.listdir(examples_dir) if os.path.isdir(examples_dir) else [])
                if f.lower().endswith((".jpg", ".jpeg", ".png"))
            )
            if example_images:
                gr.Examples(
                    examples=example_images,
                    inputs=image_input,
                    label="Example Images",
                    examples_per_page=9,
                    elem_id="custom-example-gallery"
                )

        with gr.Column(scale=1):
            image_input.render()
            predict_button = gr.Button("Predict")

        with gr.Column(scale=1):
            label_output = gr.Label(label="Prediction Probabilities")
            final_output = gr.Textbox(label="Final Prediction", interactive=False)

    # Fixed-size thumbnails for the example gallery.
    gr.HTML("""
    <style>
    #custom-example-gallery img {
        width: 150px !important;
        height: 150px !important;
        object-fit: cover !important;
    }
    </style>
    """)

    load_button.click(fn=load_model, inputs=model_selector, outputs=load_output)
    predict_button.click(fn=predict, inputs=image_input, outputs=[label_output, final_output])

demo.launch()