Spaces:
Paused
Paused
# flash-attn is installed lazily at startup rather than as a declared dependency.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is set for the install so its CUDA
# extension build step is skipped.
try:
    import flash_attn
except ImportError:
    import importlib
    import os
    import subprocess
    import sys

    print("Installing flash-attn...")
    subprocess.run(
        # Use the running interpreter's pip, so the package lands in the
        # environment that will import it. Pass an argument list instead of a
        # shell string.
        [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
        # Extend, rather than replace, the current environment. Replacing it
        # would drop PATH, CUDA_HOME, etc., which pip and the build rely on.
        env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        # Fail loudly if the install fails instead of hitting a confusing
        # ImportError below.
        check=True,
    )
    # The package was created after the interpreter started, so refresh the
    # import system's finder caches before importing it.
    importlib.invalidate_caches()
    import flash_attn

    print("flash-attn installed.")
| import os | |
| import time | |
| import spaces | |
| import torch | |
| from transformers import ( | |
| AutoModelForPreTraining, | |
| AutoProcessor, | |
| AutoConfig, | |
| PreTrainedTokenizerFast, | |
| ) | |
| from huggingface_hub import hf_hub_download | |
| from safetensors.torch import load_file | |
| import gradio as gr | |
# Hugging Face repo id of the tagger model. It is required and read from the
# environment.
MODEL_NAME = os.environ.get("MODEL_NAME", None)
if not MODEL_NAME:
    # An explicit raise, not `assert`: asserts are stripped under `python -O`,
    # and an empty string would otherwise fail obscurely inside hf_hub_download.
    raise RuntimeError(
        "The MODEL_NAME environment variable must be set to a model repo id."
    )
# NOTE(review): MODEL_PATH is not referenced anywhere else in the visible code.
# The download may exist only to pre-fetch the weights. Confirm before removing.
MODEL_PATH = hf_hub_download(repo_id=MODEL_NAME, filename="model.safetensors")
# Prefer Apple MPS when available; otherwise assume CUDA. There is no CPU
# fallback.
DEVICE = (
    torch.device("mps") if torch.backends.mps.is_available() else torch.device("cuda")
)
# Any decoder vocabulary entry containing one of these substrings is banned
# during generation (see get_bad_words_ids).
BAD_WORD_KEYWORDS = ["(medium)", " text", "(style)"]
def get_bad_words_ids(tokenizer: "PreTrainedTokenizerFast", *, keywords=None):
    """Return ``bad_words_ids`` banning every vocab token that contains a keyword.

    Args:
        tokenizer: Tokenizer whose ``vocab`` mapping (token string -> token id)
            is scanned.
        keywords: Substrings to ban. Defaults to ``BAD_WORD_KEYWORDS``.

    Returns:
        A list of single-token sequences ``[[token_id], ...]``, in vocab
        iteration order, in the shape ``generate(bad_words_ids=...)`` takes.
        The list is empty when no token matches.
    """
    if keywords is None:
        keywords = BAD_WORD_KEYWORDS
    return [
        [token_id]
        for token, token_id in tokenizer.vocab.items()
        if any(keyword in token for keyword in keywords)
    ]
def prepare_models():
    """Load the tagger model and its processor, ready for inference.

    The model is loaded from ``MODEL_NAME`` in bfloat16 with remote code
    enabled. Its decoder's ``use_cache`` flag is switched on, it is put in eval
    mode, and it is moved to ``DEVICE``.

    Returns:
        A ``(model, processor)`` tuple.
    """
    tagger = AutoModelForPreTraining.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )
    # Presumably enables KV caching in the decoder sub-model during generation.
    tagger.decoder_model.use_cache = True
    proc = AutoProcessor.from_pretrained(MODEL_NAME, trust_remote_code=True)
    # nn.Module.eval() returns the module itself, so it chains with the move.
    tagger = tagger.eval().to(DEVICE)
    return tagger, proc
# Maps each UI translation-mode label to the value embedded in the decoder prompt.
TRANSLATION_MODE_MAP = {
    "translate": "exact",
    "translate + extend": "approx",
}


def build_tag_prompt(mode, copyright_tags, length, auto_detect):
    """Build the decoder prompt that conditions tag generation.

    Args:
        mode: A key of ``TRANSLATION_MODE_MAP``.
        copyright_tags: Copyright tags typed by the user. Surrounding
            whitespace is stripped. May be empty or ``None``.
        length: Value for the ``<|length:...|>`` token, e.g. one of the
            Length dropdown choices.
        auto_detect: When False, the copyright section is closed and an empty
            character section is emitted, so the prompt ends at ``<general>``.
            When True, the prompt ends inside the open copyright section.

    Returns:
        The prompt string passed to the processor as ``decoder_text``.

    Raises:
        KeyError: If ``mode`` is not a key of ``TRANSLATION_MODE_MAP``.
    """
    tag_text = (
        "<|bos|>"
        f"<|aspect_ratio:tall|><|rating:general|><|length:{length}|>"
        "<|reserved_2|><|reserved_3|><|reserved_4|>"
        f"<|translate:{TRANSLATION_MODE_MAP[mode]}|><|input_end|>"
        "<copyright>" + (copyright_tags or "").strip()
    )
    if not auto_detect:
        tag_text += "</copyright><character></character><general>"
    return tag_text


def join_tags(tags):
    """Join decoded tags with ", ", dropping entries that are blank."""
    return ", ".join(tag for tag in tags if tag.strip())


def demo():
    """Load the model, run one warm-up generation, and serve the Gradio UI.

    Blocks inside ``ui.launch()`` until the Gradio server stops.
    """
    model, processor = prepare_models()
    ban_ids = get_bad_words_ids(processor.decoder_tokenizer)

    def generate_tags(
        text: str,
        auto_detect: bool,
        mode: str = "translate",
        copyright_tags: str = "",
        length: str = "short",
        max_new_tokens: int = 128,
        do_sample: bool = False,
        temperature: float = 0.1,
        top_k: int = 10,
        top_p: float = 0.1,
    ):
        """Translate ``text`` into tags.

        Returns:
            ``[comma-separated tags, "Time elapsed: ... seconds"]``, matching
            the two UI output components.
        """
        tag_text = build_tag_prompt(mode, copyright_tags, length, auto_detect)
        inputs = processor(
            encoder_text=text, decoder_text=tag_text, return_tensors="pt"
        )
        # Sampling knobs only affect sampling. Passing them to a greedy run
        # only makes transformers warn. Gradio numbers can arrive as floats,
        # while top_k must be an int.
        sampling_kwargs = (
            {
                "temperature": float(temperature),
                "top_k": int(top_k),
                "top_p": float(top_p),
            }
            if do_sample
            else {}
        )
        start_time = time.perf_counter()
        outputs = model.generate(
            input_ids=inputs["input_ids"].to(model.device),
            attention_mask=inputs["attention_mask"].to(model.device),
            encoder_input_ids=inputs["encoder_input_ids"].to(model.device),
            encoder_attention_mask=inputs["encoder_attention_mask"].to(model.device),
            max_new_tokens=int(max_new_tokens),
            do_sample=do_sample,
            no_repeat_ngram_size=1,
            eos_token_id=processor.decoder_tokenizer.eos_token_id,
            pad_token_id=processor.decoder_tokenizer.pad_token_id,
            bad_words_ids=ban_ids,
            **sampling_kwargs,
        )
        elapsed = time.perf_counter() - start_time
        # batch_decode over a single sequence decodes each token id on its own.
        # Each piece is then treated as one tag, which assumes tags are
        # single decoder tokens.
        decoded = join_tags(
            processor.batch_decode(outputs[0], skip_special_tokens=True)
        )
        return [decoded, f"Time elapsed: {elapsed:.2f} seconds"]

    # warmup
    print("warming up...")
    print(generate_tags("Hatsune Miku is looking at viewer.", True))
    print("done.")

    mode_choices = list(TRANSLATION_MODE_MAP)

    with gr.Blocks() as ui:
        with gr.Column():
            with gr.Row():
                with gr.Column():
                    # NOTE(review): the info text says only English is supported,
                    # yet every example below is Japanese. Confirm and update it.
                    text = gr.Text(
                        label="Text",
                        info="Enter a prompt in natural language (currently only English is supported). But maybe danbooru tags are also supported.",
                        lines=4,
                        placeholder="A girl with fox ears and tail in maid costume is looking at viewer.",
                    )
                    auto_detect = gr.Checkbox(
                        label="Auto detect copyright tags.", value=False
                    )
                    copyright_tags = gr.Textbox(
                        label="Copyright tags",
                        info="You can specify copyright tags manually. This must be valid danbooru tags.",
                        placeholder="e.g.) vocaloid, blue archive",
                    )
                    length = gr.Dropdown(
                        label="Length",
                        choices=[
                            "very_short",
                            "short",
                            "long",
                            "very_long",
                        ],
                        value="short",
                    )
                    translation_mode = gr.Radio(
                        label="Translation mode",
                        choices=mode_choices,
                        value=mode_choices[0],
                    )
                    translate_btn = gr.Button(value="Translate", variant="primary")
                    with gr.Accordion(label="Advanced", open=False):
                        # precision=0 makes Gradio hand back an int token count.
                        max_new_tokens = gr.Number(
                            label="Max new tokens", value=128, precision=0
                        )
                        do_sample = gr.Checkbox(label="Do sample", value=False)
                        temperature = gr.Slider(
                            label="Temperature",
                            minimum=0.1,
                            maximum=1.0,
                            value=0.3,
                            step=0.1,
                        )
                        # step=1 so the default of 10 lies on the slider grid.
                        # With minimum=1 and step=10, only 1, 11, 21, ... were
                        # reachable.
                        top_k = gr.Slider(
                            label="Top k",
                            minimum=1,
                            maximum=100,
                            value=10,
                            step=1,
                        )
                        top_p = gr.Slider(
                            label="Top p",
                            minimum=0.1,
                            maximum=1.0,
                            value=0.5,
                            step=0.1,
                        )
                with gr.Column():
                    output_translation = gr.Textbox(
                        label="Output", lines=4, interactive=False
                    )
                    time_elapsed = gr.Markdown(value="")
            # Example rows follow the `inputs` order:
            # text, auto_detect, copyright_tags, length, translation_mode.
            gr.Examples(
                examples=[
                    [
                        "猫耳で黒髪ロング、黄色い目で制服を着た少女がこっちを見てる。青背景で白い枠がついてる。ソファに座って足を組んでいる。",
                        False,
                        "",
                        "very_short",
                        "translate",
                    ],
                    [
                        "猫耳で黒髪ロング、黄色い目で制服を着た少女がこっちを見てる。青背景で白い枠がついてる。ソファに座って足を組んでいる。",
                        False,
                        "",
                        "long",
                        "translate + extend",
                    ],
                    [
                        "猫耳少女のポートレート。:3 ",
                        False,
                        "",
                        "very_short",
                        "translate + extend",
                    ],
                    [
                        "学園アイドルマスター。ジャージを着た篠澤広が疲れ切っており、床に座って笑いながらこっちを見ている",
                        True,
                        "",
                        "short",
                        "translate",
                    ],
                    [
                        "ガールズバンドクライの井芹ニナと桃華。シンプル背景。小指を立ててこっちを向いている。feet out of frame",
                        True,
                        "",
                        "long",
                        "translate + extend",
                    ],
                    [
                        "夜の暗い路地で、黒い服に身を包んだ女がこっちを振り返っている。白いシャツとネクタイ、ジャケットに、手袋をしている",
                        False,
                        "",
                        "long",
                        "translate + extend",
                    ],
                    [
                        "一人の少女の横顔で、全体的に赤い雰囲気。髪は肩までの長さで、横を向いている。",
                        False,
                        "",
                        "short",
                        "translate + extend",
                    ],
                    [
                        "二人の少女がいる。一人は、blonde hair で long hair、もう一人は brown hair で short hair。二人とも制服。少なくとも片方はブレザーを着ている。場所は教室で、窓から日差しが差し込んでいる。cowboy shot。一人は机に座っていて、もう一人は立っている。",
                        False,
                        "",
                        "long",
                        "translate + extend",
                    ],
                ],
                inputs=[text, auto_detect, copyright_tags, length, translation_mode],
            )

        # Input order must match generate_tags' positional parameters.
        gr.on(
            triggers=[
                translate_btn.click,
            ],
            fn=generate_tags,
            inputs=[
                text,
                auto_detect,
                translation_mode,
                copyright_tags,
                length,
                max_new_tokens,
                do_sample,
                temperature,
                top_k,
                top_p,
            ],
            outputs=[
                output_translation,
                time_elapsed,
            ],
        )

    ui.launch()


if __name__ == "__main__":
    demo()