File size: 2,796 Bytes
a200959
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
"""Predict which cat is in an image."""

import argparse
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
from pathlib import Path

# Checkpoint file used when no --model argument is supplied.
MODEL_PATH = "cat_classifier.pth"
# Edge length, in pixels, that input images are resized to before inference.
IMAGE_SIZE = 224

# Select the compute backend in priority order: Apple MPS, then CUDA, then CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


def load_model(model_path: str):
    """Rebuild the classifier network and restore its trained weights.

    Args:
        model_path: Path to a checkpoint containing a ``class_to_idx``
            mapping and a ``model_state_dict`` entry.

    Returns:
        A ``(model, idx_to_class)`` tuple: the network moved to ``DEVICE``
        and switched to eval mode, plus a dict mapping each class index
        back to its class name.
    """
    # NOTE(security): weights_only=False allows torch.load to unpickle
    # arbitrary Python objects, so only load checkpoints you trust.
    state = torch.load(model_path, map_location=DEVICE, weights_only=False)
    label_to_index = state['class_to_idx']

    # Invert the name -> index mapping so predictions can be turned into names.
    index_to_label = {}
    for label, index in label_to_index.items():
        index_to_label[index] = label

    # EfficientNet-B0 backbone without pretrained weights; its classifier is
    # swapped for dropout plus a linear layer with one output per class.
    net = models.efficientnet_b0(weights=None)
    in_dim = net.classifier[1].in_features
    head = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(in_dim, len(label_to_index)),
    )
    net.classifier = head

    net.load_state_dict(state['model_state_dict'])
    net.to(DEVICE)
    net.eval()

    return net, index_to_label


def predict(model, image_path: str, idx_to_class: dict):
    """Classify one image and report the probability of every class.

    Args:
        model: Callable network that receives a single-image batch and
            returns one row of class logits.
        image_path: Path of the image file to classify.
        idx_to_class: Mapping from class index to class name.

    Returns:
        A ``(name, confidence, all_probs)`` tuple: the name of the most
        probable class, its softmax probability as a float, and a dict
        mapping every class name to its probability (ordered by class index).
    """
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    # Image.open is lazy and may keep the underlying file open; the context
    # manager guarantees the handle is released once the pixels are copied
    # into the converted RGB image.
    with Image.open(image_path) as source:
        image = source.convert('RGB')
    tensor = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(tensor)
        probs = torch.softmax(outputs, dim=1)[0]
        pred_idx = probs.argmax().item()

    # Convert to Python floats once instead of calling .item() per class.
    prob_values = probs.tolist()
    confidence = prob_values[pred_idx]

    # Walk the mapping itself (in index order) rather than range(len(...)).
    all_probs = {
        name: prob_values[idx] for idx, name in sorted(idx_to_class.items())
    }

    return idx_to_class[pred_idx], confidence, all_probs


def main():
    """Command-line entry point: classify one image and print the result.

    Parses the positional ``image`` path and the optional ``--model``
    checkpoint path, loads the model, runs the prediction, and prints the
    predicted class, its confidence, and all class probabilities sorted from
    most to least likely.

    Raises:
        SystemExit: With status 1 when the model file or the image file does
            not exist, so calling shells and scripts can detect the failure.
    """
    import sys  # local import: only needed for stderr and the exit status

    parser = argparse.ArgumentParser(description="Predict which cat is in an image")
    parser.add_argument("image", type=str, help="Path to image file")
    parser.add_argument("--model", type=str, default=MODEL_PATH, help="Path to model file")
    args = parser.parse_args()

    # Report problems on stderr and exit non-zero instead of returning
    # silently with a success status.
    if not Path(args.model).exists():
        print(f"Error: Model not found at {args.model}", file=sys.stderr)
        print("Run train.py first to train the model.", file=sys.stderr)
        sys.exit(1)

    if not Path(args.image).exists():
        print(f"Error: Image not found at {args.image}", file=sys.stderr)
        sys.exit(1)

    print(f"Using device: {DEVICE}")
    print(f"Loading model from {args.model}...")

    model, idx_to_class = load_model(args.model)
    prediction, confidence, all_probs = predict(model, args.image, idx_to_class)

    print(f"\nImage: {args.image}")
    print(f"Prediction: {prediction.upper()}")
    print(f"Confidence: {confidence:.1%}")
    print("\nAll probabilities:")
    # Highest probability first; ties keep their original relative order.
    ranked = sorted(all_probs.items(), key=lambda item: item[1], reverse=True)
    for cat, prob in ranked:
        print(f"  {cat}: {prob:.1%}")


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()