import argparse
import atexit
import importlib
import os
import signal
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import requests
from torch import Tensor
from tqdm import tqdm

# ---------------------------------------------------------------------------
# ZeroGPU compatibility shim. The hosted HF Space provides the `spaces`
# package; running locally we substitute a no-op.
# ---------------------------------------------------------------------------
try:
    spaces = importlib.import_module("spaces")
except Exception:

    class _SpacesCompat:
        @staticmethod
        def GPU(*args, **kwargs):
            # Support both bare `@spaces.GPU` and `@spaces.GPU(duration=...)`.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            def _decorator(fn):
                return fn

            return _decorator

    spaces = _SpacesCompat()

os.environ.setdefault("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "1")
gr.TEMP_DIR = "tmp_gradio"

# ---------------------------------------------------------------------------
# Install the bundled `bpy` wheel at runtime if it isn't already importable.
#
# Why this is non-trivial:
# - Putting the wheel in requirements.txt fails: HF Spaces' Docker build
#   mounts only requirements.txt BEFORE the repo COPY, so the wheel path
#   doesn't exist at pip-install time.
# - PyPI doesn't ship a bpy wheel matching this exact build (rc0 / cp312 /
#   manylinux_2_39).
# - The `bpy-*.whl` committed in this repo gets auto-tracked by HF's LFS
#   layer (the Hub auto-LFS-tracks blobs > ~10 MB even when .gitattributes
#   doesn't list `*.whl`). The container's COPY-from-repo only carries the
#   LFS *pointer* file — a ~150-byte text stub — not the actual wheel
#   binary. So both `pip install` and `zipfile.ZipFile()` fail with
#   "is not a zip file" / "Wheel is invalid".
#
# So: we detect the LFS-pointer case and re-fetch the real wheel from the
# HF Hub at runtime (where the API resolves LFS server-side), then extract
# it directly into site-packages.
# ---------------------------------------------------------------------------
def _ensure_bpy_installed():
    try:
        import bpy  # noqa: F401

        return
    except Exception:
        pass

    import glob
    import sysconfig
    import zipfile

    here = os.path.dirname(os.path.abspath(__file__))
    wheels = sorted(glob.glob(os.path.join(here, "bpy-*.whl")))
    if not wheels:
        print("[demo] WARNING: bpy not importable and no bundled wheel found", flush=True)
        return
    wheel = wheels[-1]
    wheel_name = os.path.basename(wheel)

    # Detect LFS pointer (text stub starting with "version https://git-lfs...").
    is_real_zip = False
    try:
        with open(wheel, "rb") as f:
            is_real_zip = f.read(4).startswith(b"PK")
    except Exception:
        pass

    if not is_real_zip:
        print(
            f"[demo] {wheel_name} on disk is an LFS pointer ({os.path.getsize(wheel)} B); "
            f"fetching real wheel from HF Hub...",
            flush=True,
        )
        from huggingface_hub import hf_hub_download

        space_id = os.environ.get("SPACE_ID", "VAST-AI/SkinTokens")
        token = os.environ.get("HF_TOKEN")  # set as a Space secret for private repos
        wheel = hf_hub_download(
            repo_id=space_id,
            repo_type="space",
            filename=wheel_name,
            token=token,
        )
        print(f"[demo] fetched -> {wheel} ({os.path.getsize(wheel)} B)", flush=True)

    site = sysconfig.get_paths()["purelib"]
    print(f"[demo] Extracting {wheel_name} into {site}", flush=True)
    with zipfile.ZipFile(wheel) as z:
        z.extractall(site)
    print("[demo] bpy wheel extracted.", flush=True)


_ensure_bpy_installed()

# ---------------------------------------------------------------------------
# Download model checkpoints (TokenRig + SkinTokens FSQ-CVAE) and the Qwen3
# tokenizer/config on first cold-start.
#
# These live in the *model* repo `VAST-AI/SkinTokens` (private), separate
# from this Space repo, so they aren't COPYed into the container. Re-uses
# `HF_TOKEN` from the Space secrets.
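
# Illustrative sketch (not used by the app): what the magic-bytes check above
# distinguishes. A real wheel is a zip archive whose first bytes are b"PK";
# a Git-LFS pointer is a small UTF-8 text file that begins with
# "version https://git-lfs.github.com/spec/v1". The helper name below is
# ours, introduced only for illustration.
def _looks_like_lfs_pointer(first_bytes: bytes) -> bool:
    return first_bytes.startswith(b"version https://git-lfs")

# e.g. a pointer stub starts with b"version https://git-lfs..." -> True,
#      while real zip/wheel bytes start with b"PK\x03\x04"       -> False.
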
# ---------------------------------------------------------------------------
def _ensure_models_downloaded():
    here = os.path.dirname(os.path.abspath(__file__))
    needed_ckpts = [
        "experiments/skin_vae_2_10_32768/last.ckpt",
        "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
    ]
    qwen_dir = os.path.join(here, "models", "Qwen3-0.6B")
    all_present = (
        all(os.path.exists(os.path.join(here, p)) for p in needed_ckpts)
        and os.path.exists(os.path.join(qwen_dir, "tokenizer.json"))
    )
    if all_present:
        return

    from huggingface_hub import hf_hub_download, snapshot_download

    token = os.environ.get("HF_TOKEN")
    for rel in needed_ckpts:
        target = os.path.join(here, rel)
        if os.path.exists(target):
            continue
        print(f"[demo] Downloading checkpoint: {rel}", flush=True)
        hf_hub_download(
            repo_id="VAST-AI/SkinTokens",
            filename=rel,
            local_dir=here,
            token=token,
        )

    if not os.path.exists(os.path.join(qwen_dir, "tokenizer.json")):
        print("[demo] Downloading Qwen3-0.6B tokenizer/config", flush=True)
        snapshot_download(
            repo_id="Qwen/Qwen3-0.6B",
            local_dir=qwen_dir,
            ignore_patterns=["*.bin", "*.safetensors"],
        )
    print("[demo] All checkpoints ready.", flush=True)


_ensure_models_downloaded()

from src.data.dataset import DatasetConfig, RigDatasetModule
from src.data.transform import Transform
from src.model.tokenrig import TokenRigResult
from src.tokenizer.parse import get_tokenizer
from src.server.spec import (
    BPY_SERVER,
    get_model,
    object_to_bytes,
    bytes_to_object,
)
from src.data.vertex_group import voxel_skin

# ---------------------------------------------------------------------------
# `bpy_server` startup strategy on ZeroGPU.
#
# Each user request runs inside a fresh `@spaces.GPU` worker process with a
# hard time budget (≈60 s on the free tier). Importing the Blender shared
# object inside that budget burns 30–60 s, so the worker is killed *during*
# the bpy import — manifesting as "GPU task aborted" before any model code
# runs.
#
# Ideally we would start `bpy_server.py` here, in the always-running main
# process, so the slow bpy import happens exactly once at Space boot and
# workers just hit `localhost:59876` over HTTP — sub-millisecond, no startup
# cost. On ZeroGPU that is not possible (see the note above the entry point:
# bpy_server's import chain needs a GPU, which the main process lacks), so
# the server instead boots lazily on first request via
# `ensure_bpy_server_started()` below.
# ---------------------------------------------------------------------------
MODEL_CKPTS = [
    "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
]
HF_PATHS = [
    "None",
]


def get_dataloader_workers() -> int:
    if os.getenv("SPACE_ID"):
        return 0
    return 1


# ---------------------------------------------------------------------------
# bpy_server lifecycle — lazy start so the heavy import doesn't fight ZeroGPU
# during module load.
# ---------------------------------------------------------------------------
_BPY_SERVER_PROC = None


def is_bpy_server_alive(timeout: float = 1.0) -> bool:
    try:
        resp = requests.get(f"{BPY_SERVER}/ping", timeout=timeout)
        return resp.status_code == 200
    except Exception:
        return False


def start_bpy_server():
    proc = subprocess.Popen(
        [sys.executable, "bpy_server.py"],
        stdout=None,
        stderr=None,
        preexec_fn=os.setsid,
    )
    print(f"[Main] bpy_server.py started (pid={proc.pid})")

    def cleanup():
        print(f"[Main] Terminating bpy_server.py (pid={proc.pid})")
        try:
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
        except ProcessLookupError:
            pass

    atexit.register(cleanup)
    return proc


def wait_for_bpy_server(timeout: float = 120):
    """Wait for bpy_server.py to come up.

    The first start of bpy_server is slow because importing the Blender
    `.so` (~200 MB shared object) takes 30–60 s on a cold container.
    We allow up to 120 s."""
    t0 = time.time()
    last_log = 0.0
    while True:
        try:
            requests.get(f"{BPY_SERVER}/ping", timeout=1)
            print(f"[Main] bpy_server is ready (after {time.time() - t0:.1f}s)")
            return
        except Exception:
            now = time.time()
            if now - t0 > timeout:
                raise RuntimeError(
                    f"bpy_server failed to start after {timeout:.0f}s"
                )
            if now - last_log > 10:  # progress every 10s
                print(f"[Main] still waiting for bpy_server ({now - t0:.0f}s elapsed)")
                last_log = now
            time.sleep(0.5)


def ensure_bpy_server_started():
    global _BPY_SERVER_PROC
    if is_bpy_server_alive():
        return
    if _BPY_SERVER_PROC is not None and _BPY_SERVER_PROC.poll() is None:
        return
    _BPY_SERVER_PROC = start_bpy_server()
    wait_for_bpy_server()


# ---------------------------------------------------------------------------
# Lazy model loading.
# ---------------------------------------------------------------------------
model = None
tokenizer = None
transform = None
CURRENT_MODEL_CKPT: Optional[str] = None
CURRENT_HF_PATH: Optional[str] = None


def load_model(model_ckpt: str, hf_path: Optional[str]) -> Tuple[str, str]:
    global model, tokenizer, transform, CURRENT_MODEL_CKPT, CURRENT_HF_PATH
    if hf_path == "None":
        hf_path = None
    if model is not None and model_ckpt == CURRENT_MODEL_CKPT and hf_path == CURRENT_HF_PATH:
        return ("Model already loaded.", model_ckpt)
    if not model_ckpt:
        raise RuntimeError("model_ckpt is empty. Please select a checkpoint.")
    print(f"Loading model: {model_ckpt}, hf_path={hf_path}")
    model = get_model(model_ckpt, hf_path=hf_path)
    assert model.tokenizer_config is not None
    tokenizer = get_tokenizer(**model.tokenizer_config)
    transform = Transform.parse(**model.transform_config["predict_transform"])
    CURRENT_MODEL_CKPT = model_ckpt
    CURRENT_HF_PATH = hf_path
    return ("Model loaded.", model_ckpt)


# ---------------------------------------------------------------------------
# File utilities (CLI-side).
# ---------------------------------------------------------------------------
SUPPORTED_EXT = {".obj", ".fbx", ".glb"}


def collect_files(input_path: Path) -> List[Path]:
    if input_path.is_file():
        return [input_path]
    files = []
    for p in input_path.rglob("*"):
        if p.suffix.lower() in SUPPORTED_EXT:
            files.append(p)
    return files


def map_output_path(in_path: Path, input_root: Path, output_root: Path) -> Path:
    rel = in_path.relative_to(input_root)
    return (output_root / rel).with_suffix(".glb")


# ---------------------------------------------------------------------------
# Core inference (shared by CLI and Gradio).
# ---------------------------------------------------------------------------
def run_rig(
    filepaths: List[Path],
    top_k: int,
    top_p: float,
    temperature: float,
    repetition_penalty: float,
    num_beams: int,
    use_skeleton: bool,
    use_transfer: bool,
    use_postprocess: bool,
    output_paths: List[Path],
    model_ckpt: str,
    hf_path: Optional[str],
):
    assert len(filepaths) == len(output_paths)
    ensure_bpy_server_started()
    load_model(model_ckpt, hf_path)

    datapath = {
        "data_name": None,
        "loader": "bpy_server",
        "filepaths": {"articulation": [str(p) for p in filepaths]},
    }
    dataset_config = DatasetConfig.parse(
        shuffle=False,
        batch_size=1,
        num_workers=get_dataloader_workers(),
        pin_memory=get_dataloader_workers() > 0,
        persistent_workers=False,
        datapath=datapath,
    ).split_by_cls()
    module = RigDatasetModule(
        predict_dataset_config=dataset_config,
        predict_transform=transform,
        tokenizer=tokenizer,
        process_fn=model._process_fn,
    )
    dataloader = module.predict_dataloader()["articulation"]

    results_out = []
    infer_device = model.device if model is not None else "cuda"
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        batch = {
            k: v.to(infer_device) if isinstance(v, Tensor) else v
            for k, v in batch.items()
        }
        if not use_skeleton:
            batch.pop("skeleton_tokens", None)
            batch.pop("skeleton_mask", None)
        batch["generate_kwargs"] = dict(
            max_length=2048,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=1,
            num_beams=int(num_beams),
            do_sample=True,
        )
        if "skeleton_tokens" in batch and "skeleton_mask" in batch:
            mask = batch["skeleton_mask"][0] == 1
            skeleton_tokens = batch["skeleton_tokens"][0][mask].cpu().numpy()
        else:
            skeleton_tokens = None
        preds: List[TokenRigResult] = model.predict_step(
            batch,
            skeleton_tokens=[skeleton_tokens] if skeleton_tokens is not None else None,
            make_asset=True,
        )["results"]
        asset = preds[0].asset
        assert asset is not None

        if use_postprocess:
            voxel = asset.voxel(resolution=196)
            asset.skin *= voxel_skin(
                grid=0,
                grid_coords=voxel.coords,
                joints=asset.joints,
                vertices=asset.vertices,
                faces=asset.faces,
                mode="square",
                voxel_size=voxel.voxel_size,
            )
            asset.normalize_skin()

        out_path = output_paths[i]
        out_path.parent.mkdir(parents=True, exist_ok=True)
        if use_transfer:
            payload = dict(
                source_asset=asset,
                target_path=asset.path,
                export_path=str(out_path),
                group_per_vertex=4,
            )
            res = bytes_to_object(
                requests.post(
                    f"{BPY_SERVER}/transfer",
                    data=object_to_bytes(payload),
                ).content
            )
        else:
            payload = dict(
                asset=asset,
                filepath=str(out_path),
                group_per_vertex=4,
            )
            res = bytes_to_object(
                requests.post(
                    f"{BPY_SERVER}/export",
                    data=object_to_bytes(payload),
                ).content
            )
        if res != "ok":
            print(f"[Error] {res}")
        else:
            print(f"[OK] Exported: {out_path}")
        results_out.append(out_path)
    return results_out


# ---------------------------------------------------------------------------
# CLI entry point.
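
# Quick sanity sketch of `map_output_path` (illustrative, cheap enough to run
# at import): it mirrors the input directory layout under the output root and
# forces a `.glb` suffix. The example paths are hypothetical.
assert map_output_path(
    Path("in/chars/hero.fbx"), Path("in"), Path("out")
) == Path("out/chars/hero.glb")
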
# ---------------------------------------------------------------------------
def run_cli(args):
    input_path = Path(args.input).resolve()
    output_path = Path(args.output).resolve()
    files = collect_files(input_path)
    if not files:
        raise RuntimeError("No valid 3D files found.")
    if len(files) == 1 and output_path.suffix:
        outputs = [output_path]
    else:
        outputs = [map_output_path(f, input_path, output_path) for f in files]
    run_rig(
        files,
        args.top_k,
        args.top_p,
        args.temperature,
        args.repetition_penalty,
        args.num_beams,
        args.use_skeleton,
        args.use_transfer,
        args.use_postprocess,
        outputs,
        args.model_ckpt,
        args.hf_path,
    )


# ---------------------------------------------------------------------------
# Gradio wrapper (with ZeroGPU duration estimator).
# ---------------------------------------------------------------------------
TOT = 0


def _gpu_duration(
    files,
    top_k,
    top_p,
    temperature,
    repetition_penalty,
    num_beams,
    use_skeleton,
    use_transfer,
    use_postprocess,
    model_ckpt,
    hf_path,
):
    # Cold workers spend ~30–60 s importing bpy + loading the model before
    # any GPU work. Give every request a generous 240 s floor.
    file_count = len(files) if files is not None else 1
    return min(900, max(240, 240 + 60 * file_count))


@spaces.GPU(duration=_gpu_duration)
def run_gradio(
    files,
    top_k,
    top_p,
    temperature,
    repetition_penalty,
    num_beams,
    use_skeleton,
    use_transfer,
    use_postprocess,
    model_ckpt,
    hf_path,
):
    if not files:
        return "Please upload at least one 3D model.", None
    tmp_out = Path(tempfile.mkdtemp(prefix="tokenrig_"))
    filepaths = [Path(f.name) for f in files]
    global TOT
    outputs = []
    for filepath in filepaths:
        TOT += 1
        outputs.append(tmp_out / f"res_{TOT}.glb")
    run_rig(
        filepaths,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        num_beams,
        use_skeleton,
        use_transfer,
        use_postprocess,
        outputs,
        model_ckpt,
        hf_path,
    )
    return f"Processed {len(outputs)} models.", [str(p) for p in outputs]


# ---------------------------------------------------------------------------
# Gradio UI.
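
# Worked example for the `_gpu_duration` formula above (illustrative helper,
# not wired into the app): the estimate is a 240 s base plus 60 s per file,
# capped at ZeroGPU's 900 s ceiling.
def _gpu_duration_sketch(n_files: int) -> int:
    # same arithmetic as `_gpu_duration`
    return min(900, max(240, 240 + 60 * n_files))

# 1 file  -> 300 s, 5 files -> 540 s, 20 files -> 900 s (hits the cap).
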
# ---------------------------------------------------------------------------
def build_gradio_app():
    model_ckpts = MODEL_CKPTS
    hf_paths = HF_PATHS
    default_ckpt = model_ckpts[0] if model_ckpts else ""
    default_hf = hf_paths[0] if hf_paths else "None"

    with gr.Blocks(title="SkinTokens · TokenRig Demo") as app:
        gr.Markdown(
            """
## 🦴 Mesh to Rig with [SkinTokens](https://zjp-shadow.github.io/works/SkinTokens/) · TokenRig

Automated **skeleton generation + skinning weight prediction** for any 3D mesh, via a
unified autoregressive model over learned *SkinTokens*. Successor to
[UniRig](https://github.com/VAST-AI-Research/UniRig) (SIGGRAPH '25).

* Upload one or more meshes → click **Run** → download a rigged `.glb`.
* **Paper**: [arXiv 2602.04805](https://arxiv.org/abs/2602.04805)  ·
  **Code**: [VAST-AI-Research/SkinTokens](https://github.com/VAST-AI-Research/SkinTokens)  ·
  **Weights**: [🤗 VAST-AI/SkinTokens](https://huggingface.co/VAST-AI/SkinTokens)
* Looking for **image → rigged 3D** instead? Try our sibling Space
  [🤗 VAST-AI/AniGen](https://huggingface.co/spaces/VAST-AI/AniGen).
* Want a full AI-powered 3D workspace? → [Tripo](https://www.tripo3d.ai)
            """
        )
        gr.HTML(
            """
            <div>
            💡 <b>Tips</b> &nbsp;Defaults work well for most meshes.<br>
            &nbsp;• If your mesh already has a skeleton and you only want skinning,
            enable <b>Use existing skeleton</b> below.<br>
            &nbsp;• To keep your original textures and world scale, enable
            <b>Preserve original texture &amp; scale</b>.
            </div>
""" ) with gr.Row(): with gr.Column(scale=1): files = gr.File( label="3D Models ( .obj / .fbx / .glb, up to a few at a time )", file_count="multiple", file_types=[".obj", ".fbx", ".glb"], ) with gr.Accordion("⚙️ Generation Settings", open=False): model_ckpt = gr.Dropdown( choices=model_ckpts, value=default_ckpt, label="Model checkpoint", info="TokenRig autoregressive rigging model. The default is the GRPO-refined checkpoint recommended for most assets.", interactive=True, ) # Keep the hf_path component for callback compatibility, but hide it # from the UI since it currently only exposes the default ("None") option. hf_path = gr.Dropdown( choices=hf_paths, value=default_hf, label="HF path (advanced)", visible=False, ) gr.Markdown("**Sampling parameters** — control autoregressive decoding of the rig.") top_k = gr.Slider( 1, 200, value=5, step=1, label="top_k", info="Sample from the K most likely next tokens at each step. Lower = more deterministic output.", ) top_p = gr.Slider( 0.1, 1.0, value=0.95, step=0.01, label="top_p (nucleus)", info="Sample from the smallest set of tokens whose cumulative probability ≥ p.", ) temperature = gr.Slider( 0.1, 2.0, value=1.0, step=0.1, label="temperature", info="Softmax temperature. <1 sharpens the distribution (more conservative), >1 makes it flatter (more diverse).", ) repetition_penalty = gr.Slider( 0.5, 3.0, value=2.0, step=0.1, label="repetition_penalty", info="Multiplicative penalty on tokens that have already been generated. 1.0 = no penalty.", ) num_beams = gr.Slider( 1, 20, value=10, step=1, label="num_beams", info="Beam-search width. 
Larger = higher quality but slower; 1 disables beam search.", ) gr.Markdown("**Pipeline toggles**") use_skeleton = gr.Checkbox( False, label="Use existing skeleton (predict skinning only)", info="If the uploaded file already contains a skeleton, keep it and only predict per-vertex skinning weights.", ) use_transfer = gr.Checkbox( False, label="Preserve original texture & scale", info="Transfer the predicted rig back onto the original (unprocessed) mesh, so textures and world units are preserved.", ) use_postprocess = gr.Checkbox( False, label="Voxel skin post-processing", info="Apply a voxel-based mask to the predicted skin weights before normalization. Slower.", ) run_btn = gr.Button("🚀 Run", variant="primary") with gr.Column(scale=1): log = gr.Textbox(label="Status", lines=2, interactive=False) output = gr.File(label="Rigged GLB output", interactive=False) gr.Markdown( """ **Notes** - The output `.glb` contains the predicted **skeleton + skinning weights**. Import it in Blender (File → Import → glTF 2.0) or any DCC tool that reads glTF. - In Blender, if you see a `glTF_not_exported` placeholder node, you can safely remove it. - On busy moments Zero-GPU may queue your request for ~10–30 s before inference starts — the status box will update once the GPU is attached. - Please do **not** upload confidential or NSFW content. See the [project page](https://zjp-shadow.github.io/works/SkinTokens/) for paper-accurate results and the [code repo](https://github.com/VAST-AI-Research/SkinTokens) for local / batch inference. """ ) run_btn.click( run_gradio, inputs=[ files, top_k, top_p, temperature, repetition_penalty, num_beams, use_skeleton, use_transfer, use_postprocess, model_ckpt, hf_path, ], outputs=[log, output], ) return app demo = build_gradio_app() # Note: we do NOT pre-warm `bpy_server` in the main process. 
# `bpy_server.py` transitively imports `src.model.michelangelo.utils.misc`,
# whose module-level `use_flash3 = FLASH3()` calls
# `torch.cuda.get_device_name(0)` at import time. That call fails
# ("RuntimeError: No CUDA GPUs are available") in the main Gradio process on
# ZeroGPU, where the GPU is only attached inside `@spaces.GPU`-decorated
# workers. So the bpy_server boot happens on first request, inside the worker.

# ---------------------------------------------------------------------------
# Entry point.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser("TokenRig Demo")
    parser.add_argument("--input", help="Input file or directory")
    parser.add_argument("--output", help="Output file or directory")
    parser.add_argument("--top_k", type=int, default=5)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--repetition_penalty", type=float, default=2.0)
    parser.add_argument("--num_beams", type=int, default=10)
    parser.add_argument("--use_skeleton", action="store_true")
    parser.add_argument("--use_transfer", action="store_true")
    parser.add_argument("--use_postprocess", action="store_true")
    parser.add_argument("--model_ckpt", default=MODEL_CKPTS[0] if MODEL_CKPTS else "")
    parser.add_argument("--hf_path", default=None)
    parser.add_argument("--gradio", action="store_true")
    args = parser.parse_args()

    if args.gradio or not args.input:
        demo.queue()
        demo.launch(ssr_mode=False)
    else:
        ensure_bpy_server_started()
        run_cli(args)