---
library_name: timm
pipeline_tag: image-classification
base_model:
- timm/efficientnet_b0
tags:
- anime-classification
- real-photos
- rendered-graphics
- pytorch
- efficientnet
- vision
license: openrail
model_type: efficientnet_b0
inference: true
---

# Anime/Real/Rendered Image Classifier (EfficientNet-B0)

**Fast, lightweight classifier for distinguishing photographs from anime and 3D rendered images.**

## Model Details

- **Architecture:** EfficientNet-B0 (timm)
- **Input Size:** 224×224 RGB
- **Classes:** anime, real, rendered
- **Parameters:** 5.3M
- **Validation Accuracy:** 97.44%
- **Training Speed:** ~1 min/epoch (GPU)
- **Inference Speed:** ~20 ms per image (RTX 3060)

## Performance

| Class | Precision | Recall | F1-Score |
|-------|-----------|--------|----------|
| anime | 0.98 | 0.99 | 0.99 |
| real | 0.98 | 0.98 | 0.98 |
| rendered | 0.96 | 0.93 | 0.94 |
| **macro avg** | **0.97** | **0.97** | **0.97** |

## Usage

```python
from PIL import Image
import torch
from torchvision import transforms
import timm
from safetensors.torch import load_file

# Load model
model = timm.create_model('efficientnet_b0', num_classes=3, pretrained=False)
state_dict = load_file('model.safetensors')
model.load_state_dict(state_dict)
model.eval()

# Prepare image (ImageNet normalization)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])

image = Image.open('image.jpg').convert('RGB')
x = transform(image).unsqueeze(0)

# Predict
with torch.no_grad():
    logits = model(x)
    probs = torch.softmax(logits, dim=1)
    pred_class = probs.argmax(dim=1).item()

labels = ['anime', 'real', 'rendered']
print(f"{labels[pred_class]}: {probs[0, pred_class]:.2%}")
```

## Dataset

- **Real:** 5,000 COCO 2017 validation images (diverse real-world scenes)
- **Anime:** 2,357 curated anime/animation frames
- **Rendered:** 1,610 AAA game screenshots + 61 Pixar movie stills
- **Total:** 8,967 images (8,070 train / 897 val)

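The card does not document the split procedure; assuming a simple random 90/10 split with a fixed seed (the seed and use of `random_split` are assumptions), it could be reproduced like this:

```python
import torch
from torch.utils.data import random_split

# Counts from the dataset section above
n_total = 8967
n_val = 897                   # ~10% held out for validation
n_train = n_total - n_val     # 8,070 training images

# random_split only needs a sized, indexable collection; range() stands in
# for the real dataset here
generator = torch.Generator().manual_seed(42)
train_set, val_set = random_split(range(n_total), [n_train, n_val], generator=generator)
print(len(train_set), len(val_set))  # 8070 897
```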
## Training Details

- **Augmentation:** None (plain resize to 224×224)
- **Optimizer:** AdamW (lr=0.001)
- **Loss:** CrossEntropyLoss with class weighting
- **Epochs:** 20
- **Batch Size:** 80
- **Hardware:** NVIDIA RTX 3060 (12GB)

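Class weighting counteracts the real-photo majority in the dataset. The exact scheme isn't specified; a common choice is inverse-frequency weighting, sketched here using the per-class counts from the dataset section:

```python
import torch
import torch.nn as nn

# Per-class image counts: anime, real, rendered (from the dataset section)
counts = torch.tensor([2357.0, 5000.0, 1610.0])

# Inverse-frequency weights: weight_c = N / (K * n_c), so minority classes
# contribute proportionally more to the loss
weights = counts.sum() / (len(counts) * counts)

criterion = nn.CrossEntropyLoss(weight=weights)
print(weights)  # real (index 1) weighted lowest, rendered (index 2) highest
```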
## Known Limitations

- **Real vs. Rendered:** Some confusion; photorealistic game screenshots can be misclassified as real
- **Stylized Games:** Cel-shaded games (e.g., Fate/Extella) may be scored as anime
- **Pixar-style CGI:** Heavily stylized rendered images may produce mixed confidence

## Recommendations

- Ensemble with a second model (e.g., tf_efficientnetv2_s) for critical applications
- Apply a confidence threshold: only trust predictions above 85% confidence
- For edge cases, consult the full confusion matrix to understand failure modes

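The 85% threshold can be applied directly to the softmax output from the usage example; a minimal sketch (the `classify_with_threshold` helper is hypothetical, not part of the model's API):

```python
import torch

LABELS = ['anime', 'real', 'rendered']

def classify_with_threshold(logits: torch.Tensor, threshold: float = 0.85):
    """Return (label, confidence), or ('uncertain', confidence) below the threshold."""
    probs = torch.softmax(logits, dim=1)
    conf, idx = probs.max(dim=1)
    label = LABELS[idx.item()] if conf.item() >= threshold else 'uncertain'
    return label, conf.item()

# A confident prediction vs. an ambiguous one
print(classify_with_threshold(torch.tensor([[5.0, 0.0, 0.0]])))  # confident: anime
print(classify_with_threshold(torch.tensor([[1.0, 0.9, 0.8]])))  # flat logits: uncertain
```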
## License

OpenRAIL: free for research and commercial use with proper attribution.