Spaces:
Running
Running
| import argparse | |
| import os | |
| import gradio as gr | |
| import huggingface_hub | |
| import numpy as np | |
| import onnxruntime as rt | |
| import pandas as pd | |
| from PIL import Image | |
# Heading and blurb rendered at the top of the demo page.
TITLE = "WaifuDiffusion Tagger"
DESCRIPTION = "\nDemo for the WaifuDiffusion tagger models\n"

# Hugging Face access token; an empty string when the variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN", "")

# Model repositories trained on the "dataset v3" tag set.
SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
VIT_MODEL_DSV3_REPO = "ura23/wd-vit-tagger-v3"
VIT_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-vit-large-tagger-v3"
EVA02_LARGE_MODEL_DSV3_REPO = "SmilingWolf/wd-eva02-large-tagger-v3"

# Model repositories trained on the "dataset v2" tag set.
MOAT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-moat-tagger-v2"
SWIN_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-swinv2-tagger-v2"
CONV_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnext-tagger-v2"
CONV2_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-convnextv2-tagger-v2"
VIT_MODEL_DSV2_REPO = "SmilingWolf/wd-v1-4-vit-tagger-v2"

# IdolSankaku model repositories.
EVA02_LARGE_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-eva02-large-tagger-v1"
SWINV2_MODEL_IS_DSV1_REPO = "deepghs/idolsankaku-swinv2-tagger-v1"

# Names of the two files fetched from each model repository.
MODEL_FILENAME = "model.onnx"
LABEL_FILENAME = "selected_tags.csv"
def parse_args() -> argparse.Namespace:
    """Parse the command-line options that configure the threshold sliders.

    Returns:
        Namespace with ``score_slider_step``, ``score_general_threshold`` and
        ``score_character_threshold``, each parsed as a float.
    """
    # (flag, default) pairs; every option takes a float value.
    float_options = (
        ("--score-slider-step", 0.05),
        ("--score-general-threshold", 0.25),
        ("--score-character-threshold", 1.0),
    )

    cli = argparse.ArgumentParser()
    for flag, default in float_options:
        cli.add_argument(flag, type=float, default=default)
    return cli.parse_args()
def load_labels(dataframe) -> tuple[list[str], list[int], list[int]]:
    """Split a tag table into tag names and per-category row positions.

    Args:
        dataframe: Table with a ``name`` column and a ``category`` column.

    Returns:
        A ``(tag_names, general_indexes, character_indexes)`` tuple:
        every value of the ``name`` column in row order, the positions of
        the rows whose category equals 0, and the positions of the rows
        whose category equals 4. Positions are plain Python ints in
        ascending order.
    """
    tag_names = dataframe["name"].tolist()
    # np.where yields positional (not label-based) indexes; .tolist()
    # converts the NumPy integer scalars into ordinary ints.
    general_indexes = np.where(dataframe["category"] == 0)[0].tolist()
    character_indexes = np.where(dataframe["category"] == 4)[0].tolist()
    return tag_names, general_indexes, character_indexes
class Predictor:
    """Keep one ONNX tagger model loaded at a time and tag images with it."""

    def __init__(self):
        # All of these are filled in by load_model(); None means that no
        # model has been loaded yet.
        self.model = None
        self.tag_names = None
        self.general_indexes = None
        self.character_indexes = None
        self.model_target_size = None
        self.last_loaded_repo = None

    def download_model(self, model_repo):
        """Download the label CSV and the ONNX model file of ``model_repo``.

        Returns:
            A ``(csv_path, model_path)`` tuple with the values returned by
            ``hf_hub_download`` for the two files.
        """
        # NOTE(review): HF_TOKEN is "" when the environment variable is unset,
        # and `use_auth_token` is the older spelling of `token` — confirm both
        # against the installed huggingface_hub version.
        csv_path = huggingface_hub.hf_hub_download(model_repo, LABEL_FILENAME, use_auth_token=HF_TOKEN)
        model_path = huggingface_hub.hf_hub_download(model_repo, MODEL_FILENAME, use_auth_token=HF_TOKEN)
        return csv_path, model_path

    def load_model(self, model_repo):
        """Make ``model_repo`` the active model, doing nothing if it already is."""
        if model_repo == self.last_loaded_repo:
            return

        csv_path, model_path = self.download_model(model_repo)

        # Build everything in locals first: if any step raises, the previously
        # loaded model, its tags and its input size stay consistent with
        # last_loaded_repo instead of being half-replaced.
        tag_names, general_indexes, character_indexes = load_labels(pd.read_csv(csv_path))
        model = rt.InferenceSession(model_path)
        # The input shape is unpacked as four dimensions and the second one is
        # used as the square input size.
        # assumes an NHWC input with height == width — TODO confirm per repo
        _, height, _, _ = model.get_inputs()[0].shape

        self.tag_names = tag_names
        self.general_indexes = general_indexes
        self.character_indexes = character_indexes
        self.model_target_size = height
        self.model = model
        self.last_loaded_repo = model_repo

    def prepare_image(self, image):
        """Turn a PIL image into the float32 batch array fed to the model.

        The image is flattened onto a white background, padded with white to
        a centred square, resized to ``model_target_size`` with bicubic
        resampling, and returned with a leading batch axis of length 1.
        """
        # Create a white canvas with the same size as the input image
        canvas = Image.new("RGBA", image.size, (255, 255, 255))
        # Ensure the input image has an alpha channel for compositing
        if image.mode != "RGBA":
            image = image.convert("RGBA")
        # Composite the input image onto the canvas
        canvas.alpha_composite(image)
        # Convert to RGB (alpha channel is no longer needed)
        image = canvas.convert("RGB")

        # Pad to a square so the resize below does not distort the aspect ratio
        max_dim = max(image.size)
        padded_image = Image.new("RGB", (max_dim, max_dim), (255, 255, 255))
        pad_left = (max_dim - image.width) // 2
        pad_top = (max_dim - image.height) // 2
        padded_image.paste(image, (pad_left, pad_top))
        padded_image = padded_image.resize((self.model_target_size, self.model_target_size), Image.BICUBIC)

        # Convert to a NumPy array and reverse the channel axis (RGB -> BGR)
        # presumably the models expect BGR input — verify against the model cards
        image_array = np.asarray(padded_image, dtype=np.float32)[:, :, ::-1]
        return np.expand_dims(image_array, axis=0)

    def predict(self, images, model_repo, general_thresh, character_thresh):
        """Tag every image in ``images`` with the model from ``model_repo``.

        Args:
            images: Iterable of PIL images.
            model_repo: Repository id of the model to use; loaded on demand.
            general_thresh: A general tag is kept when its score is strictly
                greater than this value.
            character_thresh: Same, for character tags.

        Returns:
            One ``(general_tags, character_tags)`` tuple of tag-name lists per
            input image, in input order. Tags keep the order of the label file.
        """
        self.load_model(model_repo)

        # These lookups do not depend on the image, so do them once.
        input_name = self.model.get_inputs()[0].name
        label_name = self.model.get_outputs()[0].name
        # Sets give O(1) membership tests; testing against the index lists
        # made every image cost (number of tags) x (number of indexes).
        general_positions = set(self.general_indexes)
        character_positions = set(self.character_indexes)

        results = []
        for image in images:
            batch = self.prepare_image(image)
            preds = self.model.run([label_name], {input_name: batch})[0]
            # preds[0] is the score vector of the single image in the batch.
            labels = list(zip(self.tag_names, preds[0].astype(float)))
            general_res = [
                name
                for i, (name, score) in enumerate(labels)
                if i in general_positions and score > general_thresh
            ]
            character_res = [
                name
                for i, (name, score) in enumerate(labels)
                if i in character_positions and score > character_thresh
            ]
            results.append((general_res, character_res))
        return results
def main():
    """Build the Gradio tagging UI and launch it.

    Reads the slider defaults from the command line, creates one shared
    ``Predictor``, wires the "Process Images" button to the tagging callback,
    then starts the app with a request queue of at most 10 entries.
    """
    args = parse_args()

    predictor = Predictor()

    model_repos = [
        SWINV2_MODEL_DSV3_REPO,
        CONV_MODEL_DSV3_REPO,
        VIT_MODEL_DSV3_REPO,
        VIT_LARGE_MODEL_DSV3_REPO,
        EVA02_LARGE_MODEL_DSV3_REPO,
        # ---
        MOAT_MODEL_DSV2_REPO,
        SWIN_MODEL_DSV2_REPO,
        CONV_MODEL_DSV2_REPO,
        CONV2_MODEL_DSV2_REPO,
        VIT_MODEL_DSV2_REPO,
        # ---
        SWINV2_MODEL_IS_DSV1_REPO,
        EVA02_LARGE_MODEL_IS_DSV1_REPO,
    ]

    # Tags pre-filled into the "Filter Tags" box; anything listed there is
    # dropped from the generated prompts. "signature" appears twice, which is
    # harmless because the box content is collected into a set before use.
    predefined_tags = [
        "loli",
        "oppai_loli",
        "2024",
        "2023",
        "2025",
        "spot_color",
        "holding_sex_toy",
        "too_many",
        "happy_halloween",
        "clothes_writing",
        "camera",
        "holding_camera",
        "selfie",
        "anus_peek",
        "mature_female",
        "copyright_notice",
        "puckered_anus",
        "multiple_boys",
        "alarm_clock",
        "clock",
        "obliques",
        "genderswap",
        "genderswap_(otm)",
        "genderswap_(otf)",
        "genderswap_(mtf)",
        "genderswap_(ftm)",
        "respirator",
        "head-mounted_display",
        "2022",
        "muscular_female",
        "muscular",
        "abs",
        "2021",
        "peeing",
        "pee",
        "round_eyewear",
        "yellow-framed_eyewear",
        "hetero",
        "vaginal",
        "straddling",
        "girl_on_top",
        "male_pubic_hair",
        "cowgirl_position",
        "happy_sex",
        "vibrator_under_panties",
        "vibrator_in_thighhighs",
        "anal_beads",
        "butt_plug",
        "sex_toy",
        "anal",
        "object_insertion",
        "dildo",
        "anal_object_insertion",
        "vaginal_object_insertion",
        "semi-rimless_eyewear",
        "red-framed_eyewear",
        "under-rim_eyewear",
        "3d_background",
        "sample_watermark",
        "onee-shota",
        "incest",
        "furry",
        "can",
        "drinking_can",
        "holding_can",
        "twitter_strip_game_(meme)",
        "like_and_retweet",
        "furry_female",
        "realistic",
        "egg_vibrator",
        "tongue_piercing",
        "handheld_game_console",
        "game_controller",
        "nintendo_switch",
        "talking",
        "swastika",
        "sagging_breasts",
        "condom",
        "novelty_censor",
        "no_nipples",
        "clitoris",
        "sharp_teeth",
        "reflection",
        "mirror",
        "character_name",
        "vibrator",
        "black-framed_eyewear",
        "heterochromia",
        "chibi",
        "mini_person",
        "controller",
        "remote_control_vibrator",
        "vibrator_under_clothes",
        "thank_you",
        "vibrator_cord",
        "shota",
        "cropped_legs",
        "cropped_torso",
        "traditional_media",
        "color_guide",
        "photorealistic",
        "male_focus",
        "black_babydoll",
        "signature",
        "web_address",
        "censored_nipples",
        "rhodes_island_logo_(arknights)",
        "gothic_lolita",
        "glasses",
        "reference_inset",
        "twitter_logo",
        "mother_and_daughter",
        "holding_controller",
        "holding_game_controller",
        "baby",
        "heart_censor",
        "pixiv_username",
        "korean_text",
        "pixiv_logo",
        "greyscale_with_colored_background",
        "water_bottle",
        "body_writing",
        "used_condom",
        "multiple_condoms",
        "condom_belt",
        "holding_phone",
        "multiple_views",
        "phone",
        "cellphone",
        "zoom_layer",
        "smartphone",
        "lolita_hairband",
        "lactation",
        "otoko_no_ko",
        "minigirl",
        "babydoll",
        "domino_mask",
        "pixiv_id",
        "qr_code",
        "monochrome",
        "trick_or_treat",
        "happy_birthday",
        "lolita_fashion",
        "arrow_(symbol)",
        "happy_new_year",
        "dated",
        "thought_bubble",
        "greyscale",
        "speech_bubble",
        "mask",
        "comic",
        "bottle",
        "holding_bottle",
        "milk",
        "milk_bottle",
        "english_text",
        "copyright_name",
        "twitter_username",
        "fanbox_username",
        "patreon_username",
        "patreon_logo",
        "cover",
        "weibo_logo",
        "weibo_username",
        "signature",
        "content_rating",
        "cover_page",
        "doujin_cover",
        "sex",
        "artist_name",
        "watermark",
        "censored",
        "bar_censor",
        "blank_censor",
        "blur_censor",
        "light_censor",
        "mosaic_censoring",
    ]

    def process_images(files, model_repo, general_thresh, character_thresh, filter_tags):
        """Tag the uploaded files and return one prompt line per image.

        Args:
            files: Uploaded file entries from the File component, or None /
                empty when nothing was uploaded.
            model_repo: Repository id selected in the dropdown.
            general_thresh: Score threshold for general tags.
            character_thresh: Score threshold for character tags.
            filter_tags: Comma-separated tags to leave out of the prompts
                (compared case-insensitively, surrounding spaces ignored).

        Returns:
            The prompts joined with newlines, in upload order. Within a
            prompt, character tags come first, then general tags, with
            underscores replaced by spaces.
        """
        # Clicking the button with no upload would otherwise fail while
        # iterating over None.
        if not files:
            return ""

        images = []
        for file in files:
            # Entries are read through `.name`; fall back to the entry itself
            # when it has no such attribute (e.g. a plain path string).
            path = getattr(file, "name", file)
            # copy() forces the pixel data to load, so the file handle can be
            # closed right away instead of staying open for the whole request.
            with Image.open(path) as opened:
                images.append(opened.copy())

        results = predictor.predict(images, model_repo, general_thresh, character_thresh)

        # Parse filter tags
        filter_set = {tag.strip().lower() for tag in (filter_tags or "").split(",")}

        prompts = []
        for general_tags, character_tags in results:
            # Replace underscores with spaces for both character and general tags
            character_part = ", ".join(
                tag.replace('_', ' ') for tag in character_tags if tag.lower() not in filter_set
            )
            general_part = ", ".join(
                tag.replace('_', ' ') for tag in general_tags if tag.lower() not in filter_set
            )
            # Join only the non-empty parts so a prompt never starts or ends
            # with a dangling ", ".
            prompts.append(", ".join(part for part in (character_part, general_part) if part))

        # One prompt per line
        return "\n".join(prompts)

    # NOTE(review): the nesting below was reconstructed from the comments; in
    # particular, confirm whether the filter textbox belongs inside the
    # "Advanced Settings" accordion or directly in the column.
    with gr.Blocks(title=TITLE) as demo:
        gr.Markdown(f"<h1 style='text-align: center;'>{TITLE}</h1>")
        gr.Markdown(DESCRIPTION)

        with gr.Row():
            with gr.Column():
                submit = gr.Button(
                    value="Process Images", variant="primary"
                )
                image_files = gr.File(
                    file_types=["image"], label="Upload Images", file_count="multiple",
                )

                # Wrap the model selection and sliders in an Accordion
                with gr.Accordion("Advanced Settings", open=False):  # Collapsible by default
                    model_repo = gr.Dropdown(
                        model_repos,
                        value=VIT_MODEL_DSV3_REPO,
                        label="Select Model",
                    )
                    general_thresh = gr.Slider(
                        0, 1, step=args.score_slider_step, value=args.score_general_threshold, label="General Tags Threshold"
                    )
                    character_thresh = gr.Slider(
                        0, 1, step=args.score_slider_step, value=args.score_character_threshold, label="Character Tags Threshold"
                    )

                filter_tags = gr.Textbox(
                    value=", ".join(predefined_tags),
                    label="Filter Tags (comma-separated)",
                    placeholder="Add tags to filter out (e.g., winter, red, from above)",
                    lines=9
                )

            with gr.Column():
                output = gr.Textbox(label="Output", lines=10)

        submit.click(
            process_images,
            inputs=[image_files, model_repo, general_thresh, character_thresh, filter_tags],
            outputs=output
        )

    demo.queue(max_size=10)
    demo.launch()


if __name__ == "__main__":
    main()