which-cat / predict.py
khasinski's picture
Upload predict.py with huggingface_hub
a200959 verified
#!/usr/bin/env python3
"""Predict which cat is in an image."""
import argparse
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
from pathlib import Path
# Default checkpoint location (overridable with --model on the command line).
MODEL_PATH = "cat_classifier.pth"
# Square side length, in pixels, that input images are resized to.
IMAGE_SIZE = 224

# Pick the best available accelerator: Apple MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    DEVICE = "mps"
elif torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def load_model(model_path: str):
    """Rebuild the EfficientNet-B0 cat classifier and load its trained weights.

    Args:
        model_path: Path to a checkpoint dict containing ``'class_to_idx'``
            (class name -> output index) and ``'model_state_dict'``.

    Returns:
        A ``(model, idx_to_class)`` tuple: the network moved to ``DEVICE``
        and switched to eval mode, plus the inverse mapping from output
        index to class name.
    """
    # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
    # Python objects, so a malicious checkpoint can execute code on load.
    # The path is user-controllable (--model); only load files you trust.
    ckpt = torch.load(model_path, map_location=DEVICE, weights_only=False)
    label_map = ckpt['class_to_idx']
    index_to_label = dict(zip(label_map.values(), label_map.keys()))

    # weights=None: no pretrained download; every parameter is overwritten
    # by the checkpoint's state dict below.
    net = models.efficientnet_b0(weights=None)
    in_dim = net.classifier[1].in_features
    # The replacement head must reproduce the one the checkpoint was saved
    # with (presumably built in train.py — TODO confirm), or the strict
    # load_state_dict call will reject the weights.
    net.classifier = nn.Sequential(
        nn.Dropout(p=0.3),
        nn.Linear(in_dim, len(label_map)),
    )
    net.load_state_dict(ckpt['model_state_dict'])
    net.to(DEVICE)
    net.eval()
    return net, index_to_label
def predict(model, image_path: str, idx_to_class: dict):
    """Classify the cat shown in a single image.

    Args:
        model: Classifier as returned by ``load_model`` (already on
            ``DEVICE`` and in eval mode).
        image_path: Path to the image file to classify.
        idx_to_class: Mapping from model output index to class name.

    Returns:
        ``(label, confidence, all_probs)``: the top class name, its softmax
        probability as a Python float, and a dict mapping every class name
        to its probability.
    """
    # Resize + ImageNet mean/std normalization; presumably this must match
    # the preprocessing used at training time — TODO confirm against train.py.
    transform = transforms.Compose([
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Image.open is lazy and may hold the file open; the context manager
    # releases the handle deterministically once the RGB copy is made.
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    tensor = transform(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        outputs = model(tensor)
        probs = torch.softmax(outputs, dim=1)[0]
        pred_idx = probs.argmax().item()
        # Copy all probabilities to the host in one transfer instead of one
        # .item() call (and one device sync) per class.
        prob_list = probs.tolist()

    all_probs = {idx_to_class[i]: p for i, p in enumerate(prob_list)}
    return idx_to_class[pred_idx], prob_list[pred_idx], all_probs
def main(argv=None):
    """Command-line entry point: classify one image and print the results.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``
            (useful for calling from code or tests).

    Returns:
        Process exit status: ``0`` on success, ``1`` if the model or image
        file is missing.
    """
    import sys  # only needed to route error messages to stderr

    parser = argparse.ArgumentParser(description="Predict which cat is in an image")
    parser.add_argument("image", type=str, help="Path to image file")
    parser.add_argument("--model", type=str, default=MODEL_PATH, help="Path to model file")
    args = parser.parse_args(argv)

    # is_file() rather than exists(): a directory passes exists() and would
    # then fail later inside torch.load / Image.open with a confusing error.
    if not Path(args.model).is_file():
        print(f"Error: Model not found at {args.model}", file=sys.stderr)
        print("Run train.py first to train the model.", file=sys.stderr)
        return 1
    if not Path(args.image).is_file():
        print(f"Error: Image not found at {args.image}", file=sys.stderr)
        return 1

    print(f"Using device: {DEVICE}")
    print(f"Loading model from {args.model}...")
    model, idx_to_class = load_model(args.model)
    prediction, confidence, all_probs = predict(model, args.image, idx_to_class)

    print(f"\nImage: {args.image}")
    print(f"Prediction: {prediction.upper()}")
    print(f"Confidence: {confidence:.1%}")
    print("\nAll probabilities:")
    # Highest probability first; sorted() stays stable with reverse=True,
    # so tied classes keep their original relative order.
    for cat, prob in sorted(all_probs.items(), key=lambda kv: kv[1], reverse=True):
        print(f" {cat}: {prob:.1%}")
    return 0


if __name__ == "__main__":
    # Forward main()'s status so failures produce a non-zero exit code.
    raise SystemExit(main())