# TAG-MoE / app.py — Hugging Face Space entry point.
# (Originally published by YUXU915; commit ab094ba.)
import asyncio
import os
import sys
import threading
import time
import gradio as gr
# Make the vendored third_party packages importable before project modules
# (which depend on them) are imported below.
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
GROUPED_GEMM_SRC = os.path.join(ROOT_DIR, "third_party", "grouped_gemm")
MEGABLOCKS_SRC = os.path.join(ROOT_DIR, "third_party", "megablocks")
for _vendored in (GROUPED_GEMM_SRC, MEGABLOCKS_SRC):
    if not os.path.isdir(_vendored):
        continue
    if _vendored not in sys.path:
        sys.path.insert(0, _vendored)
from src.utils.device_utils import resolve_device_ids
from src.utils.inference_config import (
DEFAULT_HEIGHT,
DEFAULT_NEGATIVE_PROMPT,
DEFAULT_NUM_INFERENCE_STEPS,
DEFAULT_SEED,
DEFAULT_TRUE_CFG_SCALE,
DEFAULT_WIDTH,
generate_random_seed,
)
try:
import spaces
except ImportError:
spaces = None
def _suppress_asyncio_fd_noise_for_py310() -> None:
if sys.version_info[:2] != (3, 10):
return
original_del = getattr(asyncio.BaseEventLoop, "__del__", None)
if original_del is None:
return
if getattr(original_del, "_tagmoe_fd_noise_patched", False):
return
def _safe_del(self):
try:
original_del(self)
except ValueError as exc:
if "Invalid file descriptor: -1" not in str(exc):
raise
_safe_del._tagmoe_fd_noise_patched = True
asyncio.BaseEventLoop.__del__ = _safe_del
_suppress_asyncio_fd_noise_for_py310()
def _env_bool(name: str, default: bool = False) -> bool:
value = os.getenv(name)
if value is None:
return default
return value.strip().lower() in {"1", "true", "yes", "on"}
def _env_int(name: str, default: int) -> int:
value = os.getenv(name)
if value is None or not value.strip():
return default
return int(value.strip())
# ---- Configuration (all overridable via environment variables) ----
# Base pipeline weights: HF repo id or local path.
PRETRAINED_MODEL_PATH = os.getenv("PRETRAINED_MODEL_PATH", "Qwen/Qwen-Image")
# TAG-MoE transformer checkpoint loaded on top of the base pipeline.
TRANSFORMER_MODEL_PATH = os.getenv("TRANSFORMER_MODEL_PATH", "YUXU915/TAG-MoE")
TRANSFORMER_WEIGHT_NAME = os.getenv("TRANSFORMER_WEIGHT_NAME", "diffusion_pytorch_model.safetensors")
TRANSFORMER_SUBFOLDER = os.getenv("TRANSFORMER_SUBFOLDER", "transformer")
# Optional revision of the transformer repo; blank/unset means the default branch.
TRANSFORMER_REVISION = os.getenv("TRANSFORMER_REVISION", "").strip() or None
# When True, never hit the network for model files (asset preload is skipped).
LOCAL_FILES_ONLY = _env_bool("LOCAL_FILES_ONLY", default=False)
# Device selection string; "auto"/"default"/"none"/"framework" defer to the framework.
TAGMOE_DEVICE = os.getenv("TAGMOE_DEVICE", "auto").strip().lower()
# Seconds of GPU time requested per call under HF Spaces ZeroGPU.
ZERO_GPU_DURATION = _env_int("ZERO_GPU_DURATION", default=120)
LINKS_HTML = """
<div class="tagmoe-links">
<a href="https://arxiv.org/abs/2601.08881" target="_blank" rel="noopener noreferrer">
<img alt="ArXiv" src="https://img.shields.io/badge/ArXiv-2601.08881-red">
</a>
<a href="https://yuci-gpt.github.io/TAG-MoE/" target="_blank" rel="noopener noreferrer">
<img alt="Project Page" src="https://img.shields.io/badge/Project%20Page-homepage-green">
</a>
<a href="https://github.com/ICTMCG/TAG-MoE" target="_blank" rel="noopener noreferrer">
<img alt="GitHub Repo" src="https://img.shields.io/badge/GitHub-repo-181717?logo=github">
</a>
<a href="https://huggingface.co/YUXU915/TAG-MoE" target="_blank" rel="noopener noreferrer">
<img alt="Model Weights" src="https://img.shields.io/badge/HuggingFace-weights-F9A825?logo=huggingface">
</a>
</div>
"""
CPU_MODE_BANNER_HTML = """
<div class="cpu-warning">
<div class="cpu-warning-title">No GPU available</div>
<div class="cpu-warning-desc">
Inference is currently unavailable because TAG-MoE requires grouped_gemm on CUDA runtime.
</div>
</div>
"""
CUSTOM_CSS = """
.tagmoe-header {
display: flex;
align-items: center;
gap: 12px;
margin-bottom: 8px;
}
.tagmoe-header img {
width: 48px;
height: 48px;
object-fit: contain;
}
.tagmoe-header h1 {
margin: 0;
font-size: 1.8rem;
}
.tagmoe-header p {
margin: 0;
opacity: 0.85;
font-size: 0.95rem;
}
.param-card {
border: 1px solid var(--border-color-primary);
border-radius: 12px;
padding: 14px 14px 10px;
margin-bottom: 10px;
}
.param-card .gradio-textbox textarea {
min-height: 110px !important;
}
.run-btn button {
height: 46px !important;
font-weight: 600;
}
.image-panel {
border: 1px solid var(--border-color-primary);
border-radius: 12px;
padding: 10px;
}
.tool-btn {
margin-top: 28px !important;
min-width: 42px !important;
height: 42px !important;
padding: 0 !important;
display: flex;
align-items: center;
justify-content: center;
flex-shrink: 0;
}
.tagmoe-links {
margin: 6px 0 14px 0;
display: flex;
flex-wrap: wrap;
gap: 8px;
}
.tagmoe-links a {
text-decoration: none;
}
.tagmoe-links img {
height: 22px;
}
.cpu-warning {
margin: 2px 0 14px 0;
border: 1px solid #f59e0b;
background: #fff7ed;
color: #7c2d12;
border-radius: 10px;
padding: 10px 12px;
}
.cpu-warning-title {
font-weight: 700;
margin-bottom: 4px;
}
.cpu-warning-desc {
font-size: 0.92rem;
line-height: 1.35;
}
"""
# Lazily-built inference runtime, shared across requests.
_RUNTIME_LOCK = threading.Lock()  # guards one-time _PIPELINE / _BASE64_TO_IMAGE_FN construction
_PIPELINE = None  # End2End instance once the runtime is built
_BASE64_TO_IMAGE_FN = None  # decoder for the pipeline's base64 image output
_ASSETS_READY = False  # True once model repos have been snapshot-downloaded
_ASSETS_LOCK = threading.Lock()  # guards the one-time asset preload
_RUNTIME_STARTUP_ERROR = None  # error string if startup initialization failed, else None
def _has_cuda_runtime() -> bool:
try:
import torch
except Exception:
return False
try:
return bool(torch.cuda.is_available())
except Exception:
return False
def _has_spaces_gpu_backend() -> bool:
    """True when running under HF Spaces with the ZeroGPU backend enabled."""
    if spaces is None:
        return False
    try:
        from spaces.config import Config
        return bool(getattr(Config, "zero_gpu", False))
    except Exception:
        # Config module absent or unreadable — assume no ZeroGPU backend.
        return False
def _has_gpu_runtime() -> bool:
    """A GPU is usable if either local CUDA or the ZeroGPU backend exists."""
    return _has_cuda_runtime() or _has_spaces_gpu_backend()
# Evaluated once at import time; drives UI state and inference gating below.
GPU_RUNTIME_AVAILABLE = _has_gpu_runtime()
def _is_hf_repo_id(value: str) -> bool:
if not value or "://" in value:
return False
if os.path.exists(value):
return False
return value.count("/") == 1
def _prepare_model_assets() -> None:
    """Snapshot-download the base and transformer repos ahead of runtime build.

    One-shot and thread-safe (double-checked with ``_ASSETS_LOCK``). Skipped
    entirely when ``LOCAL_FILES_ONLY`` is set or huggingface_hub is missing;
    individual download failures are logged as warnings, never raised.
    """
    global _ASSETS_READY
    if _ASSETS_READY or LOCAL_FILES_ONLY:
        return
    with _ASSETS_LOCK:
        # Re-check under the lock: another thread may have finished the preload.
        if _ASSETS_READY or LOCAL_FILES_ONLY:
            return
        total_start = time.perf_counter()
        print("[TAG-MoE] Asset preload start")
        try:
            from huggingface_hub import snapshot_download
        except Exception:
            print("[TAG-MoE] huggingface_hub not available, skip preload")
            return
        # Opt into the faster hf_transfer download backend when installed.
        os.environ.setdefault("HF_HUB_ENABLE_HF_TRANSFER", "1")
        if _is_hf_repo_id(PRETRAINED_MODEL_PATH):
            try:
                t0 = time.perf_counter()
                snapshot_download(
                    repo_id=PRETRAINED_MODEL_PATH,
                    local_files_only=LOCAL_FILES_ONLY,
                )
                print(
                    f"[TAG-MoE] Preloaded base repo {PRETRAINED_MODEL_PATH} in "
                    f"{time.perf_counter()-t0:.2f}s"
                )
            except Exception as exc:
                # Best-effort: the later runtime build may still hit the cache.
                print(f"[TAG-MoE] Preload warning for {PRETRAINED_MODEL_PATH}: {exc}")
        if _is_hf_repo_id(TRANSFORMER_MODEL_PATH):
            try:
                t0 = time.perf_counter()
                snapshot_download(
                    repo_id=TRANSFORMER_MODEL_PATH,
                    revision=TRANSFORMER_REVISION,
                    local_files_only=LOCAL_FILES_ONLY,
                )
                print(
                    f"[TAG-MoE] Preloaded transformer repo {TRANSFORMER_MODEL_PATH} "
                    f"(rev={TRANSFORMER_REVISION or 'default'}) in {time.perf_counter()-t0:.2f}s"
                )
            except Exception as exc:
                print(f"[TAG-MoE] Preload warning for {TRANSFORMER_MODEL_PATH}: {exc}")
        _ASSETS_READY = True
        print(f"[TAG-MoE] Asset preload done in {time.perf_counter()-total_start:.2f}s")
def _resolve_runtime_device_ids():
    """Translate TAGMOE_DEVICE into explicit device ids.

    Returns None for the sentinel values ("", "auto", "default", "none",
    "framework"), letting the framework pick devices itself.
    """
    if TAGMOE_DEVICE in {"", "auto", "default", "none", "framework"}:
        return None
    return resolve_device_ids(TAGMOE_DEVICE)
def _ensure_runtime_loaded():
    """Build (once) and return the ``(pipeline, base64_decoder)`` pair.

    Double-checked locking on ``_RUNTIME_LOCK`` lets concurrent callers share
    a single End2End construction. The heavy inference module is imported
    lazily here so the UI can start before the model stack loads.
    """
    global _PIPELINE, _BASE64_TO_IMAGE_FN
    if _PIPELINE is not None and _BASE64_TO_IMAGE_FN is not None:
        return _PIPELINE, _BASE64_TO_IMAGE_FN
    with _RUNTIME_LOCK:
        # Re-check under the lock: another thread may have built the runtime.
        if _PIPELINE is not None and _BASE64_TO_IMAGE_FN is not None:
            return _PIPELINE, _BASE64_TO_IMAGE_FN
        total_start = time.perf_counter()
        print("[TAG-MoE] Runtime build start")
        t0 = time.perf_counter()
        _prepare_model_assets()
        print(f"[TAG-MoE] Runtime step: preload assets took {time.perf_counter()-t0:.2f}s")
        t0 = time.perf_counter()
        # Deferred import: pulls in the full inference stack (slow).
        from src.infer_tagmoe import End2End, base64_to_image
        print(f"[TAG-MoE] Runtime step: import infer module took {time.perf_counter()-t0:.2f}s")
        device_ids = _resolve_runtime_device_ids()
        print(f"[TAG-MoE] Runtime config: TAGMOE_DEVICE={TAGMOE_DEVICE}, device_ids={device_ids}")
        t0 = time.perf_counter()
        _PIPELINE = End2End(
            pretrained_model_path=PRETRAINED_MODEL_PATH,
            transformer_model_path=TRANSFORMER_MODEL_PATH,
            device_ids=device_ids,
            transformer_weight_name=TRANSFORMER_WEIGHT_NAME,
            transformer_subfolder=TRANSFORMER_SUBFOLDER,
            transformer_revision=TRANSFORMER_REVISION,
            local_files_only=LOCAL_FILES_ONLY,
        )
        print(f"[TAG-MoE] Runtime step: End2End init took {time.perf_counter()-t0:.2f}s")
        _BASE64_TO_IMAGE_FN = base64_to_image
        print(f"[TAG-MoE] Runtime build done in {time.perf_counter()-total_start:.2f}s")
        return _PIPELINE, _BASE64_TO_IMAGE_FN
def _initialize_runtime_on_startup() -> None:
    """Eagerly build the inference runtime when a GPU backend is present.

    Failures are recorded in ``_RUNTIME_STARTUP_ERROR`` instead of raised, so
    the UI still comes up and can surface the error to the user.
    """
    global _RUNTIME_STARTUP_ERROR
    if not GPU_RUNTIME_AVAILABLE:
        return
    print("[TAG-MoE] Startup runtime initialization begin")
    started = time.perf_counter()
    try:
        _ensure_runtime_loaded()
    except Exception as exc:
        _RUNTIME_STARTUP_ERROR = str(exc)
        print(f"[TAG-MoE] Runtime initialization failed: {exc}")
    else:
        print(
            f"[TAG-MoE] Startup runtime initialization finished in "
            f"{time.perf_counter()-started:.2f}s"
        )
def _infer_decorator():
    """Return the decorator applied to the inference function.

    Under HF Spaces this is ``spaces.GPU`` (requesting ZERO_GPU_DURATION
    seconds per call); elsewhere it is a no-op pass-through.
    """
    if spaces is not None:
        return spaces.GPU(duration=ZERO_GPU_DURATION)

    def _passthrough(fn):
        return fn

    return _passthrough
def build_demo(gr):
    """Construct and return the Gradio Blocks UI.

    The UI is buildable even without a GPU runtime: in that case the run
    button is disabled and a warning banner explains why.

    FIX: ``CUSTOM_CSS`` is now passed to ``gr.Blocks(css=...)``. Custom CSS is
    a Blocks-constructor option; ``Blocks.launch()`` has no ``css`` parameter,
    so the stylesheet was previously never applied.
    """
    # Snapshot of runtime health at build time; drives the run-button state.
    runtime_ready = (
        GPU_RUNTIME_AVAILABLE
        and _RUNTIME_STARTUP_ERROR is None
        and _PIPELINE is not None
        and _BASE64_TO_IMAGE_FN is not None
    )

    def infer(
        image,
        prompt,
        negative_prompt,
        seed,
        gen_width,
        gen_height,
        cfg_scale,
        inference_steps,
    ):
        """Run one edit inference; returns (output image, seed actually used)."""
        if _RUNTIME_STARTUP_ERROR is not None:
            raise gr.Error(f"Runtime initialization failed: {_RUNTIME_STARTUP_ERROR}")
        if _PIPELINE is None or _BASE64_TO_IMAGE_FN is None:
            raise gr.Error("Runtime is not ready yet. Please wait and retry.")
        if not GPU_RUNTIME_AVAILABLE:
            raise gr.Error("Inference is disabled because no GPU was detected.")
        if prompt is None or not str(prompt).strip():
            raise gr.Error("Prompt cannot be empty.")
        if image is None:
            raise gr.Error("Image is required.")
        # Fall back to the input image's own size when a slider value is absent.
        width_value = int(gen_width) if gen_width is not None else int(image.size[0])
        height_value = int(gen_height) if gen_height is not None else int(image.size[1])
        input_dict = {
            "image": image.convert("RGB"),
            "prompt": str(prompt).strip(),
            "negative_prompt": str(negative_prompt or DEFAULT_NEGATIVE_PROMPT),
            "seed": int(seed if seed is not None else DEFAULT_SEED),
            "target_width": width_value,
            "target_height": height_value,
            "true_cfg_scale": float(cfg_scale),
            "num_inference_steps": int(inference_steps),
            "keep_original_size": False,
        }
        result = _PIPELINE.predict(input_dict)
        out_image = _BASE64_TO_IMAGE_FN(result["generate_imgs_buffer"][0])
        return out_image, int(result["seed"])

    def randomize_seed():
        # Feed a fresh random seed back into the seed input.
        return generate_random_seed()

    def on_image_upload(image):
        # Sync the width/height sliders to the uploaded image's dimensions.
        if image is None:
            return gr.update(), gr.update()
        return int(image.size[0]), int(image.size[1])

    title_html = """
    <div class="tagmoe-header">
      <picture>
        <source srcset="https://raw.githubusercontent.com/yuci-gpt/TAG-MoE/refs/heads/master/static/images/logo_dark.png" media="(prefers-color-scheme: dark)">
        <img src="https://raw.githubusercontent.com/yuci-gpt/TAG-MoE/refs/heads/master/static/images/logo_light.png" alt="TAG-MoE logo">
      </picture>
      <div>
        <h1>TAG-MoE</h1>
        <p>Task-Aware Gating for Unified Generative Mixture-of-Experts</p>
      </div>
    </div>
    """
    # Wrap with spaces.GPU when available so ZeroGPU allocates a device per call.
    infer_fn = _infer_decorator()(infer)
    with gr.Blocks(title="TAG-MoE Space Demo", css=CUSTOM_CSS) as demo:
        gr.HTML(title_html)
        gr.HTML(LINKS_HTML)
        if not GPU_RUNTIME_AVAILABLE:
            gr.HTML(CPU_MODE_BANNER_HTML)
        elif not runtime_ready and _RUNTIME_STARTUP_ERROR is not None:
            gr.HTML(
                f'<div class="cpu-warning"><div class="cpu-warning-title">Runtime initialization failed</div>'
                f'<div class="cpu-warning-desc">{_RUNTIME_STARTUP_ERROR}</div></div>'
            )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, elem_classes=["image-panel"]):
                image_input = gr.Image(type="pil", label="Input Image", height=520)
            with gr.Column(scale=1, elem_classes=["image-panel"]):
                image_output = gr.Image(type="pil", label="Output Image", height=520)
        with gr.Group(elem_classes=["param-card"]):
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Describe the instruction",
                lines=3,
            )
            negative_prompt_input = gr.Textbox(
                label="Negative Prompt",
                value=DEFAULT_NEGATIVE_PROMPT,
                lines=2,
                placeholder="Optional negative prompt",
            )
            with gr.Row():
                gen_width_input = gr.Slider(minimum=64, maximum=4096, step=1, value=DEFAULT_WIDTH, label="Width")
                gen_height_input = gr.Slider(minimum=64, maximum=4096, step=1, value=DEFAULT_HEIGHT, label="Height")
            with gr.Row():
                cfg_scale_input = gr.Slider(
                    minimum=1.0,
                    maximum=10.0,
                    step=0.1,
                    value=DEFAULT_TRUE_CFG_SCALE,
                    label="CFG Scale",
                )
                inference_steps_input = gr.Slider(
                    minimum=10,
                    maximum=100,
                    step=1,
                    value=DEFAULT_NUM_INFERENCE_STEPS,
                    label="Inference Steps",
                )
                with gr.Column(scale=1, min_width=200):
                    with gr.Row():
                        seed_input = gr.Number(
                            label="Seed",
                            value=generate_random_seed(),
                            precision=0,
                            scale=1,
                        )
                        random_seed_btn = gr.Button(
                            "🎲",
                            elem_classes=["tool-btn"],
                            scale=0,
                            min_width=42,
                            variant="secondary",
                        )
                    run_btn = gr.Button(
                        "Run Inference",
                        variant="primary",
                        elem_classes=["run-btn"],
                        interactive=runtime_ready,
                    )
        run_btn.click(
            fn=infer_fn,
            inputs=[
                image_input,
                prompt_input,
                negative_prompt_input,
                seed_input,
                gen_width_input,
                gen_height_input,
                cfg_scale_input,
                inference_steps_input,
            ],
            outputs=[image_output, seed_input],
        )
        image_input.change(
            fn=on_image_upload,
            inputs=[image_input],
            outputs=[gen_width_input, gen_height_input],
        )
        random_seed_btn.click(fn=randomize_seed, outputs=[seed_input])
    return demo
# Warm the runtime at import time (no-op without a GPU backend) so the first
# request does not pay the full model-loading cost.
_initialize_runtime_on_startup()
demo = build_demo(gr)
# Serialize inference (single device) and bound the pending-request queue.
demo.queue(default_concurrency_limit=1, max_size=8)
if __name__ == "__main__":
    # FIX: dropped css=CUSTOM_CSS — Blocks.launch() accepts no `css` keyword
    # (custom CSS belongs to the gr.Blocks constructor), so passing it here
    # raised TypeError and prevented the app from launching.
    demo.launch(
        server_name="0.0.0.0",
        server_port=int(os.getenv("PORT", "7860")),
        ssr_mode=False,
    )