import argparse
import atexit
import importlib
import os
import signal
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import requests
from torch import Tensor
from tqdm import tqdm

# ---------------------------------------------------------------------------
# ZeroGPU compatibility shim. The hosted HF Space provides the `spaces`
# package; running locally we substitute a no-op.
# ---------------------------------------------------------------------------
try:
    spaces = importlib.import_module("spaces")
except Exception:

    class _SpacesCompat:
        def GPU(self, *args, **kwargs):
            # Bare form: @spaces.GPU applied directly to a function.
            if len(args) == 1 and callable(args[0]) and not kwargs:
                return args[0]

            # Parameterized form: @spaces.GPU(duration=...) returns a
            # pass-through decorator.
            def _decorator(fn):
                return fn

            return _decorator

    spaces = _SpacesCompat()
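
# Usage sketch: the shim accepts both invocation styles of the real
# `spaces` package, so the same source runs locally and on the hosted
# Space (the decorated functions here are illustrative only):
#
#     @spaces.GPU                # bare decorator
#     def f(): ...
#
#     @spaces.GPU(duration=120)  # parameterized decorator
#     def g(): ...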

os.environ.setdefault("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "1")
gr.TEMP_DIR = "tmp_gradio"


# ---------------------------------------------------------------------------
# Install the bundled `bpy` wheel at runtime if it isn't already importable.
#
# Why this is non-trivial:
# - Putting the wheel in requirements.txt fails: HF Spaces' Docker build
#   mounts only requirements.txt BEFORE the repo COPY, so the wheel path
#   doesn't exist at pip-install time.
# - PyPI doesn't ship a bpy wheel matching this exact build (rc0 / cp312 /
#   manylinux_2_39).
# - The `bpy-*.whl` committed in this repo gets auto-tracked by HF's LFS
#   layer (Hub auto-LFS for blobs > ~10 MB even when .gitattributes doesn't
#   list `*.whl`). The container's COPY-from-repo only carries the LFS
#   *pointer* file — a ~150-byte text stub — not the actual wheel binary.
#   So `pip install <wheel>` and `zipfile.ZipFile(<wheel>)` both fail with
#   "is not a zip file" / "Wheel is invalid".
#
# So: we detect the LFS-pointer case and re-fetch the real wheel from the
# HF Hub at runtime (where the API resolves LFS server-side), then extract
# it directly into site-packages.
# ---------------------------------------------------------------------------
def _ensure_bpy_installed():
    try:
        import bpy  # noqa: F401
        return
    except Exception:
        pass
    import glob
    import sysconfig
    import zipfile

    here = os.path.dirname(os.path.abspath(__file__))
    wheels = sorted(glob.glob(os.path.join(here, "bpy-*.whl")))
    if not wheels:
        print("[demo] WARNING: bpy not importable and no bundled wheel found", flush=True)
        return
    wheel = wheels[-1]
    wheel_name = os.path.basename(wheel)
    # Detect LFS pointer (text stub starting with "version https://git-lfs...").
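    # For reference, an LFS pointer stub is three short text lines in the
    # standard git-lfs format, e.g.:
    #   version https://git-lfs.github.com/spec/v1
    #   oid sha256:<64 hex chars>
    #   size 123456789
    # A real wheel is a zip archive, which always begins with the magic
    # bytes "PK" — hence the 4-byte sniff below.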
    is_real_zip = False
    try:
        with open(wheel, "rb") as f:
            is_real_zip = f.read(4).startswith(b"PK")
    except Exception:
        pass
    if not is_real_zip:
        print(
            f"[demo] {wheel_name} on disk is an LFS pointer ({os.path.getsize(wheel)} B); "
            f"fetching real wheel from HF Hub...",
            flush=True,
        )
        from huggingface_hub import hf_hub_download

        space_id = os.environ.get("SPACE_ID", "VAST-AI/SkinTokens")
        token = os.environ.get("HF_TOKEN")  # set as a Space secret for private repos
        wheel = hf_hub_download(
            repo_id=space_id,
            repo_type="space",
            filename=wheel_name,
            token=token,
        )
        print(f"[demo] fetched -> {wheel} ({os.path.getsize(wheel)} B)", flush=True)
    site = sysconfig.get_paths()["purelib"]
    print(f"[demo] Extracting {wheel_name} into {site}", flush=True)
    with zipfile.ZipFile(wheel) as z:
        z.extractall(site)
    print("[demo] bpy wheel extracted.", flush=True)


_ensure_bpy_installed()


# ---------------------------------------------------------------------------
# Download model checkpoints (TokenRig + SkinTokens FSQ-CVAE) and the Qwen3
# tokenizer/config on first cold-start.
#
# These live in the *model* repo `VAST-AI/SkinTokens` (private), separate
# from this Space repo, so they aren't COPYed into the container. Re-uses
# `HF_TOKEN` from the Space secrets.
# ---------------------------------------------------------------------------
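# (hf_hub_download with `local_dir=here` materializes each file at
# here/<repo-relative filename>, so the `needed_ckpts` entries double as the
# on-disk presence check below.)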
def _ensure_models_downloaded():
    here = os.path.dirname(os.path.abspath(__file__))
    needed_ckpts = [
        "experiments/skin_vae_2_10_32768/last.ckpt",
        "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
    ]
    qwen_dir = os.path.join(here, "models", "Qwen3-0.6B")
    all_present = (
        all(os.path.exists(os.path.join(here, p)) for p in needed_ckpts)
        and os.path.exists(os.path.join(qwen_dir, "tokenizer.json"))
    )
    if all_present:
        return
    from huggingface_hub import hf_hub_download, snapshot_download

    token = os.environ.get("HF_TOKEN")
    for rel in needed_ckpts:
        target = os.path.join(here, rel)
        if os.path.exists(target):
            continue
        print(f"[demo] Downloading checkpoint: {rel}", flush=True)
        hf_hub_download(
            repo_id="VAST-AI/SkinTokens",
            filename=rel,
            local_dir=here,
            token=token,
        )
    if not os.path.exists(os.path.join(qwen_dir, "tokenizer.json")):
        print("[demo] Downloading Qwen3-0.6B tokenizer/config", flush=True)
        snapshot_download(
            repo_id="Qwen/Qwen3-0.6B",
            local_dir=qwen_dir,
            ignore_patterns=["*.bin", "*.safetensors"],
        )
    print("[demo] All checkpoints ready.", flush=True)


_ensure_models_downloaded()

from src.data.dataset import DatasetConfig, RigDatasetModule
from src.data.transform import Transform
from src.model.tokenrig import TokenRigResult
from src.tokenizer.parse import get_tokenizer
from src.server.spec import (
    BPY_SERVER,
    get_model,
    object_to_bytes,
    bytes_to_object,
)
from src.data.vertex_group import voxel_skin

# ---------------------------------------------------------------------------
# bpy_server startup strategy on ZeroGPU.
#
# Each user request runs inside a fresh `@spaces.GPU` worker process with a
# hard time budget (≈60 s on the free tier). Importing the Blender shared
# object inside that budget burns 30–60 s, so a cold worker can be killed
# *during* the bpy import — manifesting as "GPU task aborted" before any
# model code runs.
#
# Ideally we would pre-warm `bpy_server.py` here, in the always-running
# main process, so the slow bpy import happens exactly once at Space boot
# and workers just hit `localhost:59876` over HTTP. That is not possible on
# ZeroGPU (bpy_server's import chain touches CUDA, which the main process
# lacks; see the note above the entry point at the bottom of this file),
# so the server is instead started lazily on first request, inside the
# worker.
# ---------------------------------------------------------------------------
MODEL_CKPTS = [
    "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
]
HF_PATHS = [
    "None",
]


def get_dataloader_workers() -> int:
    # In-process data loading (0 workers) on HF Spaces; one background
    # worker when running locally.
    if os.getenv("SPACE_ID"):
        return 0
    return 1


# ---------------------------------------------------------------------------
# bpy_server lifecycle — lazy start so the heavy import doesn't fight ZeroGPU
# during module load.
# ---------------------------------------------------------------------------
_BPY_SERVER_PROC = None


def is_bpy_server_alive(timeout: float = 1.0) -> bool:
    try:
        resp = requests.get(f"{BPY_SERVER}/ping", timeout=timeout)
        return resp.status_code == 200
    except Exception:
        return False


def start_bpy_server():
    proc = subprocess.Popen(
        [sys.executable, "bpy_server.py"],
        stdout=None,
        stderr=None,
        preexec_fn=os.setsid,
    )
    print(f"[Main] bpy_server.py started (pid={proc.pid})")

    def cleanup():
        print(f"[Main] Terminating bpy_server.py (pid={proc.pid})")
        try:
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
        except ProcessLookupError:
            pass

    atexit.register(cleanup)
    return proc


def wait_for_bpy_server(timeout: float = 120):
    """Wait for bpy_server.py to come up. The first start of bpy_server is
    slow because importing the Blender `.so` (~200 MB shared object) takes
    30–60 s on a cold container. We allow up to 120 s."""
    t0 = time.time()
    last_log = 0.0
    while True:
        try:
            requests.get(f"{BPY_SERVER}/ping", timeout=1)
            print(f"[Main] bpy_server is ready (after {time.time() - t0:.1f}s)")
            return
        except Exception:
            now = time.time()
            if now - t0 > timeout:
                raise RuntimeError(
                    f"bpy_server failed to start after {timeout:.0f}s"
                )
            if now - last_log > 10:  # progress every 10 s
                print(f"[Main] still waiting for bpy_server ({now - t0:.0f}s elapsed)")
                last_log = now
            time.sleep(0.5)


def ensure_bpy_server_started():
    global _BPY_SERVER_PROC
    if is_bpy_server_alive():
        return
    if _BPY_SERVER_PROC is not None and _BPY_SERVER_PROC.poll() is None:
        return
    _BPY_SERVER_PROC = start_bpy_server()
    wait_for_bpy_server()
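

# Workers talk to bpy_server with a serialized-object-over-HTTP convention
# (`object_to_bytes` / `bytes_to_object` from src.server.spec). A minimal
# client sketch: this helper is illustrative and unused; run_rig below
# inlines the same pattern against the /export and /transfer routes.
def _bpy_server_call(endpoint: str, payload: dict, timeout: float = 600.0):
    resp = requests.post(
        f"{BPY_SERVER}{endpoint}",
        data=object_to_bytes(payload),
        timeout=timeout,
    )
    resp.raise_for_status()
    return bytes_to_object(resp.content)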


# ---------------------------------------------------------------------------
# Lazy model loading.
# ---------------------------------------------------------------------------
model = None
tokenizer = None
transform = None
CURRENT_MODEL_CKPT: Optional[str] = None
CURRENT_HF_PATH: Optional[str] = None


def load_model(model_ckpt: str, hf_path: Optional[str]) -> Tuple[str, str]:
    global model, tokenizer, transform, CURRENT_MODEL_CKPT, CURRENT_HF_PATH
    if hf_path == "None":
        hf_path = None
    if model is not None and model_ckpt == CURRENT_MODEL_CKPT and hf_path == CURRENT_HF_PATH:
        return ("Model already loaded.", model_ckpt)
    if not model_ckpt:
        raise RuntimeError("model_ckpt is empty. Please select a checkpoint.")
    print(f"Loading model: {model_ckpt}, hf_path={hf_path}")
    model = get_model(model_ckpt, hf_path=hf_path)
    assert model.tokenizer_config is not None
    tokenizer = get_tokenizer(**model.tokenizer_config)
    transform = Transform.parse(**model.transform_config["predict_transform"])
    CURRENT_MODEL_CKPT = model_ckpt
    CURRENT_HF_PATH = hf_path
    return ("Model loaded.", model_ckpt)


# ---------------------------------------------------------------------------
# File utilities (CLI-side).
# ---------------------------------------------------------------------------
SUPPORTED_EXT = {".obj", ".fbx", ".glb"}


def collect_files(input_path: Path) -> List[Path]:
    if input_path.is_file():
        return [input_path]
    files = []
    for p in input_path.rglob("*"):
        if p.suffix.lower() in SUPPORTED_EXT:
            files.append(p)
    return files


def map_output_path(in_path: Path, input_root: Path, output_root: Path) -> Path:
    rel = in_path.relative_to(input_root)
    return (output_root / rel).with_suffix(".glb")
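
# e.g. with input_root=assets/ and output_root=out/ (paths illustrative):
#   assets/characters/hero.fbx -> out/characters/hero.glb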


# ---------------------------------------------------------------------------
# Core inference (shared by CLI and Gradio).
# ---------------------------------------------------------------------------
def run_rig(
    filepaths: List[Path],
    top_k: int,
    top_p: float,
    temperature: float,
    repetition_penalty: float,
    num_beams: int,
    use_skeleton: bool,
    use_transfer: bool,
    use_postprocess: bool,
    output_paths: List[Path],
    model_ckpt: str,
    hf_path: Optional[str],
):
    assert len(filepaths) == len(output_paths)
    ensure_bpy_server_started()
    load_model(model_ckpt, hf_path)
    # Route mesh loading through bpy_server; all inputs are grouped under
    # the "articulation" dataset key.
    datapath = {
        "data_name": None,
        "loader": "bpy_server",
        "filepaths": {"articulation": [str(p) for p in filepaths]},
    }
    dataset_config = DatasetConfig.parse(
        shuffle=False,
        batch_size=1,
        num_workers=get_dataloader_workers(),
        pin_memory=get_dataloader_workers() > 0,
        persistent_workers=False,
        datapath=datapath,
    ).split_by_cls()
    module = RigDatasetModule(
        predict_dataset_config=dataset_config,
        predict_transform=transform,
        tokenizer=tokenizer,
        process_fn=model._process_fn,
    )
    dataloader = module.predict_dataloader()["articulation"]
    results_out = []
    infer_device = model.device if model is not None else "cuda"
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        batch = {
            k: v.to(infer_device) if isinstance(v, Tensor) else v
            for k, v in batch.items()
        }
        if not use_skeleton:
            batch.pop("skeleton_tokens", None)
            batch.pop("skeleton_mask", None)
        batch["generate_kwargs"] = dict(
            max_length=2048,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=1,
            num_beams=int(num_beams),
            do_sample=True,
        )
        if "skeleton_tokens" in batch and "skeleton_mask" in batch:
            mask = batch["skeleton_mask"][0] == 1
            skeleton_tokens = batch["skeleton_tokens"][0][mask].cpu().numpy()
        else:
            skeleton_tokens = None
        preds: List[TokenRigResult] = model.predict_step(
            batch,
            skeleton_tokens=[skeleton_tokens] if skeleton_tokens is not None else None,
            make_asset=True,
        )["results"]
        asset = preds[0].asset
        assert asset is not None
        if use_postprocess:
            voxel = asset.voxel(resolution=196)
            asset.skin *= voxel_skin(
                grid=0,
                grid_coords=voxel.coords,
                joints=asset.joints,
                vertices=asset.vertices,
                faces=asset.faces,
                mode="square",
                voxel_size=voxel.voxel_size,
            )
            asset.normalize_skin()
        out_path = output_paths[i]
        out_path.parent.mkdir(parents=True, exist_ok=True)
        if use_transfer:
            payload = dict(
                source_asset=asset,
                target_path=asset.path,
                export_path=str(out_path),
                group_per_vertex=4,
            )
            res = bytes_to_object(
                requests.post(
                    f"{BPY_SERVER}/transfer",
                    data=object_to_bytes(payload),
                ).content
            )
        else:
            payload = dict(
                asset=asset,
                filepath=str(out_path),
                group_per_vertex=4,
            )
            res = bytes_to_object(
                requests.post(
                    f"{BPY_SERVER}/export",
                    data=object_to_bytes(payload),
                ).content
            )
        if res != "ok":
            print(f"[Error] {res}")
        else:
            print(f"[OK] Exported: {out_path}")
        results_out.append(out_path)
    return results_out


# ---------------------------------------------------------------------------
# CLI entry point.
# ---------------------------------------------------------------------------
def run_cli(args):
    input_path = Path(args.input).resolve()
    output_path = Path(args.output).resolve()
    files = collect_files(input_path)
    if not files:
        raise RuntimeError("No valid 3D files found.")
    if len(files) == 1 and output_path.suffix:
        outputs = [output_path]
    else:
        outputs = [map_output_path(f, input_path, output_path) for f in files]
    run_rig(
        files,
        args.top_k,
        args.top_p,
        args.temperature,
        args.repetition_penalty,
        args.num_beams,
        args.use_skeleton,
        args.use_transfer,
        args.use_postprocess,
        outputs,
        args.model_ckpt,
        args.hf_path,
    )


# ---------------------------------------------------------------------------
# Gradio wrapper (with ZeroGPU duration estimator).
# ---------------------------------------------------------------------------
TOT = 0


def _gpu_duration(
    files,
    top_k,
    top_p,
    temperature,
    repetition_penalty,
    num_beams,
    use_skeleton,
    use_transfer,
    use_postprocess,
    model_ckpt,
    hf_path,
):
    # Cold workers spend ~30–60 s importing bpy + loading the model before
    # any GPU work. Give every request a generous 240 s floor.
    file_count = len(files) if files is not None else 1
    return min(900, max(240, 240 + 60 * file_count))
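
# e.g. 1 file -> 300 s, 3 files -> 420 s; the estimate is capped at 900 s
# (reached from 11 files upward).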


# ZeroGPU accepts a callable `duration` that receives the same arguments as
# the wrapped function, so the estimator above sizes the GPU budget per
# request; locally, the no-op shim makes this decorator transparent.
@spaces.GPU(duration=_gpu_duration)
def run_gradio(
    files,
    top_k,
    top_p,
    temperature,
    repetition_penalty,
    num_beams,
    use_skeleton,
    use_transfer,
    use_postprocess,
    model_ckpt,
    hf_path,
):
    if not files:
        return "Please upload at least one 3D model.", None
    tmp_out = Path(tempfile.mkdtemp(prefix="tokenrig_"))
    filepaths = [Path(f.name) for f in files]
    global TOT
    outputs = []
    for _ in filepaths:
        TOT += 1
        outputs.append(tmp_out / f"res_{TOT}.glb")
    run_rig(
        filepaths,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        num_beams,
        use_skeleton,
        use_transfer,
        use_postprocess,
        outputs,
        model_ckpt,
        hf_path,
    )
    return f"Processed {len(outputs)} models.", [str(p) for p in outputs]


# ---------------------------------------------------------------------------
# Gradio UI.
# ---------------------------------------------------------------------------
def build_gradio_app():
    model_ckpts = MODEL_CKPTS
    hf_paths = HF_PATHS
    default_ckpt = model_ckpts[0] if model_ckpts else ""
    default_hf = hf_paths[0] if hf_paths else "None"
    with gr.Blocks(title="SkinTokens · TokenRig Demo") as app:
        gr.Markdown(
            """
## 🦴 Mesh to Rig with [SkinTokens](https://zjp-shadow.github.io/works/SkinTokens/) · TokenRig

Automated **skeleton generation + skinning weight prediction** for any 3D mesh, via a unified
autoregressive model over learned *SkinTokens*. Successor to
[UniRig](https://github.com/VAST-AI-Research/UniRig) (SIGGRAPH '25).

* Upload one or more meshes → click **Run** → download a rigged `.glb`.
* **Paper**: [arXiv 2602.04805](https://arxiv.org/abs/2602.04805) ·
  **Code**: [VAST-AI-Research/SkinTokens](https://github.com/VAST-AI-Research/SkinTokens) ·
  **Weights**: [🤗 VAST-AI/SkinTokens](https://huggingface.co/VAST-AI/SkinTokens)
* Looking for **image → rigged 3D** instead? Try our sibling Space
  [🤗 VAST-AI/AniGen](https://huggingface.co/spaces/VAST-AI/AniGen).
* Want a full AI-powered 3D workspace? → [Tripo](https://www.tripo3d.ai)
            """
        )
        gr.HTML(
            """
            <style>
            @keyframes gentle-pulse {
                0%, 100% { opacity: 1; }
                50% { opacity: 0.35; }
            }
            </style>
            <div style="text-align:left; color:#888; font-size:1em; line-height:1.6; margin: 4px 0 -4px 0;">
                <span style="animation: gentle-pulse 3s ease-in-out infinite; display:inline-block;">💡 <b>Tips</b></span>&nbsp;
                Defaults work well for most meshes.
                • If your mesh already has a skeleton and you only want skinning, enable
                <b>Use existing skeleton</b> below.
                • To keep your original textures and world scale, enable <b>Preserve original texture & scale</b>.
            </div>
            """
        )
        with gr.Row():
            with gr.Column(scale=1):
                files = gr.File(
                    label="3D Models ( .obj / .fbx / .glb, up to a few at a time )",
                    file_count="multiple",
                    file_types=[".obj", ".fbx", ".glb"],
                )
                with gr.Accordion("⚙️ Generation Settings", open=False):
                    model_ckpt = gr.Dropdown(
                        choices=model_ckpts,
                        value=default_ckpt,
                        label="Model checkpoint",
                        info="TokenRig autoregressive rigging model. The default is the GRPO-refined checkpoint recommended for most assets.",
                        interactive=True,
                    )
                    # Keep the hf_path component for callback compatibility, but hide it
                    # from the UI since it currently only exposes the default ("None") option.
                    hf_path = gr.Dropdown(
                        choices=hf_paths,
                        value=default_hf,
                        label="HF path (advanced)",
                        visible=False,
                    )
                    gr.Markdown("**Sampling parameters** — control autoregressive decoding of the rig.")
                    top_k = gr.Slider(
                        1, 200, value=5, step=1,
                        label="top_k",
                        info="Sample from the K most likely next tokens at each step. Lower = more deterministic output.",
                    )
                    top_p = gr.Slider(
                        0.1, 1.0, value=0.95, step=0.01,
                        label="top_p (nucleus)",
                        info="Sample from the smallest set of tokens whose cumulative probability ≥ p.",
                    )
                    temperature = gr.Slider(
                        0.1, 2.0, value=1.0, step=0.1,
                        label="temperature",
                        info="Softmax temperature. <1 sharpens the distribution (more conservative), >1 makes it flatter (more diverse).",
                    )
                    repetition_penalty = gr.Slider(
                        0.5, 3.0, value=2.0, step=0.1,
                        label="repetition_penalty",
                        info="Multiplicative penalty on tokens that have already been generated. 1.0 = no penalty.",
                    )
                    num_beams = gr.Slider(
                        1, 20, value=10, step=1,
                        label="num_beams",
                        info="Beam-search width. Larger = higher quality but slower; 1 disables beam search.",
                    )
                    gr.Markdown("**Pipeline toggles**")
                    use_skeleton = gr.Checkbox(
                        False,
                        label="Use existing skeleton (predict skinning only)",
                        info="If the uploaded file already contains a skeleton, keep it and only predict per-vertex skinning weights.",
                    )
                    use_transfer = gr.Checkbox(
                        False,
                        label="Preserve original texture & scale",
                        info="Transfer the predicted rig back onto the original (unprocessed) mesh, so textures and world units are preserved.",
                    )
                    use_postprocess = gr.Checkbox(
                        False,
                        label="Voxel skin post-processing",
                        info="Apply a voxel-based mask to the predicted skin weights before normalization. Slower.",
                    )
                run_btn = gr.Button("🚀 Run", variant="primary")
            with gr.Column(scale=1):
                log = gr.Textbox(label="Status", lines=2, interactive=False)
                output = gr.File(label="Rigged GLB output", interactive=False)
                gr.Markdown(
                    """
**Notes**

- The output `.glb` contains the predicted **skeleton + skinning weights**. Import it in Blender (File → Import → glTF 2.0) or any DCC tool that reads glTF.
- In Blender, if you see a `glTF_not_exported` placeholder node, you can safely remove it.
- At busy times ZeroGPU may queue your request for ~10–30 s before inference starts — the status box will update once the GPU is attached.
- Please do **not** upload confidential or NSFW content. See the
  [project page](https://zjp-shadow.github.io/works/SkinTokens/) for paper-accurate results and the
  [code repo](https://github.com/VAST-AI-Research/SkinTokens) for local / batch inference.
                    """
                )
        run_btn.click(
            run_gradio,
            inputs=[
                files,
                top_k,
                top_p,
                temperature,
                repetition_penalty,
                num_beams,
                use_skeleton,
                use_transfer,
                use_postprocess,
                model_ckpt,
                hf_path,
            ],
            outputs=[log, output],
        )
    return app


demo = build_gradio_app()

# Note: we do NOT pre-warm `bpy_server` in the main process. `bpy_server.py`
# transitively imports `src.model.michelangelo.utils.misc`, whose
# module-level `use_flash3 = FLASH3()` calls `torch.cuda.get_device_name(0)`
# at import time. That call fails ("RuntimeError: No CUDA GPUs are
# available") in the main Gradio process on ZeroGPU, where the GPU is only
# attached inside `@spaces.GPU`-decorated workers. So the bpy_server boot
# happens on first request, inside the worker.


# ---------------------------------------------------------------------------
# Entry point.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser("TokenRig Demo")
    parser.add_argument("--input", help="Input file or directory")
    parser.add_argument("--output", help="Output file or directory")
    parser.add_argument("--top_k", type=int, default=5)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--repetition_penalty", type=float, default=2.0)
    parser.add_argument("--num_beams", type=int, default=10)
    parser.add_argument("--use_skeleton", action="store_true")
    parser.add_argument("--use_transfer", action="store_true")
    parser.add_argument("--use_postprocess", action="store_true")
    parser.add_argument("--model_ckpt", default=MODEL_CKPTS[0] if MODEL_CKPTS else "")
    parser.add_argument("--hf_path", default=None)
    parser.add_argument("--gradio", action="store_true")
    args = parser.parse_args()
    if args.gradio or not args.input:
        demo.queue()
        demo.launch(ssr_mode=False)
    else:
        ensure_bpy_server_started()
        run_cli(args)
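
# Usage sketch (script name assumed to be app.py; run from the repo root):
#   python app.py --gradio                            # launch the web UI
#   python app.py --input mesh.glb --output out.glb   # rig one file via CLI
#   python app.py --input assets/ --output rigged/    # rig a directory tree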