# which-cat / app.py
# Author: khasinski
# Commit message: Add Gradio app
# Commit: d16aaa8 (verified)
#!/usr/bin/env python3
"""Gradio interface for the cat classifier."""
import gradio as gr
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
from pathlib import Path
# Checkpoint file to load; a relative path, so it resolves against the
# current working directory.
MODEL_PATH = "cat_classifier.pth"
# Side length (pixels) of the square the input image is resized to.
IMAGE_SIZE = 224

# Choose the compute device: Apple MPS if available, otherwise CUDA,
# otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# Load model once at startup
def load_model():
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
class_to_idx = checkpoint['class_to_idx']
idx_to_class = {v: k for k, v in class_to_idx.items()}
model = models.efficientnet_b0(weights=None)
num_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
nn.Dropout(p=0.3),
nn.Linear(num_features, len(class_to_idx))
)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(DEVICE)
model.eval()
return model, idx_to_class
# Load the network a single time at import; every request then reuses the
# same model and index->class mapping.
print("Loading model on {}...".format(DEVICE))
model, idx_to_class = load_model()
print("Model loaded!")

# Preprocessing applied to each uploaded image: squash it to an
# IMAGE_SIZE x IMAGE_SIZE square, convert it to a float tensor, then
# normalize each channel with the commonly used ImageNet mean/std values.
# NOTE(review): presumably this must mirror the evaluation transform used
# when the checkpoint was trained — confirm against the training script.
transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
def predict(image):
    """Predict which cat is in the image.

    Args:
        image: The uploaded picture, either a PIL image or an array that
            ``Image.fromarray`` accepts. ``None`` when nothing was uploaded.

    Returns:
        ``None`` if ``image`` is ``None``. Otherwise, a dict mapping each
        capitalized class name to its softmax probability, which is the
        shape ``gr.Label`` displays.
    """
    if image is None:
        return None

    # Convert to PIL if needed.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    # The 3-value Normalize in `transform` and the model expect 3 channels,
    # so grayscale / RGBA / palette uploads are converted to RGB first.
    image = image.convert("RGB")

    batch = transform(image).unsqueeze(0).to(DEVICE)  # batch of one image
    with torch.no_grad():
        logits = model(batch)
    # Copy the probabilities to host memory once, rather than indexing each
    # element on the device separately.
    probs = torch.softmax(logits, dim=1)[0].cpu().tolist()

    # Walk the actual index->name mapping instead of assuming the indices are
    # exactly 0..len-1.
    return {name.capitalize(): probs[idx] for idx, name in idx_to_class.items()}
# Build the Gradio interface: one image input, and a label output showing
# the two most probable classes.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a cat photo"),
    outputs=gr.Label(num_top_classes=2, label="Prediction"),
    title="Which Cat? - Lucy vs Madelaine",
    description="Upload a photo to identify if it's Lucy or Madelaine!",
    # Offer only example images that actually exist. Checking just the
    # "cats" directory would still pass Gradio a missing file if one of the
    # images was renamed or not deployed. Paths are relative to the current
    # working directory; with no usable file, examples are disabled (None).
    examples=[
        [path]
        for path in (
            "cats/lucy/IMG_7066.jpeg",
            "cats/madelaine/IMG_2730.jpeg",
        )
        if Path(path).is_file()
    ]
    or None,
    # NOTE(review): newer Gradio releases rename this option to
    # `flagging_mode`; confirm against the Gradio version in use.
    allow_flagging="never",
)

if __name__ == "__main__":
    # Bind to all network interfaces so the app is reachable from other
    # machines, not just localhost.
    demo.launch(server_name="0.0.0.0")