|
|
|
|
|
"""Predict which cat is in an image.""" |
|
|
|
|
|
import argparse
import sys
from pathlib import Path

import torch
import torch.nn as nn
from PIL import Image
from torchvision import transforms, models
|
|
|
|
|
# Default checkpoint location, used when --model is not given on the CLI.
MODEL_PATH = "cat_classifier.pth"

# Side length (in pixels) every input image is resized to before inference.
IMAGE_SIZE = 224

# Pick the compute backend once at import time: Apple MPS first, then CUDA,
# falling back to the CPU when neither accelerator is available.
DEVICE = (
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
|
|
|
|
|
|
|
|
def load_model(model_path: str):
    """Load the trained classifier and its index-to-class mapping.

    Builds an EfficientNet-B0 without pretrained weights, replaces its
    classifier head with dropout followed by a linear layer that has one
    output per class, restores the saved weights, moves the model to
    ``DEVICE`` and switches it to evaluation mode.

    Args:
        model_path: Path to a checkpoint holding the keys
            ``'class_to_idx'`` and ``'model_state_dict'``.

    Returns:
        A ``(model, idx_to_class)`` tuple; ``idx_to_class`` is the inverse
        of the checkpoint's ``class_to_idx`` mapping.

    Raises:
        KeyError: If the checkpoint lacks one of the required keys.
    """
    # SECURITY: weights_only=False lets torch.load unpickle arbitrary Python
    # objects, so a tampered checkpoint can execute code on load. Only load
    # files from a trusted source. NOTE(review): if the checkpoint contains
    # nothing but tensors and plain containers, weights_only=True should
    # work and is the safer choice -- confirm against train.py.
    checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)

    # Fail early with a readable message instead of a bare subscript error.
    required_keys = ("class_to_idx", "model_state_dict")
    missing = [key for key in required_keys if key not in checkpoint]
    if missing:
        raise KeyError(
            f"Checkpoint {model_path} is missing required key(s): "
            f"{', '.join(missing)}"
        )

    class_to_idx = checkpoint["class_to_idx"]
    idx_to_class = {idx: name for name, idx in class_to_idx.items()}

    model = models.efficientnet_b0(weights=None)
    # Read the feature width from the stock head before replacing it.
    num_features = model.classifier[1].in_features
    # NOTE(review): this head presumably mirrors the one built in train.py;
    # the state dict below is loaded into exactly this layout.
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(num_features, len(class_to_idx)),
    )
    model.load_state_dict(checkpoint["model_state_dict"])
    model.to(DEVICE)
    model.eval()

    return model, idx_to_class
|
|
|
|
|
|
|
|
def predict(model, image_path: str, idx_to_class: dict):
    """Predict the cat in an image.

    Args:
        model: Classifier called on a ``(1, 3, IMAGE_SIZE, IMAGE_SIZE)``
            tensor; its output is softmaxed over dimension 1.
        image_path: Path of the image file to classify.
        idx_to_class: Mapping from class index to class name.

    Returns:
        A ``(label, confidence, all_probs)`` tuple: the name of the most
        probable class, its probability, and a dict mapping every class
        name to its probability.
    """
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Image.open keeps the file handle open; convert() returns an independent
    # copy of the pixel data, so the file can be closed immediately after.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    tensor = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(tensor)
        probs = torch.softmax(outputs, dim=1)[0]
        pred_idx = probs.argmax().item()
        # One bulk transfer instead of a separate .item() call per class.
        prob_values = probs.tolist()

    confidence = prob_values[pred_idx]
    # Iterate the mapping itself (in index order) rather than assuming the
    # indices are exactly 0..n-1.
    all_probs = {
        name: prob_values[idx] for idx, name in sorted(idx_to_class.items())
    }
    return idx_to_class[pred_idx], confidence, all_probs
|
|
|
|
|
|
|
|
def main() -> int:
    """Command-line entry point: classify one image and print the result.

    Parses the image path and optional ``--model`` path, checks that both
    files exist, loads the model, and prints the predicted class, its
    confidence, and every class probability in descending order.

    Returns:
        Process exit status: 0 on success, 1 when the model or image file
        is missing.
    """
    parser = argparse.ArgumentParser(description="Predict which cat is in an image")
    parser.add_argument("image", type=str, help="Path to image file")
    parser.add_argument("--model", type=str, default=MODEL_PATH, help="Path to model file")
    args = parser.parse_args()

    # is_file() also rejects directories, which exists() would let through
    # only to fail later inside the loader. Errors go to stderr so they do
    # not mix with the prediction output on stdout.
    if not Path(args.model).is_file():
        print(f"Error: Model not found at {args.model}", file=sys.stderr)
        print("Run train.py first to train the model.", file=sys.stderr)
        return 1

    if not Path(args.image).is_file():
        print(f"Error: Image not found at {args.image}", file=sys.stderr)
        return 1

    print(f"Using device: {DEVICE}")
    print(f"Loading model from {args.model}...")

    model, idx_to_class = load_model(args.model)
    prediction, confidence, all_probs = predict(model, args.image, idx_to_class)

    print(f"\nImage: {args.image}")
    print(f"Prediction: {prediction.upper()}")
    print(f"Confidence: {confidence:.1%}")
    print("\nAll probabilities:")
    # Highest probability first.
    for cat, prob in sorted(all_probs.items(), key=lambda item: -item[1]):
        print(f"  {cat}: {prob:.1%}")

    return 0


if __name__ == "__main__":
    # Propagate main()'s status so shells and scripts can detect failures.
    sys.exit(main())
|
|
|