---
license: mit
tags:
- image-classification
- pytorch
- simpsons
- convnext
datasets:
- custom
metrics:
- accuracy
---

# NYCU_ML_2025_ImageClassification

## Model Description

This is a **convnextv2_base.fcmae_ft_in22k_in1k** model (timm, 2023; listed as the recommended first choice) fine-tuned for **Simpsons character classification**.

- **Developed by:** NYCU ML Course 2025
- **Model type:** Image Classification
- **Framework:** PyTorch + timm
- **Best Validation Accuracy:** 0.9934

## Training Details

### Hyperparameters

| Parameter | Value |
|-----------|-------|
| Image Resolution | 256 |
| Batch Size | 80 |
| Learning Rate | 0.0001 |
| Optimizer | AdamW |
| Weight Decay | 0.01 |
| Scheduler | CosineAnnealingLR |
| Label Smoothing | 0.1 |
| Epochs | 15 |
| CutMix | False |
| HEM-TA | False |
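
The table values map onto standard PyTorch objects roughly as follows. This is a minimal sketch, not the original training script: the `nn.Linear` model is a stand-in for the timm backbone so the snippet runs without downloads, the batches are random tensors, and stepping the scheduler once per epoch with `T_max` equal to the epoch count is an assumption.

```python
import torch
from torch import nn

# Stand-in for the ConvNeXt V2 backbone (hypothetical, keeps the sketch lightweight).
model = nn.Linear(8, 50)  # 50 outputs, matching the number of classes

EPOCHS = 15
BATCH_SIZE = 80

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# Assumption: one scheduler step per epoch, annealing over all 15 epochs.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)

lrs = []
for epoch in range(EPOCHS):
    # One dummy batch per epoch; a real loop would iterate over a DataLoader.
    inputs = torch.randn(BATCH_SIZE, 8)
    targets = torch.randint(0, 50, (BATCH_SIZE,))

    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()

    lrs.append(scheduler.get_last_lr()[0])  # learning rate used in this epoch
    scheduler.step()
```

With these settings the learning rate starts at 0.0001 and decays along a cosine curve over the 15 epochs.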

### Dataset

- **Number of Classes:** 50
- **Training Samples:** 87236
- **Validation Samples:** 9693

### Classes

```
abraham_grampa_simpson, agnes_skinner, apu_nahasapeemapetilon, barney_gumble, bart_simpson, brandine_spuckler, carl_carlson, charles_montgomery_burns, chief_wiggum, cletus_spuckler, comic_book_guy, disco_stu, dolph_starbeam, duff_man, edna_krabappel, fat_tony, gary_chalmers, gil, groundskeeper_willie, homer_simpson...
```

## Usage

```python
import torch
import timm
from PIL import Image
from torchvision import transforms

# Load model
model = timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k',
                          pretrained=False,
                          num_classes=50)
model.load_state_dict(torch.load('pytorch_model.pth', map_location='cpu'))
model.eval()

# Preprocess
transform = transforms.Compose([
    transforms.Resize(294),
    transforms.CenterCrop(256),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Predict
img = Image.open('your_image.jpg').convert('RGB')
input_tensor = transform(img).unsqueeze(0)
with torch.no_grad():
    output = model(input_tensor)
    pred = output.argmax(dim=1).item()
```
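
The `pred` value is only a class index. The following sketch shows one way to turn logits into a readable label with top-k probabilities. It rests on assumptions: `class_names` is a hypothetical list holding just the first five names from the Classes section, the index order is assumed to match the alphabetical order shown there, and the logits are made-up stand-ins for `output`. The real model emits 50 logits and needs the full label list in the order used during training.

```python
import torch

# Hypothetical, shortened label list (the real one has 50 entries).
# Assumption: indices follow the alphabetical order from the Classes section.
class_names = [
    "abraham_grampa_simpson",
    "agnes_skinner",
    "apu_nahasapeemapetilon",
    "barney_gumble",
    "bart_simpson",
]

# Made-up logits for a single image, standing in for `output` from the model.
output = torch.tensor([[0.1, 0.3, 0.2, 0.0, 2.5]])

# Convert logits to probabilities and take the three most likely classes.
probs = torch.softmax(output, dim=1)
top_probs, top_idx = probs.topk(3, dim=1)

for p, i in zip(top_probs[0].tolist(), top_idx[0].tolist()):
    print(f"{class_names[i]}: {p:.3f}")

predicted_label = class_names[top_idx[0, 0].item()]
```

Reporting a few top candidates with their probabilities makes low-confidence predictions easier to spot than a bare `argmax`.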

## License

MIT License