Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torchvision.transforms as transforms | |
| from PIL import Image | |
| import gradio as gr | |
| import os | |
class DeeperCNN(nn.Module):
    """Plain CNN classifier: three double-conv stages followed by an MLP head.

    Each stage is two ``Conv2d(3x3, padding=1) -> BatchNorm2d -> ReLU`` units.
    The first two stages end with a 2x max-pool, and every stage ends with
    ``Dropout(0.25)``. Because only two poolings are applied, a 128x128 input
    leaves a 128-channel 32x32 feature map. That is what the hard-coded
    ``128 * 32 * 32`` input size of the first ``Linear`` layer expects, so
    other input resolutions will generally fail in the classifier.

    PyTorch derives state_dict keys (e.g. ``features.0.weight``) from the
    attribute names and Sequential indices used here. Keep them stable so that
    previously saved weights still load.

    Args:
        num_classes: Number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()

        def conv_unit(in_channels, out_channels):
            # 3x3 "same" convolution, batch norm, in-place ReLU.
            return [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]

        # (input channels, output channels, whether the stage ends with a 2x max-pool)
        stages = ((3, 32, True), (32, 64, True), (64, 128, False))
        feature_layers = []
        for in_channels, out_channels, pools in stages:
            feature_layers.extend(conv_unit(in_channels, out_channels))
            feature_layers.extend(conv_unit(out_channels, out_channels))
            if pools:
                feature_layers.append(nn.MaxPool2d(2))
            feature_layers.append(nn.Dropout(0.25))
        self.features = nn.Sequential(*feature_layers)

        # channels * height * width left after two 2x pools of a 128x128 input
        flat_features = 128 * 32 * 32
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        """Return raw class logits for a batch ``x`` (assumed shape (N, 3, 128, 128))."""
        return self.classifier(self.features(x))
# Human-readable labels in output-logit order: predict() reports softmax
# entry i as class_names[i].
class_names = ["Eucalyptus globulus", "Pinus pinaster", "Quercus suber"]

# Checkpoint file name (inside models/) -> display name shown in the UI.
model_aliases = {
    "model.pth": "GBIF CNN Model",
    "resnet_model.pth": "ResNet Model",
}

# Preprocessing applied to every input image. It resizes to 128x128, then
# converts to a float (C, H, W) tensor scaled to [0, 1]. No mean/std
# normalization is applied.
transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
])

# Currently active model; assigned by load_model() and read by predict().
model = None
# Derived from class_names so the class count and the label list cannot drift apart.
num_classes = len(class_names)
def load_model(selected_model):
    """Load a pickled model from the ``models`` directory and make it active.

    The global ``model`` is replaced only after the file has been loaded and
    verified to contain an ``nn.Module``. A failed load therefore leaves the
    previously active model untouched.

    Args:
        selected_model: File name inside ``models/`` (the dropdown value,
            e.g. ``"resnet_model.pth"``).

    Returns:
        A human-readable status string for the "Model Status" textbox. Errors
        are reported in this string rather than raised.
    """
    global model
    if not selected_model:
        return "Please select a model first."
    alias = model_aliases.get(selected_model, selected_model)
    model_path = os.path.normpath(os.path.join("models", selected_model))
    try:
        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects,
        # which can execute code embedded in the file. Only load trusted checkpoints.
        loaded = torch.load(model_path, map_location=torch.device("cpu"), weights_only=False)
    except FileNotFoundError:
        return f"Error: Model file not found at {model_path}"
    except Exception as e:
        return f"Error loading {alias}: {str(e)}"
    # A state_dict (or any other non-module object) cannot be called by
    # predict(). Reject it instead of clobbering the working model.
    if not isinstance(loaded, nn.Module):
        return (
            f"Error loading {alias}: file does not contain a full model "
            f"(got {type(loaded).__name__})."
        )
    loaded.eval()  # inference mode: disables dropout, uses BatchNorm running stats
    model = loaded
    return f"{alias} loaded successfully."
# A prediction is flagged inconclusive when the top class is below
# _MIN_TOP1_PROB AND leads the runner-up by less than _MIN_MARGIN.
_MIN_TOP1_PROB = 0.6
_MIN_MARGIN = 0.2


def predict(image_path):
    """Classify one image with the currently loaded model.

    Args:
        image_path: Path to the image file (the input ``gr.Image`` uses
            ``type="filepath"``). It is ``None`` when no image was provided.

    Returns:
        A ``(probs_dict, final_label)`` tuple:
        - ``probs_dict`` maps each class name to its softmax probability, for
          the ``gr.Label`` output.
        - ``final_label`` is the text for the "Final Prediction" box.
        On an early exit the dict is empty and the text explains why.
    """
    if model is None:
        return {}, "Please load a model first."
    if not image_path:
        return {}, "Please provide an image first."
    # Image.open is lazy and holds the file open. convert() loads the pixels
    # into a new image, so the file can be closed immediately afterwards.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    batch = transform(image).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        outputs = model(batch)
    probs = F.softmax(outputs, dim=1)[0]
    # Guard against a model whose head does not match class_names. Otherwise
    # labels would be paired with the wrong (or missing) probabilities.
    if probs.numel() != len(class_names):
        return {}, (
            f"Error: model produced {probs.numel()} outputs but "
            f"{len(class_names)} class names are defined."
        )
    probs_dict = {name: float(p) for name, p in zip(class_names, probs)}
    ranked = sorted(probs_dict.items(), key=lambda item: item[1], reverse=True)
    top1_class, top1_prob = ranked[0]
    _, top2_prob = ranked[1]
    if top1_prob < _MIN_TOP1_PROB and (top1_prob - top2_prob) < _MIN_MARGIN:
        final_label = "Inconclusive, low confidence"
    else:
        final_label = top1_class
    return probs_dict, final_label
def _list_example_images(directory="examples"):
    """Return the sorted paths of the .jpg/.jpeg/.png files in *directory*.

    Returns an empty list when the directory does not exist, so a missing
    ``examples`` folder no longer crashes the app at startup. Sorting makes
    the gallery order deterministic, since ``os.listdir`` order is arbitrary.
    """
    if not os.path.isdir(directory):
        return []
    return sorted(
        os.path.join(directory, name)
        for name in os.listdir(directory)
        if name.lower().endswith((".jpg", ".jpeg", ".png"))
    )


# These components are created up front and placed into the layout with
# .render() inside the Blocks context below.
model_selector = gr.Dropdown(
    # Each choice is (label shown to the user, value passed to load_model).
    choices=[(alias, filename) for filename, alias in model_aliases.items()],
    label="Select Model",
    value="resnet_model.pth",
)
load_button = gr.Button("Load Model")
image_input = gr.Image(
    type="filepath",  # predict() receives a file path, not a pixel array
    interactive=True,
    label="Input Image",
    sources=["upload", "clipboard"],
    height=400,
)

with gr.Blocks() as demo:
    with gr.Row():
        # Left column: model picker, load status, and example gallery.
        with gr.Column(scale=1):
            model_selector.render()
            load_button.render()
            load_output = gr.Textbox(label="Model Status", interactive=False)
            example_paths = _list_example_images()
            # Only show the gallery when there is something to show.
            if example_paths:
                gr.Examples(
                    examples=example_paths,
                    inputs=image_input,
                    label="Example Images",
                    examples_per_page=9,
                    elem_id="custom-example-gallery",
                )
        # Middle column: the image to classify.
        with gr.Column(scale=1):
            image_input.render()
            predict_button = gr.Button("Predict")
        # Right column: per-class probabilities and the final verdict.
        with gr.Column(scale=1):
            label_output = gr.Label(label="Prediction Probabilities")
            final_output = gr.Textbox(label="Final Prediction", interactive=False)
    # Uniform thumbnail size for the example gallery (targets its elem_id).
    gr.HTML("""
<style>
#custom-example-gallery img {
width: 150px !important;
height: 150px !important;
object-fit: cover !important;
}
</style>
""")
    load_button.click(fn=load_model, inputs=model_selector, outputs=load_output)
    predict_button.click(fn=predict, inputs=image_input, outputs=[label_output, final_output])

demo.launch()