import os

import gradio as gr
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image


class DeeperCNN(nn.Module):
    """Convolutional image classifier: three conv stages plus an MLP head.

    Each stage is two 3x3 Conv -> BatchNorm -> ReLU blocks. The first two
    stages end with ``MaxPool2d(2)``; every stage ends with ``Dropout(0.25)``.
    The head is Flatten -> Linear(.., 256) -> ReLU -> Dropout(0.5) ->
    Linear(256, num_classes).

    Do not reorder or rename layers: the ``Sequential`` indices determine the
    ``state_dict`` keys of saved checkpoints.

    The head hard-codes ``128 * 32 * 32`` input features. That matches a
    128x128 input (the size produced by ``transform`` below) halved twice by
    the two pooling layers, with 128 output channels.

    NOTE(review): nothing in this file instantiates the class. It is presumably
    defined here so ``torch.load(..., weights_only=False)`` can unpickle
    checkpoints saved as whole ``DeeperCNN`` objects — confirm.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            # Stage 1: 3 -> 32 channels, spatial size halved.
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            # Stage 2: 32 -> 64 channels, spatial size halved again.
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Dropout(0.25),
            # Stage 3: 64 -> 128 channels, no pooling.
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Dropout(0.25),
        )
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 32 * 32, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw class logits for the image batch *x*."""
        return self.classifier(self.features(x))


# Output labels, in the order of the model's output logits.
class_names = ["Eucalyptus globulus", "Pinus pinaster", "Quercus suber"]

# Checkpoint file name (inside models/) -> human-readable name shown in the UI.
model_aliases = {
    "model.pth": "GBIF CNN Model",
    "resnet_model.pth": "ResNet Model",
}

# Preprocessing applied to every input image before inference.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Currently loaded model; None until load_model() succeeds.
model = None
# Not referenced elsewhere in this file; kept for compatibility.
num_classes = 3

_MODELS_DIR = "models"
_EXAMPLES_DIR = "examples"
_IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png")

# A prediction is "inconclusive" only when BOTH hold: the top probability is
# below _MIN_TOP_PROB AND its lead over the runner-up is below _MIN_MARGIN.
_MIN_TOP_PROB = 0.6
_MIN_MARGIN = 0.2


def _is_within(root, path):
    """Return True if *path* resolves to a location inside directory *root*."""
    real_root = os.path.realpath(root)
    try:
        return os.path.commonpath([real_root, os.path.realpath(path)]) == real_root
    except ValueError:
        # commonpath raises for paths it cannot compare (e.g. different drives).
        return False


def load_model(selected_model):
    """Load a checkpoint from ``models/`` and make it the active global ``model``.

    Args:
        selected_model: Checkpoint file name inside ``models/``. This is the
            value side of the dropdown choices built from ``model_aliases``.

    Returns:
        A status string for the UI. Failures are reported in the string, not
        raised. On failure the previously loaded model, if any, stays active.
    """
    global model
    if not selected_model:
        return "Error: No model selected."

    alias = model_aliases.get(selected_model, selected_model)
    model_path = os.path.normpath(os.path.join(_MODELS_DIR, selected_model))
    # The dropdown value arrives from the client, so refuse anything that
    # escapes models/ (e.g. "../x.pth" or an absolute path).
    if not _is_within(_MODELS_DIR, model_path):
        return f"Error: Invalid model selection '{selected_model}'."

    try:
        # SECURITY: weights_only=False fully unpickles the file, which can run
        # arbitrary code embedded in it. Only put trusted checkpoints in models/.
        loaded = torch.load(
            model_path, map_location=torch.device("cpu"), weights_only=False
        )
        if not isinstance(loaded, nn.Module):
            # e.g. a bare state_dict: it cannot be called for inference.
            raise TypeError(
                f"expected a saved nn.Module, got {type(loaded).__name__}"
            )
        loaded.eval()
    except FileNotFoundError:
        return f"Error: Model file not found at {model_path}"
    except Exception as e:
        return f"Error loading {alias}: {e}"

    # Publish only after the load fully succeeded, so a failed load never
    # leaves an unusable object in the global.
    model = loaded
    return f"{alias} loaded successfully."


def predict(image_path):
    """Classify the image at *image_path* with the currently loaded model.

    Args:
        image_path: Filesystem path to the image. Gradio passes None when the
            image input is empty.

    Returns:
        A ``(probs_dict, final_label)`` tuple. ``probs_dict`` maps each entry
        of ``class_names`` to its softmax probability. ``final_label`` is the
        top class name, "Inconclusive, low confidence", or a prompt message
        (with an empty dict) when no model or image is available.
    """
    if model is None:
        return {}, "Please load a model first."
    if image_path is None:
        return {}, "Please provide an image first."

    # Close the underlying file handle once the pixels are converted.
    with Image.open(image_path) as img:
        batch = transform(img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        outputs = model(batch)
    probs = F.softmax(outputs, dim=1)[0]

    probs_dict = {name: float(probs[i]) for i, name in enumerate(class_names)}
    ranked = sorted(probs_dict.items(), key=lambda item: item[1], reverse=True)
    top1_class, top1_prob = ranked[0]
    _, top2_prob = ranked[1]

    if top1_prob < _MIN_TOP_PROB and (top1_prob - top2_prob) < _MIN_MARGIN:
        final_label = "Inconclusive, low confidence"
    else:
        final_label = top1_class
    return probs_dict, final_label


def _list_example_images(folder=_EXAMPLES_DIR):
    """Return sorted image paths found in *folder*, or [] if it does not exist."""
    if not os.path.isdir(folder):
        return []
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if name.lower().endswith(_IMAGE_EXTENSIONS)
    ]


# Components created up front and placed into the layout via .render().
model_selector = gr.Dropdown(
    choices=[(alias, filename) for filename, alias in model_aliases.items()],
    label="Select Model",
    value="resnet_model.pth",
)
load_button = gr.Button("Load Model")
image_input = gr.Image(
    type="filepath",
    interactive=True,
    label="Input Image",
    sources=["upload", "clipboard"],
    height=400,
)

example_images = _list_example_images()

with gr.Blocks() as demo:
    with gr.Row():
        # Left column: model selection, status, and example gallery.
        with gr.Column(scale=1):
            model_selector.render()
            load_button.render()
            load_output = gr.Textbox(label="Model Status", interactive=False)
            # Skip the gallery entirely when there are no example images
            # rather than crashing at startup.
            if example_images:
                gr.Examples(
                    examples=example_images,
                    inputs=image_input,
                    label="Example Images",
                    examples_per_page=9,
                    elem_id="custom-example-gallery",
                )
        # Middle column: the input image and the predict trigger.
        with gr.Column(scale=1):
            image_input.render()
            predict_button = gr.Button("Predict")
        # Right column: per-class probabilities and the final verdict.
        with gr.Column(scale=1):
            label_output = gr.Label(label="Prediction Probabilities")
            final_output = gr.Textbox(label="Final Prediction", interactive=False)

    # NOTE(review): empty HTML placeholder; its content appears to be missing.
    gr.HTML("""
        
        """)

    load_button.click(fn=load_model, inputs=model_selector, outputs=load_output)
    predict_button.click(
        fn=predict, inputs=image_input, outputs=[label_output, final_output]
    )

demo.launch()