import spaces
import json
import yaml
import os

import torch
import gradio as gr
from huggingface_hub import hf_hub_download

from model.pipeline import JiTModel, JiTConfig
from model.config import ClassContextConfig

MODEL_REPO = os.environ.get("MODEL_REPO", "p1atdev/JiT-AnimeFace-experiment")
MODEL_PATH = os.environ.get(
    "MODEL_PATH", "jit-b256-p16-cls/12-jit-animeface_00043e_033368s.safetensors"
)
LABEL2ID_PATH = os.environ.get("LABEL2ID_PATH", "jit-b256-p16-cls/label2id.json")
CONFIG_PATH = os.environ.get("CONFIG_PATH", "jit-b256-p16-cls/config.yml")

DEVICE = (
    torch.device("cuda")
    if torch.cuda.is_available()
    else torch.device("mps")
    if torch.backends.mps.is_available()
    else torch.device("cpu")
)
MAX_TOKEN_LENGTH = 32

# Module-level caches so repeated requests reuse the loaded model and label map.
model_map: dict[str, JiTModel] = {}  # {model_path: model}
label2id_map: dict[str, dict] = {}  # {label2id_path: label2id}


def get_file_path(repo: str, path: str) -> str:
    """Fetch a file from the Hugging Face Hub and return its local path."""
    return hf_hub_download(repo, path)


def load_label2id(label2id_path: str) -> dict:
    """Load label2id.json."""
    with open(label2id_path, "r") as f:
        return json.load(f)


def load_config(config_path: str) -> JiTConfig:
    """Load the model config file (.json or .yaml/.yml)."""
    with open(config_path, "r") as f:
        if config_path.endswith(".json"):
            config_dict = json.load(f)
        elif config_path.endswith((".yaml", ".yml")):
            config_dict = yaml.safe_load(f)
        else:
            raise ValueError("Unsupported config file format. Use .json or .yaml/.yml")

    return JiTConfig.model_validate(config_dict)


def load_model(
    model_path: str,
    label2id_path: str,
    config_path: str,
    device: torch.device,
) -> tuple[JiTModel, dict]:
    """Load the model and its label map, reusing cached instances when available."""
    if model_path in model_map:  # use cache
        model = model_map[model_path]
        label2id = label2id_map[label2id_path]
        return model, label2id

    config = load_config(get_file_path(MODEL_REPO, config_path))
    if isinstance(config.context_encoder, ClassContextConfig):
        config.context_encoder.label2id_map_path = get_file_path(
            MODEL_REPO, label2id_path
        )

    model = JiTModel.from_pretrained(
        config=config,
        checkpoint_path=get_file_path(MODEL_REPO, model_path),
    )
    model.eval()
    model.requires_grad_(False)
    model.to(device=device)
    model_map[model_path] = model  # cache

    label2id = load_label2id(get_file_path(MODEL_REPO, label2id_path))
    label2id_map[label2id_path] = label2id  # cache

    return model, label2id


@spaces.GPU(duration=5)
def generate_images(
    prompt: str,
    negative_prompt: str,
    num_steps: int,
    cfg_scale: float,
    batch_size: int,
    size: int,
    seed: int,
    # Fixed via environment variables; not exposed in the UI.
    model_path: str = MODEL_PATH,
    label2id_path: str = LABEL2ID_PATH,
    config_path: str = CONFIG_PATH,
    progress=gr.Progress(track_tqdm=True),
):
    model, _label2id = load_model(
        model_path=model_path,
        label2id_path=label2id_path,
        config_path=config_path,
        device=DEVICE,
    )

    with torch.inference_mode():
        images = model.generate(
            prompt=[prompt] * batch_size,
            negative_prompt=negative_prompt,
            num_inference_steps=num_steps,
            cfg_scale=cfg_scale,
            height=size,
            width=size,
            max_token_length=MAX_TOKEN_LENGTH,
            cfg_time_range=[0.1, 1.0],
            seed=seed if seed >= 0 else None,
            device=DEVICE,
            execution_dtype=model.config.torch_dtype,
        )

    return images


def demo():
    with gr.Blocks() as ui:
        gr.Markdown(f"""
# JiT-AnimeFace Demo

Pixel-space x-prediction flow-matching model for anime face generation, trained from scratch.

- See the full list of supported tags: [label2id.json](https://huggingface.co/{MODEL_REPO}/blob/main/{LABEL2ID_PATH}).
- Current model: [{MODEL_PATH}](https://huggingface.co/{MODEL_REPO}/blob/main/{MODEL_PATH})
""")

        with gr.Row():
            with gr.Column():
                prompt = gr.TextArea(
                    label="Prompt",
                    info="Space-separated tags. Not all Danbooru tags are supported; see the link above for the full list.",
                    value="general 1girl solo portrait looking_at_viewer blue_hair short_hair blush cat_ears open_mouth animal_ears red_eyes white_background",
                    placeholder="e.g.: general 1girl solo portrait looking_at_viewer",
                )
                negative_prompt = gr.TextArea(
                    label="Negative Prompt",
                    value="retro_artstyle 1990s_(style) sketch",
                    lines=2,
                    placeholder="e.g.: retro_artstyle 1990s_(style) sketch",
                )
                num_steps = gr.Slider(
                    minimum=1,
                    maximum=100,
                    value=25,
                    step=4,
                    label="Number of Steps",
                )
                cfg_scale = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    value=3.0,
                    step=0.25,
                    label="CFG Scale",
                )
                batch_size = gr.Slider(
                    minimum=1,
                    maximum=64,
                    value=16,
                    step=1,
                    label="Batch Size",
                )
                size = gr.Slider(
                    minimum=64,
                    maximum=320,
                    value=256,
                    step=64,
                    label="Image Size",
                )
                seed = gr.Number(
                    value=-1,
                    label="Seed (-1 for random)",
                )

            with gr.Column(scale=2):
                generate_button = gr.Button("Generate Images", variant="primary")
                output_gallery = gr.Gallery(
                    label="Generated Images",
                    columns=4,
                    height="768px",
                    preview=False,
                    show_label=True,
                )

        gr.Examples(
            examples=[
                [
                    "general 1girl solo portrait looking_at_viewer blue_hair short_hair blush cat_ears open_mouth animal_ears red_eyes white_background",
                    "retro_artstyle 1990s_(style) sketch",
                ]
            ],
            inputs=[prompt, negative_prompt],
            label="Examples",
        )

        gr.on(
            triggers=[generate_button.click, prompt.submit],
            fn=generate_images,
            inputs=[
                prompt,
                negative_prompt,
                num_steps,
                cfg_scale,
                batch_size,
                size,
                seed,
            ],
            outputs=output_gallery,
        )

    return ui


if __name__ == "__main__":
    # Warm the caches before the UI starts serving requests.
    load_model(
        model_path=MODEL_PATH,
        label2id_path=LABEL2ID_PATH,
        config_path=CONFIG_PATH,
        device=DEVICE,
    )
    demo().launch()
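

# --- Usage sketch (hypothetical) -----------------------------------------
# A minimal headless invocation of `generate_images` without the Gradio UI,
# kept commented out so the module's behavior is unchanged. It assumes the
# default repo/paths above resolve, and that `model.generate` returns
# PIL-compatible images (they are displayed via `gr.Gallery` above, which
# accepts PIL images); `sample_{i}.png` is an illustrative output name, not
# part of the original app.
#
# images = generate_images(
#     prompt="general 1girl solo portrait looking_at_viewer",
#     negative_prompt="retro_artstyle 1990s_(style) sketch",
#     num_steps=25,
#     cfg_scale=3.0,
#     batch_size=4,
#     size=256,
#     seed=42,
# )
# for i, image in enumerate(images):
#     image.save(f"sample_{i}.png")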