# ma4389's picture
# Upload 3 files
# 44cc35d verified
import torch
import torch.nn as nn
from torchvision import models, transforms
from PIL import Image
import gradio as gr
# Device: prefer a CUDA GPU when PyTorch can see one, otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Classes: output index i of the classifier is reported as class_names[i].
# NOTE(review): this order must match the label indices used at training time — confirm.
class_names = ["Cat", "Dog"]
# Load model (architecture same as training)
def load_model(model_path="pet_model.pth"):
    """Build the ResNet-50 classifier and load the fine-tuned weights from disk.

    The architecture must match the one used during training: a ResNet-50
    backbone whose final ``fc`` layer is replaced by
    Linear(in_features, 512) -> ReLU -> Dropout(0.4) -> Linear(512, len(class_names)).

    Args:
        model_path: Path to a state dict written with ``torch.save``.

    Returns:
        The model, moved to ``device`` and switched to eval mode.
    """
    # No pretrained ImageNet weights are requested: load_state_dict below is
    # strict by default and overwrites every parameter and buffer, so fetching
    # the torchvision checkpoint would only add download time and a network
    # dependency at start-up without changing the final model.
    base_model = models.resnet50(weights=None)
    in_features = base_model.fc.in_features
    base_model.fc = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.4),
        nn.Linear(512, len(class_names)),
    )
    # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
    # NOTE(review): torch.load unpickles the file; only point it at trusted checkpoints.
    state_dict = torch.load(model_path, map_location=device)
    base_model.load_state_dict(state_dict)
    base_model.to(device)
    # eval() turns off Dropout and makes BatchNorm use its running statistics.
    base_model.eval()
    return base_model
model = load_model()
# Inference-time preprocessing. It keeps the deterministic steps of the
# training pipeline (RGB conversion, resize, tensor conversion, normalization
# with mean/std 0.5) but deliberately omits the random augmentations:
# RandomRotation made the same upload produce different predictions on every
# request, and ColorJitter() with default arguments changes nothing anyway.
transform = transforms.Compose([
    # Uploads may be RGBA, palette or grayscale; the network expects 3 channels.
    transforms.Lambda(lambda x: x.convert('RGB')),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # Normalization kept exactly as in the original (training) pipeline.
    transforms.Normalize(mean=[0.5, 0.5, 0.5],
                         std=[0.5, 0.5, 0.5]),
])
# Prediction function
def predict(img):
    """Classify a single PIL image as one of ``class_names``.

    Args:
        img: PIL image from the Gradio ``Image`` input, or ``None`` when the
            user submits without uploading anything.

    Returns:
        Mapping of class name to softmax probability, the format that
        ``gr.Label`` displays.

    Raises:
        gr.Error: if no image was provided.
    """
    # An empty Gradio Image input can arrive as None; without this guard the
    # transform fails with an opaque AttributeError on ``None.convert``.
    if img is None:
        raise gr.Error("Please upload an image of a cat or a dog.")
    # unsqueeze(0) adds the batch dimension the model expects.
    img_tensor = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(img_tensor)
        # Single-image batch: take row 0 and convert to Python floats.
        probs = torch.softmax(logits, dim=1)[0].tolist()
    return {name: float(p) for name, p in zip(class_names, probs)}
# Gradio interface: one uploaded image in, class probabilities (top 2) out.
# fn, inputs and outputs are passed positionally, in gr.Interface's declared order.
demo = gr.Interface(
    predict,
    gr.Image(type="pil", label="Upload Image"),
    gr.Label(num_top_classes=2, label="Prediction"),
    title="🐱🐶 Cat vs Dog Classifier",
    description="Upload a picture of a cat or a dog. Model was trained with RandomRotation and ColorJitter on all images.",
)

# Start the web server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch()