Upload model 'animetimm/resnet18.dbv4-full', on 2025-09-29 06:41:56 UTC
---
tags:
- image-classification
- timm
- transformers
- animetimm
- dghs-imgutils
library_name: timm
license: gpl-3.0
datasets:
- animetimm/danbooru-wdtagger-v4-w640-ws-full
base_model:
- timm/resnet18.a1_in1k
---
# Anime Tagger resnet18.dbv4-full
## Model Details
- **Model Type:** Multilabel image classification / feature backbone
- **Model Stats:**
  - Params: 17.6M
  - FLOPs / MACs: 10.7G / 5.3G
  - Image size: train = 384 x 384, test = 384 x 384
- **Dataset:** [animetimm/danbooru-wdtagger-v4-w640-ws-full](https://huggingface.co/datasets/animetimm/danbooru-wdtagger-v4-w640-ws-full)
  - Tags Count: 12476
    - General (#0) Tags Count: 9225
    - Character (#4) Tags Count: 3247
    - Rating (#9) Tags Count: 4
## Results
| # | Macro@0.40 (F1/MCC/P/R) | Micro@0.40 (F1/MCC/P/R) | Macro@Best (F1/P/R) |
|:----------:|:-----------------------------:|:-----------------------------:|:---------------------:|
| Validation | 0.263 / 0.277 / 0.368 / 0.236 | 0.502 / 0.508 / 0.595 / 0.435 | --- |
| Test | 0.259 / 0.275 / 0.380 / 0.227 | 0.499 / 0.507 / 0.613 / 0.421 | 0.314 / 0.360 / 0.309 |
* `Macro/Micro@0.40` are the metrics computed with a fixed threshold of 0.40 for every tag.
* `Macro@Best` are the metrics averaged over tags when each tag uses its own tag-level threshold, chosen to give that tag its best F1 score.
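To make the difference between the macro and micro columns concrete, here is a minimal pure-Python sketch that computes both F1 variants at a fixed threshold. The function names and the toy data are illustrative assumptions, not the evaluation code used for this model; scoring a tag with no positives and no predictions as F1 = 0 is just one possible convention.

```python
def f1_from_counts(tp, fp, fn):
    """F1 from true-positive, false-positive and false-negative counts."""
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def macro_micro_f1(scores, labels, threshold=0.40):
    """scores/labels: one row per image, one column per tag."""
    n_tags = len(scores[0])
    counts = [[0, 0, 0] for _ in range(n_tags)]  # per-tag [tp, fp, fn]
    for score_row, label_row in zip(scores, labels):
        for j, (s, y) in enumerate(zip(score_row, label_row)):
            pred = s >= threshold
            if pred and y:
                counts[j][0] += 1
            elif pred and not y:
                counts[j][1] += 1
            elif y:
                counts[j][2] += 1
    # Macro: compute F1 per tag, then average, so every tag weighs the same.
    macro = sum(f1_from_counts(*c) for c in counts) / n_tags
    # Micro: pool the counts over all tags first, so frequent tags dominate.
    micro = f1_from_counts(*(sum(col) for col in zip(*counts)))
    return macro, micro


# Toy data: 3 images, 2 tags (one frequent tag, one rare tag).
scores = [[0.9, 0.1], [0.8, 0.2], [0.7, 0.3]]
labels = [[1, 0], [1, 0], [1, 1]]
macro, micro = macro_micro_f1(scores, labels)
print(macro, micro)
```

On this toy input the macro F1 is 0.5 and the micro F1 is 6/7 (about 0.857): the missed rare tag halves the macro score but costs the micro score only one false negative. This is one plausible reason for macro scores being lower than micro scores on a long-tailed tag set.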
## Thresholds
| Category | Name | Alpha | Threshold | Micro@Thr (F1/P/R) | Macro@0.40 (F1/P/R) | Macro@Best (F1/P/R) |
|:----------:|:---------:|:-------:|:-----------:|:---------------------:|:---------------------:|:---------------------:|
| 0 | general | 1 | 0.3 | 0.497 / 0.520 / 0.476 | 0.148 / 0.286 / 0.119 | 0.210 / 0.232 / 0.227 |
| 4 | character | 1 | 0.48 | 0.629 / 0.746 / 0.544 | 0.572 / 0.648 / 0.532 | 0.609 / 0.724 / 0.540 |
| 9 | rating | 1 | 0.38 | 0.745 / 0.678 / 0.828 | 0.736 / 0.696 / 0.786 | 0.740 / 0.693 / 0.800 |
* `Micro@Thr` are the metrics computed with the suggested category-level thresholds listed in the table above.
* `Macro@0.40` are the metrics computed with a fixed threshold of 0.40 for every tag.
* `Macro@Best` are the metrics computed when each tag uses its own tag-level threshold, chosen to give that tag its best F1 score.
The tag-level thresholds are listed in [selected_tags.csv](https://huggingface.co/animetimm/resnet18.dbv4-full/resolve/main/selected_tags.csv).
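As a rough illustration of how the category-level thresholds from the table could be applied, the sketch below groups tags by category and keeps those whose score reaches the threshold of their category. The helper `select_by_category` and the toy inputs are hypothetical, and the sketch assumes that each tag comes with a numeric category id (0, 4 or 9), for example from a `category` column in `selected_tags.csv`; that column name is an assumption and is not confirmed by this README.

```python
# Suggested thresholds copied from the table above, keyed by category id.
CATEGORY_THRESHOLDS = {0: 0.30, 4: 0.48, 9: 0.38}
CATEGORY_NAMES = {0: 'general', 4: 'character', 9: 'rating'}


def select_by_category(names, categories, scores, thresholds=CATEGORY_THRESHOLDS):
    """Group tags whose score reaches the threshold of their category."""
    selected = {name: {} for name in CATEGORY_NAMES.values()}
    for name, category, score in zip(names, categories, scores):
        if score >= thresholds[category]:
            selected[CATEGORY_NAMES[category]][name] = score
    return selected


# Toy values for illustration only. Real inputs would be the tag names and
# (assumed) category ids from selected_tags.csv plus the model's sigmoid scores.
names = ['1girl', 'dress', 'some_character', 'sensitive']
categories = [0, 0, 4, 9]
scores = [0.96, 0.25, 0.40, 0.55]
result = select_by_category(names, categories, scores)
print(result)
# {'general': {'1girl': 0.96}, 'character': {}, 'rating': {'sensitive': 0.55}}
```

Category-level thresholds are simpler to handle than roughly 12k tag-level thresholds, at the cost of ignoring per-tag differences.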
## How to Use
We provide a sample image for the code samples below; you can find it [here](https://huggingface.co/animetimm/resnet18.dbv4-full/blob/main/sample.webp).
### Use TIMM And Torch
Install [dghs-imgutils](https://github.com/deepghs/imgutils), [timm](https://github.com/huggingface/pytorch-image-models) and the other necessary requirements with the following command:
```shell
pip install 'dghs-imgutils>=0.19.0' torch huggingface_hub timm pillow pandas
```
After that, you can load this model with the timm library and use it for training, validation, and testing with the following code:
```python
import json
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.preprocess import create_torchvision_transforms
from timm import create_model
repo_id = 'animetimm/resnet18.dbv4-full'
model = create_model(f'hf-hub:{repo_id}', pretrained=True)
model.eval()
with open(hf_hub_download(repo_id=repo_id, repo_type='model', filename='preprocess.json'), 'r') as f:
    preprocessor = create_torchvision_transforms(json.load(f)['test'])
# Compose(
#     PadToSize(size=(512, 512), interpolation=bilinear, background_color=white)
#     Resize(size=384, interpolation=bicubic, max_size=None, antialias=True)
#     CenterCrop(size=[384, 384])
#     MaybeToTensor()
#     Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
# )
image = load_image('https://huggingface.co/animetimm/resnet18.dbv4-full/resolve/main/sample.webp')
input_ = preprocessor(image).unsqueeze(0)
# input_, shape: torch.Size([1, 3, 384, 384]), dtype: torch.float32
with torch.no_grad():
    output = model(input_)
    prediction = torch.sigmoid(output)[0]
# output, shape: torch.Size([1, 12476]), dtype: torch.float32
# prediction, shape: torch.Size([12476]), dtype: torch.float32
df_tags = pd.read_csv(
    hf_hub_download(repo_id=repo_id, repo_type='model', filename='selected_tags.csv'),
    keep_default_na=False
)
tags = df_tags['name']
mask = prediction.numpy() >= df_tags['best_threshold']
print(dict(zip(tags[mask].tolist(), prediction[mask].tolist())))
# {'general': 0.5340393781661987,
# 'sensitive': 0.5520890951156616,
# '1girl': 0.9604854583740234,
# 'solo': 0.9270129799842834,
# 'looking_at_viewer': 0.7812865376472473,
# 'blush': 0.7300496697425842,
# 'smile': 0.8309312462806702,
# 'short_hair': 0.47302114963531494,
# 'shirt': 0.5381302833557129,
# 'long_sleeves': 0.6570523977279663,
# 'brown_hair': 0.7048164010047913,
# 'holding': 0.42473122477531433,
# 'dress': 0.36369165778160095,
# 'closed_mouth': 0.6090019345283508,
# 'purple_eyes': 0.6659274101257324,
# 'collarbone': 0.2718302011489868,
# 'upper_body': 0.29708775877952576,
# 'flower': 0.8741877675056458,
# 'red_hair': 0.3947563171386719,
# 'medium_hair': 0.16671739518642426,
# 'hand_up': 0.22859826683998108,
# 'blunt_bangs': 0.3909393548965454,
# 'necklace': 0.44065433740615845,
# 'sweater': 0.317997008562088,
# 'head_tilt': 0.11918959766626358,
# 'rose': 0.4927496910095215,
# 'white_flower': 0.1997985988855362,
# 'light_smile': 0.10588516294956207,
# 'hand_on_own_face': 0.07584860175848007,
# 'blue_flower': 0.253018856048584,
# 'backlighting': 0.17737838625907898,
# 'bouquet': 0.5038037300109863,
# 'playing_with_own_hair': 0.032126910984516144}
```
### Use ONNX Model For Inference
Install [dghs-imgutils](https://github.com/deepghs/imgutils) with the following command:
```shell
pip install 'dghs-imgutils>=0.19.0'
```
Use the `multilabel_timm_predict` function as follows:
```python
from imgutils.generic import multilabel_timm_predict
general, character, rating = multilabel_timm_predict(
    'https://huggingface.co/animetimm/resnet18.dbv4-full/resolve/main/sample.webp',
    repo_id='animetimm/resnet18.dbv4-full',
    fmt=('general', 'character', 'rating'),
)
print(general)
# {'1girl': 0.9604854583740234,
# 'solo': 0.9270131587982178,
# 'flower': 0.8741884827613831,
# 'smile': 0.8309313058853149,
# 'looking_at_viewer': 0.7812865972518921,
# 'blush': 0.7300494909286499,
# 'brown_hair': 0.7048173546791077,
# 'purple_eyes': 0.6659266948699951,
# 'long_sleeves': 0.6570525765419006,
# 'closed_mouth': 0.6090021729469299,
# 'shirt': 0.5381303429603577,
# 'bouquet': 0.5038049221038818,
# 'rose': 0.492749959230423,
# 'short_hair': 0.47302067279815674,
# 'necklace': 0.44065436720848083,
# 'holding': 0.42473119497299194,
# 'red_hair': 0.3947550058364868,
# 'blunt_bangs': 0.39093902707099915,
# 'dress': 0.3636913001537323,
# 'sweater': 0.3179968595504761,
# 'upper_body': 0.2970876693725586,
# 'collarbone': 0.27183014154434204,
# 'blue_flower': 0.2530190050601959,
# 'hand_up': 0.22859838604927063,
# 'white_flower': 0.19979920983314514,
# 'backlighting': 0.17737817764282227,
# 'medium_hair': 0.16671738028526306,
# 'head_tilt': 0.11918962001800537,
# 'light_smile': 0.10588520765304565,
# 'hand_on_own_face': 0.07584834098815918,
# 'playing_with_own_hair': 0.03212690353393555}
print(character)
# {}
print(rating)
# {'sensitive': 0.5520895719528198, 'general': 0.5340390801429749}
```
For further information, see the [documentation of the `multilabel_timm_predict` function](https://dghs-imgutils.deepghs.org/main/api_doc/generic/multilabel_timm.html#multilabel-timm-predict).
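The dictionaries returned above can be post-processed with plain Python. The sketch below is one possible approach, not part of dghs-imgutils: the helper `to_tag_string` and its `min_score` / `limit` parameters are hypothetical, and the input values are a subset copied from the sample output above.

```python
def to_tag_string(tag_scores, min_score=0.5, limit=None):
    """Join tags scoring at least `min_score`, highest score first."""
    ranked = sorted(tag_scores.items(), key=lambda kv: kv[1], reverse=True)
    kept = [tag for tag, score in ranked if score >= min_score]
    return ', '.join(kept[:limit])  # limit=None keeps every remaining tag


# Subset of the sample output shown above.
general = {
    '1girl': 0.9604854583740234,
    'solo': 0.9270131587982178,
    'flower': 0.8741884827613831,
    'rose': 0.492749959230423,
}
rating = {'sensitive': 0.5520895719528198, 'general': 0.5340390801429749}

tag_string = to_tag_string(general)
top_rating = max(rating, key=rating.get)  # single most likely rating
print(tag_string)  # 1girl, solo, flower
print(top_rating)  # sensitive
```

Taking the highest-scoring rating is a design choice for cases where exactly one rating label is wanted; the thresholded output can contain several ratings at once, as in the sample above.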