---
tags:
- image-classification
- timm
- transformers
- animetimm
- dghs-imgutils
library_name: timm
license: gpl-3.0
datasets:
- animetimm/danbooru-wdtagger-v4-w640-ws-full
base_model:
- timm/convnext_base.fb_in22k_ft_in1k
---

# Anime Tagger convnext_base.dbv4-full

## Model Details

- **Model Type:** Multilabel image classification / feature backbone
- **Model Stats:**
  - Params: 100.4M
  - FLOPs / MACs: 123.1G / 61.4G
  - Image size: train = 448 x 448, test = 448 x 448
- **Dataset:** [animetimm/danbooru-wdtagger-v4-w640-ws-full](https://huggingface.co/datasets/animetimm/danbooru-wdtagger-v4-w640-ws-full)
  - Tags Count: 12476
    - General (#0) Tags Count: 9225
    - Character (#4) Tags Count: 3247
    - Rating (#9) Tags Count: 4

## Results

| Split      | Macro@0.40 (F1/MCC/P/R)       | Micro@0.40 (F1/MCC/P/R)       | Macro@Best (F1/P/R)   |
|:----------:|:-----------------------------:|:-----------------------------:|:---------------------:|
| Validation | 0.521 / 0.525 / 0.538 / 0.524 | 0.669 / 0.668 / 0.663 / 0.675 | ---                   |
| Test       | 0.521 / 0.525 / 0.538 / 0.525 | 0.669 / 0.668 / 0.663 / 0.675 | 0.553 / 0.564 / 0.565 |

* `Macro/Micro@0.40`: macro- and micro-averaged metrics with a single threshold of 0.40 applied to every tag.
* `Macro@Best`: macro-averaged (mean over tags) metrics when each tag uses its own threshold, chosen to give that tag the best F1 score.
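
To make the difference between the macro and micro columns concrete, here is a minimal pure-Python sketch. It assumes binary ground-truth labels and predicted probabilities stored as one list per tag; the two-tag toy data is invented for illustration and has nothing to do with the actual evaluation set. Macro metrics score each tag separately and then average, while micro metrics pool the true/false positives and negatives of all tags before scoring once.

```python
import math


def confusion(y_true, y_prob, threshold):
    """Count (TP, FP, FN, TN) for one tag at a fixed threshold."""
    tp = fp = fn = tn = 0
    for truth, prob in zip(y_true, y_prob):
        predicted = prob >= threshold
        if predicted and truth:
            tp += 1
        elif predicted:
            fp += 1
        elif truth:
            fn += 1
        else:
            tn += 1
    return tp, fp, fn, tn


def scores(tp, fp, fn, tn):
    """Return (F1, MCC, precision, recall); undefined ratios fall back to 0."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    denom = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    mcc = (tp * tn - fp * fn) / denom if denom else 0.0
    return f1, mcc, precision, recall


def macro_micro(labels, probs, threshold=0.40):
    """labels, probs: dict tag -> per-image lists. Returns (macro, micro) score tuples."""
    counts = [confusion(labels[tag], probs[tag], threshold) for tag in labels]
    # Macro: score every tag on its own, then take the mean of each metric.
    per_tag = [scores(*c) for c in counts]
    macro = tuple(sum(s[i] for s in per_tag) / len(per_tag) for i in range(4))
    # Micro: pool the confusion counts of all tags, then score once.
    micro = scores(*(sum(c[i] for c in counts) for i in range(4)))
    return macro, micro


# Invented toy data: 4 images, 2 tags.
labels = {'1girl': [1, 1, 0, 1], 'flower': [0, 1, 0, 0]}
probs = {'1girl': [0.9, 0.7, 0.2, 0.3], 'flower': [0.5, 0.8, 0.1, 0.2]}
macro, micro = macro_micro(labels, probs)
print('macro F1/MCC/P/R:', [round(x, 3) for x in macro])
# macro F1/MCC/P/R: [0.733, 0.577, 0.75, 0.833]
print('micro F1/MCC/P/R:', [round(x, 3) for x in micro])
# micro F1/MCC/P/R: [0.75, 0.5, 0.75, 0.75]
```

Micro averaging is dominated by frequent tags, while macro averaging gives every tag equal weight, which is why the two can differ noticeably.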

## Thresholds

| Category | Name      | Alpha | Threshold | Micro@Thr (F1/P/R)    | Macro@0.40 (F1/P/R)   | Macro@Best (F1/P/R)   |
|:--------:|:---------:|:-----:|:---------:|:---------------------:|:---------------------:|:---------------------:|
| 0        | general   | 1     | 0.4       | 0.657 / 0.652 / 0.662 | 0.398 / 0.426 / 0.397 | 0.431 / 0.434 / 0.457 |
| 4        | character | 1     | 0.66      | 0.907 / 0.939 / 0.877 | 0.870 / 0.857 / 0.887 | 0.900 / 0.933 / 0.873 |
| 9        | rating    | 1     | 0.4       | 0.823 / 0.784 / 0.866 | 0.829 / 0.801 / 0.860 | 0.830 / 0.806 / 0.857 |

* `Micro@Thr`: micro-averaged metrics at the suggested category-level threshold given in the `Threshold` column.
* `Macro@0.40`: macro-averaged metrics at a threshold of 0.40.
* `Macro@Best`: macro-averaged metrics when each tag uses its own threshold, chosen to give that tag the best F1 score.
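
As a minimal sketch of applying the suggested category-level thresholds from the table above, assume you already have one probability per tag plus each tag's category id (0 = general, 4 = character, 9 = rating, as in Model Details). `split_by_category` is a hypothetical helper, not part of any library; the first four scores are rounded from the sample output under How to Use, and `hypothetical_character` / `hypothetical_rating` are invented entries.

```python
# Suggested category-level thresholds, copied from the Thresholds table.
CATEGORY_THRESHOLDS = {0: 0.40, 4: 0.66, 9: 0.40}
CATEGORY_NAMES = {0: 'general', 4: 'character', 9: 'rating'}


def split_by_category(scores, categories, thresholds=CATEGORY_THRESHOLDS):
    """Hypothetical helper: keep tags whose score reaches their category's threshold.

    scores: dict tag -> probability; categories: dict tag -> category id.
    Returns dict category name -> {tag: score}, ordered by descending score.
    """
    result = {name: {} for name in CATEGORY_NAMES.values()}
    for tag, score in sorted(scores.items(), key=lambda kv: kv[1], reverse=True):
        category = categories[tag]
        if score >= thresholds[category]:
            result[CATEGORY_NAMES[category]][tag] = score
    return result


# Illustrative input; the two 'hypothetical_*' tags are invented.
scores = {'1girl': 0.9986, 'shirt': 0.4787, 'closed_mouth': 0.3529, 'sensitive': 0.8121,
          'hypothetical_character': 0.6000, 'hypothetical_rating': 0.1500}
categories = {'1girl': 0, 'shirt': 0, 'closed_mouth': 0, 'sensitive': 9,
              'hypothetical_character': 4, 'hypothetical_rating': 9}

print(split_by_category(scores, categories))
# {'general': {'1girl': 0.9986, 'shirt': 0.4787}, 'character': {}, 'rating': {'sensitive': 0.8121}}
```

With the category-level threshold of 0.40, `closed_mouth` (about 0.353) is filtered out; the TIMM example under How to Use keeps it because it filters with the tag-level `best_threshold` values instead.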

Tag-level thresholds are listed in the `best_threshold` column of [selected_tags.csv](https://huggingface.co/animetimm/convnext_base.dbv4-full/resolve/main/selected_tags.csv).
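
The `Macro@Best` columns rely on one threshold per tag, chosen to maximize that tag's F1 score. The sketch below shows one straightforward way such a threshold could be picked: sweep a grid of candidate values over held-out labels and keep the best. This is an illustration under that assumption, not necessarily how the values in `selected_tags.csv` were produced; `f1_at`, `best_threshold` and the toy data are all hypothetical.

```python
def f1_at(y_true, y_prob, threshold):
    """F1 of one tag when predicting positive for prob >= threshold (2TP / (2TP + FP + FN))."""
    tp = sum(1 for t, p in zip(y_true, y_prob) if t and p >= threshold)
    fp = sum(1 for t, p in zip(y_true, y_prob) if not t and p >= threshold)
    fn = sum(1 for t, p in zip(y_true, y_prob) if t and p < threshold)
    return 2 * tp / (2 * tp + fp + fn) if tp else 0.0


def best_threshold(y_true, y_prob, step=0.01):
    """Sweep thresholds step, 2*step, ..., 1 - step and return (threshold, F1) with the highest F1."""
    candidates = [round(i * step, 4) for i in range(1, int(round(1 / step)))]
    best = (0.5, -1.0)
    for thr in candidates:
        f1 = f1_at(y_true, y_prob, thr)
        if f1 > best[1]:  # strict '>' keeps the lowest threshold among ties
            best = (thr, f1)
    return best


# Invented toy data for a single tag on 8 held-out images.
y_true = [1, 0, 1, 0, 0, 1, 0, 0]
y_prob = [0.35, 0.30, 0.55, 0.05, 0.20, 0.28, 0.10, 0.45]
thr, f1 = best_threshold(y_true, y_prob)
print(thr, round(f1, 3))
# 0.21 0.75
```

Here a fixed threshold of 0.40 would give F1 = 0.4 for this tag, while the tuned threshold reaches 0.75, which illustrates why per-tag thresholds can raise macro scores.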

## How to Use

We provide a sample image for the code samples; you can find it [here](https://huggingface.co/animetimm/convnext_base.dbv4-full/blob/main/sample.webp).

### Use TIMM And Torch

Install [dghs-imgutils](https://github.com/deepghs/imgutils), [timm](https://github.com/huggingface/pytorch-image-models) and the other required packages with the following command:

```shell
pip install 'dghs-imgutils>=0.17.0' torch huggingface_hub timm pillow pandas
```

After that, you can load this model with the timm library and use it for training, validation and testing. The following code runs inference on the sample image:
```python
import json

import pandas as pd
import torch
from huggingface_hub import hf_hub_download
from imgutils.data import load_image
from imgutils.preprocess import create_torchvision_transforms
from timm import create_model

repo_id = 'animetimm/convnext_base.dbv4-full'
model = create_model(f'hf-hub:{repo_id}', pretrained=True)
model.eval()

# Build the test-time preprocessing pipeline from the repo's preprocess.json
with open(hf_hub_download(repo_id=repo_id, repo_type='model', filename='preprocess.json'), 'r') as f:
    preprocessor = create_torchvision_transforms(json.load(f)['test'])
# Compose(
#     PadToSize(size=(512, 512), interpolation=bilinear, background_color=white)
#     Resize(size=448, interpolation=bicubic, max_size=None, antialias=True)
#     CenterCrop(size=[448, 448])
#     MaybeToTensor()
#     Normalize(mean=tensor([0.4850, 0.4560, 0.4060]), std=tensor([0.2290, 0.2240, 0.2250]))
# )

image = load_image('https://huggingface.co/animetimm/convnext_base.dbv4-full/resolve/main/sample.webp')
input_ = preprocessor(image).unsqueeze(0)
# input_, shape: torch.Size([1, 3, 448, 448]), dtype: torch.float32
with torch.no_grad():
    output = model(input_)
    # Multilabel output: an independent sigmoid per tag, not a softmax
    prediction = torch.sigmoid(output)[0]
# output, shape: torch.Size([1, 12476]), dtype: torch.float32
# prediction, shape: torch.Size([12476]), dtype: torch.float32

df_tags = pd.read_csv(
    hf_hub_download(repo_id=repo_id, repo_type='model', filename='selected_tags.csv'),
    keep_default_na=False
)
tags = df_tags['name']
# Keep each tag whose probability reaches its own tag-level threshold
mask = prediction.numpy() >= df_tags['best_threshold'].to_numpy()
print(dict(zip(tags[mask].tolist(), prediction[torch.from_numpy(mask)].tolist())))
# {'sensitive': 0.8121482729911804,
#  '1girl': 0.9985615611076355,
#  'solo': 0.9850295186042786,
#  'looking_at_viewer': 0.7997868061065674,
#  'blush': 0.8888278603553772,
#  'smile': 0.9216164350509644,
#  'short_hair': 0.7902982234954834,
#  'shirt': 0.4786570370197296,
#  'long_sleeves': 0.7071284055709839,
#  'brown_hair': 0.5399358868598938,
#  'holding': 0.6793648600578308,
#  'dress': 0.6637858152389526,
#  'closed_mouth': 0.3529078960418701,
#  'sitting': 0.6096656322479248,
#  'purple_eyes': 0.9120762348175049,
#  'flower': 0.9145119786262512,
#  'braid': 0.8637544512748718,
#  'red_hair': 0.5444879531860352,
#  'blunt_bangs': 0.3928454518318176,
#  'tears': 0.7742242217063904,
#  'plant': 0.8320311307907104,
#  'blue_flower': 0.4381242096424103,
#  'crown_braid': 0.7148879766464233,
#  'potted_plant': 0.8044518232345581,
#  'flower_pot': 0.7585387229919434,
#  'pavement': 0.21422843635082245,
#  'wiping_tears': 0.8426964282989502,
#  'stone_floor': 0.09097751975059509,
#  'cobblestone': 0.12426003068685532}
```
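
The preprocessing printout in the code above ends with `MaybeToTensor()` and `Normalize(...)`. Assuming these follow the usual torchvision semantics (8-bit RGB values scaled to [0, 1], then `(x - mean) / std` per channel), the arithmetic for a single pixel looks like the dependency-free sketch below. It also shows why `torch.sigmoid` suits a multilabel tagger: each logit becomes an independent probability, so several tags can be active at once. `normalize_pixel` and `sigmoid` are illustrative helpers, not the code used by imgutils or timm.

```python
import math

# Mean/std copied from the Normalize(...) line of the printed pipeline.
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb):
    """Scale an 8-bit RGB pixel to [0, 1], then apply (x - mean) / std per channel."""
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, MEAN, STD))


def sigmoid(logit):
    """Map one raw logit to a probability in (0, 1), independently of all other tags."""
    return 1.0 / (1.0 + math.exp(-logit))


# White (255, 255, 255) is the padding colour shown for PadToSize.
print([round(c, 3) for c in normalize_pixel((255, 255, 255))])
# [2.249, 2.429, 2.64]

# Independent sigmoids do not have to sum to 1, unlike a softmax.
print([round(sigmoid(z), 3) for z in (3.0, 1.0, -2.0)])
# [0.953, 0.731, 0.119]
```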

### Use ONNX Model For Inference

Install [dghs-imgutils](https://github.com/deepghs/imgutils) with the following command:

```shell
pip install 'dghs-imgutils>=0.17.0'
```

Then call the `multilabel_timm_predict` function, as in the following code:
```python
from imgutils.generic import multilabel_timm_predict

general, character, rating = multilabel_timm_predict(
    'https://huggingface.co/animetimm/convnext_base.dbv4-full/resolve/main/sample.webp',
    repo_id='animetimm/convnext_base.dbv4-full',
    fmt=('general', 'character', 'rating'),
)

print(general)
# {'1girl': 0.9985615611076355,
#  'solo': 0.9850296378135681,
#  'smile': 0.9216161966323853,
#  'flower': 0.9145116806030273,
#  'purple_eyes': 0.9120762944221497,
#  'blush': 0.8888277411460876,
#  'braid': 0.863754391670227,
#  'wiping_tears': 0.8426950573921204,
#  'plant': 0.8320306539535522,
#  'potted_plant': 0.8044508695602417,
#  'looking_at_viewer': 0.7997866868972778,
#  'short_hair': 0.7902982234954834,
#  'tears': 0.7742242813110352,
#  'flower_pot': 0.758537232875824,
#  'crown_braid': 0.7148873209953308,
#  'long_sleeves': 0.7071276307106018,
#  'holding': 0.6793639659881592,
#  'dress': 0.6637860536575317,
#  'sitting': 0.6096659898757935,
#  'red_hair': 0.544488251209259,
#  'brown_hair': 0.539935827255249,
#  'shirt': 0.4786553978919983,
#  'blue_flower': 0.4381209909915924,
#  'blunt_bangs': 0.39284440875053406,
#  'closed_mouth': 0.35290712118148804,
#  'pavement': 0.21422752737998962,
#  'cobblestone': 0.12425824999809265,
#  'stone_floor': 0.0909762978553772}
print(character)
# {}
print(rating)
# {'sensitive': 0.812149167060852}
```

For further information, see the [documentation of `multilabel_timm_predict`](https://dghs-imgutils.deepghs.org/main/api_doc/generic/multilabel_timm.html#multilabel-timm-predict).
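
The ONNX scores above agree with the TIMM/PyTorch scores from the previous section only up to a few millionths, not bit for bit (for example, `solo` is 0.9850295186... in one and 0.9850296378... in the other). If you switch between the two backends and want to check that they agree, comparing with a tolerance is safer than exact equality. A minimal sketch follows; `max_score_diff` is a hypothetical helper, the handful of scores are copied from the two sample outputs in this card, and the 1e-4 tolerance is an arbitrary choice.

```python
def max_score_diff(a, b):
    """Largest absolute score difference between two {tag: score} dicts with the same tags."""
    if a.keys() != b.keys():
        raise ValueError(f'tag sets differ: {sorted(a.keys() ^ b.keys())}')
    return max(abs(a[tag] - b[tag]) for tag in a)


# A few general-tag scores copied from the TIMM and ONNX sample outputs.
timm_scores = {'1girl': 0.9985615611076355, 'solo': 0.9850295186042786,
               'blue_flower': 0.4381242096424103, 'stone_floor': 0.09097751975059509}
onnx_scores = {'1girl': 0.9985615611076355, 'solo': 0.9850296378135681,
               'blue_flower': 0.4381209909915924, 'stone_floor': 0.0909762978553772}

diff = max_score_diff(timm_scores, onnx_scores)
print(f'max abs diff: {diff:.2e}')
# max abs diff: 3.22e-06
assert diff < 1e-4, 'backends disagree more than expected'
```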