# Spaces: Sleeping
| from PIL import Image | |
| import torch | |
| from transformers import ( | |
| AutoImageProcessor, | |
| AutoModelForImageClassification, | |
| ) | |
| import gradio as gr | |
| import spaces # ZERO GPU | |
# Checkpoints offered by this demo; the first entry is the one loaded below.
MODEL_NAMES = ["p1atdev/wd-swinv2-tagger-v3-hf"]
MODEL_NAME = MODEL_NAMES[0]

# Load the classifier once at import time so every request reuses it.
model = AutoModelForImageClassification.from_pretrained(
    MODEL_NAME,
)
# Use the GPU when torch reports one, otherwise stay on the CPU.
model.to("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# repository to run while the processor loads — confirm this is still needed.
processor = AutoImageProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
| def _people_tag(noun: str, minimum: int = 1, maximum: int = 5): | |
| return ( | |
| [f"1{noun}"] | |
| + [f"{num}{noun}s" for num in range(minimum + 1, maximum + 1)] | |
| + [f"{maximum+1}+{noun}s"] | |
| ) | |
# All subject-count tags for girls, boys and others, plus the "no humans" tag.
PEOPLE_TAGS = [
    *_people_tag("girl"),
    *_people_tag("boy"),
    *_people_tag("other"),
    "no humans",
]
# Translates a rating label into the rating words placed in the prompt.
RATING_MAP = dict(
    general="safe",
    sensitive="sensitive",
    questionable="nsfw",
    explicit="explicit, nsfw",
)

# Markdown shown at the top of the UI.
DESCRIPTION_MD = "\n".join(
    [
        "# WD Tagger with 🤗 transformers",
        "Currently supports the following model(s):",
        "- [p1atdev/wd-swinv2-tagger-v3-hf](https://huggingface.co/p1atdev/wd-swinv2-tagger-v3-hf)",
    ]
)
def postprocess_results(
    results: dict[str, float], general_threshold: float, character_threshold: float
) -> tuple[dict[str, float], dict[str, float], dict[str, float]]:
    """Split scored labels into rating, character and general groups.

    Labels starting with ``"rating:"`` or ``"character:"`` go to the matching
    group with that prefix removed; every other label is a general tag.
    Ratings are never filtered, while character and general tags are kept
    only when their score is at least the corresponding threshold.

    Args:
        results: Mapping of label name to score.
        general_threshold: Minimum score (inclusive) for a general tag.
        character_threshold: Minimum score (inclusive) for a character tag.

    Returns:
        ``(rating, character, general)`` dictionaries, each ordered by
        descending score.
    """
    rating: dict[str, float] = {}
    character: dict[str, float] = {}
    general: dict[str, float] = {}

    # Sorting first means each output dict is filled in descending score order.
    ranked = sorted(results.items(), key=lambda item: item[1], reverse=True)
    for label, score in ranked:
        if label.startswith("rating:"):
            # removeprefix strips only the leading marker; str.replace would
            # also delete any later occurrence inside the tag name.
            rating[label.removeprefix("rating:")] = score
        elif label.startswith("character:"):
            if score >= character_threshold:
                character[label.removeprefix("character:")] = score
        elif score >= general_threshold:
            general[label] = score

    return rating, character, general
def animagine_prompt(rating: list[str], character: list[str], general: list[str]):
    """Join the tags into one comma-separated prompt string.

    The order is: general tags found in ``PEOPLE_TAGS``, then character
    tags, then the remaining general tags, then the rating word that
    ``RATING_MAP`` gives for the first entry of ``rating``.
    """
    # Only the first rating entry is used.
    mapped_rating = RATING_MAP[rating[0]]

    subjects = [tag for tag in general if tag in PEOPLE_TAGS]
    remaining = [tag for tag in general if tag not in PEOPLE_TAGS]

    ordered = subjects + character + remaining + [mapped_rating]
    return ", ".join(ordered)
def predict_tags(
    image: Image.Image, general_threshold: float = 0.3, character_threshold: float = 0.8
) -> tuple[dict[str, float], dict[str, float], dict[str, float], str]:
    """Tag an image and build a prompt from the predicted tags.

    Args:
        image: The picture to classify.
        general_threshold: Minimum score for a general tag to be kept.
        character_threshold: Minimum score for a character tag to be kept.

    Returns:
        ``(rating, character, general, prompt)``: three label-to-score
        dictionaries followed by the comma-separated prompt string.

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque failure
        # inside the image processor when nothing was uploaded.
        raise gr.Error("Please provide an input image.")

    inputs = processor.preprocess(image, return_tensors="pt")

    # Pure inference: skip autograd bookkeeping to save memory and time.
    with torch.inference_mode():
        outputs = model(**inputs.to(model.device, model.dtype))
        # Only the first item of the batch is used; sigmoid gives one
        # independent score per label.
        probabilities = torch.sigmoid(outputs.logits[0])

    # Convert all scores in a single transfer rather than one at a time.
    scores = probabilities.float().cpu().tolist()
    results = {model.config.id2label[i]: score for i, score in enumerate(scores)}

    rating, character, general = postprocess_results(
        results, general_threshold, character_threshold
    )
    prompt = animagine_prompt(
        list(rating.keys()), list(character.keys()), list(general.keys())
    )
    return rating, character, general, prompt
def demo() -> gr.Blocks:
    """Build the Gradio interface and wire the Start button to predict_tags.

    Returns:
        The assembled ``gr.Blocks`` app; the caller is responsible for
        launching it.
    """
    with gr.Blocks() as ui:
        gr.Markdown(DESCRIPTION_MD)
        with gr.Row():
            # First column: the image and the tagging controls.
            with gr.Column():
                input_image = gr.Image(label="Input image", type="pil")
                with gr.Group():
                    general_threshold = gr.Slider(
                        label="Threshold",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.3,
                        step=0.01,
                        interactive=True,
                    )
                    character_threshold = gr.Slider(
                        label="Character threshold",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.8,
                        step=0.01,
                        interactive=True,
                    )
                    # Displayed only: this dropdown is not passed to the
                    # click handler below.
                    _model_radio = gr.Dropdown(
                        choices=MODEL_NAMES,
                        label="Model",
                        value=MODEL_NAMES[0],
                        interactive=True,
                    )
                start_btn = gr.Button(value="Start", variant="primary")
            # Second column: the generated prompt and the per-group scores.
            with gr.Column():
                prompt_text = gr.Text(label="Prompt")
                rating_tags_label = gr.Label(label="Rating tags")
                character_tags_label = gr.Label(label="Character tags")
                general_tags_label = gr.Label(label="General tags")
        # The outputs are listed in the order predict_tags returns them.
        start_btn.click(
            predict_tags,
            inputs=[input_image, general_threshold, character_threshold],
            outputs=[
                rating_tags_label,
                character_tags_label,
                general_tags_label,
                prompt_text,
            ],
        )
    return ui
if __name__ == "__main__":
    # Build the UI, turn on request queuing, then start serving.
    app = demo()
    app.queue().launch()