khasinski commited on
Commit
a200959
·
verified ·
1 Parent(s): 28cd547

Upload predict.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. predict.py +89 -0
predict.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Predict which cat is in an image."""
3
+
4
+ import argparse
5
+ import torch
6
+ import torch.nn as nn
7
+ from torchvision import transforms, models
8
+ from PIL import Image
9
+ from pathlib import Path
10
+
11
+ MODEL_PATH = "cat_classifier.pth"
12
+ IMAGE_SIZE = 224
13
+ DEVICE = (
14
+ "mps" if torch.backends.mps.is_available()
15
+ else "cuda" if torch.cuda.is_available()
16
+ else "cpu"
17
+ )
18
+
19
+
20
+ def load_model(model_path: str):
21
+ """Load the trained model."""
22
+ checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
23
+ class_to_idx = checkpoint['class_to_idx']
24
+ idx_to_class = {v: k for k, v in class_to_idx.items()}
25
+
26
+ model = models.efficientnet_b0(weights=None)
27
+ num_features = model.classifier[1].in_features
28
+ model.classifier = nn.Sequential(
29
+ nn.Dropout(p=0.3),
30
+ nn.Linear(num_features, len(class_to_idx))
31
+ )
32
+ model.load_state_dict(checkpoint['model_state_dict'])
33
+ model.to(DEVICE)
34
+ model.eval()
35
+
36
+ return model, idx_to_class
37
+
38
+
39
+ def predict(model, image_path: str, idx_to_class: dict):
40
+ """Predict the cat in an image."""
41
+ transform = transforms.Compose([
42
+ transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
43
+ transforms.ToTensor(),
44
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
45
+ ])
46
+
47
+ image = Image.open(image_path).convert('RGB')
48
+ tensor = transform(image).unsqueeze(0).to(DEVICE)
49
+
50
+ with torch.no_grad():
51
+ outputs = model(tensor)
52
+ probs = torch.softmax(outputs, dim=1)[0]
53
+ pred_idx = probs.argmax().item()
54
+ confidence = probs[pred_idx].item()
55
+
56
+ return idx_to_class[pred_idx], confidence, {idx_to_class[i]: probs[i].item() for i in range(len(idx_to_class))}
57
+
58
+
59
+ def main():
60
+ parser = argparse.ArgumentParser(description="Predict which cat is in an image")
61
+ parser.add_argument("image", type=str, help="Path to image file")
62
+ parser.add_argument("--model", type=str, default=MODEL_PATH, help="Path to model file")
63
+ args = parser.parse_args()
64
+
65
+ if not Path(args.model).exists():
66
+ print(f"Error: Model not found at {args.model}")
67
+ print("Run train.py first to train the model.")
68
+ return
69
+
70
+ if not Path(args.image).exists():
71
+ print(f"Error: Image not found at {args.image}")
72
+ return
73
+
74
+ print(f"Using device: {DEVICE}")
75
+ print(f"Loading model from {args.model}...")
76
+
77
+ model, idx_to_class = load_model(args.model)
78
+ prediction, confidence, all_probs = predict(model, args.image, idx_to_class)
79
+
80
+ print(f"\nImage: {args.image}")
81
+ print(f"Prediction: {prediction.upper()}")
82
+ print(f"Confidence: {confidence:.1%}")
83
+ print(f"\nAll probabilities:")
84
+ for cat, prob in sorted(all_probs.items(), key=lambda x: -x[1]):
85
+ print(f" {cat}: {prob:.1%}")
86
+
87
+
88
+ if __name__ == "__main__":
89
+ main()