# SkinTokens / demo.py
# pookiefoof's picture
# Public release: SkinTokens · TokenRig demo
# 9d7cf7f
import argparse
import atexit
import importlib
import os
import signal
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import List, Optional, Tuple
import gradio as gr
import requests
from torch import Tensor
from tqdm import tqdm
# ---------------------------------------------------------------------------
# ZeroGPU compatibility shim. The hosted HF Space provides the `spaces`
# package; running locally we substitute a no-op.
# ---------------------------------------------------------------------------
try:
    spaces = importlib.import_module("spaces")
except Exception:

    class _SpacesCompat:
        """Local stand-in for the `spaces` package whose `GPU` decorator does nothing."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Mimic `spaces.GPU` as a pass-through.

            Works both as a bare decorator (`@spaces.GPU`) and as a decorator
            factory (`@spaces.GPU(duration=...)`); in either form the
            decorated function is returned unchanged.
            """
            used_bare = not kwargs and len(args) == 1 and callable(args[0])
            if used_bare:
                # `@spaces.GPU` directly on a function: hand it straight back.
                return args[0]
            # `@spaces.GPU(...)`: return an identity decorator.
            return lambda fn: fn

    spaces = _SpacesCompat()
# Default XFORMERS_IGNORE_FLASH_VERSION_CHECK to "1" unless the environment
# already defines it (presumably silences xformers' flash-attention version
# check — confirm against the xformers version in use).
os.environ.setdefault("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "1")
# Sets a `TEMP_DIR` attribute on the gradio module to a relative folder.
# NOTE(review): Gradio documents the GRADIO_TEMP_DIR environment variable for
# relocating temp files; confirm that this attribute assignment has any effect.
gr.TEMP_DIR = "tmp_gradio"
# ---------------------------------------------------------------------------
# Install the bundled `bpy` wheel at runtime if it isn't already importable.
#
# Why this is non-trivial:
# - Putting the wheel in requirements.txt fails: HF Spaces' Docker build
# mounts only requirements.txt BEFORE the repo COPY, so the wheel path
# doesn't exist at pip-install time.
# - PyPI doesn't ship a bpy wheel matching this exact build (rc0 / cp312 /
# manylinux_2_39).
# - The `bpy-*.whl` committed in this repo gets auto-tracked by HF's LFS
# layer (Hub auto-LFS for blobs > ~10 MB even when .gitattributes doesn't
# list `*.whl`). The container's COPY-from-repo only carries the LFS
# *pointer* file — a ~150-byte text stub — not the actual wheel binary.
# So `pip install <wheel>` and `zipfile.ZipFile(<wheel>)` both fail with
# "is not a zip file" / "Wheel is invalid".
#
# So: we detect the LFS-pointer case and re-fetch the real wheel from the
# HF Hub at runtime (where the API resolves LFS server-side), then extract
# it directly into site-packages.
# ---------------------------------------------------------------------------
def _ensure_bpy_installed():
    """Make `bpy` importable, unpacking the bundled wheel if necessary.

    Steps:
      1. If `import bpy` already works, do nothing.
      2. Otherwise look for `bpy-*.whl` next to this file (the last one in
         sorted order wins). If none is found, print a warning and return.
      3. If the file on disk does not start with the zip magic bytes (`PK`),
         treat it as a Git-LFS pointer and download the real wheel from the
         Space repo given by `SPACE_ID` (default `VAST-AI/SkinTokens`),
         authenticating with `HF_TOKEN` when set.
      4. Extract the wheel into the interpreter's `purelib` directory and
         invalidate the import caches so the new package is discoverable.

    Returns:
        None. Errors from the download or the extraction propagate.
    """
    try:
        import bpy  # noqa: F401

        return
    except Exception:
        # Broad on purpose (kept from the original): any failure to import
        # bpy means we fall through and try the bundled wheel.
        pass

    import glob
    import sysconfig
    import zipfile

    here = os.path.dirname(os.path.abspath(__file__))
    wheels = sorted(glob.glob(os.path.join(here, "bpy-*.whl")))
    if not wheels:
        print("[demo] WARNING: bpy not importable and no bundled wheel found", flush=True)
        return

    wheel = wheels[-1]
    wheel_name = os.path.basename(wheel)

    # Detect LFS pointer (text stub starting with "version https://git-lfs...").
    # A real wheel is a zip archive and therefore starts with b"PK".
    is_real_zip = False
    try:
        with open(wheel, "rb") as f:
            is_real_zip = f.read(4).startswith(b"PK")
    except OSError as exc:
        # Report the reason instead of swallowing it; we still fall back to
        # fetching the wheel from the Hub below.
        print(f"[demo] WARNING: could not read {wheel_name}: {exc}", flush=True)

    if not is_real_zip:
        print(
            f"[demo] {wheel_name} on disk is an LFS pointer ({os.path.getsize(wheel)} B); "
            f"fetching real wheel from HF Hub...",
            flush=True,
        )
        from huggingface_hub import hf_hub_download

        space_id = os.environ.get("SPACE_ID", "VAST-AI/SkinTokens")
        token = os.environ.get("HF_TOKEN")  # set as a Space secret for private repos
        wheel = hf_hub_download(
            repo_id=space_id,
            repo_type="space",
            filename=wheel_name,
            token=token,
        )
        print(f"[demo] fetched -> {wheel} ({os.path.getsize(wheel)} B)", flush=True)

    site = sysconfig.get_paths()["purelib"]
    print(f"[demo] Extracting {wheel_name} into {site}", flush=True)
    with zipfile.ZipFile(wheel) as z:
        z.extractall(site)

    # The package was created after the interpreter started, so the path
    # finders may hold a stale directory listing; refresh it so that a later
    # `import bpy` in this process can see the extracted files.
    importlib.invalidate_caches()
    print("[demo] bpy wheel extracted.", flush=True)


_ensure_bpy_installed()
# ---------------------------------------------------------------------------
# Download model checkpoints (TokenRig + SkinTokens FSQ-CVAE) and the Qwen3
# tokenizer/config on first cold-start.
#
# These live in the *model* repo `VAST-AI/SkinTokens` (private), separate
# from this Space repo, so they aren't COPYed into the container. Re-uses
# `HF_TOKEN` from the Space secrets.
# ---------------------------------------------------------------------------
def _ensure_models_downloaded():
    """Download whatever checkpoints / tokenizer files are missing locally.

    Looks for the two checkpoint files and `models/Qwen3-0.6B/tokenizer.json`
    next to this file. If everything is present the function returns without
    importing `huggingface_hub`. Otherwise each missing checkpoint is fetched
    from `VAST-AI/SkinTokens` (using `HF_TOKEN` if set) and, if needed, the
    `Qwen/Qwen3-0.6B` snapshot is fetched without `*.bin` / `*.safetensors`.
    """
    here = os.path.dirname(os.path.abspath(__file__))
    needed_ckpts = (
        "experiments/skin_vae_2_10_32768/last.ckpt",
        "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
    )
    qwen_dir = os.path.join(here, "models", "Qwen3-0.6B")
    tokenizer_json = os.path.join(qwen_dir, "tokenizer.json")

    missing_ckpts = [
        rel for rel in needed_ckpts if not os.path.exists(os.path.join(here, rel))
    ]
    need_tokenizer = not os.path.exists(tokenizer_json)
    if not missing_ckpts and not need_tokenizer:
        return

    # Imported lazily so a fully provisioned checkout never touches the Hub client.
    from huggingface_hub import hf_hub_download, snapshot_download

    token = os.environ.get("HF_TOKEN")
    for rel in missing_ckpts:
        print(f"[demo] Downloading checkpoint: {rel}", flush=True)
        hf_hub_download(
            repo_id="VAST-AI/SkinTokens",
            filename=rel,
            local_dir=here,
            token=token,
        )

    if need_tokenizer:
        print("[demo] Downloading Qwen3-0.6B tokenizer/config", flush=True)
        snapshot_download(
            repo_id="Qwen/Qwen3-0.6B",
            local_dir=qwen_dir,
            ignore_patterns=["*.bin", "*.safetensors"],
        )

    print("[demo] All checkpoints ready.", flush=True)


_ensure_models_downloaded()
from src.data.dataset import DatasetConfig, RigDatasetModule
from src.data.transform import Transform
from src.model.tokenrig import TokenRigResult
from src.tokenizer.parse import get_tokenizer
from src.server.spec import (
BPY_SERVER,
get_model,
object_to_bytes,
bytes_to_object,
)
from src.data.vertex_group import voxel_skin
# ---------------------------------------------------------------------------
# Checkpoint / HF-path choices offered by the CLI and the Gradio UI.
#
# NOTE: this comment used to describe pre-warming `bpy_server` in the main
# (Gradio) process at module load. The motivation still applies on ZeroGPU:
# each user request runs inside a fresh `@spaces.GPU` worker process with a
# hard time budget (≈60 s on free tier). Importing the Blender shared object
# inside that budget burns 30–60 s, so the worker can be killed *during* the
# bpy import — manifesting as "GPU task aborted" before any model code runs.
#
# The idea was to start `bpy_server.py` in the always-running main process so
# the slow bpy import happens exactly once at Space boot, with workers then
# hitting `localhost:59876` over HTTP. Nothing is started here any more,
# though: `bpy_server.py` is launched lazily by `ensure_bpy_server_started()`
# on the first request. The note after `demo = build_gradio_app()` near the
# end of this file explains why pre-warming in the main process was dropped.
# ---------------------------------------------------------------------------
# Checkpoint choices for the UI dropdown and the CLI `--model_ckpt` default;
# the first entry is the default. The path matches one of the files that
# `_ensure_models_downloaded` places next to this file.
MODEL_CKPTS = [
    "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
]
# Choices for the (hidden) `hf_path` dropdown. The literal string "None" is
# turned into a real `None` inside `load_model`.
HF_PATHS = [
    "None",
]
def get_dataloader_workers() -> int:
    """Return the DataLoader worker count: 0 when `SPACE_ID` is set and non-empty, else 1."""
    running_on_space = bool(os.getenv("SPACE_ID"))
    return 0 if running_on_space else 1
# ---------------------------------------------------------------------------
# bpy_server lifecycle — lazy start so the heavy import doesn't fight ZeroGPU
# during module load.
# ---------------------------------------------------------------------------
# Handle of the bpy_server.py child process started by this process, or None
# if this process has not started one (set in `ensure_bpy_server_started`).
_BPY_SERVER_PROC = None
def is_bpy_server_alive(timeout: float = 1.0) -> bool:
    """Report whether the bpy server answers `GET {BPY_SERVER}/ping` with HTTP 200.

    Args:
        timeout: Seconds passed to `requests.get` as the request timeout.

    Returns:
        True only for a 200 response; any exception raised while making the
        request is treated as "not alive".
    """
    try:
        status = requests.get(f"{BPY_SERVER}/ping", timeout=timeout).status_code
    except Exception:
        return False
    return status == 200
def start_bpy_server():
    """Launch `bpy_server.py` as a child process and arrange for its cleanup.

    The script path is relative, so it is resolved against the current
    working directory. The child inherits this process's stdout/stderr and is
    placed in a new session (hence its own process group) so that the whole
    group can be signalled at exit.

    Returns:
        The `subprocess.Popen` handle of the started server.
    """
    proc = subprocess.Popen(
        [sys.executable, "bpy_server.py"],
        stdout=None,
        stderr=None,
        # Same effect as the former `preexec_fn=os.setsid`, without running
        # Python code between fork and exec (which the subprocess docs warn
        # is unsafe when the parent has threads).
        start_new_session=True,
    )
    print(f"[Main] bpy_server.py started (pid={proc.pid})")

    def cleanup():
        """atexit hook: SIGTERM the server's process group if it is still running."""
        if proc.poll() is not None:
            # Already exited; its pid / group id may have been reused by an
            # unrelated process, so do not signal it.
            return
        print(f"[Main] Terminating bpy_server.py (pid={proc.pid})")
        try:
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
        except ProcessLookupError:
            # Raced with the child exiting on its own.
            pass

    atexit.register(cleanup)
    return proc
def wait_for_bpy_server(timeout: float = 120):
    """Block until bpy_server.py answers `/ping` with HTTP 200.

    The first start of bpy_server is slow because importing the Blender `.so`
    (~200 MB shared object) takes 30–60 s on a cold container, so up to
    120 s is allowed by default.

    Args:
        timeout: Maximum number of seconds to keep polling.

    Raises:
        RuntimeError: If the server has not answered with 200 within `timeout`.
    """
    # Monotonic clock: elapsed-time measurement must not be affected by
    # wall-clock adjustments.
    t0 = time.monotonic()
    last_log = float("-inf")  # so the first failed poll is logged immediately
    while True:
        try:
            # Require 200, matching `is_bpy_server_alive`; an error response
            # does not mean the server is usable.
            ready = requests.get(f"{BPY_SERVER}/ping", timeout=1).status_code == 200
        except Exception:
            ready = False

        elapsed = time.monotonic() - t0
        if ready:
            print(f"[Main] bpy_server is ready (after {elapsed:.1f}s)")
            return
        if elapsed > timeout:
            raise RuntimeError(
                f"bpy_server failed to start after {timeout:.0f}s"
            )
        if elapsed - last_log > 10:  # progress every 10s
            print(f"[Main] still waiting for bpy_server ({elapsed:.0f}s elapsed)")
            last_log = elapsed
        time.sleep(0.5)
def ensure_bpy_server_started():
    """Make sure a bpy server is reachable before returning.

    - If a server already answers the ping, return immediately.
    - If this process has no live child server (never started, or it has
      exited), start one.
    - In every non-responsive case, wait until the server answers. This
      includes a child that is still running but not yet responding (for
      example still importing bpy); previously that case returned at once
      and callers went on to talk to a server that was not ready.

    Raises:
        RuntimeError: Propagated from `wait_for_bpy_server` on timeout.
    """
    global _BPY_SERVER_PROC
    if is_bpy_server_alive():
        return
    if _BPY_SERVER_PROC is None or _BPY_SERVER_PROC.poll() is not None:
        _BPY_SERVER_PROC = start_bpy_server()
    wait_for_bpy_server()
# ---------------------------------------------------------------------------
# Lazy model loading.
# ---------------------------------------------------------------------------
# Module-level cache populated by `load_model`; all start out unset.
model = None  # object returned by `get_model`
tokenizer = None  # built from `model.tokenizer_config`
transform = None  # built from `model.transform_config["predict_transform"]`
# Arguments the cached model was loaded with, used to skip redundant reloads.
CURRENT_MODEL_CKPT: Optional[str] = None
CURRENT_HF_PATH: Optional[str] = None
def load_model(model_ckpt: str, hf_path: Optional[str]) -> Tuple[str, str]:
    """Load (or reuse) the model, tokenizer and predict transform.

    The loaded objects are cached in module globals. A call with the same
    `model_ckpt` / `hf_path` as the cached model is a no-op.

    Args:
        model_ckpt: Checkpoint path handed to `get_model`; must be non-empty.
        hf_path: Forwarded to `get_model`; the literal string "None" is
            treated as `None`.

    Returns:
        A `(status_message, model_ckpt)` tuple.

    Raises:
        RuntimeError: If `model_ckpt` is empty or the loaded model carries no
            tokenizer config.
    """
    global model, tokenizer, transform, CURRENT_MODEL_CKPT, CURRENT_HF_PATH

    if not model_ckpt:
        raise RuntimeError("model_ckpt is empty. Please select a checkpoint.")
    if hf_path == "None":
        hf_path = None
    if model is not None and model_ckpt == CURRENT_MODEL_CKPT and hf_path == CURRENT_HF_PATH:
        return ("Model already loaded.", model_ckpt)

    print(f"Loading model: {model_ckpt}, hf_path={hf_path}")
    # Build everything in locals first and publish to the globals only once
    # all steps succeeded. Otherwise a failure after `get_model` would leave
    # the new model paired with the old tokenizer / CURRENT_* values, and a
    # later call for the old checkpoint would wrongly report "already loaded".
    new_model = get_model(model_ckpt, hf_path=hf_path)
    if new_model.tokenizer_config is None:
        # Explicit raise instead of `assert`, which is stripped under `-O`.
        raise RuntimeError(f"Model loaded from {model_ckpt} has no tokenizer_config.")
    new_tokenizer = get_tokenizer(**new_model.tokenizer_config)
    new_transform = Transform.parse(**new_model.transform_config["predict_transform"])

    model, tokenizer, transform = new_model, new_tokenizer, new_transform
    CURRENT_MODEL_CKPT, CURRENT_HF_PATH = model_ckpt, hf_path
    return ("Model loaded.", model_ckpt)
# ---------------------------------------------------------------------------
# File utilities (CLI-side).
# ---------------------------------------------------------------------------
# Mesh file extensions (lower-case, with dot) picked up when scanning a directory.
SUPPORTED_EXT = {".obj", ".fbx", ".glb"}


def collect_files(input_path: Path) -> List[Path]:
    """Return the mesh files designated by `input_path`.

    Args:
        input_path: Either a single file, which is returned as-is regardless
            of its extension, or a directory that is searched recursively.

    Returns:
        For a directory: every regular file below it whose suffix
        (case-insensitive) is in `SUPPORTED_EXT`, in sorted order so that
        runs are reproducible. Directories that merely carry a mesh-like
        suffix are skipped. An empty list if nothing matches or the path
        does not exist.
    """
    if input_path.is_file():
        return [input_path]
    return sorted(
        p
        for p in input_path.rglob("*")
        if p.suffix.lower() in SUPPORTED_EXT and p.is_file()
    )
def map_output_path(in_path: Path, input_root: Path, output_root: Path) -> Path:
    """Mirror `in_path`'s position under `input_root` into `output_root` as a `.glb`.

    Args:
        in_path: Input file; must be `input_root` itself or located below it.
        input_root: Directory that was scanned, or the single input file.
        output_root: Directory that receives the outputs.

    Returns:
        `output_root / <relative path>` with the final suffix replaced by
        `.glb`. When `in_path` equals `input_root` (single-file input) the
        relative path would be empty, so the file's own name is used and the
        result still lands inside `output_root` rather than becoming
        `<output_root>.glb`.

    Raises:
        ValueError: If `in_path` is not relative to `input_root`.
    """
    rel = in_path.relative_to(input_root)
    if not rel.parts:
        rel = Path(in_path.name)
    return (output_root / rel).with_suffix(".glb")
# ---------------------------------------------------------------------------
# Core inference (shared by CLI and Gradio).
# ---------------------------------------------------------------------------
def run_rig(
    filepaths: List[Path],
    top_k: int,
    top_p: float,
    temperature: float,
    repetition_penalty: float,
    num_beams: int,
    use_skeleton: bool,
    use_transfer: bool,
    use_postprocess: bool,
    output_paths: List[Path],
    model_ckpt: str,
    hf_path: Optional[str],
):
    """Rig each input mesh and write the result through the bpy server.

    Ensures the bpy server is up and the model is loaded, builds a predict
    dataloader over `filepaths`, runs `model.predict_step` per batch and
    POSTs the resulting asset to the bpy server so that it is written to the
    corresponding entry of `output_paths`.

    Args:
        filepaths: Input mesh files, passed to the dataset as the
            "articulation" file list.
        top_k, top_p, temperature, repetition_penalty, num_beams: Sampling
            settings, coerced to int/float and forwarded in `generate_kwargs`.
        use_skeleton: When false, `skeleton_tokens` / `skeleton_mask` are
            dropped from each batch before prediction.
        use_transfer: When true, POST to `/transfer` (with
            `target_path=asset.path`); otherwise POST to `/export`.
        use_postprocess: When true, multiply `asset.skin` by the result of
            `voxel_skin(...)` before export.
        output_paths: One output path per entry of `filepaths`.
        model_ckpt, hf_path: Forwarded to `load_model`.

    Returns:
        The list of output paths collected in `results_out`.

    NOTE(review): the indentation of this function was reconstructed. Two
    placements are assumptions to confirm against the upstream file:
    `asset.normalize_skin()` inside the `use_postprocess` branch, and
    `results_out.append(out_path)` inside the success (`else`) branch.
    """
    assert len(filepaths) == len(output_paths)
    ensure_bpy_server_started()
    load_model(model_ckpt, hf_path)
    datapath = {
        "data_name": None,
        "loader": "bpy_server",
        "filepaths": {"articulation": [str(p) for p in filepaths]},
    }
    # batch_size=1 with shuffle=False: the loop below indexes `output_paths`
    # with the batch index, so batches presumably arrive in input order.
    dataset_config = DatasetConfig.parse(
        shuffle=False,
        batch_size=1,
        num_workers=get_dataloader_workers(),
        pin_memory=get_dataloader_workers() > 0,
        persistent_workers=False,
        datapath=datapath,
    ).split_by_cls()
    module = RigDatasetModule(
        predict_dataset_config=dataset_config,
        predict_transform=transform,
        tokenizer=tokenizer,
        process_fn=model._process_fn,
    )
    dataloader = module.predict_dataloader()["articulation"]
    results_out = []
    # Falls back to "cuda" only if the global `model` were still None.
    infer_device = model.device if model is not None else "cuda"
    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        # Move tensor entries to the inference device; leave other values as-is.
        batch = {
            k: v.to(infer_device) if isinstance(v, Tensor) else v
            for k, v in batch.items()
        }
        if not use_skeleton:
            batch.pop("skeleton_tokens", None)
            batch.pop("skeleton_mask", None)
        batch["generate_kwargs"] = dict(
            max_length=2048,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=1,
            num_beams=int(num_beams),
            do_sample=True,
        )
        if "skeleton_tokens" in batch and "skeleton_mask" in batch:
            # Keep only the tokens of the first batch item whose mask equals 1.
            mask = batch["skeleton_mask"][0] == 1
            skeleton_tokens = batch["skeleton_tokens"][0][mask].cpu().numpy()
        else:
            skeleton_tokens = None
        preds: List[TokenRigResult] = model.predict_step(
            batch,
            skeleton_tokens=[skeleton_tokens] if skeleton_tokens is not None else None,
            make_asset=True,
        )["results"]
        asset = preds[0].asset
        assert asset is not None
        if use_postprocess:
            # Scale the predicted skin weights by the voxel-based term, then
            # renormalize (nesting of normalize_skin is assumed, see docstring).
            voxel = asset.voxel(resolution=196)
            asset.skin *= voxel_skin(
                grid=0,
                grid_coords=voxel.coords,
                joints=asset.joints,
                vertices=asset.vertices,
                faces=asset.faces,
                mode="square",
                voxel_size=voxel.voxel_size,
            )
            asset.normalize_skin()
        out_path = output_paths[i]
        out_path.parent.mkdir(parents=True, exist_ok=True)
        if use_transfer:
            # Ask the bpy server to transfer the rig onto the file at
            # `asset.path` and export it to `out_path`.
            payload = dict(
                source_asset=asset,
                target_path=asset.path,
                export_path=str(out_path),
                group_per_vertex=4,
            )
            res = bytes_to_object(
                requests.post(
                    f"{BPY_SERVER}/transfer",
                    data=object_to_bytes(payload),
                ).content
            )
        else:
            # Ask the bpy server to export the predicted asset directly.
            payload = dict(
                asset=asset,
                filepath=str(out_path),
                group_per_vertex=4,
            )
            res = bytes_to_object(
                requests.post(
                    f"{BPY_SERVER}/export",
                    data=object_to_bytes(payload),
                ).content
            )
        # The server replies with the decoded value "ok" on success; anything
        # else is printed as an error.
        if res != "ok":
            print(f"[Error] {res}")
        else:
            print(f"[OK] Exported: {out_path}")
            results_out.append(out_path)
    return results_out
# ---------------------------------------------------------------------------
# CLI entry point.
# ---------------------------------------------------------------------------
def run_cli(args):
    """Run batch rigging from parsed command-line arguments.

    Resolves `args.input` / `args.output`, gathers the input files and picks
    an output path for each: when exactly one file was found and the output
    path has a suffix, that path is used verbatim; otherwise every input is
    mapped into the output location via `map_output_path`.

    Raises:
        RuntimeError: If no input files were found.
    """
    src = Path(args.input).resolve()
    dst = Path(args.output).resolve()

    found = collect_files(src)
    if not found:
        raise RuntimeError("No valid 3D files found.")

    single_explicit_target = len(found) == 1 and bool(dst.suffix)
    if single_explicit_target:
        targets = [dst]
    else:
        targets = [map_output_path(item, src, dst) for item in found]

    run_rig(
        filepaths=found,
        top_k=args.top_k,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=args.repetition_penalty,
        num_beams=args.num_beams,
        use_skeleton=args.use_skeleton,
        use_transfer=args.use_transfer,
        use_postprocess=args.use_postprocess,
        output_paths=targets,
        model_ckpt=args.model_ckpt,
        hf_path=args.hf_path,
    )
# ---------------------------------------------------------------------------
# Gradio wrapper (with ZeroGPU duration estimator).
# ---------------------------------------------------------------------------
# Running count of outputs allocated by `run_gradio`; used to number the
# `res_<n>.glb` file names.
TOT = 0
def _gpu_duration(
    files,
    top_k,
    top_p,
    temperature,
    repetition_penalty,
    num_beams,
    use_skeleton,
    use_transfer,
    use_postprocess,
    model_ckpt,
    hf_path,
):
    """Estimate the GPU time budget (seconds) for one `run_gradio` request.

    Takes the same arguments as `run_gradio`; only `files` is used. The
    budget is 240 s plus 60 s per uploaded file (a missing `files` counts as
    one file), kept within the range 240–900 s.
    """
    # Cold workers spend ~30–60 s importing bpy + loading the model before
    # any GPU work, hence the generous 240 s floor for every request.
    n_files = 1 if files is None else len(files)
    budget = 240 + 60 * n_files
    if budget < 240:
        budget = 240
    if budget > 900:
        budget = 900
    return budget
@spaces.GPU(duration=_gpu_duration)
def run_gradio(
    files,
    top_k,
    top_p,
    temperature,
    repetition_penalty,
    num_beams,
    use_skeleton,
    use_transfer,
    use_postprocess,
    model_ckpt,
    hf_path,
):
    """Gradio click handler: rig the uploaded meshes and return the results.

    Args:
        files: Uploaded files from the `gr.File` input. Each item may be an
            object exposing a `.name` path or a plain path string.
        top_k, top_p, temperature, repetition_penalty, num_beams,
        use_skeleton, use_transfer, use_postprocess, model_ckpt, hf_path:
            Forwarded unchanged to `run_rig`.

    Returns:
        A `(status_text, file_paths)` tuple for the status box and the file
        output. `file_paths` lists only outputs that actually exist on disk,
        and is `None` when nothing was uploaded or nothing was produced.
    """
    global TOT

    if not files:
        return "Please upload at least one 3D model.", None

    tmp_out = Path(tempfile.mkdtemp(prefix="tokenrig_"))
    # Accept both file-like wrappers (with `.name`) and bare path strings.
    filepaths = [Path(getattr(f, "name", f)) for f in files]

    outputs = []
    for _ in filepaths:
        TOT += 1
        outputs.append(tmp_out / f"res_{TOT}.glb")

    run_rig(
        filepaths,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        num_beams,
        use_skeleton,
        use_transfer,
        use_postprocess,
        outputs,
        model_ckpt,
        hf_path,
    )

    # `run_rig` only prints export failures, so a planned output may be
    # missing. Returning a nonexistent path would hand Gradio a file it
    # cannot serve, so report only what was written.
    exported = [str(p) for p in outputs if p.exists()]
    failed = len(outputs) - len(exported)
    if failed:
        status = (
            f"Processed {len(exported)} of {len(outputs)} models "
            f"({failed} failed; see server logs)."
        )
    else:
        status = f"Processed {len(outputs)} models."
    return status, exported or None
# ---------------------------------------------------------------------------
# Gradio UI.
# ---------------------------------------------------------------------------
def build_gradio_app():
    """Build and return the Gradio `Blocks` app for the demo.

    Layout: an intro Markdown block and a tips banner, then a row with two
    columns — inputs (file upload, generation settings, Run button) and
    outputs (status text, rigged-file download, usage notes). The Run button
    is wired to `run_gradio`.

    NOTE(review): the nesting of the `with` blocks was reconstructed from a
    source whose indentation was lost. In particular, the placement of the
    pipeline toggles (inside the accordion) and of the notes Markdown (inside
    the right column) should be confirmed against the upstream file.
    """
    model_ckpts = MODEL_CKPTS
    hf_paths = HF_PATHS
    # Fall back to placeholder defaults if either choice list is empty.
    default_ckpt = model_ckpts[0] if model_ckpts else ""
    default_hf = hf_paths[0] if hf_paths else "None"
    with gr.Blocks(title="SkinTokens · TokenRig Demo") as app:
        # Header: project description and links.
        gr.Markdown(
            """
## 🦴 Mesh to Rig with [SkinTokens](https://zjp-shadow.github.io/works/SkinTokens/) · TokenRig
Automated **skeleton generation + skinning weight prediction** for any 3D mesh, via a unified
autoregressive model over learned *SkinTokens*. Successor to
[UniRig](https://github.com/VAST-AI-Research/UniRig) (SIGGRAPH&nbsp;'25).
* Upload one or more meshes → click **Run** → download a rigged `.glb`.
* **Paper**: [arXiv&nbsp;2602.04805](https://arxiv.org/abs/2602.04805) &nbsp;·&nbsp;
**Code**: [VAST-AI-Research/SkinTokens](https://github.com/VAST-AI-Research/SkinTokens) &nbsp;·&nbsp;
**Weights**: [🤗&nbsp;VAST-AI/SkinTokens](https://huggingface.co/VAST-AI/SkinTokens)
* Looking for **image → rigged 3D** instead? Try our sibling Space
[🤗&nbsp;VAST-AI/AniGen](https://huggingface.co/spaces/VAST-AI/AniGen).
* Want a full AI-powered 3D workspace? → [Tripo](https://www.tripo3d.ai)
"""
        )
        # Tips banner (raw HTML so the pulse animation can be styled inline).
        gr.HTML(
            """
<style>
@keyframes gentle-pulse {
0%, 100% { opacity: 1; }
50% { opacity: 0.35; }
}
</style>
<div style="text-align:left; color:#888; font-size:1em; line-height:1.6; margin: 4px 0 -4px 0;">
<span style="animation: gentle-pulse 3s ease-in-out infinite; display:inline-block;">&#128161; <b>Tips</b></span>&ensp;
Defaults work well for most meshes.
&nbsp;• If your mesh already has a skeleton and you only want skinning, enable
<b>Use existing skeleton</b> below.
&nbsp;• To keep your original textures and world scale, enable <b>Preserve original texture &amp; scale</b>.
</div>
"""
        )
        with gr.Row():
            # Left column: inputs and settings.
            with gr.Column(scale=1):
                files = gr.File(
                    label="3D Models ( .obj / .fbx / .glb, up to a few at a time )",
                    file_count="multiple",
                    file_types=[".obj", ".fbx", ".glb"],
                )
                with gr.Accordion("⚙️ Generation Settings", open=False):
                    model_ckpt = gr.Dropdown(
                        choices=model_ckpts,
                        value=default_ckpt,
                        label="Model checkpoint",
                        info="TokenRig autoregressive rigging model. The default is the GRPO-refined checkpoint recommended for most assets.",
                        interactive=True,
                    )
                    # Keep the hf_path component for callback compatibility, but hide it
                    # from the UI since it currently only exposes the default ("None") option.
                    hf_path = gr.Dropdown(
                        choices=hf_paths,
                        value=default_hf,
                        label="HF path (advanced)",
                        visible=False,
                    )
                    gr.Markdown("**Sampling parameters** — control autoregressive decoding of the rig.")
                    # Slider defaults mirror the CLI argument defaults.
                    top_k = gr.Slider(
                        1, 200, value=5, step=1,
                        label="top_k",
                        info="Sample from the K most likely next tokens at each step. Lower = more deterministic output.",
                    )
                    top_p = gr.Slider(
                        0.1, 1.0, value=0.95, step=0.01,
                        label="top_p (nucleus)",
                        info="Sample from the smallest set of tokens whose cumulative probability ≥ p.",
                    )
                    temperature = gr.Slider(
                        0.1, 2.0, value=1.0, step=0.1,
                        label="temperature",
                        info="Softmax temperature. <1 sharpens the distribution (more conservative), >1 makes it flatter (more diverse).",
                    )
                    repetition_penalty = gr.Slider(
                        0.5, 3.0, value=2.0, step=0.1,
                        label="repetition_penalty",
                        info="Multiplicative penalty on tokens that have already been generated. 1.0 = no penalty.",
                    )
                    num_beams = gr.Slider(
                        1, 20, value=10, step=1,
                        label="num_beams",
                        info="Beam-search width. Larger = higher quality but slower; 1 disables beam search.",
                    )
                    gr.Markdown("**Pipeline toggles**")
                    use_skeleton = gr.Checkbox(
                        False,
                        label="Use existing skeleton (predict skinning only)",
                        info="If the uploaded file already contains a skeleton, keep it and only predict per-vertex skinning weights.",
                    )
                    use_transfer = gr.Checkbox(
                        False,
                        label="Preserve original texture & scale",
                        info="Transfer the predicted rig back onto the original (unprocessed) mesh, so textures and world units are preserved.",
                    )
                    use_postprocess = gr.Checkbox(
                        False,
                        label="Voxel skin post-processing",
                        info="Apply a voxel-based mask to the predicted skin weights before normalization. Slower.",
                    )
                run_btn = gr.Button("🚀 Run", variant="primary")
            # Right column: status, downloadable result and usage notes.
            with gr.Column(scale=1):
                log = gr.Textbox(label="Status", lines=2, interactive=False)
                output = gr.File(label="Rigged GLB output", interactive=False)
                gr.Markdown(
                    """
**Notes**
- The output `.glb` contains the predicted **skeleton + skinning weights**. Import it in Blender (File → Import → glTF&nbsp;2.0) or any DCC tool that reads glTF.
- In Blender, if you see a `glTF_not_exported` placeholder node, you can safely remove it.
- On busy moments Zero-GPU may queue your request for ~10–30&nbsp;s before inference starts — the status box will update once the GPU is attached.
- Please do **not** upload confidential or NSFW content. See the
[project page](https://zjp-shadow.github.io/works/SkinTokens/) for paper-accurate results and the
[code repo](https://github.com/VAST-AI-Research/SkinTokens) for local / batch inference.
"""
                )
        # The input order must match the positional parameters of `run_gradio`.
        run_btn.click(
            run_gradio,
            inputs=[
                files,
                top_k,
                top_p,
                temperature,
                repetition_penalty,
                num_beams,
                use_skeleton,
                use_transfer,
                use_postprocess,
                model_ckpt,
                hf_path,
            ],
            outputs=[log, output],
        )
    return app
# Module-level Gradio app, built at import time; queued and launched from the
# `__main__` block at the end of this file.
demo = build_gradio_app()
# Note: we do NOT pre-warm `bpy_server` in the main process. `bpy_server.py`
# transitively imports `src.model.michelangelo.utils.misc`, whose
# module-level `use_flash3 = FLASH3()` calls `torch.cuda.get_device_name(0)`
# at import time. That call fails ("RuntimeError: No CUDA GPUs are
# available") in the main Gradio process on ZeroGPU, where the GPU is only
# attached inside `@spaces.GPU`-decorated workers. So the bpy_server boot
# happens on first request, inside the worker.
# ---------------------------------------------------------------------------
# Entry point.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Two modes: the Gradio web UI (when `--gradio` is given or no `--input`
    # is supplied) or one-shot batch rigging from the command line.
    parser = argparse.ArgumentParser("TokenRig Demo")
    parser.add_argument("--input", help="Input file or directory")
    parser.add_argument("--output", help="Output file or directory")
    parser.add_argument("--top_k", type=int, default=5)
    parser.add_argument("--top_p", type=float, default=0.95)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--repetition_penalty", type=float, default=2.0)
    parser.add_argument("--num_beams", type=int, default=10)
    parser.add_argument("--use_skeleton", action="store_true")
    parser.add_argument("--use_transfer", action="store_true")
    parser.add_argument("--use_postprocess", action="store_true")
    parser.add_argument("--model_ckpt", default=MODEL_CKPTS[0] if MODEL_CKPTS else "")
    parser.add_argument("--hf_path", default=None)
    parser.add_argument("--gradio", action="store_true")
    args = parser.parse_args()

    if args.gradio or not args.input:
        demo.queue()
        demo.launch(ssr_mode=False)
    else:
        if not args.output:
            # Without this check `run_cli` fails with an opaque TypeError
            # from `Path(None)`; exit with a proper usage message instead.
            parser.error("--output is required when --input is given")
        ensure_bpy_server_started()
        run_cli(args)