File size: 2,854 Bytes
e8e07f4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from typing import Optional

import gradio as gr
from hfutils.repository import hf_hub_repo_url
from imgutils.generic import MultiLabelTIMMModel

# Every dbv4-full tagger offered by the playground, as Hugging Face model repo
# ids ("owner/name"). When render_model_demo() gets no explicit label, the part
# after the last "/" becomes the tab title.
KNOWN_MODELS = [
    'animetimm/caformer_b36.dbv4-full',
    'animetimm/caformer_m36.dbv4-full',
    'animetimm/caformer_s18.dbv4-full',
    'animetimm/caformer_s36.dbv4-full',
    'animetimm/convnext_base.dbv4-full',
    'animetimm/eva02_large_patch14_448.dbv4-full',
    'animetimm/mobilenetv3_large_100.dbv4-full',
    'animetimm/mobilenetv3_large_100.dbv4-full.r224',
    'animetimm/mobilenetv3_large_150d.dbv4-full',
    'animetimm/mobilenetv4_conv_aa_large.dbv4-full',
    'animetimm/mobilenetv4_conv_small.dbv4-full',
    'animetimm/mobilenetv4_conv_small_050.dbv4-full',
    'animetimm/mobilevitv2_200.dbv4-full',
    'animetimm/resnet18.dbv4-full',
    'animetimm/resnet34.dbv4-full',
    'animetimm/resnet50.dbv4-full',
    'animetimm/resnet101.dbv4-full',
    'animetimm/resnet152.dbv4-full',
    'animetimm/swinv2_base_window8_256.dbv4-full',
    'animetimm/vit_base_patch16_224.dbv4-full',
]
# Featured models: tag shown in parentheses after the tab title -> repo id.
# The __main__ block renders these first (dict insertion order), then the rest
# of KNOWN_MODELS that were not already rendered here.
SPECIAL_MODELS = {
    'Recommended': 'animetimm/caformer_b36.dbv4-full',
    'Lightweight': 'animetimm/mobilenetv4_conv_aa_large.dbv4-full',
    'Classic EVA02': 'animetimm/eva02_large_patch14_448.dbv4-full',
    'Classic SwinV2': 'animetimm/swinv2_base_window8_256.dbv4-full',
}


def render_model_demo(repo_id, label: Optional[str] = None):
    """Render one playground tab for a single tagger repository.

    Opens a new ``gr.Tab``. Inside it are a description row linking to the
    model repo and a row holding the model's quick-demo UI. It is presumably
    meant to be called inside an enclosing ``gr.Tabs()`` context, as the
    ``__main__`` block does.

    :param repo_id: Hugging Face model repo id (``"owner/name"``).
    :param label: Tab title. When falsy (``None`` or ``''``), the text after
        the last ``/`` in ``repo_id`` is used instead.
    """
    if not label:
        label = repo_id.rsplit('/', 1)[-1]

    with gr.Tab(label):
        tagger = MultiLabelTIMMModel(repo_id=repo_id)

        # Header: a one-line description with a link to the model repo.
        with gr.Row(), gr.Column():
            link = hf_hub_repo_url(repo_id=repo_id, repo_type='model')
            gr.Markdown(f'This is the quick demo for tagger model [{repo_id}]({link}).')

        # Body: make_ui() (from imgutils) presumably builds the interactive
        # demo widgets inside this row.
        with gr.Row():
            tagger.make_ui()


if __name__ == '__main__':
    # Build the playground page: a header row, then one tab per tagger model.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                gr.HTML('<h2 style="text-align: center;">Tagger Playground For Dbv4 Full</h2>')
                # The two literals are joined by implicit concatenation, so the
                # space between the sentences must be written explicitly.
                gr.Markdown('This is the playground for taggers trained on [animetimm/danbooru-wdtagger-v4-w640-ws-full](https://huggingface.co/datasets/animetimm/danbooru-wdtagger-v4-w640-ws-full). '
                            "Powered by `dghs-imgutils`'s quick demo module.")
                gr.Markdown('Official ranklist is on [animetimm/dbv4-full-ranklist](https://huggingface.co/spaces/animetimm/dbv4-full-ranklist).')

        with gr.Row():
            with gr.Tabs():
                # Featured models come first, titled "<model name> (<tag>)".
                # The remaining KNOWN_MODELS follow in list order. The set of
                # already-rendered repo ids keeps a featured model from also
                # getting a second, unlabelled tab, and drops duplicate
                # entries in KNOWN_MODELS.
                _exist_models = set()
                for t, repo_id in SPECIAL_MODELS.items():
                    render_model_demo(repo_id, f'{repo_id.split("/")[-1]} ({t})')
                    _exist_models.add(repo_id)

                for repo_id in KNOWN_MODELS:
                    if repo_id not in _exist_models:
                        render_model_demo(repo_id)
                        _exist_models.add(repo_id)

    demo.launch()