---
license: apache-2.0
datasets:
- deepghs/danbooru2024
tags:
- Classification
pipeline_tag: image-classification
---
# DT24-Tiny
A multi-label image classification model trained on [deepghs/danbooru2024](https://huggingface.co/datasets/deepghs/danbooru2024), designed to tag anime-style illustrations with a vocabulary of **10,000 tags**.
The model uses **ConvNeXt V2 Tiny** as the backbone and runs at 448px resolution, balancing speed and accuracy.
| Attribute | Details |
| :--- | :--- |
| **Model Architecture** | ConvNeXt V2 Tiny + GeM Pooling |
| **Resolution** | 448 x 448 (Letterbox Padding) |
| **Vocabulary** | Top 10,000 Tags (Danbooru) |
| **Format** | SafeTensors (`model.safetensors`) |
## ๐Ÿš€ Quick Start (Inference)
You need `torch`, `torchvision`, `timm`, `Pillow`, `huggingface_hub`, and `safetensors` installed.
```bash
pip install torch torchvision timm pillow huggingface_hub safetensors
```
### Python Inference Script
Since this model uses a custom head (GeM Pooling + Linear), you need to define the class structure before loading weights.
```python
import json

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# --- 1. Define Architecture ---
class GeM(nn.Module):
    """Generalized Mean (GeM) pooling with a learnable exponent p."""
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return F.avg_pool2d(
            x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))
        ).pow(1.0 / self.p)

class DT24Tiny(nn.Module):
    def __init__(self, num_classes=10000):
        super().__init__()
        # Load backbone without head or pooling
        self.backbone = timm.create_model(
            "convnextv2_tiny.fcmae_ft_in1k",
            pretrained=False, num_classes=0, global_pool="",
        )
        self.pooling = GeM()
        self.head = nn.Linear(768, num_classes)

    def forward(self, x):
        # Sigmoid is applied here for multi-label classification
        return torch.sigmoid(self.head(self.pooling(self.backbone(x)).flatten(1)))

# --- 2. Load Model & Tags ---
REPO_ID = "igidn/DT24-Tiny"

# Load tags (tags.json maps Tag Name -> Index)
tag_path = hf_hub_download(repo_id=REPO_ID, filename="tags.json")
with open(tag_path, "r") as f:
    tag_map = json.load(f)
idx_to_tag = {v: k for k, v in tag_map.items()}

# Load weights (SafeTensors)
model_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
state_dict = load_file(model_path)

model = DT24Tiny(num_classes=len(tag_map))
model.load_state_dict(state_dict)
model.eval()

# --- 3. Preprocessing (Letterbox Pad) ---
class LetterboxPad:
    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        w, h = img.size
        scale = self.size / max(w, h)
        new_w, new_h = int(w * scale), int(h * scale)
        img = img.resize((new_w, new_h), Image.BICUBIC)
        # Center the resized image on a black square canvas
        new_img = Image.new("RGB", (self.size, self.size), (0, 0, 0))
        new_img.paste(img, ((self.size - new_w) // 2, (self.size - new_h) // 2))
        return new_img

transform = transforms.Compose([
    LetterboxPad(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# --- 4. Prediction ---
def predict(image_path, threshold=0.60):
    img = Image.open(image_path).convert("RGB")
    tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        probs = model(tensor)[0]
    # Keep tags scoring above the threshold, sorted by confidence
    results = {
        idx_to_tag[idx]: score.item()
        for idx, score in enumerate(probs)
        if score > threshold
    }
    return dict(sorted(results.items(), key=lambda item: item[1], reverse=True))

# Test
# print(predict("test_image.jpg"))
```
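The `GeM` layer defined above interpolates between average and max pooling. A quick sanity check on a toy tensor (independent of the model weights) shows that with `p=3` a single strong activation survives pooling much better than under plain averaging:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GeM(nn.Module):
    # Same definition as in the inference script above
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return F.avg_pool2d(
            x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))
        ).pow(1.0 / self.p)

# One channel with a single strong activation among near-zero ones
x = torch.zeros(1, 1, 4, 4)
x[0, 0, 0, 0] = 1.0

gem = GeM(p=3)
with torch.no_grad():
    gem_out = gem(x).item()   # ~0.40: the peak dominates
avg_out = x.mean().item()     # 0.0625: the peak is washed out

print(f"avg: {avg_out:.4f}, gem: {gem_out:.4f}")
```

This is why the model swaps Global Average Pooling for GeM: small but confident activations (e.g. a tiny accessory) contribute more to the pooled feature.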
## ๐Ÿ›  Model Details
### Training Data
* **Dataset:** `deepghs/danbooru2024`
* **Selection:** Top 10,000 most frequent tags.
### Preprocessing
Unlike standard resizing, which distorts the aspect ratio, this model uses **Letterbox Padding**:
1. Resize the longest edge to 448px, preserving aspect ratio.
2. Center-paste the image onto a black 448x448 canvas.
3. Apply standard ImageNet normalization.
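The resize-and-pad arithmetic from steps 1–2 can be sketched as a small standalone helper (`letterbox_geometry` is a hypothetical name for illustration; the actual preprocessing is the `LetterboxPad` transform in the script above):

```python
def letterbox_geometry(w, h, size=448):
    """Return the resized (w, h) and the top-left paste offset for
    letterboxing a w x h image onto a size x size canvas."""
    scale = size / max(w, h)                      # longest edge -> size
    new_w, new_h = int(w * scale), int(h * scale)
    off_x = (size - new_w) // 2                   # center horizontally
    off_y = (size - new_h) // 2                   # center vertically
    return (new_w, new_h), (off_x, off_y)

# A 1000x500 landscape image: width -> 448, height -> 224,
# padded with 112px black bars on top and bottom
print(letterbox_geometry(1000, 500))  # ((448, 224), (0, 112))
```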
### Architecture Nuances
* **Backbone:** ConvNeXt V2 Tiny (Pretrained on ImageNet-1k).
* **Pooling:** Replaced standard Global Average Pooling with **GeM (Generalized Mean Pooling)**. This allows the model to better focus on salient features (like small accessories) rather than washing them out.
* **Head:** A single Linear layer mapping 768 features to 10,000 tags.
* **Loss:** Trained with **Asymmetric Loss (ASL)** to handle the extreme class imbalance of sparse tagging.
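ASL focuses the gradient on hard examples and discards easy negatives, which matters when nearly all of the 10,000 labels are negative for any given image. The training hyperparameters are not published in this repo, so the sketch below uses the common defaults from the ASL paper (`gamma_neg=4`, `gamma_pos=0`, `clip=0.05`) purely as assumptions:

```python
import torch

def asymmetric_loss(probs, targets, gamma_neg=4.0, gamma_pos=0.0,
                    clip=0.05, eps=1e-8):
    """Minimal Asymmetric Loss over sigmoid probabilities.
    gamma_neg/gamma_pos/clip are assumed defaults, not this model's
    published training values."""
    # Probability shifting: easy negatives (p < clip) get exactly zero loss
    probs_neg = (1 - probs + clip).clamp(max=1.0)
    # Asymmetric focusing: down-weight easy examples on each side
    loss_pos = targets * torch.log(probs.clamp(min=eps)) * (1 - probs) ** gamma_pos
    loss_neg = (1 - targets) * torch.log(probs_neg.clamp(min=eps)) * (1 - probs_neg) ** gamma_neg
    return -(loss_pos + loss_neg).mean()

# Toy batch: 2 samples x 4 tags
probs = torch.tensor([[0.9, 0.1, 0.6, 0.05],
                      [0.2, 0.8, 0.3, 0.01]])
targets = torch.tensor([[1.0, 0.0, 1.0, 0.0],
                        [0.0, 1.0, 0.0, 0.0]])
print(asymmetric_loss(probs, targets))
```

Note the clipping: a negative tag predicted below `clip` contributes nothing, so the loss is not dominated by the thousands of trivially absent tags per image.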
## ๐Ÿ“‚ Files in Repo
* `model.safetensors`: The FP16 trained weights (use this for inference).
* `config.json`: Basic configuration parameters.
* `tags.json`: The mapping of `Tag Name -> Index`.
* `optimizer.pt`: (Optional) Optimizer state, only needed if you plan to resume training this model.
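Since `tags.json` stores `Tag Name -> Index`, it must be inverted once at load time so model output indices can be mapped back to tag names (a three-entry toy mapping stands in for the real 10,000-entry file):

```python
# Toy stand-in for the contents of tags.json (Tag Name -> Index)
tag_map = {"1girl": 0, "solo": 1, "smile": 2}

# Invert once: Index -> Tag Name, used when reading off predictions
idx_to_tag = {v: k for k, v in tag_map.items()}

print(idx_to_tag[2])  # smile
```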
## โš–๏ธ License
This model is released under the **Apache 2.0** license.