# stevengrove's picture
# init
# 0e0d430
import gradio as gr
import numpy as np
import random
import torch
import spaces
import os
import sys
import tempfile
from pathlib import Path
# Directory containing this file, and its "src" subdirectory.
ROOT_DIR = Path(__file__).resolve().parent
SRC_DIR = ROOT_DIR.joinpath("src")

# Put SRC_DIR at the front of the import search path (once), presumably so
# the project-local imports that follow resolve from it -- confirm layout.
if str(SRC_DIR) not in sys.path:
    sys.path.insert(0, str(SRC_DIR))
from huggingface_hub import snapshot_download
from demo_release import (
EditorApp,
build_demo_examples_from_config,
create_demo,
is_rank0,
read_local_js_inline,
resolve_local_three_js,
)
from modules.models.attention import describe_attention_backend
from modules.utils import clean_dist_env, maybe_init_distributed
# Normalised (stripped, lower-cased) environment values that env_flag()
# treats as "enabled".
TRUE_VALUES = {
    "1",
    "true",
    "yes",
    "y",
}

# MODEL_LOAD_MODE spellings for which maybe_preload() reports GPU preload
# mode and triggers the app's preload at startup.
GPU_PRELOAD_MODES = {
    "startup", "startup_preload", "boot", "auto", "eager",
    "preload", "gpu", "gpu_preload", "cuda", "global_cuda",
}

# MODEL_LOAD_MODE spellings for which maybe_preload() reports CPU preload
# mode and triggers the app's preload at startup.
CPU_PRELOAD_MODES = {
    "cpu_preload",
    "cpu",
    "cpu_only",
    "cpu_global_preload",
}
def env_flag(name: str, default: str = "0") -> bool:
    """Interpret environment variable ``name`` as a boolean switch.

    Args:
        name: Environment variable to read.
        default: Raw string used when the variable is unset.

    Returns:
        True when the value, after stripping surrounding whitespace and
        lower-casing, is a member of ``TRUE_VALUES``; False otherwise.
    """
    raw = os.getenv(name, default)
    normalized = raw.strip().lower()
    return normalized in TRUE_VALUES
def env_optional_int(name: str) -> int | None:
value = os.getenv(name, "").strip()
return int(value) if value else None
def env_optional_str(name: str) -> str | None:
value = os.getenv(name, "").strip()
return value or None
def resolve_ckpt_root(model_repo_id: str, explicit_ckpt_root: str | None, hf_token: str | None) -> str:
if explicit_ckpt_root:
return explicit_ckpt_root
return snapshot_download(repo_id=model_repo_id, token=hf_token)
def build_app() -> tuple[EditorApp, str, str, str | None, bool, str, bool, int, bool, str]:
    """Build the ``EditorApp`` from environment configuration.

    Environment variables read (defaults in parentheses):
        MODEL_REPO_ID ("jdopensource/JoyAI-Image-Edit"), CKPT_ROOT (unset),
        CONFIG_PATH (unset), REWRITE_PROMPT (off), REWRITE_MODEL ("gpt-5"),
        BASESIZE ("1024"), HIDE_ADVANCED_OPTIONS (off), AUTO_PE (off),
        DEFAULT_SAVE_DIR (""), HSDP_SHARD_DIM (unset),
        MODEL_LOAD_MODE ("startup_preload", stripped and lower-cased),
        HF_TOKEN, falling back to HUGGING_FACE_HUB_TOKEN (unset).

    Returns:
        A 10-tuple, in this order: ``(app, model_repo_id, ckpt_root,
        config_path, rewrite_prompt, rewrite_model, hide_advanced_options,
        basesize, auto_pe, default_save_dir)``. ``ckpt_root`` is the resolved
        checkpoint root (always a string); ``config_path`` may be None.

    Raises:
        ValueError: If BASESIZE or HSDP_SHARD_DIM is set to a non-integer.
    """
    model_repo_id = os.getenv("MODEL_REPO_ID", "jdopensource/JoyAI-Image-Edit")
    ckpt_root_env = env_optional_str("CKPT_ROOT")
    config_path = env_optional_str("CONFIG_PATH")
    rewrite_prompt = env_flag("REWRITE_PROMPT")
    rewrite_model = os.getenv("REWRITE_MODEL", "gpt-5")
    basesize = int(os.getenv("BASESIZE", "1024"))
    hide_advanced_options = env_flag("HIDE_ADVANCED_OPTIONS")
    auto_pe = env_flag("AUTO_PE")
    default_save_dir = os.getenv("DEFAULT_SAVE_DIR", "")
    hsdp_shard_dim = env_optional_int("HSDP_SHARD_DIM")
    model_load_mode = os.getenv("MODEL_LOAD_MODE", "startup_preload").strip().lower()
    hf_token = env_optional_str("HF_TOKEN") or env_optional_str("HUGGING_FACE_HUB_TOKEN")

    # An explicit CKPT_ROOT wins; otherwise the repository snapshot is used.
    ckpt_root = resolve_ckpt_root(model_repo_id, ckpt_root_env, hf_token)

    app = EditorApp(
        ckpt_root=ckpt_root,
        config_path=config_path,
        rewrite_model=rewrite_model,
        hsdp_shard_dim=hsdp_shard_dim,
        enable_prompt_rewrite=rewrite_prompt,
        basesize=basesize,
        device=None,
        model_load_mode=model_load_mode,
    )
    return (
        app,
        model_repo_id,
        ckpt_root,
        config_path,
        rewrite_prompt,
        rewrite_model,
        hide_advanced_options,
        basesize,
        auto_pe,
        default_save_dir,
    )
def print_startup_info(
    *,
    model_repo_id: str,
    ckpt_root: str,
    config_path: str | None,
    rewrite_prompt: bool,
    rewrite_model: str,
    basesize: int,
    auto_pe: bool,
    hide_advanced_options: bool,
    three_js_file: str | None,
) -> None:
    """Print the startup banner describing the effective configuration.

    Nothing is printed unless ``is_rank0()`` returns true. All arguments are
    keyword-only and are echoed as ``[Info]`` lines; a falsy ``config_path``
    is shown as ``(auto)``, and the final line depends on whether
    ``three_js_file`` is set. The first banner line about GPU preload is
    printed regardless of the configured load mode.
    """
    if not is_rank0():
        return

    print("[Info] Direct GPU startup preload is enabled by default; the app will try to build the model globally on CUDA during startup.")
    print(f"[Info] Attention backend: {describe_attention_backend()}")

    # One line per configuration value, in a fixed order.
    for info_line in (
        f"[Info] MODEL_REPO_ID: {model_repo_id}",
        f"[Info] CKPT_ROOT: {ckpt_root}",
        f"[Info] CONFIG_PATH: {config_path or '(auto)'}",
        f"[Info] REWRITE_PROMPT: {rewrite_prompt}",
        f"[Info] REWRITE_MODEL: {rewrite_model}",
        f"[Info] BASESIZE: {basesize}",
        f"[Info] AUTO_PE: {auto_pe}",
        f"[Info] HIDE_ADVANCED_OPTIONS: {hide_advanced_options}",
    ):
        print(info_line)

    three_js_line = (
        f"[Info] Using local three.js: {three_js_file}"
        if three_js_file
        else "[Info] No local three.min.js found. Falling back to slider-only mode."
    )
    print(three_js_line)
def maybe_preload(app: EditorApp) -> None:
    """Preload the model at startup when the app's load mode asks for it.

    ``app.model_load_mode`` is normalised (None becomes "", then stripped and
    lower-cased). If it is in ``GPU_PRELOAD_MODES`` or ``CPU_PRELOAD_MODES``
    the matching message is printed and ``app.maybe_preload_model()`` is
    called; any other mode only prints a runtime-loading message.
    """
    mode = (app.model_load_mode or "").strip().lower()

    if mode in GPU_PRELOAD_MODES:
        announcement = "[Model] Using direct global GPU preload mode."
    elif mode in CPU_PRELOAD_MODES:
        announcement = "[Model] Using CPU preload mode."
    else:
        print(f"[Model] Using runtime loading mode: {mode}")
        return

    # Both preload flavours use the same entry point on the app.
    print(announcement)
    app.maybe_preload_model()
def build_demo(app: EditorApp, hide_advanced_options: bool, auto_pe: bool, default_save_dir: str):
    """Assemble the demo UI and the extra values needed to launch it.

    Args:
        app: Editor application handed to ``create_demo``.
        hide_advanced_options: Forwarded to ``create_demo``.
        auto_pe: Forwarded to ``create_demo``.
        default_save_dir: Forwarded to ``create_demo``.

    Returns:
        ``(demo, inline_js, launch_css, allowed_paths, three_js_file)``:
        the demo object, the result of ``read_local_js_inline``, the page CSS
        with a ``.fillable`` max-width rule appended, the list of paths to
        allow (system temp dir and ``ROOT_DIR / "images"``, both resolved),
        and the three.js file reported by ``resolve_local_three_js``.
    """
    examples_table, examples_full = build_demo_examples_from_config()

    # THREE_JS_PATH overrides the bundled default; a path that does not exist
    # is passed on as None.
    bundled_three_js = str(ROOT_DIR / "three.min.js")
    three_js_path = os.getenv("THREE_JS_PATH", bundled_three_js)
    three_js_candidate = three_js_path if Path(three_js_path).exists() else None
    three_js_file = resolve_local_three_js(three_js_candidate)
    inline_js = read_local_js_inline(three_js_file)

    # The middle element of create_demo's result is not used here.
    demo, _, page_css = create_demo(
        app,
        three_available=three_js_file is not None,
        hide_advanced_options=hide_advanced_options,
        examples_table=examples_table,
        examples_full=examples_full,
        auto_pe=auto_pe,
        default_save_dir=default_save_dir,
    )

    launch_css = page_css + "\n.fillable{max-width: 1400px !important}"
    served_dirs = (Path(tempfile.gettempdir()), ROOT_DIR / "images")
    allowed_paths = [str(directory.resolve()) for directory in served_dirs]
    return demo, inline_js, launch_css, allowed_paths, three_js_file
def main() -> None:
    """Entry point: configure, build, optionally preload, and serve the demo.

    Steps, in order: optionally initialise the distributed environment, build
    the app from environment variables, build the demo UI, print the startup
    banner, run the startup preload if the load mode requests it, then launch
    the queued demo on ``0.0.0.0`` at ``PORT`` (default 7860).

    If ``maybe_init_distributed()`` reported that it initialised something,
    ``clean_dist_env()`` is called on the way out, including when any step
    after initialisation raises (checkpoint download, model preload, launch).
    """
    dist_initialized = maybe_init_distributed()
    try:
        (
            app,
            model_repo_id,
            ckpt_root,
            config_path,
            rewrite_prompt,
            rewrite_model,
            hide_advanced_options,
            basesize,
            auto_pe,
            default_save_dir,
        ) = build_app()
        demo, inline_js, launch_css, allowed_paths, three_js_file = build_demo(
            app,
            hide_advanced_options=hide_advanced_options,
            auto_pe=auto_pe,
            default_save_dir=default_save_dir,
        )
        print_startup_info(
            model_repo_id=model_repo_id,
            ckpt_root=ckpt_root,
            config_path=config_path,
            rewrite_prompt=rewrite_prompt,
            rewrite_model=rewrite_model,
            basesize=basesize,
            auto_pe=auto_pe,
            hide_advanced_options=hide_advanced_options,
            three_js_file=three_js_file,
        )
        maybe_preload(app)
        demo.queue(default_concurrency_limit=1, max_size=20).launch(
            server_name="0.0.0.0",
            server_port=int(os.getenv("PORT", "7860")),
            ssr_mode=False,
            head=inline_js,
            css=launch_css,
            allowed_paths=allowed_paths,
        )
    finally:
        # Cleanup must also run when setup fails, not only when launch() does.
        if dist_initialized:
            clean_dist_env()


if __name__ == "__main__":
    main()