"""Entry point for the image-edit Gradio demo.

Resolves the model checkpoint (local path or Hugging Face snapshot), builds the
EditorApp and its Gradio UI, optionally preloads the model, and launches the
server. Configuration is read from environment variables.
"""
# NOTE(review): gr, np, random, torch and spaces are not referenced below;
# presumably they are kept for import side effects or import order — confirm
# before removing.
import gradio as gr
import numpy as np
import random
import torch
import spaces
import os
import sys
import tempfile
from pathlib import Path

ROOT_DIR = Path(__file__).resolve().parent
SRC_DIR = ROOT_DIR / "src"
# Make the project's ``src`` directory importable before the local imports below.
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))

from huggingface_hub import snapshot_download
from demo_release import (
    EditorApp,
    build_demo_examples_from_config,
    create_demo,
    is_rank0,
    read_local_js_inline,
    resolve_local_three_js,
)
from modules.models.attention import describe_attention_backend
from modules.utils import clean_dist_env, maybe_init_distributed

# Lower-cased strings that env_flag() treats as "on".
TRUE_VALUES = {"1", "true", "yes", "y"}
# MODEL_LOAD_MODE aliases that maybe_preload() reports as GPU preload.
GPU_PRELOAD_MODES = {
    "startup",
    "startup_preload",
    "boot",
    "auto",
    "eager",
    "preload",
    "gpu",
    "gpu_preload",
    "cuda",
    "global_cuda",
}
# MODEL_LOAD_MODE aliases that maybe_preload() reports as CPU preload.
CPU_PRELOAD_MODES = {"cpu_preload", "cpu", "cpu_only", "cpu_global_preload"}


def env_flag(name: str, default: str = "0") -> bool:
    """Return True if env var *name* (or *default* when unset) is in TRUE_VALUES."""
    return os.getenv(name, default).strip().lower() in TRUE_VALUES


def env_optional_int(name: str) -> int | None:
    """Parse env var *name* as an int; return None when it is unset or blank.

    Raises:
        ValueError: if the variable is set to a value that is not an integer.
    """
    value = os.getenv(name, "").strip()
    if not value:
        return None
    try:
        return int(value)
    except ValueError as err:
        # Name the offending variable; a bare int() error does not.
        raise ValueError(f"Environment variable {name} must be an integer, got {value!r}") from err


def env_int(name: str, default: int) -> int:
    """Parse env var *name* as an int, using *default* when it is unset or blank.

    Raises:
        ValueError: if the variable is set to a value that is not an integer.
    """
    value = env_optional_int(name)
    return default if value is None else value


def env_optional_str(name: str) -> str | None:
    """Return env var *name* stripped of whitespace, or None if it is unset or blank."""
    value = os.getenv(name, "").strip()
    return value or None


def resolve_ckpt_root(model_repo_id: str, explicit_ckpt_root: str | None, hf_token: str | None) -> str:
    """Return *explicit_ckpt_root* if given; otherwise download *model_repo_id*.

    Returns the local directory produced by ``snapshot_download``.
    """
    if explicit_ckpt_root:
        return explicit_ckpt_root
    return snapshot_download(repo_id=model_repo_id, token=hf_token)


def build_app() -> tuple[EditorApp, str, str, str | None, bool, str, bool, int, bool, str]:
    """Read configuration from the environment and construct the EditorApp.

    Returns:
        A tuple ``(app, model_repo_id, ckpt_root, config_path, rewrite_prompt,
        rewrite_model, hide_advanced_options, basesize, auto_pe,
        default_save_dir)``.

    Raises:
        ValueError: if BASESIZE or HSDP_SHARD_DIM is set to a non-integer.
    """
    model_repo_id = os.getenv("MODEL_REPO_ID", "jdopensource/JoyAI-Image-Edit")
    ckpt_root_env = env_optional_str("CKPT_ROOT")
    config_path = env_optional_str("CONFIG_PATH")
    rewrite_prompt = env_flag("REWRITE_PROMPT")
    rewrite_model = os.getenv("REWRITE_MODEL", "gpt-5")
    # A blank BASESIZE now falls back to the default instead of crashing in int("").
    basesize = env_int("BASESIZE", 1024)
    hide_advanced_options = env_flag("HIDE_ADVANCED_OPTIONS")
    auto_pe = env_flag("AUTO_PE")
    default_save_dir = os.getenv("DEFAULT_SAVE_DIR", "")
    hsdp_shard_dim = env_optional_int("HSDP_SHARD_DIM")
    model_load_mode = os.getenv("MODEL_LOAD_MODE", "startup_preload").strip().lower()
    hf_token = env_optional_str("HF_TOKEN") or env_optional_str("HUGGING_FACE_HUB_TOKEN")
    ckpt_root = resolve_ckpt_root(model_repo_id, ckpt_root_env, hf_token)
    app = EditorApp(
        ckpt_root=ckpt_root,
        config_path=config_path,
        rewrite_model=rewrite_model,
        hsdp_shard_dim=hsdp_shard_dim,
        enable_prompt_rewrite=rewrite_prompt,
        basesize=basesize,
        device=None,
        model_load_mode=model_load_mode,
    )
    return (
        app,
        model_repo_id,
        ckpt_root,
        config_path,
        rewrite_prompt,
        rewrite_model,
        hide_advanced_options,
        basesize,
        auto_pe,
        default_save_dir,
    )


def print_startup_info(
    *,
    model_repo_id: str,
    ckpt_root: str,
    config_path: str | None,
    rewrite_prompt: bool,
    rewrite_model: str,
    basesize: int,
    auto_pe: bool,
    hide_advanced_options: bool,
    three_js_file: str | None,
    model_load_mode: str | None = None,
) -> None:
    """Print the resolved startup configuration (rank 0 only).

    Args:
        model_load_mode: The app's load mode. When given, the preload banner
            reflects it. When omitted (older callers), the original GPU-preload
            banner is printed unconditionally.
        All other arguments are echoed as ``[Info]`` lines.
    """
    if not is_rank0():
        return
    mode = (model_load_mode or "").strip().lower()
    if model_load_mode is None or mode in GPU_PRELOAD_MODES:
        print("[Info] Direct GPU startup preload is enabled by default; the app will try to build the model globally on CUDA during startup.")
    elif mode in CPU_PRELOAD_MODES:
        print(f"[Info] CPU startup preload is enabled (MODEL_LOAD_MODE={mode}).")
    else:
        print(f"[Info] Startup preload is disabled; runtime loading is used (MODEL_LOAD_MODE={mode}).")
    print(f"[Info] Attention backend: {describe_attention_backend()}")
    print(f"[Info] MODEL_REPO_ID: {model_repo_id}")
    print(f"[Info] CKPT_ROOT: {ckpt_root}")
    print(f"[Info] CONFIG_PATH: {config_path or '(auto)'}")
    print(f"[Info] REWRITE_PROMPT: {rewrite_prompt}")
    print(f"[Info] REWRITE_MODEL: {rewrite_model}")
    print(f"[Info] BASESIZE: {basesize}")
    print(f"[Info] AUTO_PE: {auto_pe}")
    print(f"[Info] HIDE_ADVANCED_OPTIONS: {hide_advanced_options}")
    if three_js_file:
        print(f"[Info] Using local three.js: {three_js_file}")
    else:
        print("[Info] No local three.min.js found. Falling back to slider-only mode.")


def maybe_preload(app: EditorApp) -> None:
    """Call app.maybe_preload_model() when app.model_load_mode is a preload alias.

    GPU and CPU aliases both call the same method; presumably EditorApp picks
    the device from its model_load_mode (not visible here). Any other mode
    defers loading to runtime. The preload runs on every rank, but the log
    line is printed only on rank 0, matching print_startup_info().
    """
    mode = (app.model_load_mode or "").strip().lower()
    if mode in GPU_PRELOAD_MODES:
        message = "[Model] Using direct global GPU preload mode."
    elif mode in CPU_PRELOAD_MODES:
        message = "[Model] Using CPU preload mode."
    else:
        message = f"[Model] Using runtime loading mode: {mode}"
    if is_rank0():
        print(message)
    if mode in GPU_PRELOAD_MODES or mode in CPU_PRELOAD_MODES:
        app.maybe_preload_model()


def build_demo(app: EditorApp, hide_advanced_options: bool, auto_pe: bool, default_save_dir: str):
    """Build the Gradio UI and the assets needed to launch it.

    Returns:
        ``(demo, inline_js, launch_css, allowed_paths, three_js_file)``.
        *three_js_file* is None when no local three.js was resolved.
        *allowed_paths* holds the system temp dir and ``<ROOT_DIR>/images``,
        both resolved to absolute paths.
    """
    examples_table, examples_full = build_demo_examples_from_config()
    three_js_path = os.getenv("THREE_JS_PATH", str(ROOT_DIR / "three.min.js"))
    # Pass None when the configured file is missing so the resolver can decide.
    three_js_file = resolve_local_three_js(three_js_path if Path(three_js_path).exists() else None)
    inline_js = read_local_js_inline(three_js_file)
    demo, _, page_css = create_demo(
        app,
        three_available=three_js_file is not None,
        hide_advanced_options=hide_advanced_options,
        examples_table=examples_table,
        examples_full=examples_full,
        auto_pe=auto_pe,
        default_save_dir=default_save_dir,
    )
    launch_css = page_css + "\n.fillable{max-width: 1400px !important}"
    allowed_paths = [
        str(Path(tempfile.gettempdir()).resolve()),
        str((ROOT_DIR / "images").resolve()),
    ]
    return demo, inline_js, launch_css, allowed_paths, three_js_file


def main() -> None:
    """Build the app and UI, optionally preload the model, and serve the demo.

    Distributed state created by maybe_init_distributed() is torn down in a
    ``finally`` that covers every later step. A failure while downloading the
    checkpoint, building the UI, preloading or serving therefore still runs
    clean_dist_env().
    """
    dist_initialized = maybe_init_distributed()
    try:
        (
            app,
            model_repo_id,
            ckpt_root,
            config_path,
            rewrite_prompt,
            rewrite_model,
            hide_advanced_options,
            basesize,
            auto_pe,
            default_save_dir,
        ) = build_app()
        demo, inline_js, launch_css, allowed_paths, three_js_file = build_demo(
            app,
            hide_advanced_options=hide_advanced_options,
            auto_pe=auto_pe,
            default_save_dir=default_save_dir,
        )
        print_startup_info(
            model_repo_id=model_repo_id,
            ckpt_root=ckpt_root,
            config_path=config_path,
            rewrite_prompt=rewrite_prompt,
            rewrite_model=rewrite_model,
            basesize=basesize,
            auto_pe=auto_pe,
            hide_advanced_options=hide_advanced_options,
            three_js_file=three_js_file,
            model_load_mode=app.model_load_mode,
        )
        maybe_preload(app)
        demo.queue(default_concurrency_limit=1, max_size=20).launch(
            server_name="0.0.0.0",
            server_port=env_int("PORT", 7860),
            ssr_mode=False,
            head=inline_js,
            css=launch_css,
            allowed_paths=allowed_paths,
        )
    finally:
        if dist_initialized:
            clean_dist_env()


if __name__ == "__main__":
    main()