---
license: apache-2.0
datasets:
- deepghs/danbooru2024
tags:
- Classification
pipeline_tag: image-classification
---

# DT24-Tiny

A multi-label image classification model trained on [deepghs/danbooru2024](https://huggingface.co/datasets/deepghs/danbooru2024), designed to tag anime-style illustrations with a vocabulary of **10,000 tags**.

This model uses **ConvNeXt V2 Tiny** as the backbone, optimized for a balance between speed and accuracy (448px resolution).

| Attribute | Details |
| :--- | :--- |
| **Model Architecture** | ConvNeXt V2 Tiny + GeM Pooling |
| **Resolution** | 448 x 448 (Letterbox Padding) |
| **Vocabulary** | Top 10,000 Tags (Danbooru) |
| **Format** | SafeTensors (`model.safetensors`) |

## Quick Start (Inference)

You need `torch`, `torchvision`, `timm`, `Pillow`, `huggingface_hub`, and `safetensors` installed.

```bash
pip install torch torchvision timm pillow huggingface_hub safetensors
```

### Python Inference Script

Since this model uses a custom head (GeM Pooling + Linear), you need to define the class structure before loading weights.

```python
import json

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
from torchvision import transforms

# --- 1. Define Architecture ---
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super().__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1.0 / self.p)

class DT24Tiny(nn.Module):
    def __init__(self, num_classes=10000):
        super().__init__()
        # Load backbone without head
        self.backbone = timm.create_model("convnextv2_tiny.fcmae_ft_in1k", pretrained=False, num_classes=0, global_pool='')
        self.pooling = GeM()
        self.head = nn.Linear(768, num_classes)

    def forward(self, x):
        # Sigmoid is applied here for multi-label classification
        return torch.sigmoid(self.head(self.pooling(self.backbone(x)).flatten(1)))

# --- 2. Load Model & Tags ---
REPO_ID = "igidn/DT24-Tiny"

# Load tags (Tag Name -> Index)
tag_path = hf_hub_download(repo_id=REPO_ID, filename="tags.json")
with open(tag_path, "r") as f:
    tag_map = json.load(f)
idx_to_tag = {v: k for k, v in tag_map.items()}

# Load weights (SafeTensors)
model_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
state_dict = load_file(model_path)

model = DT24Tiny(num_classes=len(tag_map))
model.load_state_dict(state_dict)
model.eval()

# --- 3. Preprocessing (Letterbox Pad) ---
class LetterboxPad:
    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        w, h = img.size
        scale = self.size / max(w, h)
        new_w, new_h = int(w * scale), int(h * scale)
        img = img.resize((new_w, new_h), Image.BICUBIC)
        new_img = Image.new("RGB", (self.size, self.size), (0, 0, 0))
        new_img.paste(img, ((self.size - new_w) // 2, (self.size - new_h) // 2))
        return new_img

transform = transforms.Compose([
    LetterboxPad(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# --- 4. Prediction ---
def predict(image_path, threshold=0.60):
    img = Image.open(image_path).convert("RGB")
    tensor = transform(img).unsqueeze(0)

    with torch.no_grad():
        probs = model(tensor)[0]

    # Keep only tags scoring above the threshold
    results = {}
    for idx, score in enumerate(probs):
        if score > threshold:
            results[idx_to_tag[idx]] = score.item()

    return dict(sorted(results.items(), key=lambda item: item[1], reverse=True))

# Test
# print(predict("test_image.jpg"))
```
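
If a fixed threshold filters out too many rarer tags, a top-k variant can complement `predict`. The helper below is a small illustrative sketch, not part of this repo:

```python
import torch

def topk_tags(probs, idx_to_tag, k=15):
    # Return the k highest-scoring tags, sorted by descending probability
    scores, indices = torch.topk(probs, min(k, probs.numel()))
    return {idx_to_tag[i.item()]: s.item() for i, s in zip(indices, scores)}
```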

## Model Details

### Training Data
* **Dataset:** `deepghs/danbooru2024`
* **Selection:** Top 10,000 most frequent tags.

### Preprocessing
Unlike standard resizing, which distorts the aspect ratio, this model uses **Letterbox Padding**:
1. Resize the longest edge to 448px.
2. Paste the image onto a black 448x448 canvas.
3. Apply standard ImageNet normalization.
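
The steps above can be checked in isolation. This sketch mirrors the `LetterboxPad` transform from the Quick Start script, with normalization omitted:

```python
from PIL import Image

def letterbox(img, size=448):
    # Resize the longest edge to `size`, then center-paste on a black canvas
    w, h = img.size
    scale = size / max(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.BICUBIC)
    canvas = Image.new("RGB", (size, size), (0, 0, 0))
    canvas.paste(img, ((size - new_w) // 2, (size - new_h) // 2))
    return canvas

# A 640x360 landscape image becomes a 448x448 square with black bars top and bottom
out = letterbox(Image.new("RGB", (640, 360), (255, 255, 255)))
```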

### Architecture Nuances
* **Backbone:** ConvNeXt V2 Tiny (pretrained on ImageNet-1k).
* **Pooling:** Standard Global Average Pooling is replaced with **GeM (Generalized Mean) Pooling**, which lets the model focus on salient features (like small accessories) rather than washing them out.
* **Head:** A single linear layer mapping 768 features to 10,000 tags.
* **Loss:** Trained with **Asymmetric Loss (ASL)** to handle the extreme class imbalance of sparse tagging.
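
For reference, a common formulation of Asymmetric Loss looks like the sketch below. The hyperparameters (`gamma_neg=4`, `gamma_pos=0`, `clip=0.05`) are the defaults from the original ASL paper, not necessarily the values used to train DT24-Tiny:

```python
import torch
import torch.nn as nn

class AsymmetricLoss(nn.Module):
    """Asymmetric Loss for multi-label classification (Ben-Baruch et al., 2021)."""
    def __init__(self, gamma_neg=4, gamma_pos=0, clip=0.05, eps=1e-8):
        super().__init__()
        self.gamma_neg, self.gamma_pos = gamma_neg, gamma_pos
        self.clip, self.eps = clip, eps

    def forward(self, logits, targets):
        p = torch.sigmoid(logits)
        # Probability shifting: very easy negatives are clamped to zero loss
        p_neg = (1 - p + self.clip).clamp(max=1)
        # Positive term, down-weighted by (1 - p)^gamma_pos
        loss_pos = targets * torch.log(p.clamp(min=self.eps)) * (1 - p) ** self.gamma_pos
        # Negative term, down-weighted by (1 - p_neg)^gamma_neg to focus on hard negatives
        loss_neg = (1 - targets) * torch.log(p_neg.clamp(min=self.eps)) * (1 - p_neg) ** self.gamma_neg
        return -(loss_pos + loss_neg).mean()
```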

## Files in Repo
* `model.safetensors`: The FP16 trained weights (use this for inference).
* `config.json`: Basic configuration parameters.
* `tags.json`: The mapping of `Tag Name -> Index`.
* `optimizer.pt`: (Optional) Optimizer state, only needed if you plan to resume training this model.
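
`tags.json` holds the `Tag Name -> Index` mapping consumed by the Quick Start script, which inverts it for decoding. A minimal sketch of that inversion (the three sample tags are illustrative, not the actual first entries):

```python
import json

# Illustrative sample of the tags.json structure: {"tag_name": index, ...}
tag_map = json.loads('{"1girl": 0, "solo": 1, "long_hair": 2}')
idx_to_tag = {v: k for k, v in tag_map.items()}  # invert for index -> tag lookup
```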

## License
This model is released under the **Apache 2.0** license.