# DT24-Tiny

A multi-label image classification model trained on deepghs/danbooru2024, designed to tag anime-style illustrations with a vocabulary of 10,000 tags.

This model uses ConvNeXt V2 Tiny as the backbone, optimized for a balance between speed and accuracy at 448px resolution.
| Attribute | Details |
|---|---|
| Model Architecture | ConvNeXt V2 Tiny + GeM Pooling |
| Resolution | 448 x 448 (Letterbox Padding) |
| Vocabulary | Top 10,000 Tags (Danbooru) |
| Format | SafeTensors (model.safetensors) |
## Quick Start (Inference)

You need `timm`, `torch`, `Pillow`, and `safetensors` installed:

```bash
pip install torch torchvision timm pillow huggingface_hub safetensors
```
### Python Inference Script

Since this model uses a custom head (GeM pooling + a linear layer), you need to define the class structure before loading the weights.
```python
import json

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# --- 1. Define Architecture ---
class GeM(nn.Module):
    """Generalized Mean (GeM) pooling with a learnable exponent p."""
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1.0 / self.p)

class DT24Tiny(nn.Module):
    def __init__(self, num_classes=10000):
        super().__init__()
        # Load the backbone without its classification head or pooling
        self.backbone = timm.create_model("convnextv2_tiny.fcmae_ft_in1k", pretrained=False, num_classes=0, global_pool='')
        self.pooling = GeM()
        self.head = nn.Linear(768, num_classes)

    def forward(self, x):
        # Sigmoid is applied here for multi-label classification
        return torch.sigmoid(self.head(self.pooling(self.backbone(x)).flatten(1)))

# --- 2. Load Model & Tags ---
REPO_ID = "igidn/DT24-Tiny"

# Load tags (Tag Name -> Index)
tag_path = hf_hub_download(repo_id=REPO_ID, filename="tags.json")
with open(tag_path, "r") as f:
    tag_map = json.load(f)
idx_to_tag = {v: k for k, v in tag_map.items()}

# Load weights (SafeTensors)
model_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
state_dict = load_file(model_path)

model = DT24Tiny(num_classes=len(tag_map))
model.load_state_dict(state_dict)
model.eval()

# --- 3. Preprocessing (Letterbox Pad) ---
class LetterboxPad:
    """Resize the longest edge to `size`, then center on a black square canvas."""
    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        w, h = img.size
        scale = self.size / max(w, h)
        new_w, new_h = int(w * scale), int(h * scale)
        img = img.resize((new_w, new_h), Image.BICUBIC)
        new_img = Image.new("RGB", (self.size, self.size), (0, 0, 0))
        new_img.paste(img, ((self.size - new_w) // 2, (self.size - new_h) // 2))
        return new_img

transform = transforms.Compose([
    LetterboxPad(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# --- 4. Prediction ---
def predict(image_path, threshold=0.60):
    img = Image.open(image_path).convert("RGB")
    tensor = transform(img).unsqueeze(0)
    with torch.no_grad():
        probs = model(tensor)[0]
    # Keep only tags whose probability exceeds the threshold
    results = {}
    for idx, score in enumerate(probs):
        if score > threshold:
            results[idx_to_tag[idx]] = score.item()
    return dict(sorted(results.items(), key=lambda item: item[1], reverse=True))

# Test
# print(predict("test_image.jpg"))
```
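The GeM module acts as a tunable pooling: at p = 1 it reduces exactly to average pooling, and as p grows it shifts toward max pooling. A standalone sketch on a random tensor (not real model features) illustrating this behavior:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.rand(1, 3, 8, 8) + 0.1  # positive activations, as after a ConvNeXt stage

def gem(x, p, eps=1e-6):
    # Same formula as the GeM module above, with a fixed (non-learnable) p
    return F.avg_pool2d(x.clamp(min=eps).pow(p), (x.size(-2), x.size(-1))).pow(1.0 / p)

avg = F.adaptive_avg_pool2d(x, 1)
mx = F.adaptive_max_pool2d(x, 1)

print(torch.allclose(gem(x, 1.0), avg, atol=1e-5))  # p=1 is average pooling
# For p > 1, GeM lies between the average and the max of each channel
print(bool(((gem(x, 3.0) >= avg - 1e-6) & (gem(x, 3.0) <= mx + 1e-6)).all()))
```

With the default p = 3, large activations (small but distinctive features) contribute more to the pooled value than under plain averaging.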
## Model Details

### Training Data

- Dataset: deepghs/danbooru2024
- Selection: Top 10,000 most frequent tags.
### Preprocessing

Unlike standard resizing, which distorts the aspect ratio, this model uses Letterbox Padding:

1. Resize the longest edge to 448px.
2. Paste the image, centered, onto a black 448x448 canvas.
3. Apply standard ImageNet normalization.
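As a concrete worked example of these steps (standalone, using a hypothetical 1000x500 input rather than a real illustration):

```python
from PIL import Image

src = Image.new("RGB", (1000, 500), (255, 255, 255))  # stand-in for a real image
size = 448
scale = size / max(src.size)                                     # 448 / 1000 = 0.448
new_w, new_h = int(src.width * scale), int(src.height * scale)   # 448 x 224
resized = src.resize((new_w, new_h), Image.BICUBIC)
canvas = Image.new("RGB", (size, size), (0, 0, 0))               # black canvas
canvas.paste(resized, ((size - new_w) // 2, (size - new_h) // 2))  # centered

print(canvas.size)              # (448, 448)
print(canvas.getpixel((0, 0)))  # (0, 0, 0) -- black letterbox band above the content
```

The content keeps its 2:1 aspect ratio; the remaining area is black padding, split evenly above and below.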
### Architecture Nuances

- Backbone: ConvNeXt V2 Tiny (pretrained on ImageNet-1k).
- Pooling: Standard global average pooling is replaced with GeM (Generalized Mean) pooling, which lets the model focus on salient features (such as small accessories) rather than averaging them out.
- Head: A single linear layer mapping 768 features to 10,000 tags.
- Loss: Trained with Asymmetric Loss (ASL) to handle the extreme class imbalance of sparse tagging.
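For reference, a minimal sketch of Asymmetric Loss; the hyperparameters (`gamma_neg`, `gamma_pos`, `clip`) are illustrative defaults from the ASL paper, not necessarily the values used to train this model:

```python
import torch

def asymmetric_loss(logits, targets, gamma_neg=4.0, gamma_pos=1.0, clip=0.05):
    p = torch.sigmoid(logits)
    # Probability shifting: negatives below the clip margin contribute nothing,
    # discarding the flood of easy negatives in a 10,000-tag vocabulary.
    p_neg = (p - clip).clamp(min=0)
    # Positives are down-weighted gently (gamma_pos), negatives aggressively (gamma_neg)
    loss_pos = targets * (1 - p).pow(gamma_pos) * torch.log(p.clamp(min=1e-8))
    loss_neg = (1 - targets) * p_neg.pow(gamma_neg) * torch.log((1 - p_neg).clamp(min=1e-8))
    return -(loss_pos + loss_neg).sum(dim=-1).mean()

targets = torch.tensor([[1.0, 0.0, 0.0, 0.0]])  # sparse multi-label target
good = asymmetric_loss(torch.tensor([[4.0, -4.0, -4.0, -4.0]]), targets)
bad = asymmetric_loss(torch.tensor([[-4.0, 4.0, 4.0, 4.0]]), targets)
print(good.item() < bad.item())  # confident correct predictions incur less loss
```

The asymmetry matters here because almost every tag is negative for any given image: without it, the gradient would be dominated by easy negatives.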
## Files in Repo

- `model.safetensors`: The FP16 trained weights (use this for inference).
- `config.json`: Basic configuration parameters.
- `tags.json`: The mapping of `Tag Name -> Index`.
- `optimizer.pt`: (Optional) Optimizer state, only needed if you plan to resume training this model.
## License

This model is released under the Apache 2.0 license.