"""Gradio demo serving a fine-tuned ResNet-50 animal classifier.

When this script runs, it:
1. Locates ``model.pth``.
2. Rebuilds a ResNet-50 whose final layer is sized to the checkpoint's
   ``class_names``.
3. Loads the saved weights.
4. Launches a Gradio web UI that shows the top predicted classes for an
   uploaded image.
"""

import os

import gradio as gr
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image

# -----------------------------
# Safe model loading
# -----------------------------
# Candidate checkpoint locations, checked in order; the first existing file wins.
possible_paths = [
    "model/model.pth",
    "model.pth",
    "/app/model/model.pth",
    "/app/model.pth",
]

# isfile (not exists) so a directory that happens to be named model.pth is skipped.
model_path = next((p for p in possible_paths if os.path.isfile(p)), None)
if model_path is None:
    raise FileNotFoundError(
        "❌ model.pth not found. Upload it to /model/model.pth or root folder."
    )

# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. The code below reads the keys "class_names" and
# "model_state_dict" from it.
checkpoint = torch.load(model_path, map_location="cpu")
class_names = checkpoint["class_names"]

# -----------------------------
# Load Model
# -----------------------------
# weights=None builds the architecture without pretrained weights; the real
# weights come from the checkpoint below. This replaces the deprecated
# pretrained=False argument (requires torchvision >= 0.13).
model = models.resnet50(weights=None)
# Replace the default classification head with one sized to this checkpoint's classes.
model.fc = nn.Linear(model.fc.in_features, len(class_names))
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
model.eval()

# -----------------------------
# Image Preprocessing
# -----------------------------
# Resize to 224x224, convert to a tensor, and normalise with ImageNet mean/std.
transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


# -----------------------------
# Prediction Function
# -----------------------------
def predict(img):
    """Classify an image and return the most likely classes.

    Args:
        img: PIL image from the Gradio ``Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        Dict mapping class name to softmax probability for the top
        ``min(3, len(class_names))`` classes, as consumed by ``gr.Label``.

    Raises:
        gr.Error: if no image was provided.
    """
    if img is None:
        raise gr.Error("Please upload an image.")

    # ToTensor keeps the source channel count. Without this conversion,
    # RGBA/greyscale/palette images break the 3-channel Normalize and the
    # ResNet input layer.
    batch = transform(img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    probs = torch.softmax(logits[0], dim=0)
    # Clamp k so topk does not fail on checkpoints with fewer than 3 classes.
    k = min(3, probs.numel())
    top_probs, top_idxs = torch.topk(probs, k)

    return {
        class_names[idx]: prob
        for idx, prob in zip(top_idxs.tolist(), top_probs.tolist())
    }


# -----------------------------
# Gradio Interface
# -----------------------------
title = "🐾 Animal Classifier — ResNet50 Fine-Tuned"
description = """
Upload an image of an animal and the model will predict what species it is.
"""

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Label(num_top_classes=3),
    title=title,
    description=description,
)

iface.launch()