# asad9641's picture
# Update app.py
# bc58c61 verified
import gradio as gr
import torch
import torch.nn as nn
import torchvision.transforms as T
import torchvision.models as models
from PIL import Image
import os
# -----------------------------
# Safe model loading
# -----------------------------
# Candidate checkpoint locations, tried in order; the first one that is an
# existing file is used.
possible_paths = [
    "model/model.pth",
    "model.pth",
    "/app/model/model.pth",
    "/app/model.pth",
]

# isfile() rather than exists(): a directory that happens to be called
# "model.pth" must not be selected, or torch.load would fail confusingly.
model_path = next((p for p in possible_paths if os.path.isfile(p)), None)

if model_path is None:
    # Report the locations that were actually searched so the user knows
    # where the file may be uploaded.
    raise FileNotFoundError(
        "❌ model.pth not found. Looked in: " + ", ".join(possible_paths)
    )

# SECURITY NOTE: torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source (consider passing
# weights_only=True where the installed torch version supports it).
checkpoint = torch.load(model_path, map_location="cpu")

# Fail early with a readable message if the checkpoint lacks the entries this
# script reads ("class_names" here, "model_state_dict" when the weights are
# restored).
_missing_keys = [
    key for key in ("class_names", "model_state_dict") if key not in checkpoint
]
if _missing_keys:
    raise KeyError(
        f"Checkpoint {model_path!r} is missing required keys: {_missing_keys}"
    )

class_names = checkpoint["class_names"]
# -----------------------------
# Load Model
# -----------------------------
# Build a ResNet-50 without pretrained weights. The factory is called with no
# arguments on purpose: the old `pretrained=False` flag is deprecated in
# recent torchvision releases, while the no-argument default is "no
# pretrained weights" in both the old and the new API.
model = models.resnet50()

# Replace the final fully connected layer so the network outputs one logit
# per entry in class_names.
model.fc = nn.Linear(model.fc.in_features, len(class_names))

# strict=True makes any missing or unexpected parameter name an error instead
# of silently leaving layers randomly initialised.
model.load_state_dict(checkpoint["model_state_dict"], strict=True)

# Switch to evaluation mode for inference.
model.eval()
# -----------------------------
# Image Preprocessing
# -----------------------------
# Pipeline applied to every input image before it reaches the model:
#   1. resize to 224x224 (aspect ratio is not preserved),
#   2. convert the PIL image to a tensor,
#   3. normalise each channel with the given mean / std
#      (these values match the commonly used ImageNet statistics).
transform = T.Compose(
    [
        T.Resize(size=(224, 224)),
        T.ToTensor(),
        T.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# -----------------------------
# Prediction Function
# -----------------------------
def predict(img):
    """Classify an image and return its most likely classes.

    Args:
        img: A PIL image (the Gradio input component is configured with
            ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        A dict mapping class name to probability for the top predictions
        (at most 3, fewer if the model has fewer classes). An empty dict is
        returned when ``img`` is ``None``.
    """
    # Gradio passes None when the user submits without an image.
    if img is None:
        return {}

    # The normalisation step is defined for 3 channels, so force RGB to cope
    # with grayscale / RGBA / palette images.
    batch = transform(img.convert("RGB")).unsqueeze(0)

    with torch.no_grad():
        logits = model(batch)

    probs = torch.softmax(logits[0], dim=0)

    # topk raises if k exceeds the number of classes, so clamp it.
    k = min(3, probs.numel())
    top_probs, top_idxs = torch.topk(probs, k)

    # .tolist() yields plain Python floats / ints, so class_names is indexed
    # with ordinary integers rather than 0-dim tensors.
    return {
        class_names[idx]: prob
        for prob, idx in zip(top_probs.tolist(), top_idxs.tolist())
    }
# -----------------------------
# Gradio Interface
# -----------------------------
# Text shown at the top of the web page.
title = "🐾 Animal Classifier — ResNet50 Fine-Tuned"
description = (
    "\n"
    "Upload an image of an animal and the model will predict what species it is."
    "\n"
)

# Input: an uploaded image handed to predict() as a PIL image.
image_input = gr.Image(type="pil")
# Output: a label widget showing the three highest-scoring classes.
label_output = gr.Label(num_top_classes=3)

# Wire the prediction function to the input / output components.
iface = gr.Interface(
    fn=predict,
    inputs=image_input,
    outputs=label_output,
    title=title,
    description=description,
)

# Start the web server.
iface.launch()