---
license: mit
tags:
- image-classification
- pytorch
- cats
- efficientnet
library_name: pytorch
pipeline_tag: image-classification
---
# Which Cat? - Lucy vs Madelaine Classifier
A fine-tuned EfficientNet-B0 model that distinguishes between two cats: Lucy and Madelaine.
## Model Details
- **Base Model**: EfficientNet-B0 (pretrained on ImageNet)
- **Task**: Binary image classification
- **Classes**: `lucy`, `madelaine`
- **Training Data**: ~190 personal photos
- **Validation Accuracy**: 90%
## Usage
```python
import torch
import torch.nn as nn
from torchvision import transforms, models
from PIL import Image
from huggingface_hub import hf_hub_download

# Download the checkpoint from the Hugging Face Hub
model_path = hf_hub_download(repo_id="khasinski/which-cat", filename="cat_classifier.pth")

# Load the checkpoint. weights_only=False unpickles arbitrary Python objects,
# so only use it with files you trust.
device = "cuda" if torch.cuda.is_available() else "cpu"
checkpoint = torch.load(model_path, map_location=device, weights_only=False)

# Rebuild EfficientNet-B0 with the same 2-class head used during fine-tuning
model = models.efficientnet_b0(weights=None)
model.classifier = nn.Sequential(
    nn.Dropout(p=0.3),
    nn.Linear(1280, 2),
)
model.load_state_dict(checkpoint['model_state_dict'])
model.to(device)
model.eval()

# Invert the class -> index mapping stored in the checkpoint
idx_to_class = {v: k for k, v in checkpoint['class_to_idx'].items()}

# Preprocess with the standard ImageNet normalization statistics
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

image = Image.open("your_cat.jpg").convert('RGB')
tensor = transform(image).unsqueeze(0).to(device)  # add a batch dimension: (1, 3, 224, 224)

# Predict
with torch.no_grad():
    probs = torch.softmax(model(tensor), dim=1)[0]
pred_idx = probs.argmax().item()
print(f"Prediction: {idx_to_class[pred_idx]} ({probs[pred_idx].item():.1%})")
```
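To score several photos at once, the single-image steps above can be batched into one forward pass. This is a minimal sketch that assumes you already have the `model`, `transform`, `idx_to_class` and `device` objects from the Usage snippet; `predict_batch` is a hypothetical helper, not something shipped in this repository.

```python
import torch
from PIL import Image


def predict_batch(model, transform, idx_to_class, images, device="cpu"):
    """Classify a list of PIL images in a single forward pass (hypothetical helper).

    Returns a list of (class_name, probability) tuples, one per input image.
    """
    # Apply the same preprocessing as single-image inference, then stack to (N, 3, H, W)
    batch = torch.stack([transform(img.convert("RGB")) for img in images]).to(device)
    with torch.no_grad():
        probs = torch.softmax(model(batch), dim=1)
    # Highest probability and its class index for each image in the batch
    confidences, indices = probs.max(dim=1)
    return [(idx_to_class[i], c) for i, c in zip(indices.tolist(), confidences.tolist())]
```

For example, `predict_batch(model, transform, idx_to_class, [Image.open(p) for p in paths], device)`, where `paths` is your own list of image files. Very large batches may need to be split into chunks to fit in GPU memory.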
## Training
Trained using transfer learning with:
- Data augmentation (flips, rotations, color jitter)
- Weighted random sampling for class balance
- AdamW optimizer with learning rate scheduling
- 20 epochs on an Apple GPU via PyTorch's MPS (Metal Performance Shaders) backend
## Limitations
This model was trained only on photos of Lucy and Madelaine and does not generalize to other cats. Because it has exactly two output classes, it will label any input image, whether another cat or not a cat at all, as either Lucy or Madelaine.
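The forced choice follows from the softmax over two logits: the probabilities always sum to 1, so one class always wins. The sketch below illustrates this with raw logit pairs. The class order and the `min_confidence` threshold are assumptions for illustration (the real order comes from `checkpoint['class_to_idx']`), and `two_way_prediction` is a hypothetical helper.

```python
import torch

# Assumed index order; the real mapping is stored in checkpoint['class_to_idx']
CLASSES = ["lucy", "madelaine"]


def two_way_prediction(logits, min_confidence=0.8):
    """Return (label, confidence, is_low_confidence) for one pair of logits.

    min_confidence is a hypothetical threshold, not something defined by this model card.
    """
    probs = torch.softmax(torch.as_tensor(logits, dtype=torch.float32), dim=0)
    # The two probabilities sum to 1, so argmax always picks one of the two cats
    conf, idx = probs.max(dim=0)
    return CLASSES[idx.item()], conf.item(), conf.item() < min_confidence
```

A threshold like this can flag near 50/50 outputs, but it does not make the model recognize other cats: an unfamiliar cat can still receive a high score for Lucy or Madelaine.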