Spaces:
Running on Zero
Running on Zero
| import asyncio | |
| import os | |
| import sys | |
| import threading | |
| import time | |
| import gradio as gr | |
# Make the vendored third_party packages importable before any project import
# that depends on them (grouped_gemm / megablocks ship inside the repo).
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
GROUPED_GEMM_SRC = os.path.join(ROOT_DIR, "third_party", "grouped_gemm")
MEGABLOCKS_SRC = os.path.join(ROOT_DIR, "third_party", "megablocks")
for _vendored in (GROUPED_GEMM_SRC, MEGABLOCKS_SRC):
    # Only prepend real directories, and never duplicate an entry.
    if _vendored not in sys.path and os.path.isdir(_vendored):
        sys.path.insert(0, _vendored)
| from src.utils.device_utils import resolve_device_ids | |
| from src.utils.inference_config import ( | |
| DEFAULT_HEIGHT, | |
| DEFAULT_NEGATIVE_PROMPT, | |
| DEFAULT_NUM_INFERENCE_STEPS, | |
| DEFAULT_SEED, | |
| DEFAULT_TRUE_CFG_SCALE, | |
| DEFAULT_WIDTH, | |
| generate_random_seed, | |
| ) | |
# `spaces` only exists on Hugging Face Spaces (ZeroGPU hardware); fall back
# to None locally so the GPU decorator can degrade to a no-op.
try:
    import spaces
except ImportError:
    spaces = None
| def _suppress_asyncio_fd_noise_for_py310() -> None: | |
| if sys.version_info[:2] != (3, 10): | |
| return | |
| original_del = getattr(asyncio.BaseEventLoop, "__del__", None) | |
| if original_del is None: | |
| return | |
| if getattr(original_del, "_tagmoe_fd_noise_patched", False): | |
| return | |
| def _safe_del(self): | |
| try: | |
| original_del(self) | |
| except ValueError as exc: | |
| if "Invalid file descriptor: -1" not in str(exc): | |
| raise | |
| _safe_del._tagmoe_fd_noise_patched = True | |
| asyncio.BaseEventLoop.__del__ = _safe_del | |
| _suppress_asyncio_fd_noise_for_py310() | |
| def _env_bool(name: str, default: bool = False) -> bool: | |
| value = os.getenv(name) | |
| if value is None: | |
| return default | |
| return value.strip().lower() in {"1", "true", "yes", "on"} | |
| def _env_int(name: str, default: int) -> int: | |
| value = os.getenv(name) | |
| if value is None or not value.strip(): | |
| return default | |
| return int(value.strip()) | |
# --- Deployment configuration (each overridable via environment variable) ---
# Base diffusion pipeline weights: HF repo id or a local path.
PRETRAINED_MODEL_PATH = os.getenv("PRETRAINED_MODEL_PATH", "Qwen/Qwen-Image")
# TAG-MoE transformer checkpoint that replaces the base transformer.
TRANSFORMER_MODEL_PATH = os.getenv("TRANSFORMER_MODEL_PATH", "YUXU915/TAG-MoE")
# Weight file name and subfolder inside the transformer repo.
TRANSFORMER_WEIGHT_NAME = os.getenv("TRANSFORMER_WEIGHT_NAME", "diffusion_pytorch_model.safetensors")
TRANSFORMER_SUBFOLDER = os.getenv("TRANSFORMER_SUBFOLDER", "transformer")
# Optional git revision for the transformer repo; empty means default branch.
TRANSFORMER_REVISION = os.getenv("TRANSFORMER_REVISION", "").strip() or None
# When true, never hit the network for model files (offline mode).
LOCAL_FILES_ONLY = _env_bool("LOCAL_FILES_ONLY", default=False)
# Device spec; "auto" defers selection to the framework.
TAGMOE_DEVICE = os.getenv("TAGMOE_DEVICE", "auto").strip().lower()
# Seconds requested per ZeroGPU allocation (passed to spaces.GPU).
ZERO_GPU_DURATION = _env_int("ZERO_GPU_DURATION", default=120)
| LINKS_HTML = """ | |
| <div class="tagmoe-links"> | |
| <a href="https://arxiv.org/abs/2601.08881" target="_blank" rel="noopener noreferrer"> | |
| <img alt="ArXiv" src="https://img.shields.io/badge/ArXiv-2601.08881-red"> | |
| </a> | |
| <a href="https://yuci-gpt.github.io/TAG-MoE/" target="_blank" rel="noopener noreferrer"> | |
| <img alt="Project Page" src="https://img.shields.io/badge/Project%20Page-homepage-green"> | |
| </a> | |
| <a href="https://github.com/ICTMCG/TAG-MoE" target="_blank" rel="noopener noreferrer"> | |
| <img alt="GitHub Repo" src="https://img.shields.io/badge/GitHub-repo-181717?logo=github"> | |
| </a> | |
| <a href="https://huggingface.co/YUXU915/TAG-MoE" target="_blank" rel="noopener noreferrer"> | |
| <img alt="Model Weights" src="https://img.shields.io/badge/HuggingFace-weights-F9A825?logo=huggingface"> | |
| </a> | |
| </div> | |
| """ | |
| CPU_MODE_BANNER_HTML = """ | |
| <div class="cpu-warning"> | |
| <div class="cpu-warning-title">No GPU available</div> | |
| <div class="cpu-warning-desc"> | |
| Inference is currently unavailable because TAG-MoE requires grouped_gemm on CUDA runtime. | |
| </div> | |
| </div> | |
| """ | |
| CUSTOM_CSS = """ | |
| .tagmoe-header { | |
| display: flex; | |
| align-items: center; | |
| gap: 12px; | |
| margin-bottom: 8px; | |
| } | |
| .tagmoe-header img { | |
| width: 48px; | |
| height: 48px; | |
| object-fit: contain; | |
| } | |
| .tagmoe-header h1 { | |
| margin: 0; | |
| font-size: 1.8rem; | |
| } | |
| .tagmoe-header p { | |
| margin: 0; | |
| opacity: 0.85; | |
| font-size: 0.95rem; | |
| } | |
| .param-card { | |
| border: 1px solid var(--border-color-primary); | |
| border-radius: 12px; | |
| padding: 14px 14px 10px; | |
| margin-bottom: 10px; | |
| } | |
| .param-card .gradio-textbox textarea { | |
| min-height: 110px !important; | |
| } | |
| .run-btn button { | |
| height: 46px !important; | |
| font-weight: 600; | |
| } | |
| .image-panel { | |
| border: 1px solid var(--border-color-primary); | |
| border-radius: 12px; | |
| padding: 10px; | |
| } | |
| .tool-btn { | |
| margin-top: 28px !important; | |
| min-width: 42px !important; | |
| height: 42px !important; | |
| padding: 0 !important; | |
| display: flex; | |
| align-items: center; | |
| justify-content: center; | |
| flex-shrink: 0; | |
| } | |
| .tagmoe-links { | |
| margin: 6px 0 14px 0; | |
| display: flex; | |
| flex-wrap: wrap; | |
| gap: 8px; | |
| } | |
| .tagmoe-links a { | |
| text-decoration: none; | |
| } | |
| .tagmoe-links img { | |
| height: 22px; | |
| } | |
| .cpu-warning { | |
| margin: 2px 0 14px 0; | |
| border: 1px solid #f59e0b; | |
| background: #fff7ed; | |
| color: #7c2d12; | |
| border-radius: 10px; | |
| padding: 10px 12px; | |
| } | |
| .cpu-warning-title { | |
| font-weight: 700; | |
| margin-bottom: 4px; | |
| } | |
| .cpu-warning-desc { | |
| font-size: 0.92rem; | |
| line-height: 1.35; | |
| } | |
| """ | |
# --- Lazily-initialized runtime state (guarded by the locks below) ---
_RUNTIME_LOCK = threading.Lock()  # serializes pipeline construction
_PIPELINE = None  # End2End pipeline instance, set by _ensure_runtime_loaded
_BASE64_TO_IMAGE_FN = None  # decoder for the pipeline's base64 image output
_ASSETS_READY = False  # set once model files have been prefetched
_ASSETS_LOCK = threading.Lock()  # serializes asset prefetching
_RUNTIME_STARTUP_ERROR = None  # error message from failed startup init, if any
| def _has_cuda_runtime() -> bool: | |
| try: | |
| import torch | |
| except Exception: | |
| return False | |
| try: | |
| return bool(torch.cuda.is_available()) | |
| except Exception: | |
| return False | |
def _has_spaces_gpu_backend() -> bool:
    """Return True when the HF `spaces` SDK reports a ZeroGPU backend.

    False when the SDK is absent, its config cannot be imported, or the
    `zero_gpu` flag is unset — never raises.
    """
    if spaces is None:
        return False
    try:
        from spaces.config import Config

        return bool(getattr(Config, "zero_gpu", False))
    except Exception:
        return False
def _has_gpu_runtime() -> bool:
    """Return True when either local CUDA or a ZeroGPU backend is available."""
    return _has_cuda_runtime() or _has_spaces_gpu_backend()


# Evaluated once at import; gates the inference UI and startup initialization.
GPU_RUNTIME_AVAILABLE = _has_gpu_runtime()
| def _is_hf_repo_id(value: str) -> bool: | |
| if not value or "://" in value: | |
| return False | |
| if os.path.exists(value): | |
| return False | |
| return value.count("/") == 1 | |
def _prepare_model_assets() -> None:
    """Prefetch model repos from the Hugging Face Hub (best effort).

    Runs at most once per process (double-checked on _ASSETS_READY under
    _ASSETS_LOCK) and is skipped entirely in LOCAL_FILES_ONLY mode.
    Individual download failures only log a warning — the later pipeline
    load resolves the files itself.
    """
    global _ASSETS_READY
    # Fast path without taking the lock.
    if _ASSETS_READY or LOCAL_FILES_ONLY:
        return
    with _ASSETS_LOCK:
        # Re-check under the lock: another thread may have finished first.
        if _ASSETS_READY or LOCAL_FILES_ONLY:
            return
        total_start = time.perf_counter()
        print("[TAG-MoE] Asset preload start")
        try:
            from huggingface_hub import snapshot_download
        except Exception:
            # Preload is only an optimization; proceed without it.
            # NOTE(review): _ASSETS_READY stays False on this path, so the
            # import is retried on every call — presumably intentional.
            print("[TAG-MoE] huggingface_hub not available, skip preload")
            return
        # Opt into the accelerated transfer backend unless already configured.
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
        if _is_hf_repo_id(PRETRAINED_MODEL_PATH):
            try:
                t0 = time.perf_counter()
                snapshot_download(
                    repo_id=PRETRAINED_MODEL_PATH,
                    local_files_only=LOCAL_FILES_ONLY,
                )
                print(
                    f"[TAG-MoE] Preloaded base repo {PRETRAINED_MODEL_PATH} in "
                    f"{time.perf_counter()-t0:.2f}s"
                )
            except Exception as exc:
                print(f"[TAG-MoE] Preload warning for {PRETRAINED_MODEL_PATH}: {exc}")
        if _is_hf_repo_id(TRANSFORMER_MODEL_PATH):
            try:
                t0 = time.perf_counter()
                snapshot_download(
                    repo_id=TRANSFORMER_MODEL_PATH,
                    revision=TRANSFORMER_REVISION,
                    local_files_only=LOCAL_FILES_ONLY,
                )
                print(
                    f"[TAG-MoE] Preloaded transformer repo {TRANSFORMER_MODEL_PATH} "
                    f"(rev={TRANSFORMER_REVISION or 'default'}) in {time.perf_counter()-t0:.2f}s"
                )
            except Exception as exc:
                print(f"[TAG-MoE] Preload warning for {TRANSFORMER_MODEL_PATH}: {exc}")
        _ASSETS_READY = True
        print(f"[TAG-MoE] Asset preload done in {time.perf_counter()-total_start:.2f}s")
# Device specs that defer device selection entirely to the framework.
_FRAMEWORK_DEVICE_SPECS = frozenset({"", "auto", "default", "none", "framework"})


def _resolve_runtime_device_ids():
    """Translate the TAGMOE_DEVICE setting into explicit device ids.

    Returns:
        None for "auto"-style specs (the framework picks the device);
        otherwise the result of ``resolve_device_ids`` for an explicit spec.
    """
    # The original had two consecutive `if` branches with identical bodies;
    # merged into a single membership test on one frozen set.
    if TAGMOE_DEVICE in _FRAMEWORK_DEVICE_SPECS:
        return None
    return resolve_device_ids(TAGMOE_DEVICE)
def _ensure_runtime_loaded():
    """Build (or return the cached) inference pipeline and image decoder.

    Thread-safe lazy initializer: double-checked locking on _RUNTIME_LOCK,
    with results cached in the module globals _PIPELINE and
    _BASE64_TO_IMAGE_FN.

    Returns:
        Tuple of (End2End pipeline, base64_to_image decoder function).
    """
    global _PIPELINE, _BASE64_TO_IMAGE_FN
    # Fast path without the lock once both globals are populated.
    if _PIPELINE is not None and _BASE64_TO_IMAGE_FN is not None:
        return _PIPELINE, _BASE64_TO_IMAGE_FN
    with _RUNTIME_LOCK:
        # Another thread may have finished initializing while we waited.
        if _PIPELINE is not None and _BASE64_TO_IMAGE_FN is not None:
            return _PIPELINE, _BASE64_TO_IMAGE_FN
        total_start = time.perf_counter()
        print("[TAG-MoE] Runtime build start")
        t0 = time.perf_counter()
        _prepare_model_assets()
        print(f"[TAG-MoE] Runtime step: preload assets took {time.perf_counter()-t0:.2f}s")
        t0 = time.perf_counter()
        # Imported lazily: pulls in the heavy torch/diffusers stack.
        from src.infer_tagmoe import End2End, base64_to_image
        print(f"[TAG-MoE] Runtime step: import infer module took {time.perf_counter()-t0:.2f}s")
        device_ids = _resolve_runtime_device_ids()
        print(f"[TAG-MoE] Runtime config: TAGMOE_DEVICE={TAGMOE_DEVICE}, device_ids={device_ids}")
        t0 = time.perf_counter()
        _PIPELINE = End2End(
            pretrained_model_path=PRETRAINED_MODEL_PATH,
            transformer_model_path=TRANSFORMER_MODEL_PATH,
            device_ids=device_ids,
            transformer_weight_name=TRANSFORMER_WEIGHT_NAME,
            transformer_subfolder=TRANSFORMER_SUBFOLDER,
            transformer_revision=TRANSFORMER_REVISION,
            local_files_only=LOCAL_FILES_ONLY,
        )
        print(f"[TAG-MoE] Runtime step: End2End init took {time.perf_counter()-t0:.2f}s")
        _BASE64_TO_IMAGE_FN = base64_to_image
        print(f"[TAG-MoE] Runtime build done in {time.perf_counter()-total_start:.2f}s")
        return _PIPELINE, _BASE64_TO_IMAGE_FN
def _initialize_runtime_on_startup() -> None:
    """Eagerly build the inference runtime when the process starts.

    No-op without a GPU backend. A failure is recorded in
    _RUNTIME_STARTUP_ERROR (instead of crashing the app) so the UI can
    surface the message and disable the run button.
    """
    global _RUNTIME_STARTUP_ERROR
    if not GPU_RUNTIME_AVAILABLE:
        return
    print("[TAG-MoE] Startup runtime initialization begin")
    started = time.perf_counter()
    try:
        _ensure_runtime_loaded()
        print(
            "[TAG-MoE] Startup runtime initialization finished in "
            f"{time.perf_counter()-started:.2f}s"
        )
    except Exception as exc:
        _RUNTIME_STARTUP_ERROR = str(exc)
        print(f"[TAG-MoE] Runtime initialization failed: {exc}")
def _infer_decorator():
    """Return the decorator to wrap the inference callback with.

    On HF Spaces this is ``spaces.GPU`` (requesting ZERO_GPU_DURATION
    seconds per call); elsewhere it is the identity decorator.
    """
    if spaces is not None:
        return spaces.GPU(duration=ZERO_GPU_DURATION)

    def _identity(fn):
        return fn

    return _identity
def build_demo(gr):
    """Assemble and return the Gradio Blocks UI for the TAG-MoE demo.

    Args:
        gr: The imported ``gradio`` module (injected by the caller).

    Returns:
        The fully wired ``gr.Blocks`` application.
    """
    # Inference is usable only when a GPU backend exists and startup
    # initialization populated the pipeline globals without recording an error.
    runtime_ready = (
        GPU_RUNTIME_AVAILABLE
        and _RUNTIME_STARTUP_ERROR is None
        and _PIPELINE is not None
        and _BASE64_TO_IMAGE_FN is not None
    )

    def infer(
        image,
        prompt,
        negative_prompt,
        seed,
        gen_width,
        gen_height,
        cfg_scale,
        inference_steps,
    ):
        """Run one editing inference; returns (output image, seed used)."""
        if _RUNTIME_STARTUP_ERROR is not None:
            raise gr.Error(f"Runtime initialization failed: {_RUNTIME_STARTUP_ERROR}")
        if _PIPELINE is None or _BASE64_TO_IMAGE_FN is None:
            raise gr.Error("Runtime is not ready yet. Please wait and retry.")
        if not GPU_RUNTIME_AVAILABLE:
            raise gr.Error(
                "Inference is disabled with GPU undetected. "
            )
        if prompt is None or not str(prompt).strip():
            raise gr.Error("Prompt cannot be empty.")
        if image is None:
            raise gr.Error("Image is required.")
        # Cleared sliders fall back to the uploaded image's own dimensions.
        width_value = int(gen_width) if gen_width is not None else int(image.size[0])
        height_value = int(gen_height) if gen_height is not None else int(image.size[1])
        input_dict = {
            "image": image.convert("RGB"),
            "prompt": str(prompt).strip(),
            "negative_prompt": str(negative_prompt or DEFAULT_NEGATIVE_PROMPT),
            "seed": int(seed if seed is not None else DEFAULT_SEED),
            "target_width": width_value,
            "target_height": height_value,
            "true_cfg_scale": float(cfg_scale),
            "num_inference_steps": int(inference_steps),
            "keep_original_size": False,
        }
        result = _PIPELINE.predict(input_dict)
        # The pipeline emits base64-encoded images; decode the first one.
        out_image = _BASE64_TO_IMAGE_FN(result["generate_imgs_buffer"][0])
        return out_image, int(result["seed"])

    def randomize_seed():
        """Return a fresh random seed for the seed input box."""
        return generate_random_seed()

    def on_image_upload(image):
        """Sync the width/height sliders to a newly uploaded image's size."""
        if image is None:
            return gr.update(), gr.update()
        return int(image.size[0]), int(image.size[1])

    title_html = """
<div class="tagmoe-header">
<picture>
<source srcset="https://raw.githubusercontent.com/yuci-gpt/TAG-MoE/refs/heads/master/static/images/logo_dark.png" media="(prefers-color-scheme: dark)">
<img src="https://raw.githubusercontent.com/yuci-gpt/TAG-MoE/refs/heads/master/static/images/logo_light.png" alt="TAG-MoE logo">
</picture>
<div>
<h1>TAG-MoE</h1>
<p>Task-Aware Gating for Unified Generative Mixture-of-Experts</p>
</div>
</div>
"""
    # Wrap the callback with spaces.GPU on ZeroGPU hardware (identity elsewhere).
    infer_fn = _infer_decorator()(infer)

    # FIX: CUSTOM_CSS was previously passed only to demo.launch(), but
    # gr.Blocks.launch() has no `css` parameter, so the stylesheet was never
    # applied. Custom CSS belongs on the Blocks constructor.
    with gr.Blocks(title="TAG-MoE Space Demo", css=CUSTOM_CSS) as demo:
        gr.HTML(title_html)
        gr.HTML(LINKS_HTML)
        # Surface runtime problems prominently at the top of the page.
        if not GPU_RUNTIME_AVAILABLE:
            gr.HTML(CPU_MODE_BANNER_HTML)
        elif not runtime_ready and _RUNTIME_STARTUP_ERROR is not None:
            gr.HTML(
                f'<div class="cpu-warning"><div class="cpu-warning-title">Runtime initialization failed</div>'
                f'<div class="cpu-warning-desc">{_RUNTIME_STARTUP_ERROR}</div></div>'
            )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, elem_classes=["image-panel"]):
                image_input = gr.Image(type="pil", label="Input Image", height=520)
            with gr.Column(scale=1, elem_classes=["image-panel"]):
                image_output = gr.Image(type="pil", label="Output Image", height=520)
        with gr.Group(elem_classes=["param-card"]):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Describe the instruction",
                lines=3,
            )
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt",
                value=DEFAULT_NEGATIVE_PROMPT,
                lines=2,
                placeholder="Optional negative prompt",
            )
            with gr.Row():
                gen_width_input = gr.Slider(minimum=64, maximum=4096, step=1, value=DEFAULT_WIDTH, label="Width")
                gen_height_input = gr.Slider(minimum=64, maximum=4096, step=1, value=DEFAULT_HEIGHT, label="Height")
            with gr.Row():
                cfg_scale_input = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=DEFAULT_TRUE_CFG_SCALE,
                    label="CFG Scale",
                )
                inference_steps_input = gr.Slider(
                    minimum=10,
                    maximum=100,
                    step=1,
                    value=DEFAULT_NUM_INFERENCE_STEPS,
                    label="Inference Steps",
                )
                with gr.Column(scale=1, min_width=200):
                    with gr.Row():
                        seed_input = gr.Number(
                            label="Seed",
                            value=generate_random_seed(),
                            precision=0,
                            scale=1,
                        )
                        random_seed_btn = gr.Button(
                            "🎲",
                            elem_classes=["tool-btn"],
                            scale=0,
                            min_width=42,
                            variant="secondary",
                        )
                    run_btn = gr.Button(
                        "Run Inference",
                        variant="primary",
                        elem_classes=["run-btn"],
                        interactive=runtime_ready,
                    )
        # Wiring: run inference, sync sliders on upload, randomize the seed.
        run_btn.click(
            fn=infer_fn,
            inputs=[
                image_input,
                prompt_input,
                negative_prompt_input,
                seed_input,
                gen_width_input,
                gen_height_input,
                cfg_scale_input,
                inference_steps_input,
            ],
            outputs=[image_output, seed_input],
        )
        image_input.change(
            fn=on_image_upload,
            inputs=[image_input],
            outputs=[gen_width_input, gen_height_input],
        )
        random_seed_btn.click(fn=randomize_seed, outputs=[seed_input])
    return demo
# --- Module-level startup: warm the runtime, build the UI, enable queuing ---
_initialize_runtime_on_startup()
demo = build_demo(gr)
# Single concurrent worker keeps GPU memory bounded; up to 8 queued requests.
demo.queue(default_concurrency_limit=1, max_size=8)

if __name__ == "__main__":
    # FIX: gr.Blocks.launch() does not accept a `css` keyword — the previous
    # `css=CUSTOM_CSS` argument here was invalid; custom CSS is supplied to
    # the gr.Blocks constructor instead.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        ssr_mode=False,
    )