---
license: apache-2.0
datasets:
- deepghs/danbooru2024
tags:
- Classification
pipeline_tag: image-classification
---

# DT24-Tiny

A multi-label image classification model trained on [deepghs/danbooru2024](https://huggingface.co/datasets/deepghs/danbooru2024), designed to tag anime-style illustrations with a vocabulary of **10,000 tags**.

This model uses **ConvNeXt V2 Tiny** as the backbone, optimized for a balance between speed and accuracy (448px resolution).

| Attribute | Details |
| :--- | :--- |
| **Model Architecture** | ConvNeXt V2 Tiny + GeM Pooling |
| **Resolution** | 448 x 448 (Letterbox Padding) |
| **Vocabulary** | Top 10,000 Tags (Danbooru) |
| **Format** | SafeTensors (`model.safetensors`) |

## 🚀 Quick Start (Inference)

You need `timm`, `torch`, `Pillow`, and `safetensors` installed.

```bash
pip install torch torchvision timm pillow huggingface_hub safetensors
```

### Python Inference Script

Since this model uses a custom head (GeM Pooling + Linear), you need to define the class structure before loading weights.

```python
import json

import timm
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

# --- 1. Define Architecture ---
class GeM(nn.Module):
    def __init__(self, p=3, eps=1e-6):
        super(GeM, self).__init__()
        self.p = nn.Parameter(torch.ones(1) * p)
        self.eps = eps

    def forward(self, x):
        # Generalized mean over the spatial dimensions
        return F.avg_pool2d(
            x.clamp(min=self.eps).pow(self.p),
            (x.size(-2), x.size(-1)),
        ).pow(1.0 / self.p)


class DT24Tiny(nn.Module):
    def __init__(self, num_classes=10000):
        super().__init__()
        # Load backbone without classifier head or pooling
        self.backbone = timm.create_model(
            "convnextv2_tiny.fcmae_ft_in1k",
            pretrained=False,
            num_classes=0,
            global_pool="",
        )
        self.pooling = GeM()
        self.head = nn.Linear(768, num_classes)

    def forward(self, x):
        # Sigmoid is applied here for multi-label classification
        features = self.pooling(self.backbone(x)).flatten(1)
        return torch.sigmoid(self.head(features))

# --- 2. Load Model & Tags ---
REPO_ID = "igidn/DT24-Tiny"

# Load tag vocabulary (Tag Name -> Index)
tag_path = hf_hub_download(repo_id=REPO_ID, filename="tags.json")
with open(tag_path, "r") as f:
    tag_map = json.load(f)
idx_to_tag = {v: k for k, v in tag_map.items()}

# Load weights (SafeTensors)
from safetensors.torch import load_file

model_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
state_dict = load_file(model_path)

model = DT24Tiny(num_classes=len(tag_map))
model.load_state_dict(state_dict)
model.eval()

# --- 3. Preprocessing (Letterbox Pad) ---
class LetterboxPad:
    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        # Scale the longest edge to `size`, preserving aspect ratio
        w, h = img.size
        scale = self.size / max(w, h)
        new_w, new_h = int(w * scale), int(h * scale)
        img = img.resize((new_w, new_h), Image.BICUBIC)
        # Paste onto a black square canvas, centered
        new_img = Image.new("RGB", (self.size, self.size), (0, 0, 0))
        new_img.paste(img, ((self.size - new_w) // 2, (self.size - new_h) // 2))
        return new_img


transform = transforms.Compose([
    LetterboxPad(448),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# --- 4. Prediction ---
def predict(image_path, threshold=0.60):
    img = Image.open(image_path).convert("RGB")
    tensor = transform(img).unsqueeze(0)

    with torch.no_grad():
        probs = model(tensor)[0]

    # Keep only tags whose probability exceeds the threshold
    results = {}
    for idx, score in enumerate(probs):
        if score > threshold:
            results[idx_to_tag[idx]] = score.item()

    return dict(sorted(results.items(), key=lambda item: item[1], reverse=True))

# Test
# print(predict("test_image.jpg"))
```

## 🛠 Model Details

### Training Data

* **Dataset:** `deepghs/danbooru2024`
* **Selection:** Top 10,000 most frequent tags.

### Preprocessing

Unlike standard resizing, which distorts the aspect ratio, this model uses **Letterbox Padding**:

1. Resize the longest edge to 448px.
2. Paste the image onto a black 448x448 canvas.
3. Apply standard ImageNet normalization.

### Architecture Nuances

* **Backbone:** ConvNeXt V2 Tiny (pretrained on ImageNet-1k).
* **Pooling:** Replaced standard Global Average Pooling with **GeM (Generalized Mean Pooling)**. This allows the model to better focus on salient features (like small accessories) rather than washing them out.
* **Head:** A single Linear layer mapping 768 features to 10,000 tags.
* **Loss:** Trained with **Asymmetric Loss (ASL)** to handle the extreme class imbalance of sparse tagging.

## 📂 Files in Repo

* `model.safetensors`: The FP16 trained weights (use this for inference).
* `config.json`: Basic configuration parameters.
* `tags.json`: The mapping of `Tag Name -> Index`.
* `optimizer.pt`: (Optional) Optimizer state, only needed if you plan to resume training this model.

## ⚖️ License

This model is released under the **Apache 2.0** license.
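
## 📉 Appendix: Asymmetric Loss Sketch

The training hyperparameters for the Asymmetric Loss mentioned under Architecture Nuances are not published in this card. For readers who want to resume training from `optimizer.pt` or fine-tune the model, the sketch below shows the general form of ASL (Ridnik et al., "Asymmetric Loss for Multi-Label Classification"). The `gamma_neg=4`, `gamma_pos=0`, and `clip=0.05` values are the paper's commonly used defaults, **not** the values used to train DT24-Tiny.

```python
import torch
import torch.nn as nn


class AsymmetricLoss(nn.Module):
    """Sketch of Asymmetric Loss for multi-label classification.

    gamma_neg/gamma_pos/clip are assumed defaults from the ASL paper,
    not the hyperparameters used to train DT24-Tiny.
    """

    def __init__(self, gamma_neg=4.0, gamma_pos=0.0, clip=0.05, eps=1e-8):
        super().__init__()
        self.gamma_neg = gamma_neg
        self.gamma_pos = gamma_pos
        self.clip = clip
        self.eps = eps

    def forward(self, probs, targets):
        # probs: sigmoid outputs in [0, 1]; targets: multi-hot {0, 1}
        xs_pos = probs
        xs_neg = 1.0 - probs
        if self.clip > 0:
            # Probability shifting: very easy negatives contribute zero loss
            xs_neg = (xs_neg + self.clip).clamp(max=1.0)
        # Positive term: (1 - p)^gamma_pos * log(p)
        loss_pos = targets * torch.log(xs_pos.clamp(min=self.eps)) \
            * (1.0 - xs_pos) ** self.gamma_pos
        # Negative term: p_m^gamma_neg * log(1 - p_m), with shifted p_m
        loss_neg = (1.0 - targets) * torch.log(xs_neg.clamp(min=self.eps)) \
            * (1.0 - xs_neg) ** self.gamma_neg
        return -(loss_pos + loss_neg).sum(dim=1).mean()
```

Because the model's `forward` already applies a sigmoid, this loss takes probabilities rather than logits; a numerically safer variant would operate on logits with `log_sigmoid`, which is a common design choice when adapting ASL.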