#!/usr/bin/env python3
"""Gradio interface for the cat classifier."""

from pathlib import Path

import gradio as gr
import torch
import torch.nn as nn
from PIL import Image
from torchvision import models, transforms

MODEL_PATH = "cat_classifier.pth"
IMAGE_SIZE = 224
# Pick the first available backend: MPS, then CUDA, otherwise the CPU.
DEVICE = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

# Example images offered in the UI; only the ones present on disk are shown.
EXAMPLE_IMAGES = (
    "cats/lucy/IMG_7066.jpeg",
    "cats/madelaine/IMG_2730.jpeg",
)


def load_model():
    """Load the trained classifier and its index-to-class-name mapping.

    Reads the checkpoint at ``MODEL_PATH``, which must provide the keys
    ``'class_to_idx'`` and ``'model_state_dict'``, rebuilds an
    EfficientNet-B0 whose classifier head is ``Dropout(0.3)`` followed by a
    ``Linear`` layer with one output per class, loads the saved weights,
    moves the model to ``DEVICE`` and switches it to eval mode.

    Returns:
        A ``(model, idx_to_class)`` tuple, where ``idx_to_class`` is the
        inverse of the checkpoint's ``class_to_idx`` mapping.
    """
    # SECURITY: weights_only=False unpickles arbitrary Python objects, so the
    # checkpoint must come from a trusted source.
    # NOTE(review): if the checkpoint only holds tensors and plain containers,
    # weights_only=True should work and is safer - confirm before switching.
    checkpoint = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=False)
    class_to_idx = checkpoint['class_to_idx']
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}

    # weights=None: no pretrained weights are fetched; everything comes from
    # the checkpoint's state dict below.
    model = models.efficientnet_b0(weights=None)
    num_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(num_features, len(class_to_idx)),
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(DEVICE)
    model.eval()

    return model, idx_to_class


# Load model once at startup so every request reuses the same instance.
print(f"Loading model on {DEVICE}...")
model, idx_to_class = load_model()
print("Model loaded!")

# Image transform: resize to the fixed input size, convert to a tensor and
# normalize per channel.
# NOTE(review): these are the usual ImageNet mean/std values; presumably they
# match the preprocessing used at training time - confirm.
transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def predict(image):
    """Predict which cat is in the image.

    Args:
        image: A PIL image, or any array accepted by ``Image.fromarray``;
            ``None`` when nothing was supplied.

    Returns:
        A dict mapping each capitalized class name to its softmax
        probability, or ``None`` when ``image`` is ``None``.
    """
    if image is None:
        return None

    # Convert to PIL if needed, then force three channels so the
    # 3-channel normalization applies.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert('RGB')

    # unsqueeze(0) adds the batch dimension the model expects.
    tensor = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(tensor)
        probs = torch.softmax(outputs, dim=1)[0]

    # One tolist() call converts all scores at once instead of one
    # float(probs[i]) conversion per class.
    return {
        idx_to_class[i].capitalize(): score
        for i, score in enumerate(probs.tolist())
    }


def _existing_examples():
    """Return example rows for the images that exist on disk, or None."""
    # Checking each file (not just the "cats" directory) avoids handing
    # Gradio paths that do not exist.
    rows = [[path] for path in EXAMPLE_IMAGES if Path(path).is_file()]
    return rows or None


# Build interface
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a cat photo"),
    outputs=gr.Label(num_top_classes=2, label="Prediction"),
    title="Which Cat? - Lucy vs Madelaine",
    description="Upload a photo to identify if it's Lucy or Madelaine!",
    examples=_existing_examples(),
    # NOTE(review): newer Gradio releases rename this parameter to
    # `flagging_mode` - confirm against the installed version.
    allow_flagging="never",
)

if __name__ == "__main__":
    # "0.0.0.0" listens on all network interfaces, not just localhost.
    demo.launch(server_name="0.0.0.0")