# Tree-Identifier / app.py
# Author: Rafa-bork
# Commit: d1afbc4 ("Update app.py", verified)
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
import gradio as gr
import os
class DeeperCNN(nn.Module):
    """Small VGG-style convolutional classifier for 3-channel images.

    The feature extractor has three stages (32, 64 and 128 channels). Each
    stage is two conv -> batch-norm -> ReLU units. The first two stages end
    with a 2x2 max-pool followed by Dropout(0.25); the last stage ends with
    Dropout(0.25) only. The classifier flattens the feature map and applies
    a 256-unit hidden layer with Dropout(0.5) before the output layer.

    The first Linear layer expects ``128 * 32 * 32`` features. The 3x3
    convolutions use padding=1 and keep the spatial size, and two 2x
    poolings follow, so this corresponds to a 128x128 input image.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        def conv_unit(c_in, c_out):
            # 3x3 "same" convolution -> batch norm -> in-place ReLU.
            return [
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        # Layer order (and therefore Sequential indices / state_dict keys)
        # is kept exactly as in the original flat listing.
        layers = []
        layers += conv_unit(3, 32) + conv_unit(32, 32)
        layers += [nn.MaxPool2d(2), nn.Dropout(0.25)]
        layers += conv_unit(32, 64) + conv_unit(64, 64)
        layers += [nn.MaxPool2d(2), nn.Dropout(0.25)]
        layers += conv_unit(64, 128) + conv_unit(128, 128)
        layers.append(nn.Dropout(0.25))
        self.features = nn.Sequential(*layers)

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 32 * 32, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return the raw class logits for the image batch *x*."""
        return self.classifier(self.features(x))
# Output labels; predict() pairs class_names[i] with the i-th softmax output.
class_names = ["Eucalyptus globulus", "Pinus pinaster", "Quercus suber"]

# Checkpoint file names (relative to models/) -> display names shown in the UI.
model_aliases = {
    "model.pth": "GBIF CNN Model",
    "resnet_model.pth": "ResNet Model",
}

# Preprocessing applied to every input image: resize to 128x128, then convert
# to a float tensor. 128x128 matches DeeperCNN's classifier input size
# (128 * 32 * 32 after two 2x poolings).
# NOTE(review): there is no Normalize step -- this assumes both checkpoints
# were trained on un-normalized ToTensor() output; confirm against training.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Currently active model; None until one is loaded via load_model().
model = None
# Derived from class_names so the label list and the class count cannot drift apart.
num_classes = len(class_names)
def load_model(selected_model):
    """Load a pickled PyTorch model from the ``models`` directory.

    The loaded model becomes the module-level ``model`` used by ``predict``.

    Args:
        selected_model: Checkpoint file name relative to ``models/`` (the
            value of the model dropdown).

    Returns:
        A human-readable status string. Failures are reported in the returned
        string instead of being raised, so the UI status box can display them.

    The global ``model`` is replaced only after the checkpoint has loaded
    *and* been switched to eval mode. On any failure, the previously loaded
    model (if any) stays active.
    """
    global model
    if not selected_model:
        return "Error: No model selected."
    alias = model_aliases.get(selected_model, selected_model)
    models_dir = os.path.abspath("models")
    model_path = os.path.normpath(os.path.join("models", selected_model))
    # torch.load(weights_only=False) unpickles arbitrary objects, which can
    # execute code, so refuse any name that resolves outside models/.
    try:
        inside = os.path.commonpath([models_dir, os.path.abspath(model_path)]) == models_dir
    except ValueError:  # e.g. paths on different drives (Windows)
        inside = False
    if not inside:
        return f"Error: {alias} is not inside the models directory."
    try:
        # The result must be a callable module (it is .eval()'d here and
        # called in predict), not a bare state_dict. Presumably the
        # checkpoints were saved with torch.save(model), which is why
        # DeeperCNN is defined in this file -- confirm.
        loaded = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
        loaded.eval()
    except FileNotFoundError:
        return f"Error: Model file not found at {model_path}"
    except Exception as e:
        return f"Error loading {alias}: {str(e)}"
    model = loaded
    return f"{alias} loaded successfully."
def predict(image_path):
    """Classify the image at *image_path* with the currently loaded model.

    Args:
        image_path: Path of the input image (the ``gr.Image(type="filepath")``
            value), or ``None`` when no image has been provided.

    Returns:
        ``(probs_dict, final_label)``. ``probs_dict`` maps each entry of
        ``class_names`` to its softmax probability. ``final_label`` is the
        top class name, or ``"Inconclusive, low confidence"`` when the top
        probability is below 0.6 *and* leads the runner-up by less than 0.2.
        If no model is loaded or no image is given, returns
        ``({}, <message>)``.
    """
    if model is None:
        return {}, "Please load a model first."
    if image_path is None:
        return {}, "Please provide an input image first."
    # The context manager closes the file handle. convert() returns a new,
    # fully loaded image, and transform() runs before the file is closed.
    with Image.open(image_path) as img:
        batch = transform(img.convert("RGB")).unsqueeze(0)
    with torch.no_grad():
        probs = F.softmax(model(batch), dim=1)[0]
    probs_dict = {name: float(probs[i]) for i, name in enumerate(class_names)}
    # Stable sort: ties keep class_names order, as before.
    ranked = sorted(probs_dict.items(), key=lambda kv: kv[1], reverse=True)
    top1_class, top1_prob = ranked[0]
    # With a single class there is no runner-up; treat its probability as 0.
    top2_prob = ranked[1][1] if len(ranked) > 1 else 0.0
    if top1_prob < 0.6 and (top1_prob - top2_prob) < 0.2:
        final_label = "Inconclusive, low confidence"
    else:
        final_label = top1_class
    return probs_dict, final_label
# UI components are created before the Blocks context and placed later with
# .render(). This lets gr.Examples (first column) reference image_input even
# though image_input is rendered in the second column.

# Dropdown entries as (display label, value) pairs: the UI shows the alias,
# and load_model receives the checkpoint file name.
_model_choices = [(alias, filename) for filename, alias in model_aliases.items()]
model_selector = gr.Dropdown(
    label="Select Model",
    choices=_model_choices,
    value="resnet_model.pth",
)

load_button = gr.Button("Load Model")

image_input = gr.Image(
    label="Input Image",
    type="filepath",
    sources=["upload", "clipboard"],
    interactive=True,
    height=400,
)
def _list_example_images(folder="examples"):
    """Return the sorted paths of the .jpg/.jpeg/.png files in *folder*.

    Returns an empty list when *folder* does not exist, so a missing examples
    directory does not stop the app from starting. Sorting gives a stable
    gallery order, since os.listdir order is arbitrary.
    """
    if not os.path.isdir(folder):
        return []
    return sorted(
        os.path.join(folder, name)
        for name in os.listdir(folder)
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
        and os.path.isfile(os.path.join(folder, name))
    )


example_images = _list_example_images()

with gr.Blocks() as demo:
    with gr.Row():
        # Column 1: model selection, load status and example gallery.
        with gr.Column(scale=1):
            model_selector.render()
            load_button.render()
            load_output = gr.Textbox(label="Model Status", interactive=False)
            if example_images:
                gr.Examples(
                    examples=example_images,
                    inputs=image_input,
                    label="Example Images",
                    examples_per_page=9,
                    elem_id="custom-example-gallery",
                )
        # Column 2: input image and predict button.
        with gr.Column(scale=1):
            image_input.render()
            predict_button = gr.Button("Predict")
        # Column 3: prediction results.
        with gr.Column(scale=1):
            label_output = gr.Label(label="Prediction Probabilities")
            final_output = gr.Textbox(label="Final Prediction", interactive=False)
    # Fixed-size, cropped thumbnails for the gallery with elem_id
    # "custom-example-gallery".
    gr.HTML("""
<style>
#custom-example-gallery img {
width: 150px !important;
height: 150px !important;
object-fit: cover !important;
}
</style>
""")
    # Event listeners must be attached inside the Blocks context.
    load_button.click(fn=load_model, inputs=model_selector, outputs=load_output)
    predict_button.click(fn=predict, inputs=image_input, outputs=[label_output, final_output])
# Start the server only when run as a script (e.g. `python app.py`), so that
# importing this module does not block on launch().
if __name__ == "__main__":
    demo.launch()