Spaces:
Running
on
Zero
Running
on
Zero
| import spaces | |
| import json | |
| import yaml | |
| import os | |
| import torch | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from model.pipeline import JiTModel, JiTConfig | |
| from model.config import ClassContextConfig | |
# Hub locations of the model artifacts; every one can be overridden through
# an environment variable of the same name.
MODEL_REPO = os.environ.get("MODEL_REPO", "p1atdev/JiT-AnimeFace-experiment")
MODEL_PATH = os.environ.get(
    "MODEL_PATH",
    "jit-b256-p16-cls/12-jit-animeface_00043e_033368s.safetensors",
)
LABEL2ID_PATH = os.environ.get("LABEL2ID_PATH", "jit-b256-p16-cls/label2id.json")
CONFIG_PATH = os.environ.get("CONFIG_PATH", "jit-b256-p16-cls/config.yml")


def _select_device() -> torch.device:
    """Return the preferred torch device: CUDA first, then Apple MPS, else CPU."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


DEVICE = _select_device()
# bfloat16 on CUDA, float16 everywhere else (MPS and CPU).
# NOTE(review): float16 on CPU may be slow or unsupported for some ops
# depending on the installed torch version — confirm if CPU use matters.
DTYPE = torch.bfloat16 if DEVICE.type == "cuda" else torch.float16

# Upper bound on prompt tokens, forwarded to model.generate().
MAX_TOKEN_LENGTH = 32

# In-process caches populated by load_model().
model_map: dict[str, JiTModel] = {}  # {model_path: model}
label2id_map: dict[str, dict] = {}  # {label2id_path: label2id}
def get_file_path(repo: str, path: str) -> str:
    """Fetch a file from the Hugging Face Hub and return its local path.

    Args:
        repo: Hub repository id, e.g. ``"user/model"``.
        path: Path of the file inside that repository.

    Returns:
        The local filesystem path returned by ``hf_hub_download``.
    """
    local_path = hf_hub_download(repo_id=repo, filename=path)
    return local_path
def load_label2id(label2id_path: str) -> dict:
    """Read a ``label2id.json`` file and return its parsed contents.

    The file is decoded explicitly as UTF-8 (the JSON standard encoding), so
    non-ASCII tag names load correctly regardless of the platform locale.

    Args:
        label2id_path: Local path to the JSON file.

    Returns:
        The parsed JSON object (expected to map tag label -> id).
    """
    with open(label2id_path, "r", encoding="utf-8") as f:
        return json.load(f)
| def load_config(config_path: str) -> JiTConfig: | |
| """設定ファイルを読み込む""" | |
| with open(config_path, "r") as f: | |
| if config_path.endswith(".json"): | |
| config_dict = json.load(f) | |
| elif config_path.endswith((".yaml", ".yml")): | |
| config_dict = yaml.safe_load(f) | |
| else: | |
| raise ValueError("Unsupported config file format. Use .json or .yaml/.yml") | |
| return JiTConfig.model_validate(config_dict) | |
def load_model(
    model_path: str,
    label2id_path: str,
    config_path: str,
    device: torch.device,
    dtype: torch.dtype = DTYPE,
) -> tuple[JiTModel, dict]:
    """Load a JiT model and its label2id map, reusing in-process caches.

    All paths are resolved inside ``MODEL_REPO`` via ``get_file_path``.

    The model is cached in ``model_map`` by ``model_path``. The label map is
    cached in ``label2id_map`` by ``label2id_path``. The two caches are
    checked independently, so a cached model requested together with a
    not-yet-loaded label map (or one whose earlier load failed) no longer
    raises ``KeyError``.

    NOTE: ``config_path``, ``device`` and ``dtype`` are not part of the cache
    key. A cache hit returns the model exactly as it was first loaded.

    Args:
        model_path: Checkpoint path inside ``MODEL_REPO``.
        label2id_path: ``label2id.json`` path inside ``MODEL_REPO``.
        config_path: Config (JSON/YAML) path inside ``MODEL_REPO``.
        device: Device the freshly loaded model is moved to.
        dtype: Dtype the freshly loaded model is cast to.

    Returns:
        ``(model, label2id)``.
    """
    model = model_map.get(model_path)
    if model is None:
        config = load_config(get_file_path(MODEL_REPO, config_path))
        if isinstance(config.context_encoder, ClassContextConfig):
            # The class-context encoder reads its label map from a local file,
            # so point it at the downloaded copy.
            config.context_encoder.label2id_map_path = get_file_path(
                MODEL_REPO, label2id_path
            )
        model = JiTModel.from_pretrained(
            config=config,
            checkpoint_path=get_file_path(MODEL_REPO, model_path),
        )
        # Inference only: eval mode and no gradient tracking on parameters.
        model.eval()
        model.requires_grad_(False)
        model.to(device=device, dtype=dtype)
        model_map[model_path] = model  # cache

    label2id = label2id_map.get(label2id_path)
    if label2id is None:
        label2id = load_label2id(get_file_path(MODEL_REPO, label2id_path))
        label2id_map[label2id_path] = label2id  # cache

    return model, label2id
# ZeroGPU attaches a GPU only while a @spaces.GPU-decorated function runs;
# outside ZeroGPU the decorator has no effect.
@spaces.GPU
def generate_images(
    prompt: str,
    negative_prompt: str,
    num_steps: int,
    cfg_scale: float,
    batch_size: int,
    size: int,
    seed: int,
    #
    model_path: str = MODEL_PATH,
    label2id_path: str = LABEL2ID_PATH,
    config_path: str = CONFIG_PATH,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio handler: generate ``batch_size`` square images for ``prompt``.

    Args:
        prompt: Space-separated tags. The same prompt is repeated
            ``batch_size`` times.
        negative_prompt: Space-separated tags to steer away from.
        num_steps: Number of inference steps.
        cfg_scale: Classifier-free guidance scale.
        batch_size: Number of images to generate.
        size: Output height and width in pixels.
        seed: Seed for reproducibility. A negative value or ``None`` (an
            emptied number box) means a random seed.
        model_path / label2id_path / config_path: Artifact paths inside
            ``MODEL_REPO``. They default to the module constants.
        progress: Injected by Gradio. ``track_tqdm=True`` surfaces tqdm
            progress bars in the UI.

    Returns:
        Whatever ``model.generate`` returns. Presumably a list of images
        suitable for ``gr.Gallery`` (TODO confirm).
    """
    # Gradio sliders/number boxes can deliver floats (e.g. 25.0, -1.0).
    # List repetition and RNG seeding need real ints.
    batch_size = int(batch_size)
    num_steps = int(num_steps)
    size = int(size)
    seed_value = None if seed is None or seed < 0 else int(seed)

    model, _label2id = load_model(
        model_path=model_path,
        label2id_path=label2id_path,
        config_path=config_path,
        device=DEVICE,
        dtype=DTYPE,
    )
    with torch.inference_mode(), torch.autocast(device_type=DEVICE.type, dtype=DTYPE):
        images = model.generate(
            prompt=[prompt] * batch_size,
            negative_prompt=negative_prompt,
            num_inference_steps=num_steps,
            cfg_scale=cfg_scale,
            height=size,
            width=size,
            max_token_length=MAX_TOKEN_LENGTH,
            cfg_time_range=[0.1, 1.0],
            seed=seed_value,
            device=DEVICE,
            execution_dtype=DTYPE,
        )
    return images
# Browsable Hub page of the label map, linked from the UI so users can look
# up which tags are supported.
LABEL2ID_URL = "https://huggingface.co/{repo}/blob/main/{path}".format(
    repo=MODEL_REPO, path=LABEL2ID_PATH
)
def demo():
    """Build and return (without launching) the Gradio Blocks UI.

    Layout:
      - a Markdown header;
      - a row with a left column of generation controls and a wider right
        column holding the Generate button, the output gallery and
        clickable examples.

    Clicking the button or submitting the prompt box calls
    ``generate_images`` with the seven control values, in the order of its
    positional parameters.
    """
    with gr.Blocks() as ui:
        gr.Markdown(f"""
# JiT-AnimeFace Demo
Pixel-space x-prediction flow-matching 90M parameter model for anime face generation, trained from scratch.
- See full supported tags: [label2id.json]({LABEL2ID_URL}). 対応しているタグ一覧は [こちら]({LABEL2ID_URL}) から確認できます。ここに載っていないタグは反応しません。
- Current model: [{MODEL_PATH}](https://huggingface.co/{MODEL_REPO}/blob/main/{MODEL_PATH})
""")
        with gr.Row():
            # Left: generation controls.
            with gr.Column():
                prompt = gr.TextArea(
                    label="Prompt",
                    info=f"Space-separated tags. Not all of danbooru tags are supported. See [the full supported tags]({LABEL2ID_URL}). スペースで区切ってください。カンマ区切りは対応してません。",
                    value="general 1girl solo portrait looking_at_viewer medium_hair parted_lips blue_ribbon hair_ornament hairclip half_updo halterneck bokeh depth_of_field blurry_background head_tilt",
                    placeholder="e.g.: general 1girl solo portrait looking_at_viewer",
                )
                negative_prompt = gr.TextArea(
                    label="Negative Prompt",
                    info="Space-separated negative tags to avoid in generation. スペースで区切ってください。カンマ区切りは対応してません。",
                    value="retro_artstyle 1990s_(style) sketch",
                    lines=2,
                    placeholder="e.g.: retro_artstyle 1990s_(style) sketch",
                )
                num_steps = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=25,
                    step=1,
                    label="Number of Steps",
                    info="Recommended: more than 20 steps for better quality.",
                )
                cfg_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=5.0,
                    step=0.25,
                    label="CFG Scale",
                    info="Recommended: more than 2.0 for better adherence to the prompt.",
                )
                batch_size = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=25,
                    step=1,
                    label="Batch Size",
                    info="Number of images to generate in one batch.",
                )
                size = gr.Slider(
                    minimum=64,
                    maximum=320,
                    value=256,
                    step=64,
                    label="Image Size",
                    info="Only 256x256 is supported in the current model. Other sizes may cause quality degradation.",
                )
                # precision=0 makes Gradio round the value and deliver an int.
                # Without it the handler would receive a float seed (e.g. 42.0).
                seed = gr.Number(
                    value=-1,
                    label="Seed (-1 for random)",
                    precision=0,
                )
            # Right: trigger, results and example prompts.
            with gr.Column(scale=2):
                generate_button = gr.Button("Generate Images", variant="primary")
                output_gallery = gr.Gallery(
                    label="Generated Images",
                    columns=5,
                    height="768px",
                    preview=False,
                    show_label=True,
                )
                # Each example fills [prompt, negative_prompt].
                gr.Examples(
                    examples=[
                        [
                            "general 1girl solo portrait looking_at_viewer medium_hair parted_lips blue_ribbon hair_ornament hairclip half_updo halterneck bokeh depth_of_field blurry_background head_tilt",
                            "retro_artstyle 1990s_(style) sketch",
                        ],
                        [
                            "general 1girl solo portrait looking_at_viewer",
                            "retro_artstyle 1990s_(style) sketch",
                        ],
                        [
                            "general 1girl solo portrait looking_at_viewer blue_hair short_hair blush open_mouth cat_ears animal_ears red_eyes white_background",
                            "retro_artstyle 1990s_(style) sketch",
                        ],
                        [
                            "general 1girl aqua_eyes baseball_cap blonde_hair closed_mouth earrings green_background hat jewelry looking_at_viewer shirt short_hair simple_background solo portrait yellow_shirt",
                            "retro_artstyle 1990s_(style) sketch",
                        ],
                        [
                            "general 1girl solo portrait looking_at_viewer brown_hair ahoge long_hair :| expressionless closed_mouth swept_bangs pink_eyes pink_background simple_background dutch_angle",
                            "retro_artstyle 1990s_(style) sketch smile",
                        ],
                        [
                            "general 1girl solo portrait looking_at_viewer hatsune_miku twintails long_hair blue_eyes one_eye_closed simple_background green_background",
                            "retro_artstyle 1990s_(style) sketch",
                        ],
                        [
                            "general 1girl portrait looking_at_viewer sketch head_tilt white_background monochrome open_mouth long_hair",
                            "retro_artstyle 1990s_(style)",
                        ],
                        [
                            "general 1girl solo from_behind short_hair simple_background black_background",
                            "retro_artstyle 1990s_(style) sketch",
                        ],
                        [
                            "general 1girl portrait looking_to_the_side glasses",
                            "retro_artstyle 1990s_(style) sketch",
                        ],
                        [
                            "general 1girl portrait looking_at_viewer cat_ears purple_theme ;d forehead animal_ears animal_ear_fluff cat_ears",
                            "retro_artstyle 1990s_(style) sketch",
                        ],
                    ],
                    inputs=[prompt, negative_prompt],
                    label="Examples",
                    examples_per_page=20,
                )
        # Input order must match generate_images' positional parameters.
        gr.on(
            triggers=[generate_button.click, prompt.submit],
            fn=generate_images,
            inputs=[
                prompt,
                negative_prompt,
                num_steps,
                cfg_scale,
                batch_size,
                size,
                seed,
            ],
            outputs=output_gallery,
        )
    return ui
if __name__ == "__main__":
    # Preload the default model and label map into the caches so the first
    # request does not pay for the download and weight loading.
    load_model(MODEL_PATH, LABEL2ID_PATH, CONFIG_PATH, DEVICE, DTYPE)
    app = demo()
    app.launch()