Upload model 'animetimm/swinv2_base_window8_256.dbv4-full', on 2025-06-03 02:09:12 JST
---
tags:
- image-classification
- timm
- transformers
- animetimm
- dghs-imgutils
library_name: timm
license: gpl-3.0
datasets:
- animetimm/danbooru-wdtagger-v4-w640-ws-full
base_model:
- timm/swinv2_base_window8_256.ms_in1k
---
# Anime Tagger swinv2_base_window8_256.dbv4-full
## Model Details
- **Model Type:** Multilabel image classification / feature backbone
- **Model Stats:**
  - Params: 99.7M
  - FLOPs / MACs: 121.6G / 60.7G
  - Image size: train = 448 x 448, test = 448 x 448
- **Dataset:** [animetimm/danbooru-wdtagger-v4-w640-ws-full](https://huggingface.co/datasets/animetimm/danbooru-wdtagger-v4-w640-ws-full)
  - Tags Count: 12476
    - General (#0) Tags Count: 9225
    - Character (#4) Tags Count: 3247
    - Rating (#9) Tags Count: 4
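The numbers in parentheses are the category IDs that the Thresholds table below also uses (0 = general, 4 = character, 9 = rating), and the three category counts add up to the total of 12476 tags. The sketch below shows one way to tally tags per category. It assumes that `selected_tags.csv` has `name` and `category` columns (only `name` and `best_threshold` are confirmed by the code later in this card), and the CSV text it parses is a made-up miniature, not the real file.

```python
import csv
import io
from collections import Counter

# Category IDs as used in this model card.
CATEGORY_NAMES = {0: 'general', 4: 'character', 9: 'rating'}


def count_tags_by_category(csv_text):
    """Count tags per category name.

    Assumes a CSV with at least `name` and `category` columns
    (an assumption about the layout of selected_tags.csv).
    """
    reader = csv.DictReader(io.StringIO(csv_text))
    counts = Counter(CATEGORY_NAMES[int(row['category'])] for row in reader)
    return dict(counts)


# Hypothetical miniature stand-in for selected_tags.csv.
sample_csv = """name,category
1girl,0
solo,0
hatsune_miku,4
sensitive,9
explicit,9
"""

print(count_tags_by_category(sample_csv))
# {'general': 2, 'character': 1, 'rating': 2}

# The per-category counts listed above add up to the total tag count.
assert 9225 + 3247 + 4 == 12476
```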
## Results
| # | Macro@0.40 (F1/MCC/P/R) | Micro@0.40 (F1/MCC/P/R) | Macro@Best (F1/P/R) |
|:----------:|:-----------------------------:|:-----------------------------:|:---------------------:|
| Validation | 0.540 / 0.546 / 0.583 / 0.528 | 0.683 / 0.682 / 0.672 / 0.693 | --- |
| Test | 0.541 / 0.547 / 0.584 / 0.528 | 0.683 / 0.682 / 0.673 / 0.694 | 0.575 / 0.581 / 0.591 |
* `Macro/Micro@0.40` means the macro- or micro-averaged metrics computed at a global threshold of 0.40.
* `Macro@Best` means the macro-averaged metrics when each tag uses its own tag-level threshold, chosen to give that tag the best F1 score.
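To make the micro/macro distinction concrete: micro averaging pools the true positives, false positives, and false negatives of all tags before computing P/R/F1, while macro averaging computes P/R/F1 for every tag and then takes the unweighted mean, so rare tags weigh as much as frequent ones. The following is a minimal stdlib sketch on toy data. Treating an undefined ratio (for example, a tag that is never predicted) as 0.0 is an assumption of this sketch, not necessarily what was used to produce the tables above.

```python
def tag_counts(y_true, y_score, threshold):
    """Return (tp, fp, fn) for a single tag at the given threshold."""
    tp = fp = fn = 0
    for truth, score in zip(y_true, y_score):
        predicted = score >= threshold
        if predicted and truth:
            tp += 1
        elif predicted:
            fp += 1
        elif truth:
            fn += 1
    return tp, fp, fn


def f1_p_r(tp, fp, fn):
    """F1, precision and recall (in the card's F1/P/R order); 0.0 when undefined."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return f1, precision, recall


def micro_macro(labels, scores, threshold=0.40):
    """labels/scores map each tag to per-image ground truth / sigmoid scores."""
    counts = [tag_counts(labels[tag], scores[tag], threshold) for tag in labels]
    # Micro: sum the counts over all tags, then compute the metrics once.
    micro = f1_p_r(*(sum(c[i] for c in counts) for i in range(3)))
    # Macro: compute the metrics per tag, then take the unweighted mean.
    per_tag = [f1_p_r(*c) for c in counts]
    macro = tuple(sum(m[i] for m in per_tag) / len(per_tag) for i in range(3))
    return micro, macro


# Toy data: two tags, four images.
labels = {'a': [1, 0, 1, 0], 'b': [0, 0, 1, 1]}
scores = {'a': [0.9, 0.5, 0.3, 0.1], 'b': [0.2, 0.1, 0.8, 0.3]}
micro, macro = micro_macro(labels, scores)
print([round(x, 4) for x in micro])  # [0.5714, 0.6667, 0.5]
print([round(x, 4) for x in macro])  # [0.5833, 0.75, 0.5]
```

In this toy run the macro F1 (about 0.583) differs from the harmonic mean of the macro P and R (0.6), the same kind of gap that is visible between the Macro@0.40 columns above.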
## Thresholds
| Category | Name | Alpha | Threshold | Micro@Thr (F1/P/R) | Macro@0.40 (F1/P/R) | Macro@Best (F1/P/R) |
|:----------:|:---------:|:-------:|:-----------:|:---------------------:|:---------------------:|:---------------------:|
| 0 | general | 1 | 0.41 | 0.671 / 0.667 / 0.675 | 0.415 / 0.471 / 0.397 | 0.453 / 0.454 / 0.482 |
| 4 | character | 1 | 0.59 | 0.927 / 0.951 / 0.904 | 0.901 / 0.906 / 0.900 | 0.920 / 0.945 / 0.901 |
| 9 | rating | 1 | 0.41 | 0.827 / 0.791 / 0.867 | 0.833 / 0.803 / 0.866 | 0.834 / 0.812 / 0.859 |
* `Micro@Thr` means the micro-averaged metrics at the category-level suggested thresholds listed in the table above.
* `Macro@0.40` means the macro-averaged metrics at a global threshold of 0.40.
* `Macro@Best` means the macro-averaged metrics at the tag-level thresholds, each chosen to give its tag the best F1 score.
For tag-level thresholds, you can find them in [selected_tags.csv](https://huggingface.co/animetimm/swinv2_base_window8_256.dbv4-full/resolve/main/selected_tags.csv).
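This card does not say how the tag-level `best_threshold` values were selected. One plausible approach, shown here only as an assumption, is to sweep a grid of candidate thresholds for each tag on held-out data and keep the one with the highest F1. The grid, the tie-breaking rule, and the toy data below are all choices made for this sketch.

```python
def f1_at(y_true, y_score, threshold):
    """F1 of one tag at a threshold, using F1 = 2TP / (2TP + FP + FN)."""
    tp = sum(1 for t, s in zip(y_true, y_score) if t and s >= threshold)
    fp = sum(1 for t, s in zip(y_true, y_score) if not t and s >= threshold)
    fn = sum(1 for t, s in zip(y_true, y_score) if t and s < threshold)
    denom = 2 * tp + fp + fn
    return 2 * tp / denom if denom else 0.0


def best_threshold(y_true, y_score, grid=None):
    """Return (threshold, f1) maximizing F1 over a grid of candidate thresholds."""
    if grid is None:
        grid = [i / 100 for i in range(1, 100)]  # 0.01, 0.02, ..., 0.99
    best_t, best_f1 = grid[0], -1.0
    for t in grid:
        f1 = f1_at(y_true, y_score, t)
        if f1 > best_f1:  # strict '>' keeps the lowest threshold on ties
            best_t, best_f1 = t, f1
    return best_t, best_f1


# Toy ground truth and sigmoid scores for a single tag over five images.
y_true = [1, 1, 0, 0, 1]
y_score = [0.9, 0.35, 0.3, 0.1, 0.6]
print(best_threshold(y_true, y_score))  # (0.31, 1.0)
```

Per-tag thresholds can end up far from the category-level ones: in the examples below, `brick_floor` is kept with a score of only about 0.10.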
## How to Use
We provide a sample image for the code samples; you can find it [here](https://huggingface.co/animetimm/swinv2_base_window8_256.dbv4-full/blob/main/sample.webp).
### Use TIMM And Torch
Install [dghs-imgutils](https://github.com/deepghs/imgutils), [timm](https://github.com/huggingface/pytorch-image-models), and the other required packages with the following command:
```shell
pip install 'dghs-imgutils>=0.17.0' torch huggingface_hub timm pillow pandas
```
After that, you can load this model with the timm library and use it for training, validation, and testing, as in the following code:
```python
import json
import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.preprocess import create_torchvision_transforms
from timm import create_model
repo_id = 'animetimm/swinv2_base_window8_256.dbv4-full'
model = create_model(f'hf-hub:{repo_id}', pretrained=True)
model.eval()
with open(hf_hub_download(repo_id=repo_id, repo_type='model', filename='preprocess.json'), 'r') as f:
    preprocessor = create_torchvision_transforms(json.load(f)['test'])
# Compose(
# PadToSize(size=(512, 512), interpolation=bilinear, background_color=white)
# Resize(size=448, interpolation=bicubic, max_size=None, antialias=True)
# CenterCrop(size=[448, 448])
# MaybeToTensor()
# Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
# )
image = load_image('https://huggingface.co/animetimm/swinv2_base_window8_256.dbv4-full/resolve/main/sample.webp')
input_ = preprocessor(image).unsqueeze(0)
# input_, shape: torch.Size([1, 3, 448, 448]), dtype: torch.float32
with torch.no_grad():
    output = model(input_)
    prediction = torch.sigmoid(output)[0]
# output, shape: torch.Size([1, 12476]), dtype: torch.float32
# prediction, shape: torch.Size([12476]), dtype: torch.float32
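# keep_default_na=False keeps tag names such as 'NA' as strings instead of NaN
df_tags = pd.read_csv(
    hf_hub_download(repo_id=repo_id, repo_type='model', filename='selected_tags.csv'),
    keep_default_na=False
)
tags = df_tags['name']
# keep each tag whose score reaches its own tag-level threshold
mask = prediction.numpy() >= df_tags['best_threshold'].to_numpy()
print(dict(zip(tags[mask].tolist(), prediction[mask].tolist())))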
# {'sensitive': 0.7605047821998596,
# '1girl': 0.9980626702308655,
# 'solo': 0.985005795955658,
# 'looking_at_viewer': 0.8788912892341614,
# 'blush': 0.8115326762199402,
# 'smile': 0.9378465414047241,
# 'short_hair': 0.8466857671737671,
# 'shirt': 0.49170181155204773,
# 'long_sleeves': 0.7332525849342346,
# 'brown_hair': 0.6334490180015564,
# 'holding': 0.5199263691902161,
# 'dress': 0.6529194116592407,
# 'closed_mouth': 0.43448883295059204,
# 'sitting': 0.6391631364822388,
# 'purple_eyes': 0.7848204970359802,
# 'flower': 0.9325912594795227,
# 'braid': 0.8920556902885437,
# 'outdoors': 0.41246461868286133,
# 'red_hair': 0.6809423565864563,
# 'blunt_bangs': 0.4314112067222595,
# 'tears': 0.8375990986824036,
# 'floral_print': 0.4037105143070221,
# 'crying': 0.3995090425014496,
# 'plant': 0.6664840579032898,
# 'blue_flower': 0.7186758518218994,
# 'backlighting': 0.27747398614883423,
# 'crown_braid': 0.7316360473632812,
# 'potted_plant': 0.5671563148498535,
# 'yellow_dress': 0.44971445202827454,
# 'flower_pot': 0.539954423904419,
# 'happy_tears': 0.37840017676353455,
# 'pavement': 0.22281722724437714,
# 'wiping_tears': 0.8595536351203918,
# 'brick_floor': 0.10392400622367859}
```
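The last lines of the example apply a sigmoid to the logits and keep every tag whose score reaches that tag's `best_threshold`. The dependency-free sketch below restates that step and additionally groups the kept tags by category and sorts them by score, similar in shape to the `general`/`character`/`rating` dictionaries returned by `multilabel_timm_predict`. `split_by_category` is a hypothetical helper written for illustration, not the imgutils implementation; it assumes that `selected_tags.csv` has a `category` column holding the IDs 0, 4 and 9, and the toy names, thresholds, and logits are made up.

```python
import math

# Category IDs used in this card's tables.
CATEGORY_NAMES = {0: 'general', 4: 'character', 9: 'rating'}


def sigmoid(x):
    """Numerically stable logistic function (the function torch.sigmoid applies)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def split_by_category(names, categories, thresholds, logits):
    """Hypothetical helper: keep tags whose score reaches their own threshold,
    grouped by category and ordered from highest to lowest score.

    All arguments are parallel sequences with one entry per tag.
    """
    result = {cat: {} for cat in CATEGORY_NAMES.values()}
    scored = zip(names, categories, thresholds, map(sigmoid, logits))
    for name, category, threshold, score in sorted(scored, key=lambda row: row[3], reverse=True):
        if score >= threshold:
            result[CATEGORY_NAMES[category]][name] = score
    return result['general'], result['character'], result['rating']


# Toy inputs standing in for the CSV columns and the model's raw logits.
names = ['1girl', 'solo', 'hatsune_miku', 'general', 'sensitive']
categories = [0, 0, 4, 9, 9]
thresholds = [0.40, 0.40, 0.60, 0.50, 0.50]
logits = [3.0, -1.0, 0.0, -2.0, 1.0]

general, character, rating = split_by_category(names, categories, thresholds, logits)
print(general)    # only '1girl' survives (score about 0.953)
print(character)  # {}
print(rating)     # only 'sensitive' survives (score about 0.731)
```

With the real model you would pass `df_tags['name']`, the assumed `df_tags['category']`, `df_tags['best_threshold']`, and `output[0].tolist()`; for 12476 tags, a vectorized torch or NumPy version is faster than this Python loop.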
### Use ONNX Model For Inference
Install [dghs-imgutils](https://github.com/deepghs/imgutils) with the following command:
```shell
pip install 'dghs-imgutils>=0.17.0'
```
Use the `multilabel_timm_predict` function with the following code:
```python
from imgutils.generic import multilabel_timm_predict
general, character, rating = multilabel_timm_predict(
'https://huggingface.co/animetimm/swinv2_base_window8_256.dbv4-full/resolve/main/sample.webp',
repo_id='animetimm/swinv2_base_window8_256.dbv4-full',
fmt=('general', 'character', 'rating'),
)
print(general)
# {'1girl': 0.9980627298355103,
# 'solo': 0.985005795955658,
# 'smile': 0.9378466010093689,
# 'flower': 0.932591438293457,
# 'braid': 0.8920557498931885,
# 'looking_at_viewer': 0.8788915872573853,
# 'wiping_tears': 0.8595534563064575,
# 'short_hair': 0.8466861248016357,
# 'tears': 0.8375992178916931,
# 'blush': 0.8115329742431641,
# 'purple_eyes': 0.784820556640625,
# 'long_sleeves': 0.7332528829574585,
# 'crown_braid': 0.7316359281539917,
# 'blue_flower': 0.7186765074729919,
# 'red_hair': 0.6809430122375488,
# 'plant': 0.6664847731590271,
# 'dress': 0.6529207229614258,
# 'sitting': 0.6391631364822388,
# 'brown_hair': 0.6334487199783325,
# 'potted_plant': 0.567157506942749,
# 'flower_pot': 0.5399554371833801,
# 'holding': 0.5199264287948608,
# 'shirt': 0.4917019009590149,
# 'yellow_dress': 0.44971588253974915,
# 'closed_mouth': 0.4344888925552368,
# 'blunt_bangs': 0.4314114451408386,
# 'outdoors': 0.4124644994735718,
# 'floral_print': 0.40371057391166687,
# 'crying': 0.399509072303772,
# 'happy_tears': 0.37840035557746887,
# 'backlighting': 0.2774738669395447,
# 'pavement': 0.22281798720359802,
# 'brick_floor': 0.10392436385154724}
print(character)
# {}
print(rating)
# {'sensitive': 0.7605049014091492}
```
For further information, see the [documentation of the `multilabel_timm_predict` function](https://dghs-imgutils.deepghs.org/main/api_doc/generic/multilabel_timm.html#multilabel-timm-predict).