| |
| """Gradio interface for the cat classifier.""" |
|
|
| import gradio as gr |
| import torch |
| import torch.nn as nn |
| from torchvision import transforms, models |
| from PIL import Image |
| from pathlib import Path |
|
|
# Checkpoint file read by load_model().
MODEL_PATH = "cat_classifier.pth"
# Side length in pixels that input images are resized to before inference.
IMAGE_SIZE = 224

# Pick the best available backend: Apple-silicon MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
| |
def load_model(model_path=None, device=None):
    """Rebuild the EfficientNet-B0 cat classifier from a saved checkpoint.

    Args:
        model_path: Path to the checkpoint file. Defaults to ``MODEL_PATH``.
        device: Device the checkpoint tensors are mapped to and the model is
            moved to. Defaults to ``DEVICE``.

    Returns:
        A tuple ``(model, idx_to_class)``: the model in eval mode on *device*,
        and a dict mapping output index -> class name (the inverse of the
        checkpoint's ``class_to_idx``).

    Raises:
        FileNotFoundError: if the checkpoint file does not exist.
        KeyError: if the checkpoint lacks ``class_to_idx`` or
            ``model_state_dict``.
    """
    # Resolve defaults at call time so the module-level constants stay the
    # single source of truth.
    model_path = MODEL_PATH if model_path is None else model_path
    device = DEVICE if device is None else device

    # SECURITY: weights_only=False fully unpickles the file, which can execute
    # arbitrary code from a malicious checkpoint. Only load trusted files.
    # TODO(review): consider weights_only=True if the checkpoint holds only
    # tensors and plain containers (the two keys read below would qualify).
    checkpoint = torch.load(model_path, map_location=device, weights_only=False)
    class_to_idx = checkpoint["class_to_idx"]
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}

    # weights=None builds the bare architecture; every weight comes from the
    # checkpoint. The replacement head must match the one the checkpoint was
    # saved with, or load_state_dict (strict by default) will reject it.
    model = models.efficientnet_b0(weights=None)
    num_features = model.classifier[1].in_features
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(num_features, len(class_to_idx)),
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(device)
    model.eval()  # inference mode: disables dropout, freezes batch-norm stats

    return model, idx_to_class
|
|
# Load the classifier once at import time; every predict() call reuses it.
print(f"Loading model on {DEVICE}...")
model, idx_to_class = load_model()
print("Model loaded!")


# Inference preprocessing:
#   1. stretch to a fixed IMAGE_SIZE x IMAGE_SIZE square (aspect ratio is not
#      preserved),
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel with the standard ImageNet mean/std.
# NOTE(review): this presumably mirrors the training-time pipeline; keep the
# two in sync.
transform = transforms.Compose(
    [
        transforms.Resize(size=(IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
def predict(image):
    """Predict which cat is in the image.

    Args:
        image: A ``PIL.Image.Image`` (what the ``gr.Image(type="pil")`` input
            supplies), or an array accepted by ``Image.fromarray``; ``None``
            when no image has been provided.

    Returns:
        ``None`` if *image* is ``None``. Otherwise a dict mapping each
        capitalized class name to its softmax probability (a Python float),
        in the form ``gr.Label`` accepts.
    """
    if image is None:
        return None

    # Also accept raw arrays, presumably for non-PIL callers.
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)

    # Force 3 channels (drops alpha / palette / grayscale) before normalizing.
    batch = transform(image.convert("RGB")).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        logits = model(batch)
        # Copy to host once instead of one device sync per class.
        probs = torch.softmax(logits, dim=1)[0].cpu().tolist()

    # Look up each class by its own index rather than assuming the
    # idx_to_class keys are exactly 0..n-1.
    return {name.capitalize(): probs[idx] for idx, name in idx_to_class.items()}
|
|
| |
# Example images offered in the UI (paths relative to the working directory).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a cat photo"),
    outputs=gr.Label(num_top_classes=2, label="Prediction"),
    title="Which Cat? - Lucy vs Madelaine",
    description="Upload a photo to identify if it's Lucy or Madelaine!",
    # Check each example file rather than only the cats/ directory, so a
    # partially present dataset never hands Gradio a nonexistent path.
    # Falls back to no examples when none are found.
    examples=[
        [path]
        for path in ("cats/lucy/IMG_7066.jpeg", "cats/madelaine/IMG_2730.jpeg")
        if Path(path).is_file()
    ] or None,
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the Gradio version this app is pinned to.
    allow_flagging="never",
)


if __name__ == "__main__":
    # Listen on all network interfaces so other hosts can reach the app;
    # use "127.0.0.1" instead to keep it local-only.
    demo.launch(server_name="0.0.0.0")
|
|