Spaces:
Runtime error
Runtime error
| from os import getenv | |
| from pathlib import Path | |
| from typing import Optional | |
| import gradio as gr | |
| import numpy as np | |
| import onnxruntime as rt | |
| from PIL import Image | |
| from tagger.common import LabelData, load_labels_hf, preprocess_image | |
| from tagger.model import create_session | |
TITLE = "WaifuDiffusion Tagger"
DESCRIPTION = """
Tag images with the WaifuDiffusion Tagger models!
Primarily used as a backend for a Discord bot.
"""
# Optional Hugging Face access token from the environment; passed to
# create_session() when a model is loaded.
HF_TOKEN = getenv("HF_TOKEN", None)

# Model version -> variant name -> Hugging Face repo id.
# (Nested mapping, hence dict[str, dict[str, str]].)
MODEL_VARIANTS: dict[str, dict[str, str]] = {
    "v3": {
        "SwinV2": "SmilingWolf/wd-swinv2-tagger-v3",
        "ConvNeXT": "SmilingWolf/wd-convnext-tagger-v3",
        "ViT": "SmilingWolf/wd-vit-tagger-v3",
    },
    "v2": {
        "MOAT": "SmilingWolf/wd-v1-4-moat-tagger-v2",
        "SwinV2": "SmilingWolf/wd-v1-4-swinv2-tagger-v2",
        "ConvNeXT": "SmilingWolf/wd-v1-4-convnext-tagger-v2",
        "ConvNeXTv2": "SmilingWolf/wd-v1-4-convnextv2-tagger-v2",
        "ViT": "SmilingWolf/wd-v1-4-vit-tagger-v2",
    },
}
# Model cache: one "<version>-<variant>" slot per known model, all starting
# empty (None) until load_model() fills them on first use.
cache_keys = [
    f"{version_name}-{variant_name}"
    for version_name, variants in MODEL_VARIANTS.items()
    for variant_name in variants
]
loaded_models: dict[str, Optional[rt.InferenceSession]] = dict.fromkeys(cache_keys)
# Directory containing this script, or the current working directory when
# __file__ is undefined (e.g. when running inside IPython).
WORK_DIR = Path(__file__).parent.resolve() if "__file__" in globals() else Path().resolve()
# File suffixes (compared case-insensitively) accepted as example images.
IMAGE_EXTENSIONS = [".jpg", ".jpeg", ".png", ".gif", ".webp", ".bmp", ".tiff", ".tif"]
# Where the bundled example images live.
EXAMPLES_DIR = WORK_DIR / "examples"
# Sorted example image paths, relative to WORK_DIR. A missing examples
# directory yields an empty list: Path.iterdir() would otherwise raise
# FileNotFoundError and take the whole app down at import time.
example_images = sorted(
    str(path.relative_to(WORK_DIR))
    for path in (EXAMPLES_DIR.iterdir() if EXAMPLES_DIR.is_dir() else ())
    if path.is_file() and path.suffix.lower() in IMAGE_EXTENSIONS
)
def load_model(version: str, variant: str) -> rt.InferenceSession:
    """Return the inference session for ``version``/``variant``, loading it once.

    Sessions are memoised in the module-level ``loaded_models`` dict under the
    key ``"<version>-<variant>"``; a slot holding None is (re)loaded.

    Raises:
        ValueError: If the version/variant pair is not in MODEL_VARIANTS.
    """
    repo_id = MODEL_VARIANTS.get(version, {}).get(variant)
    if repo_id is None:
        raise ValueError(f"Unknown model variant: {version}-{variant}")

    key = f"{version}-{variant}"
    session = loaded_models.get(key)
    if session is None:
        session = create_session(repo_id, token=HF_TOKEN)
        loaded_models[key] = session
    return session
def mcut_threshold(probs: np.ndarray) -> float:
    """
    Maximum Cut Thresholding (MCut)
    Largeron, C., Moulin, C., & Gery, M. (2012). MCut: A Thresholding Strategy
    for Multi-label Classification. In 11th International Symposium, IDA 2012
    (pp. 172-183).

    Sorts the probabilities in descending order, finds the largest gap between
    two neighbouring values, and returns the midpoint of that gap.

    Args:
        probs: Probabilities to split. Any array-like is accepted; it is
            flattened before sorting.

    Returns:
        The threshold as a Python float.

    Raises:
        ValueError: If fewer than two probabilities are given (there is no gap
            to cut).
    """
    values = np.sort(np.asarray(probs).ravel())[::-1]
    if values.size < 2:
        raise ValueError(f"MCut needs at least two probabilities, got {values.size}")
    gaps = values[:-1] - values[1:]
    # argmax returns the first of equal gaps, i.e. the highest-probability cut.
    cut = int(gaps.argmax())
    return float((values[cut] + values[cut + 1]) / 2)
def _select_labels(
    probs: list[tuple[str, float]],
    indices,
    threshold: float,
    use_mcut: bool,
) -> tuple[dict[str, float], float]:
    """Keep the labels at ``indices`` whose probability exceeds a threshold.

    Args:
        probs: ``(name, probability)`` pairs for every model output.
        indices: Positions in ``probs`` that belong to one tag category.
        threshold: Exclusive probability cutoff, used when ``use_mcut`` is False.
        use_mcut: If True, ignore ``threshold`` and compute one with
            mcut_threshold() over this category's probabilities.

    Returns:
        ``(labels, threshold_used)``. ``labels`` maps name -> probability,
        ordered by descending probability. The threshold is not rounded, so
        the value reported is exactly the one that was applied.
    """
    candidates = [probs[i] for i in indices]
    if use_mcut:
        threshold = mcut_threshold(np.array([prob for _, prob in candidates]))
    kept = {name: prob for name, prob in candidates if prob > threshold}
    ranked = dict(sorted(kept.items(), key=lambda item: item[1], reverse=True))
    return ranked, threshold


def predict(
    image: Image.Image,
    version: str,
    variant: str,
    gen_threshold: float = 0.35,
    gen_use_mcut: bool = False,
    char_threshold: float = 0.85,
    char_use_mcut: bool = False,
):
    """Tag an image with the selected WaifuDiffusion tagger model.

    Args:
        image: Input image.
        version: Key of MODEL_VARIANTS (e.g. "v3").
        variant: Variant name within that version (e.g. "SwinV2").
        gen_threshold: Probability cutoff for general tags.
        gen_use_mcut: Replace gen_threshold with an MCut threshold.
        char_threshold: Probability cutoff for character tags.
        char_use_mcut: Replace char_threshold with an MCut threshold.

    Returns:
        ``(preprocessed_image, caption, booru_tags, rating_labels,
        char_labels, char_threshold, gen_labels, gen_threshold)``. The
        thresholds are the ones actually applied (MCut-derived when enabled).

    Raises:
        ValueError: From load_model() for an unknown version/variant.
    """
    model: rt.InferenceSession = load_model(version, variant)
    labels: LabelData = load_labels_hf(MODEL_VARIANTS[version][variant])

    # The first model input is N,H,W,C; read the spatial size it expects.
    model_input = model.get_inputs()[0]
    _, h, w, _ = model_input.shape
    output_name = model.get_outputs()[0].name

    image = preprocess_image(image, (h, w))
    # These models take float32 BGR pixels. Reverse the channel axis of the
    # RGB array instead of converting through PIL's "BGR;24" mode, which
    # Pillow has deprecated.
    pixels = np.asarray(image.convert("RGB"), dtype=np.float32)[:, :, ::-1]
    inputs = np.ascontiguousarray(np.expand_dims(pixels, axis=0))

    scores = model.run([output_name], {model_input.name: inputs})[0][0].astype(float)
    probs = list(zip(labels.names, scores))

    rating_labels = dict(probs[i] for i in labels.rating)
    # Both categories filter with the unrounded threshold; rounding before
    # comparing could admit labels that lie below the true MCut.
    gen_labels, gen_threshold = _select_labels(probs, labels.general, gen_threshold, gen_use_mcut)
    char_labels, char_threshold = _select_labels(probs, labels.character, char_threshold, char_use_mcut)

    # General tags first, then character tags. Each group is already ordered
    # by confidence; the two groups are not re-sorted against each other.
    combined_names = [*gen_labels, *char_labels]
    caption = ", ".join(combined_names)
    # Booru style: spaces instead of underscores, parentheses escaped.
    booru = caption.replace("_", " ").replace("(", r"\(").replace(")", r"\)")
    return image, caption, booru, rating_labels, char_labels, char_threshold, gen_labels, gen_threshold
# Extra styling: align the Max-Cut checkboxes with their sliders and dim a
# slider while Max-Cut has taken over its value.
css = """
#gen_mcut, #char_mcut {
    padding-top: var(--scale-3);
}
#gen_threshold.dimmed, #char_threshold.dimmed {
    filter: brightness(75%);
}
"""

with gr.Blocks(theme="NoCrypt/miku", analytics_enabled=False, title=TITLE, css=css) as demo:
    with gr.Row(equal_height=False):
        # Left column: input image, model selection, thresholds, buttons.
        with gr.Column(min_width=720):
            with gr.Group():
                img_input = gr.Image(
                    label="Input",
                    type="pil",
                    image_mode="RGB",
                    sources=["upload", "clipboard"],
                )
                show_processed = gr.Checkbox(label="Show Preprocessed Image", value=False)
            with gr.Row():
                version = gr.Radio(
                    choices=list(MODEL_VARIANTS.keys()),
                    label="Model Version",
                    value="v3",
                    min_width=160,
                    scale=1,
                )
                variant = gr.Radio(
                    choices=list(MODEL_VARIANTS[version.value].keys()),
                    label="Model Variant",
                    value="SwinV2",
                    min_width=560,
                )
            with gr.Group():
                with gr.Row():
                    gen_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.35,
                        step=0.01,
                        label="General Tag Threshold",
                        scale=5,
                        elem_id="gen_threshold",
                    )
                    gen_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="gen_mcut")
                with gr.Row():
                    char_threshold = gr.Slider(
                        minimum=0.0,
                        maximum=1.0,
                        value=0.85,
                        step=0.01,
                        label="Character Tag Threshold",
                        scale=5,
                        elem_id="char_threshold",
                    )
                    char_mcut = gr.Checkbox(label="Use Max-Cut", value=False, scale=1, elem_id="char_mcut")
            with gr.Row():
                clear = gr.ClearButton(
                    components=[],
                    variant="secondary",
                    size="lg",
                )
                submit = gr.Button(value="Submit", variant="primary", size="lg")
        # Right column: processed image and tagging results.
        with gr.Column(min_width=720):
            img_output = gr.Image(
                label="Preprocessed Image", type="pil", image_mode="RGB", scale=1, visible=False
            )
            with gr.Group():
                caption = gr.Textbox(label="Caption", show_copy_button=True)
                tags = gr.Textbox(label="Tags", show_copy_button=True)
            with gr.Group():
                rating = gr.Label(label="Rating")
            with gr.Group():
                char_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
                character = gr.Label(label="Character")
            with gr.Group():
                gen_mcut_out = gr.Number(label="Max-Cut Threshold", precision=2, visible=False)
                general = gr.Label(label="General")

    # Every example image once with Max-Cut off and once with it on.
    example_rows = [[imgpath, 0.35, mc, 0.85, mc] for mc in [False, True] for imgpath in example_images]
    if example_rows:
        with gr.Row():
            examples = gr.Examples(
                examples=example_rows,
                inputs=[img_input, gen_threshold, gen_mcut, char_threshold, char_mcut],
            )

    # Components reset by the Clear button (includes the "Tags" textbox).
    clear.add([img_input, img_output, caption, tags, rating, character, general])

    def on_select_variant(evt: gr.SelectData, selected_version: str):
        """Repopulate the variant radio for the chosen model version."""
        if evt.selected:
            choices = list(MODEL_VARIANTS[selected_version])
            return gr.update(choices=choices, value=choices[0])
        return gr.update()

    version.select(on_select_variant, inputs=[version], outputs=[variant])

    def on_change_show(val: bool):
        """Show or hide the preprocessed-image output."""
        return gr.update(visible=val)

    show_processed.select(on_change_show, inputs=[show_processed], outputs=[img_output])

    def on_change_mcut(val: bool):
        """Toggle Max-Cut: lock and dim the slider, show the Max-Cut number."""
        return (
            gr.update(interactive=not val, elem_classes=["dimmed"] if val else []),
            gr.update(visible=val),
        )

    gen_mcut.change(on_change_mcut, inputs=[gen_mcut], outputs=[gen_threshold, gen_mcut_out])
    char_mcut.change(on_change_mcut, inputs=[char_mcut], outputs=[char_threshold, char_mcut_out])

    # NOTE(review): predict() returns the applied thresholds into the sliders
    # themselves; char_mcut_out / gen_mcut_out are only shown/hidden and never
    # receive a value. Confirm which components were meant to display them.
    submit.click(
        predict,
        inputs=[img_input, version, variant, gen_threshold, gen_mcut, char_threshold, char_mcut],
        outputs=[img_output, caption, tags, rating, character, char_threshold, general, gen_threshold],
        api_name="predict",
    )
if __name__ == "__main__":
    # Cap the number of pending requests waiting in the queue.
    demo.queue(max_size=10)
    # SPACE_ID is presumably set by Hugging Face Spaces; there, use Gradio's
    # default launch settings. Otherwise listen on all interfaces, port 7871.
    running_on_space = getenv("SPACE_ID") is not None
    launch_options = {} if running_on_space else {"server_name": "0.0.0.0", "server_port": 7871}
    demo.launch(**launch_options)