|
|
from os import getenv |
|
|
from pathlib import Path |
|
|
from typing import Optional |
|
|
|
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import onnxruntime as rt |
|
|
from PIL import Image |
|
|
|
|
|
from tagger.common import LabelData, load_labels_hf, preprocess_image |
|
|
from tagger.model import create_session |
|
|
|
|
|
# App title shown in the browser tab and the Gradio page header.
TITLE = "WaifuDiffusion Tagger"
# Markdown blurb for the UI; also documents the app's primary use case.
DESCRIPTION = """
Tag images with the WaifuDiffusion Tagger models!

Primarily used as a backend for a Discord bot.
"""
# HuggingFace access token for downloading models; None when the env var is unset.
HF_TOKEN = getenv("HF_TOKEN", None)
|
|
|
|
|
# Model registry: version -> variant name -> HuggingFace repo id.
# NOTE: the annotation was previously `dict[str, str]`, which did not match
# the nested structure; corrected to `dict[str, dict[str, str]]`.
MODEL_VARIANTS: dict[str, dict[str, str]] = {
    "v3": {
        "SwinV2": "SmilingWolf/wd-swinv2-tagger-v3",
        "ConvNeXT": "SmilingWolf/wd-convnext-tagger-v3",
        "ViT": "SmilingWolf/wd-vit-tagger-v3",
    },
    "v2": {
        "MOAT": "SmilingWolf/wd-v1-4-moat-tagger-v2",
        "SwinV2": "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
        "ConvNeXT": "SmilingWolf/wd-v1-4-convnext-tagger-v2",
        "ConvNeXTv2": "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        "ViT": "SmilingWolf/wd-v1-4-vit-tagger-v2",
    },
}

# One cache slot per "version-variant" pair; sessions are created lazily by load_model().
cache_keys = [f"{version}-{variant}" for version in MODEL_VARIANTS for variant in MODEL_VARIANTS[version]]
loaded_models: dict[str, Optional[rt.InferenceSession]] = {k: None for k in cache_keys}
|
|
|
|
|
|
|
|
# Directory containing this script (fall back to the CWD when __file__ is absent,
# e.g. in an interactive session).
WORK_DIR = Path(__file__).parent.resolve() if "__file__" in globals() else Path().resolve()

# File suffixes accepted as example images (compared case-insensitively).
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]

# Collect example images for the Gradio Examples widget, sorted for a stable order.
# Guard against a missing ./examples directory so the app still starts (with no
# examples) instead of raising FileNotFoundError at import time.
_examples_dir = WORK_DIR.joinpath("examples")
example_images = sorted(
    str(x.relative_to(WORK_DIR))
    for x in (_examples_dir.iterdir() if _examples_dir.is_dir() else [])
    if x.is_file() and x.suffix.lower() in IMAGE_EXTENSIONS
)
|
|
|
|
|
|
|
|
def load_model(version: str, variant: str) -> rt.InferenceSession:
    """Return the ONNX inference session for a model version/variant.

    Sessions are created lazily on first request and cached in the
    module-level ``loaded_models`` dict, keyed by "version-variant".

    Raises:
        ValueError: if the version/variant pair is not in MODEL_VARIANTS.
    """
    repo = MODEL_VARIANTS.get(version, {}).get(variant, None)
    if repo is None:
        raise ValueError(f"Unknown model variant: {version}-{variant}")

    key = f"{version}-{variant}"
    session = loaded_models.get(key, None)
    if session is None:
        # First use of this pair: download/build the session and cache it.
        session = create_session(repo, token=HF_TOKEN)
        loaded_models[key] = session
    return session
|
|
|
|
|
|
|
|
def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)

    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    Sorts the probabilities in descending order and places the threshold at
    the midpoint of the largest gap between consecutive values.

    Returns 0.0 when fewer than two probabilities are given: no gap exists,
    and the original code would raise (empty argmax / out-of-bounds index).
    """
    if probs.size < 2:
        return 0.0
    probs = probs[probs.argsort()[::-1]]
    diffs = probs[:-1] - probs[1:]
    idx = diffs.argmax()
    thresh = (probs[idx] + probs[idx + 1]) / 2
    return float(thresh)
|
|
|
|
|
|
|
|
def predict(
    image: Image.Image,
    version: str,
    variant: str,
    gen_threshold: float = 0.35,
    gen_use_mcut: bool = False,
    char_threshold: float = 0.85,
    char_use_mcut: bool = False,
):
    """Run one tagger inference on an image.

    Args:
        image: input image (converted to RGB internally).
        version: model version key in MODEL_VARIANTS (e.g. "v3").
        variant: model variant key (e.g. "SwinV2").
        gen_threshold: minimum score for general tags (ignored when gen_use_mcut).
        gen_use_mcut: derive the general threshold via Maximum Cut instead.
        char_threshold: minimum score for character tags (ignored when char_use_mcut).
        char_use_mcut: derive the character threshold via Maximum Cut instead.

    Returns:
        Tuple of (preprocessed image, caption, booru-escaped caption,
        rating scores, character tags, character threshold, general tags,
        general threshold) — matching the submit.click output wiring.
    """
    model: rt.InferenceSession = load_model(version, variant)
    labels: LabelData = load_labels_hf(MODEL_VARIANTS[version][variant])

    # The model input is NHWC; read the expected H/W from the model itself.
    _, h, w, _ = model.get_inputs()[0].shape
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name

    image = preprocess_image(image, (h, w))

    # The WD taggers expect BGR channel order, float32, with a batch axis.
    inputs = image.convert("RGB").convert("BGR;24")
    inputs = np.array(inputs).astype(np.float32)
    inputs = np.expand_dims(inputs, axis=0)

    probs = model.run([output_name], {input_name: inputs})

    # Pair every label name with its score for the single image in the batch.
    probs = list(zip(labels.names, probs[0][0].astype(float)))

    rating_labels = dict(probs[i] for i in labels.rating)

    if gen_use_mcut:
        # Replace the manual threshold with the MCut estimate for general tags.
        gen_array = np.array([probs[i][1] for i in labels.general])
        gen_threshold = mcut_threshold(gen_array)

    gen_labels = {name: score for name, score in (probs[i] for i in labels.general) if score > gen_threshold}
    gen_labels = dict(sorted(gen_labels.items(), key=lambda item: item[1], reverse=True))

    if char_use_mcut:
        # Rounded to 2 decimals to match the slider/number widget precision.
        char_array = np.array([probs[i][1] for i in labels.character])
        char_threshold = round(mcut_threshold(char_array), 2)

    char_labels = {name: score for name, score in (probs[i] for i in labels.character) if score > char_threshold}
    char_labels = dict(sorted(char_labels.items(), key=lambda item: item[1], reverse=True))

    # Caption lists general tags first, then character tags.
    combined_names = list(gen_labels) + list(char_labels)
    caption = ", ".join(combined_names)
    # Booru form: spaces instead of underscores, parentheses escaped for prompt
    # syntax. Raw strings fix the invalid "\(" escape sequences of the original
    # (a SyntaxWarning on modern Python) without changing the runtime value.
    booru = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")

    return image, caption, booru, rating_labels, char_labels, char_threshold, gen_labels, gen_threshold
|
|
|
|
|
|
|
|
# Custom CSS: align the Max-Cut checkboxes with their sliders, and dim a
# threshold slider while its Max-Cut override is active (the "dimmed" class
# is toggled by on_change_mcut below).
css = """
#gen_mcut, #char_mcut {
    padding-top: var(--scale-3);
}
#gen_threshold.dimmed, #char_threshold.dimmed {
    filter: brightness(75%);
}
"""
|
|
|
|
|
# Build the Gradio UI: inputs (image, model selection, thresholds) in the left
# column, prediction outputs in the right column, plus event wiring below.
with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        # Left column: image input, model selection, threshold controls, buttons.
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
                show_processed = gr.Checkbox(label="Show Preprocessed Image", value=False)
            with gr.Row():
                version = gr.Radio(
                    choices=list(MODEL_VARIANTS.keys()),
                    label="Model Version",
                    value="v3",
                    min_width=160,
                    scale=1,
                )
                # Variant choices are seeded from the default version; they are
                # repopulated by on_select_variant when the version changes.
                variant = gr.Radio(
                    choices=list(MODEL_VARIANTS[version.value].keys()),
                    label="Model Variant",
                    value="SwinV2",
                    min_width=560,
                )
            with gr.Group():
                with gr.Row():
                    gen_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.35,
                        step=0.01,
                        label="General Tag Threshold",
                        scale=5,
                        elem_id="gen_threshold",
                    )
                    gen_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="gen_mcut")
                with gr.Row():
                    char_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.85,
                        step=0.01,
                        label="Character Tag Threshold",
                        scale=5,
                        elem_id="char_threshold",
                    )
                    char_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="char_mcut")
            with gr.Row():
                clear = gr.ClearButton(
                    components=[],
                    variant="secondary",
                    size="lg",
                )
                submit = gr.Button(value="Submit", variant="primary", size="lg")

        # Right column: preprocessed-image preview (hidden by default) and outputs.
        with gr.Column(min_width=720):
            img_output = gr.Image(
                label="Preprocessed Image", type="pil", image_mode="RGB", scale=1, visible=False
            )
            with gr.Group():
                caption = gr.Textbox(label="Caption", show_copy_button=True)
                tags = gr.Textbox(label="Tags", show_copy_button=True)
            with gr.Group():
                rating = gr.Label(label="Rating")
            with gr.Group():
                # Number boxes show the computed Max-Cut thresholds; hidden until
                # the corresponding Max-Cut checkbox is enabled (see on_change_mcut).
                char_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
                character = gr.Label(label="Character")
            with gr.Group():
                gen_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
                general = gr.Label(label="General")

    with gr.Row():
        # Each example image appears twice: once with Max-Cut off, once with it on.
        examples = [[imgpath, 0.35, mc, 0.85, mc] for mc in [False, True] for imgpath in example_images]

        examples = gr.Examples(
            examples=examples,
            inputs=[img_input, gen_threshold, gen_mcut, char_threshold, char_mcut],
        )

    clear.add([img_input, img_output, caption, rating, character, general])

    def on_select_variant(evt: gr.SelectData, variant: str):
        # On version change, repopulate the variant list and select its first entry.
        if evt.selected:
            choices = list(MODEL_VARIANTS[variant])
            return gr.update(choices=choices, value=choices[0])
        return gr.update()

    version.select(on_select_variant, inputs=[version], outputs=[variant])

    def on_change_show(val: bool):
        # Toggle visibility of the preprocessed-image preview.
        return gr.update(visible=val)

    show_processed.select(on_change_show, inputs=[show_processed], outputs=[img_output])

    def on_change_mcut(val: bool):
        # When Max-Cut is enabled: lock and dim the manual slider, and reveal
        # the number box that reports the computed threshold.
        return (
            gr.update(interactive=not val, elem_classes=["dimmed"] if val else []),
            gr.update(visible=val),
        )

    gen_mcut.change(on_change_mcut, inputs=[gen_mcut], outputs=[gen_threshold, gen_mcut_out])
    char_mcut.change(on_change_mcut, inputs=[char_mcut], outputs=[char_threshold, char_mcut_out])

    submit.click(
        predict,
        inputs=[img_input, version, variant, gen_threshold, gen_mcut, char_threshold, char_mcut],
        outputs=[img_output, caption, tags, rating, character, char_threshold, general, gen_threshold],
        api_name="predict",
    )
|
|
|
|
|
if __name__ == "__main__":
    # Cap the request queue so the Space does not accumulate unbounded work.
    demo.queue(max_size=10)
    # On HuggingFace Spaces (SPACE_ID set) use Gradio's default launch settings;
    # locally, bind to all interfaces on a fixed port instead.
    if getenv("SPACE_ID", None) is not None:
        launch_kwargs = {}
    else:
        launch_kwargs = {"server_name": "0.0.0.0", "server_port": 7871}
    demo.launch(**launch_kwargs)
|
|
|