---
license: mit
tags:
- image-classification
- pytorch
- cats-vs-dogs
- computer-vision
datasets:
- microsoft/cats_vs_dogs
metrics:
- accuracy
---

# Cat vs Dog Classifier

This model is a simple CNN (Convolutional Neural Network) trained to classify images as either cats or dogs.

## Model Description

- **Architecture**: Custom CNN with 3 convolutional layers and 3 fully connected layers
- **Input**: 224x224 RGB images
- **Output**: Two logits, one per class (index 0 = Cat, index 1 = Dog, following the class list in the usage example); the predicted class is the argmax
- **Framework**: PyTorch
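
The usage example instantiates `CatDogClassifier`, but its definition is not part of this card. The following is a hypothetical sketch that only matches the stated shape (three convolutional layers, three fully connected layers, 224x224 RGB input, two output logits). Channel widths, kernel sizes, max-pooling, dropout and hidden sizes are all assumptions, so this sketch will only load `model_weights.pth` if it happens to match the original training code.

```python
import torch
import torch.nn as nn


class CatDogClassifier(nn.Module):
    """Hypothetical sketch: 3 conv layers + 3 fully connected layers.

    Layer widths, kernel sizes, pooling and dropout are assumptions,
    not the architecture the published weights were trained with.
    """

    def __init__(self, num_classes: int = 2):
        super().__init__()
        # Each block: 3x3 conv (padding=1 keeps H and W) -> ReLU -> 2x2 max-pool (halves H and W)
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),    # 224 -> 112
            nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),   # 112 -> 56
            nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),  # 56 -> 28
        )
        # Three fully connected layers, ending in one logit per class
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(128 * 28 * 28, 256), nn.ReLU(), nn.Dropout(0.5),
            nn.Linear(256, 64), nn.ReLU(),
            nn.Linear(64, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, 3, 224, 224) -> logits: (batch, num_classes)
        return self.classifier(self.features(x))
```

With a 224x224 input, the three 2x2 poolings shrink the feature map to 28x28, which is why the first linear layer in this sketch takes `128 * 28 * 28` features; a different input size would require changing that number.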

## Training Data

The model was trained on the [microsoft/cats_vs_dogs](https://huggingface.co/datasets/microsoft/cats_vs_dogs) dataset from Hugging Face.

## Usage

```python
import torch
from PIL import Image
from torchvision import transforms

# CatDogClassifier comes from the training code, which is not included in this card;
# its definition must match the architecture the weights were saved with.

# Load model (map_location='cpu' lets weights saved on a GPU load on a CPU-only machine)
model = CatDogClassifier()
model.load_state_dict(torch.load('model_weights.pth', map_location='cpu'))
model.eval()

# Prepare image: resize to 224x224 and normalize with the ImageNet mean/std
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

image = Image.open('your_image.jpg').convert('RGB')
image_tensor = transform(image).unsqueeze(0)  # add a batch dimension: (1, 3, 224, 224)

# Predict: the class with the highest logit wins
with torch.no_grad():
    outputs = model(image_tensor)
    predicted = outputs.argmax(dim=1)

classes = ['Cat', 'Dog']  # index order must match the labels used during training
print(f"Prediction: {classes[predicted.item()]}")
```
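
Because the model outputs one logit per class and is trained with `CrossEntropyLoss`, a softmax over those logits gives class probabilities, which can be reported alongside the label. A minimal sketch, assuming the class order `['Cat', 'Dog']` from the usage example; `predict_with_confidence` is a hypothetical helper, not part of the released code.

```python
from typing import List, Sequence, Tuple

import torch
import torch.nn as nn


def predict_with_confidence(
    model: nn.Module,
    batch: torch.Tensor,
    classes: Sequence[str] = ("Cat", "Dog"),
) -> List[Tuple[str, float]]:
    """Return (label, probability) for each item in a preprocessed batch.

    Hypothetical helper; assumes the model outputs one logit per class.
    """
    model.eval()
    with torch.no_grad():
        # Softmax over the class dimension turns logits into probabilities that sum to 1
        probs = torch.softmax(model(batch), dim=1)
    confidence, index = probs.max(dim=1)
    return [(classes[i], p) for i, p in zip(index.tolist(), confidence.tolist())]
```

Called as `predict_with_confidence(model, image_tensor)` with the objects from the usage example, it returns a one-element list holding the predicted label and its probability. Note that a high probability does not mean the image actually contains a cat or a dog, since the model must always pick one of the two.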

## Training Procedure

- **Optimizer**: Adam with learning rate 0.001
- **Loss Function**: CrossEntropyLoss
- **Batch Size**: 32
- **Epochs**: 10
- **Data Augmentation**: Random horizontal flip, rotation, and color jitter

## Performance

The model achieves approximately 85-90% accuracy on the validation set (results may vary based on training run).
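
Accuracy here is the fraction of validation images whose argmax prediction equals the true label. A minimal sketch of that computation, assuming a `DataLoader` that yields `(images, labels)` batches; `evaluate_accuracy` is a hypothetical name, and the card does not state how the validation split was built.

```python
import torch
from torch.utils.data import DataLoader


@torch.no_grad()  # no gradients needed for evaluation
def evaluate_accuracy(model: torch.nn.Module, loader: DataLoader, device: str = "cpu") -> float:
    """Fraction of samples whose argmax prediction matches the label."""
    model.eval()  # disables dropout so predictions are deterministic
    correct = total = 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        preds = model(images).argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)
    return correct / total
```

Because the split and random seed affect both training and this measurement, re-running training can land anywhere in (or outside) the quoted range.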

## Limitations

- The model only knows two classes; any other image (including one with no animal) is still labeled Cat or Dog
- Performance may degrade on images containing multiple animals
- Works best with clear, well-lit images
- Input images must be resized to 224x224 and normalized exactly as in the usage example

## License

MIT License

## Author

Created as an educational project for learning image classification with PyTorch.