Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import torchvision.transforms as T | |
| import torchvision.models as models | |
| from PIL import Image | |
| import os | |
# -----------------------------
# Safe model loading
# -----------------------------
# Candidate checkpoint locations, checked in order: relative paths first
# (resolved against the current working directory), then absolute paths
# under /app.
possible_paths = [
    "model/model.pth",
    "model.pth",
    "/app/model/model.pth",
    "/app/model.pth",
]

# First candidate that is an existing regular file. os.path.isfile (unlike
# os.path.exists) will not pick a directory that happens to be named
# "model.pth".
model_path = next((p for p in possible_paths if os.path.isfile(p)), None)

if model_path is None:
    # List the paths that were actually searched so a deployment mistake is
    # easy to diagnose.
    raise FileNotFoundError(
        "❌ model.pth not found. Looked in: "
        + ", ".join(possible_paths)
        + ". Upload it to model/model.pth or the repository root."
    )
# The checkpoint is expected to be a dict with at least "class_names" (labels
# in the same order as the classifier's output logits) and "model_state_dict".
# map_location="cpu" loads every tensor onto the CPU, so no GPU is required.
checkpoint = torch.load(model_path, map_location="cpu")
class_names = checkpoint["class_names"]
# -----------------------------
# Load Model
# -----------------------------
# Build a ResNet-50 without downloading pretrained weights; the fine-tuned
# weights come from the checkpoint below. weights=None is the torchvision
# >= 0.13 replacement for the deprecated pretrained=False.
model = models.resnet50(weights=None)
# Replace the final classifier with one sized to this checkpoint's classes.
model.fc = nn.Linear(model.fc.in_features, len(class_names))
# strict=True: fail loudly if the checkpoint keys do not exactly match the
# architecture built above.
model.load_state_dict(checkpoint["model_state_dict"], strict=True)
# Put layers such as BatchNorm into inference behaviour.
model.eval()
# -----------------------------
# Image Preprocessing
# -----------------------------
# Resize to a fixed 224x224 (aspect ratio is not preserved), convert the PIL
# image to a float tensor in [0, 1], then standardize each channel with the
# mean/std values below. These must stay identical to what the model was
# trained with.
transform = T.Compose(
    [
        T.Resize(size=(224, 224)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# -----------------------------
# Prediction Function
# -----------------------------
def predict(img):
    """Classify *img* and return the top predictions for a ``gr.Label``.

    Args:
        img: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without uploading an image.

    Returns:
        Dict mapping class name -> softmax probability for the (up to) three
        most likely classes, highest first. Empty dict when *img* is ``None``.
    """
    if img is None:
        return {}

    # The network and the 3-value Normalize expect exactly 3 channels; RGBA,
    # grayscale (L) or palette (P) uploads would otherwise fail.
    img = img.convert("RGB")
    batch = transform(img).unsqueeze(0)  # add a leading batch dimension

    with torch.no_grad():
        outputs = model(batch)

    probs = torch.softmax(outputs[0], dim=0)
    # torch.topk raises if k exceeds the number of elements, so cap k for
    # checkpoints with fewer than three classes.
    k = min(3, probs.numel())
    top_probs, top_idxs = torch.topk(probs, k)
    # .tolist() yields plain Python ints/floats for list indexing and for
    # the Gradio label payload.
    return {
        class_names[i]: p
        for i, p in zip(top_idxs.tolist(), top_probs.tolist())
    }
# -----------------------------
# Gradio Interface
# -----------------------------
# Page heading and the short help text shown above the input widget.
title = "🐾 Animal Classifier — ResNet50 Fine-Tuned"
description = (
    "\n"
    "Upload an image of an animal and the model will predict what species it is.\n"
)

# Wire predict() to a PIL-image input and a label output that shows the
# three highest-scoring classes (fn, inputs, outputs are positional).
iface = gr.Interface(
    predict,
    gr.Image(type="pil"),
    gr.Label(num_top_classes=3),
    title=title,
    description=description,
)

# Start the web server; this runs as a module-level side effect.
iface.launch()