---
license: mit
base_model:
- google/vit-base-patch16-224
pipeline_tag: image-classification
datasets:
- N-o-1/pokemon-images
---
# Pokemon Team Classification with Vision Transformer
A fine-tuned Vision Transformer (ViT) model that classifies images of 6 specific Pokemon from a competitive team setup: Arceus, Marshadow, Sandy Shocks, Slaking, Reshiram, and Magearna.
## Model Details
- **Base Model**: `google/vit-base-patch16-224`
- **Model Type**: Vision Transformer for Image Classification
- **Classes**: 6 Pokemon (Arceus, Marshadow, Sandy Shocks, Slaking, Reshiram, Magearna)
- **Input Size**: 224x224 RGB images
- **Framework**: PyTorch + Transformers
## Training Details
### Dataset
- **Arceus**: 644 images
- **Marshadow**: 101 images
- **Sandy Shocks**: 75 images
- **Slaking**: 152 images
- **Reshiram**: 118 images
- **Magearna**: 200 images
### Training Strategy
- **Balanced Sampling**: Each epoch uses exactly 75 samples per class (the size of the smallest class, Sandy Shocks) to prevent overfitting on Arceus, which has 644 images
- **Data Augmentation**: Random horizontal flip, rotation (±15°), color jitter, and resized crop
- **Transfer Learning**: Froze early ViT layers, fine-tuned classifier and later transformer layers
- **Early Stopping**: Training stopped when validation loss plateaued (patience=3 epochs)
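
The balanced sampling step above can be sketched roughly as follows. This is a minimal illustration rather than the author's training code: the `balanced_epoch` helper and the placeholder file names are hypothetical, and it assumes each epoch draws its 75 items per class without replacement.

```python
import random

# Class sizes taken from the Dataset section.
CLASS_COUNTS = {
    "arceus": 644,
    "marshadow": 101,
    "sandy-shocks": 75,
    "slaking": 152,
    "reshiram": 118,
    "magearna": 200,
}
SAMPLES_PER_CLASS = 75


def balanced_epoch(items_by_class, per_class=SAMPLES_PER_CLASS, seed=None):
    """Return a shuffled list of (item, label) pairs, `per_class` items per class."""
    rng = random.Random(seed)
    epoch = []
    for label, items in items_by_class.items():
        # Sample without replacement so no image repeats within one epoch.
        for item in rng.sample(items, per_class):
            epoch.append((item, label))
    rng.shuffle(epoch)
    return epoch


# Hypothetical placeholder file names standing in for the real images.
items_by_class = {
    label: [f"{label}_{i:04d}.jpg" for i in range(count)]
    for label, count in CLASS_COUNTS.items()
}
epoch = balanced_epoch(items_by_class, seed=0)
print(len(epoch))  # 6 classes * 75 samples = 450
```

Re-drawing the subset every epoch means the larger classes still contribute different images over time, while each individual epoch stays balanced.
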
### Hyperparameters
- **Learning Rate**: 2e-5
- **Batch Size**: 16
- **Weight Decay**: 0.01
- **Optimizer**: AdamW
- **Epochs**: ~18 (early stopped from max 1000)
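
The early-stopping rule (patience of 3 epochs, cap of 1000 epochs) can be sketched as below. This is an assumption-laden illustration: the `early_stop_epoch` helper is hypothetical, the loss values are made up, and "plateaued" is interpreted here as "no strict improvement in validation loss", which may differ from the criterion actually used.

```python
def early_stop_epoch(val_losses, patience=3, max_epochs=1000):
    """Return (epoch_stopped, best_loss) for a sequence of validation losses.

    Training stops once `patience` consecutive epochs fail to improve on the
    best validation loss seen so far, or when `max_epochs` is reached.
    """
    best_loss = float("inf")
    stale_epochs = 0
    epoch = 0
    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best_loss:
            best_loss = loss
            stale_epochs = 0
        else:
            stale_epochs += 1
        if stale_epochs >= patience or epoch >= max_epochs:
            break
    return epoch, best_loss


# Made-up validation losses, for illustration only.
losses = [0.90, 0.60, 0.45, 0.40, 0.41, 0.42, 0.43, 0.30]
stopped_at, best = early_stop_epoch(losses)
print(stopped_at, best)  # stops at epoch 7; best loss 0.4 was seen at epoch 4
```

In a real run, the weights from the best epoch would typically be restored after stopping; whether that was done here is not stated.
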
## Performance
The model classifies all 6 Pokemon classes with balanced accuracy despite the imbalanced training dataset. Quantitative metrics (such as per-class accuracy or a confusion matrix) are not reported here.
## Usage
### Basic Classification
```python
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch

# Load model and processor ("your-username/pokemon-team-vit" is a placeholder repo ID)
model = ViTForImageClassification.from_pretrained("your-username/pokemon-team-vit")
processor = ViTImageProcessor.from_pretrained("your-username/pokemon-team-vit")

# Load the image and convert it to RGB, since the model expects 3-channel input
image = Image.open("pokemon_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Run inference without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)
    predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

# Map the highest-probability index to a class name
pokemon_names = ["arceus", "marshadow", "sandy-shocks", "slaking", "reshiram", "magearna"]
predicted_class = predictions.argmax(dim=-1).item()
confidence = predictions.max().item()
print(f"Predicted: {pokemon_names[predicted_class]} (confidence: {confidence:.2%})")
```
### Detailed Probabilities
```python
# Continues from Basic Classification: reuses `outputs` and `pokemon_names`.
# Get all class probabilities for the single input image
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]
results = {}
for idx, pokemon in enumerate(pokemon_names):
    results[pokemon] = float(probabilities[idx])

# Sort by probability, highest first
sorted_results = sorted(results.items(), key=lambda x: x[1], reverse=True)
for pokemon, prob in sorted_results:
    print(f"{pokemon}: {prob:.1%}")
```
## Applications
- **Pokemon Recognition**: Identify specific Pokemon in images, artwork, or screenshots
- **Competitive Team Analysis**: Analyze team compositions in competitive Pokemon content
- **Content Moderation**: Filter or categorize Pokemon-related content
- **Educational Tools**: Pokemon identification for learning applications
## Limitations
- **Specific Pokemon Only**: Only recognizes the 6 trained Pokemon classes
- **Image Quality**: Performance may vary with very low resolution or heavily distorted images
- **Artistic Variations**: May struggle with highly stylized or non-canonical Pokemon representations
- **Background Complexity**: Performance may decrease with very cluttered backgrounds
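
Because softmax always distributes probability over the 6 trained classes, an image of any other Pokemon still receives one of these labels. One common mitigation, which is not part of this model and is shown only as a sketch, is to treat low-confidence predictions as "unknown". The `label_or_unknown` helper and the 0.8 threshold are hypothetical; a suitable threshold would need to be tuned on held-out data, and softmax confidence is only a rough signal for out-of-class inputs.

```python
def label_or_unknown(probabilities, threshold=0.8):
    """Return (label, probability), using "unknown" when confidence is low.

    `probabilities` maps class name to probability, for example the
    `results` dict built in the Detailed Probabilities example.
    """
    best_label = max(probabilities, key=probabilities.get)
    best_prob = probabilities[best_label]
    if best_prob < threshold:
        return "unknown", best_prob
    return best_label, best_prob


# Made-up probabilities, for illustration only.
confident = {"arceus": 0.93, "marshadow": 0.02, "sandy-shocks": 0.01,
             "slaking": 0.02, "reshiram": 0.01, "magearna": 0.01}
uncertain = {"arceus": 0.30, "marshadow": 0.25, "sandy-shocks": 0.15,
             "slaking": 0.10, "reshiram": 0.10, "magearna": 0.10}

print(label_or_unknown(confident))  # ('arceus', 0.93)
print(label_or_unknown(uncertain))  # ('unknown', 0.3)
```
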
## Model Architecture
The model uses the Vision Transformer (ViT) architecture:
- **Patch Size**: 16x16
- **Hidden Size**: 768
- **Attention Heads**: 12
- **Layers**: 12
- **Parameters**: ~86M (base model) + classification head
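
The figures above imply a few derived quantities, worked out in the short sketch below. It assumes the standard ViT layout (non-overlapping patches plus one `[CLS]` token) and a single linear classification head over the 6 classes, as in `ViTForImageClassification`; the variable names are illustrative only.

```python
image_size = 224   # input resolution (pixels per side)
patch_size = 16    # patch resolution (pixels per side)
channels = 3       # RGB
hidden_size = 768
num_heads = 12
num_classes = 6

patches_per_side = image_size // patch_size      # 14
num_patches = patches_per_side ** 2              # 196 patches per image
sequence_length = num_patches + 1                # 197 tokens including [CLS]
patch_dim = patch_size * patch_size * channels   # 768 values per flattened patch
head_dim = hidden_size // num_heads              # 64 dimensions per attention head

# Linear head: one weight per (hidden unit, class) pair plus one bias per class.
head_params = hidden_size * num_classes + num_classes  # 4614

print(num_patches, sequence_length, head_dim, head_params)
```

The classification head is therefore tiny compared with the ~86M-parameter backbone, so almost all of the model's capacity comes from the pretrained transformer layers.
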
## Training Infrastructure
- **Hardware**: AMD GPU with ROCm support
- **Framework**: PyTorch with Transformers library
- **Duration**: ~2 minutes per epoch, early stopped at epoch 18
- **Memory**: Optimized for consumer-grade GPU memory
## Citation
If you use this model, please cite:
```bibtex
@misc{pokemon-team-vit,
title={Pokemon Team Classification with Vision Transformer},
author={Steven Van Ingelgem},
year={2025},
url={https://huggingface.co/your-username/pokemon-team-vit}
}
```
## License
This model is released under the MIT License. The training data consists of Pokemon images which are © The Pokémon Company/Nintendo. This model is for research and educational purposes.
## Acknowledgments
- Base model: Google's Vision Transformer (ViT)
- Training framework: Hugging Face Transformers
- Pokemon images: Various sources for competitive team analysis
---
**Note**: This model is specifically trained for a competitive Pokemon team setup and may not generalize to other Pokemon or use cases. For broader Pokemon classification, consider training on a more comprehensive dataset.