---
license: apache-2.0
datasets:
- deepghs/danbooru2024
tags:
- Classification
pipeline_tag: image-classification
---

# DT24-Tiny

A multi-label image classification model trained on [deepghs/danbooru2024](https://huggingface.co/datasets/deepghs/danbooru2024), designed to tag anime-style illustrations with a vocabulary of **10,000 tags**.

This model uses **ConvNeXt V2 Tiny** as the backbone, optimized for a balance between speed and accuracy (448px resolution).

| Attribute | Details |
| :--- | :--- |
| **Model Architecture** | ConvNeXt V2 Tiny + GeM Pooling |
| **Resolution** | 448 x 448 (Letterbox Padding) |
| **Vocabulary** | Top 10,000 Tags (Danbooru) |
| **Format** | SafeTensors (`model.safetensors`) |

## 🚀 Quick Start (Inference)

You need `torch`, `torchvision`, `timm`, `Pillow`, `safetensors`, and `huggingface_hub` installed.

```bash
pip install torch torchvision timm pillow safetensors huggingface_hub
```

### Python Inference Script

Since this model uses a custom head (GeM Pooling + Linear), you need to define the class structure before loading weights.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import timm
from huggingface_hub import hf_hub_download
import json
from torchvision import transforms

# --- 1. Define Architecture ---
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps
    def forward(self, x):
        return F.avg_pool2d(x.clamp(min=self.eps).pow(self.p), (x.size(-2), x.size(-1))).pow(1. / self.p)

class DT24Tiny(nn.Module):
    def __init__(self, num_classes=10000):
        super().__init__()
        # Load backbone without head
        self.backbone = timm.create_model("convnextv2_tiny.fcmae_ft_in1k", pretrained=False, num_classes=0, global_pool='')
        self.pooling = GeM()
        self.head = nn.Linear(768, num_classes)

    def forward(self, x):
        # Sigmoid is applied here for multi-label classification
        return torch.sigmoid(self.head(self.pooling(self.backbone(x)).flatten(1)))

# --- 2. Load Model & Tags ---
REPO_ID = "igidn/DT24-Tiny"

# Load Tags
tag_path = hf_hub_download(repo_id=REPO_ID, filename="tags.json")
with open(tag_path, "r") as f:
    tag_map = json.load(f)
idx_to_tag = {v: k for k, v in tag_map.items()}

# Load Weights (SafeTensors)
from safetensors.torch import load_file
model_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
state_dict = load_file(model_path)

model = DT24Tiny(num_classes=len(tag_map))
model.load_state_dict(state_dict)
model.eval()

# --- 3. Preprocessing (Letterbox Pad) ---
class LetterboxPad:
    def __init__(self, size): self.size = size
    def __call__(self, img):
        w, h = img.size
        scale = self.size / max(w, h)
        new_w, new_h = int(w * scale), int(h * scale)
        img = img.resize((new_w, new_h), Image.BICUBIC)
        new_img = Image.new("RGB", (self.size, self.size), (0, 0, 0))
        new_img.paste(img, ((self.size - new_w) // 2, (self.size - new_h) // 2))
        return new_img

transform = transforms.Compose([
    LetterboxPad(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# --- 4. Prediction ---
def predict(image_path, threshold=0.60):
    img = Image.open(image_path).convert("RGB")
    tensor = transform(img).unsqueeze(0)
    
    with torch.no_grad():
        probs = model(tensor)[0]
    
    # Filter results
    results = {}
    for idx, score in enumerate(probs):
        if score > threshold:
            results[idx_to_tag[idx]] = score.item()
    
    return dict(sorted(results.items(), key=lambda item: item[1], reverse=True))

# Test
# print(predict("test_image.jpg"))
```
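If you prefer a fixed number of tags instead of a probability threshold, a top-k variant of the filtering step can be used. This is a sketch, not part of the original script; the demo tensors below are illustrative:

```python
import torch

def top_k_tags(probs, idx_to_tag, k=5):
    # Take the k highest-scoring indices and map them back to tag names
    scores, indices = torch.topk(probs, k)
    return [(idx_to_tag[i.item()], s.item()) for s, i in zip(scores, indices)]

# Demo with a dummy 10-tag probability vector and mapping
demo_probs = torch.tensor([0.1, 0.9, 0.3, 0.8, 0.05, 0.7, 0.2, 0.6, 0.4, 0.5])
demo_tags = {i: f"tag_{i}" for i in range(10)}
print(top_k_tags(demo_probs, demo_tags, k=3))
```

In practice you would pass `probs` from `model(tensor)[0]` and the `idx_to_tag` dict built in step 2.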

## 🛠 Model Details

### Training Data
*   **Dataset:** `deepghs/danbooru2024`
*   **Selection:** Top 10,000 most frequent tags.


### Preprocessing
Unlike standard resizing, which distorts the aspect ratio, this model uses **Letterbox Padding**:
1.  Resize the longest edge to 448px, preserving the aspect ratio.
2.  Center the resized image on a black 448x448 canvas.
3.  Apply standard ImageNet normalization.
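
These steps match the `LetterboxPad` class from the Quick Start script; as a self-contained function (the white test image below is purely illustrative):

```python
from PIL import Image

def letterbox(img, size=448):
    # Scale so the longest edge equals `size`, keeping aspect ratio
    w, h = img.size
    scale = size / max(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.BICUBIC)
    # Center the resized image on a black square canvas
    canvas = Image.new("RGB", (size, size), (0, 0, 0))
    canvas.paste(img, ((size - new_w) // 2, (size - new_h) // 2))
    return canvas

# A 640x360 image is scaled to 448x252 and centered on a 448x448 canvas
out = letterbox(Image.new("RGB", (640, 360), (255, 255, 255)))
print(out.size)  # (448, 448)
```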

### Architecture Nuances
*   **Backbone:** ConvNeXt V2 Tiny (Pretrained on ImageNet-1k).
*   **Pooling:** Replaced standard Global Average Pooling with **GeM (Generalized Mean Pooling)**. This allows the model to better focus on salient features (like small accessories) rather than washing them out.
*   **Head:** A single Linear layer mapping 768 features to 10,000 tags.
*   **Loss:** Trained with **Asymmetric Loss (ASL)** to handle the extreme class imbalance of sparse tagging.
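
For reference, a minimal sketch of Asymmetric Loss following Ridnik et al. The hyperparameters (`gamma_pos`, `gamma_neg`, `clip`) are illustrative defaults from the paper, not necessarily the values used to train this model:

```python
import torch

def asymmetric_loss(logits, targets, gamma_pos=0.0, gamma_neg=4.0, clip=0.05):
    # Hyperparameters are illustrative, not this model's training config
    p = torch.sigmoid(logits)
    # Probability shifting: ignore easy negatives below the clip margin
    p_neg = (p - clip).clamp(min=0)
    # Positives are down-weighted by (1-p)^gamma_pos, negatives by p_neg^gamma_neg
    loss_pos = targets * torch.log(p.clamp(min=1e-8)) * (1 - p) ** gamma_pos
    loss_neg = (1 - targets) * torch.log((1 - p_neg).clamp(min=1e-8)) * p_neg ** gamma_neg
    return -(loss_pos + loss_neg).mean()
```

The asymmetry (a high `gamma_neg` plus probability clipping) suppresses the gradient contribution of the overwhelming majority of easy negative tags, which is what makes sparse 10,000-way tagging trainable.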

## 📂 Files in Repo
*   `model.safetensors`: The FP16 trained weights (use this for inference).
*   `config.json`: Basic configuration parameters.
*   `tags.json`: The mapping of `Tag Name -> Index`.
*   `optimizer.pt`: (Optional) Optimizer state, only needed if you plan to resume training this model.

## ⚖️ License
This model is released under the **Apache 2.0** license.