---
license: mit
base_model:
- google/vit-base-patch16-224
pipeline_tag: image-classification
datasets:
- N-o-1/pokemon-images
---
# Pokemon Team Classification with Vision Transformer

A fine-tuned Vision Transformer (ViT) model that classifies images of 6 specific Pokemon from a competitive team setup: Arceus, Marshadow, Sandy Shocks, Slaking, Reshiram, and Magearna.

## Model Details

- **Base Model**: `google/vit-base-patch16-224`
- **Model Type**: Vision Transformer for Image Classification
- **Classes**: 6 Pokemon (Arceus, Marshadow, Sandy Shocks, Slaking, Reshiram, Magearna)
- **Input Size**: 224x224 RGB images
- **Framework**: PyTorch + Hugging Face Transformers
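
Here is a rough picture of what the model actually receives. This assumes the processor keeps the defaults of the `google/vit-base-patch16-224` checkpoint: resize to 224x224, rescale by 1/255, then normalize each channel with mean 0.5 and standard deviation 0.5. Under that assumption, every 8-bit channel value ends up in the range [-1, 1]. The dependency-free helper below is a hypothetical sketch of that per-pixel arithmetic only. In real use, `ViTImageProcessor` handles resizing and normalization for you.

```python
# Assumed processor defaults (from the base checkpoint; not verified for this fine-tune).
IMAGE_MEAN = (0.5, 0.5, 0.5)
IMAGE_STD = (0.5, 0.5, 0.5)


def normalize_pixel(rgb: tuple[int, int, int]) -> tuple[float, ...]:
    """Rescale one 8-bit RGB pixel to [0, 1], then normalize each channel."""
    return tuple(
        (value / 255 - mean) / std
        for value, mean, std in zip(rgb, IMAGE_MEAN, IMAGE_STD)
    )


print(normalize_pixel((0, 128, 255)))  # channel value 0 -> -1.0, 255 -> 1.0
```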

## Training Details

### Dataset
- **Arceus**: 644 images
- **Marshadow**: 101 images
- **Sandy Shocks**: 75 images
- **Slaking**: 152 images
- **Reshiram**: 118 images
- **Magearna**: 200 images
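
The counts above add up to 1,290 images, and the imbalance is substantial. Arceus alone makes up about half the dataset, with roughly 8.6 times as many images as Sandy Shocks. Below is a quick standard-library check of those figures. The lowercase class keys are illustrative and simply mirror the list above.

```python
# Image counts per class, copied from the dataset list above.
counts = {
    "arceus": 644,
    "marshadow": 101,
    "sandy-shocks": 75,
    "slaking": 152,
    "reshiram": 118,
    "magearna": 200,
}

total = sum(counts.values())
smallest = min(counts, key=counts.get)
largest = max(counts, key=counts.get)

print(f"total images: {total}")
for name, n in counts.items():
    print(f"{name:>13}: {n:4d} ({n / total:.1%} of the data)")
print(f"imbalance {largest}/{smallest}: {counts[largest] / counts[smallest]:.1f}x")
```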

### Training Strategy
- **Balanced Sampling**: Each epoch draws exactly 75 samples per class (the size of the smallest class, Sandy Shocks), so training is not dominated by, or overfit to, the majority class Arceus
- **Data Augmentation**: Random horizontal flip, rotation (±15°), color jitter, and resized crop
- **Transfer Learning**: Froze the early ViT layers and fine-tuned the later transformer layers and the classifier head
- **Early Stopping**: Training stopped once validation loss had not improved for 3 consecutive epochs (patience = 3)
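
The balanced-sampling step can be sketched in plain Python. This is a minimal illustration, not the author's training code. `sample_balanced_epoch` and its arguments are hypothetical names. The sketch assumes that samples are drawn without replacement and redrawn every epoch, so that different Arceus images are seen across epochs.

```python
import random


def sample_balanced_epoch(
    items_by_class: dict[str, list[str]],
    per_class: int,
    rng: random.Random,
) -> list[tuple[str, str]]:
    """Draw `per_class` items from every class without replacement, then shuffle.

    Returns (item, label) pairs for one epoch.
    """
    epoch: list[tuple[str, str]] = []
    for label, items in items_by_class.items():
        if len(items) < per_class:
            raise ValueError(f"class {label!r} has only {len(items)} items")
        epoch.extend((item, label) for item in rng.sample(items, per_class))
    rng.shuffle(epoch)  # mix classes so batches are not single-class
    return epoch


# Hypothetical file lists sized like the dataset above.
counts = {"arceus": 644, "marshadow": 101, "sandy-shocks": 75,
          "slaking": 152, "reshiram": 118, "magearna": 200}
items_by_class = {name: [f"{name}/{i:04d}.png" for i in range(n)]
                  for name, n in counts.items()}

rng = random.Random(0)  # fixed seed for reproducibility
epoch_1 = sample_balanced_epoch(items_by_class, per_class=75, rng=rng)
epoch_2 = sample_balanced_epoch(items_by_class, per_class=75, rng=rng)
print(len(epoch_1))  # 6 classes x 75 = 450 samples per epoch
```

With this scheme, every Sandy Shocks image is used in every epoch. Each epoch sees a fresh subset of about 12% of the Arceus images (75 of 644).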

### Hyperparameters
- **Learning Rate**: 2e-5
- **Batch Size**: 16
- **Weight Decay**: 0.01
- **Optimizer**: AdamW
- **Epochs**: ~18 (early-stopped; maximum set to 1000)
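
The early-stopping rule can be sketched as a small stateful helper: a patience of 3 epochs on validation loss, capped at 1000 epochs. This is an illustrative sketch rather than the training code used for this model. The `EarlyStopping` class, its `min_delta` parameter, and the loss values are all hypothetical.

```python
class EarlyStopping:
    """Signal a stop when validation loss has not improved for `patience` epochs in a row."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True if training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss  # improvement: reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


MAX_EPOCHS = 1000
stopper = EarlyStopping(patience=3)
# Hypothetical validation losses: improvement until epoch 4, then a plateau.
val_losses = [0.90, 0.55, 0.40, 0.35, 0.36, 0.35, 0.37, 0.30]

stopped_at = None
for epoch, loss in enumerate(val_losses[:MAX_EPOCHS], start=1):
    if stopper.step(loss):
        stopped_at = epoch
        break

print(stopped_at, stopper.best_loss)
```

Here training halts after epoch 7, so the later improvement at epoch 8 is never seen. That trade-off is inherent to patience-based stopping. Typically you would also keep the weights from the best epoch (epoch 4 in this example) rather than those from the last one.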

## Performance

The model achieves high classification accuracy, with performance balanced across all 6 Pokemon classes despite the imbalanced training dataset.
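
Because the classes are imbalanced, overall accuracy can look good even when a minority class is ignored entirely. Balanced accuracy exposes this. It is the unweighted mean of per-class recall, which is also the default definition used by scikit-learn's `balanced_accuracy_score`. Below is a dependency-free sketch on made-up labels. The function names are illustrative.

```python
from collections import defaultdict


def per_class_recall(y_true: list[str], y_pred: list[str]) -> dict[str, float]:
    """Fraction of each true class that the model predicted correctly."""
    if len(y_true) != len(y_pred):
        raise ValueError("y_true and y_pred must have the same length")
    seen: dict[str, int] = defaultdict(int)
    correct: dict[str, int] = defaultdict(int)
    for truth, pred in zip(y_true, y_pred):
        seen[truth] += 1
        correct[truth] += int(truth == pred)
    return {label: correct[label] / seen[label] for label in seen}


def balanced_accuracy(y_true: list[str], y_pred: list[str]) -> float:
    """Unweighted mean of per-class recall: every class counts equally."""
    recalls = per_class_recall(y_true, y_pred)
    return sum(recalls.values()) / len(recalls)


# Hypothetical labels: a degenerate model that always answers "arceus".
y_true = ["arceus"] * 8 + ["sandy-shocks"] * 2
y_pred = ["arceus"] * 10

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
print(f"accuracy: {accuracy:.0%}")                                    # 80%
print(f"balanced accuracy: {balanced_accuracy(y_true, y_pred):.0%}")  # 50%
```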

## Usage

### Basic Classification

```python
from transformers import ViTImageProcessor, ViTForImageClassification
from PIL import Image
import torch

# Load model and processor ("your-username/pokemon-team-vit" is a placeholder repo id)
model = ViTForImageClassification.from_pretrained("your-username/pokemon-team-vit")
processor = ViTImageProcessor.from_pretrained("your-username/pokemon-team-vit")

# Load the image as 3-channel RGB (PNG sprites are often RGBA or palette-based)
image = Image.open("pokemon_image.jpg").convert("RGB")
inputs = processor(images=image, return_tensors="pt")

# Run inference without tracking gradients
with torch.no_grad():
    outputs = model(**inputs)

# Softmax over the 6 class logits of the single image in the batch
predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

# Class names, in the label order used during training
pokemon_names = ["arceus", "marshadow", "sandy-shocks", "slaking", "reshiram", "magearna"]
predicted_class = predictions.argmax().item()
confidence = predictions[predicted_class].item()

print(f"Predicted: {pokemon_names[predicted_class]} (confidence: {confidence:.2%})")
```

### Detailed Probabilities

```python
# Continues from the example above (reuses `torch`, `outputs`, and `pokemon_names`)
probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)[0]

# Map each class name to its probability
results = {pokemon: float(prob) for pokemon, prob in zip(pokemon_names, probabilities)}

# Print classes from most to least likely
for pokemon, prob in sorted(results.items(), key=lambda item: item[1], reverse=True):
    print(f"{pokemon}: {prob:.1%}")
```

## Applications

- **Pokemon Recognition**: Identify the 6 supported Pokemon in images, artwork, or screenshots
- **Competitive Team Analysis**: Analyze team compositions in competitive Pokemon content
- **Content Moderation**: Filter or categorize Pokemon-related content
- **Educational Tools**: Pokemon identification for learning applications

## Limitations

- **Specific Pokemon Only**: Only recognizes the 6 trained classes. Because the output is a softmax over these 6 labels, an image of any other Pokemon (or of no Pokemon at all) is still assigned to one of them
- **Image Quality**: Performance may degrade on very low-resolution or heavily distorted images
- **Artistic Variations**: May struggle with highly stylized or non-canonical Pokemon representations
- **Background Complexity**: Performance may decrease with very cluttered backgrounds
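
Because the softmax over the 6 classes always sums to 1, an image of any other Pokemon is still assigned one of the six labels. One common mitigation, which is not part of this model, is to reject predictions whose top probability falls below a threshold. The sketch below is hypothetical. `classify_with_threshold` and the 0.7 threshold are illustrative choices that would need tuning, and a low maximum probability is only a rough signal, not a reliable detector of unseen classes.

```python
import math

POKEMON_NAMES = ["arceus", "marshadow", "sandy-shocks", "slaking", "reshiram", "magearna"]


def softmax(logits: list[float]) -> list[float]:
    """Numerically stable softmax (subtract the max before exponentiating)."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def classify_with_threshold(logits: list[float], threshold: float = 0.7) -> tuple[str, float]:
    """Return (label, confidence), using "unknown" when confidence < threshold."""
    probs = softmax(logits)
    best = max(range(len(probs)), key=probs.__getitem__)
    label = POKEMON_NAMES[best] if probs[best] >= threshold else "unknown"
    return label, probs[best]


print(classify_with_threshold([6.0, 0.5, 0.2, 0.1, 0.3, 0.4]))  # one dominant logit
print(classify_with_threshold([1.0, 0.9, 0.8, 1.1, 0.7, 0.9]))  # flat logits -> "unknown"
```

With the usage example above, you can obtain the logits for a single image from `outputs.logits[0].tolist()`. Tune the threshold on held-out images, ideally including some Pokemon outside the six classes.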

## Model Architecture

The model uses the Vision Transformer (ViT-Base) architecture:
- **Patch Size**: 16x16
- **Hidden Size**: 768
- **Attention Heads**: 12
- **Layers**: 12
- **Parameters**: ~86M (base model) + classification head
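
The token count and head size follow from these numbers. A 224x224 input split into 16x16 patches yields a 14x14 grid of 196 patch tokens, plus one `[CLS]` token. In `ViTForImageClassification`, the classification head is a single linear layer from the 768-dimensional `[CLS]` output to the 6 class logits. The constant names below are illustrative. Here is the arithmetic, spelled out:

```python
IMAGE_SIZE = 224
PATCH_SIZE = 16
HIDDEN_SIZE = 768
NUM_CLASSES = 6

patches_per_side = IMAGE_SIZE // PATCH_SIZE             # 14
num_patches = patches_per_side ** 2                     # 196
sequence_length = num_patches + 1                       # +1 for the [CLS] token
patch_values = PATCH_SIZE * PATCH_SIZE * 3              # RGB values per flattened patch
head_params = HIDDEN_SIZE * NUM_CLASSES + NUM_CLASSES   # weights + biases

print(patches_per_side, num_patches, sequence_length, patch_values, head_params)
```

The new head therefore contributes only 4,614 parameters on top of the ~86M-parameter backbone.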

## Training Infrastructure

- **Hardware**: AMD GPU with ROCm support
- **Framework**: PyTorch with the Hugging Face Transformers library
- **Duration**: ~2 minutes per epoch; early-stopped at epoch 18 (roughly 36 minutes of training)
- **Memory**: Optimized for consumer-grade GPU memory

## Citation

If you use this model, please cite:

```bibtex
@misc{pokemon-team-vit,
  title={Pokemon Team Classification with Vision Transformer},
  author={Steven Van Ingelgem},
  year={2025},
  url={https://huggingface.co/your-username/pokemon-team-vit}
}
```

## License

This model is released under the MIT License. The training data consists of Pokemon images, which are © The Pokémon Company/Nintendo. This model is intended for research and educational purposes.

## Acknowledgments

- Base model: Google's Vision Transformer (ViT)
- Training framework: Hugging Face Transformers
- Pokemon images: Various sources for competitive team analysis

---

**Note**: This model is trained specifically for one competitive Pokemon team and may not generalize to other Pokemon or use cases. For broader Pokemon classification, consider training on a more comprehensive dataset.