#!/usr/bin/env python3 """Predict which cat is in an image.""" import argparse import torch import torch.nn as nn from torchvision import transforms, models from PIL import Image from pathlib import Path MODEL_PATH = "cat_classifier.pth" IMAGE_SIZE = 224 DEVICE = ( "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu" ) def load_model(model_path: str): """Load the trained model.""" checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False) class_to_idx = checkpoint['class_to_idx'] idx_to_class = {v: k for k, v in class_to_idx.items()} model = models.efficientnet_b0(weights=None) num_features = model.classifier[1].in_features model.classifier = nn.Sequential( nn.Dropout(p=0.3), nn.Linear(num_features, len(class_to_idx)) ) model.load_state_dict(checkpoint['model_state_dict']) model.to(DEVICE) model.eval() return model, idx_to_class def predict(model, image_path: str, idx_to_class: dict): """Predict the cat in an image.""" transform = transforms.Compose([ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) image = Image.open(image_path).convert('RGB') tensor = transform(image).unsqueeze(0).to(DEVICE) with torch.no_grad(): outputs = model(tensor) probs = torch.softmax(outputs, dim=1)[0] pred_idx = probs.argmax().item() confidence = probs[pred_idx].item() return idx_to_class[pred_idx], confidence, {idx_to_class[i]: probs[i].item() for i in range(len(idx_to_class))} def main(): parser = argparse.ArgumentParser(description="Predict which cat is in an image") parser.add_argument("image", type=str, help="Path to image file") parser.add_argument("--model", type=str, default=MODEL_PATH, help="Path to model file") args = parser.parse_args() if not Path(args.model).exists(): print(f"Error: Model not found at {args.model}") print("Run train.py first to train the model.") return if not Path(args.image).exists(): print(f"Error: Image not found at {args.image}") return print(f"Using device: {DEVICE}") print(f"Loading model from {args.model}...") model, idx_to_class = load_model(args.model) prediction, confidence, all_probs = predict(model, args.image, idx_to_class) print(f"\nImage: {args.image}") print(f"Prediction: {prediction.upper()}") print(f"Confidence: {confidence:.1%}") print(f"\nAll probabilities:") for cat, prob in sorted(all_probs.items(), key=lambda x: -x[1]): print(f" {cat}: {prob:.1%}") if __name__ == "__main__": main()