File size: 26,784 Bytes
9d7cf7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
498751c
9d7cf7f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
import argparse
import atexit
import importlib
import os
import signal
import subprocess
import sys
import tempfile
import time
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import requests
from torch import Tensor
from tqdm import tqdm

# ---------------------------------------------------------------------------
# ZeroGPU compatibility shim. The hosted HF Space provides the `spaces` 
# package; running locally we substitute a no-op.
# ---------------------------------------------------------------------------
try:
    spaces = importlib.import_module("spaces")
except Exception:
    class _SpacesCompat:
        """Local stand-in for the HF ``spaces`` package.

        Only ``GPU`` is provided, and it never changes the decorated function:
        both ``@spaces.GPU`` and ``@spaces.GPU(...)`` return it unchanged.
        """

        @staticmethod
        def GPU(*args, **kwargs):
            # Bare form: `@spaces.GPU` passes the function itself as the only arg.
            bare_use = not kwargs and len(args) == 1 and callable(args[0])
            if bare_use:
                return args[0]
            # Parameterised form: `@spaces.GPU(duration=...)` needs a decorator.
            return lambda fn: fn

    spaces = _SpacesCompat()

# setdefault keeps any value the user already exported. Presumably read by
# xformers when it gets imported later via the project modules — TODO confirm.
os.environ.setdefault("XFORMERS_IGNORE_FLASH_VERSION_CHECK", "1")
# NOTE(review): intended to redirect Gradio temp files to ./tmp_gradio; confirm
# the installed Gradio version honours `gr.TEMP_DIR` (the GRADIO_TEMP_DIR
# environment variable is the documented mechanism).
gr.TEMP_DIR = "tmp_gradio"


# ---------------------------------------------------------------------------
# Install the bundled `bpy` wheel at runtime if it isn't already importable.
#
# Why this is non-trivial:
#  - Putting the wheel in requirements.txt fails: HF Spaces' Docker build
#    mounts only requirements.txt BEFORE the repo COPY, so the wheel path
#    doesn't exist at pip-install time.
#  - PyPI doesn't ship a bpy wheel matching this exact build (rc0 / cp312 /
#    manylinux_2_39).
#  - The `bpy-*.whl` committed in this repo gets auto-tracked by HF's LFS
#    layer (Hub auto-LFS for blobs > ~10 MB even when .gitattributes doesn't
#    list `*.whl`). The container's COPY-from-repo only carries the LFS
#    *pointer* file — a ~150-byte text stub — not the actual wheel binary.
#    So `pip install <wheel>` and `zipfile.ZipFile(<wheel>)` both fail with
#    "is not a zip file" / "Wheel is invalid".
#
# So: we detect the LFS-pointer case and re-fetch the real wheel from the
# HF Hub at runtime (where the API resolves LFS server-side), then extract
# it directly into site-packages.
# ---------------------------------------------------------------------------
def _ensure_bpy_installed():
    try:
        import bpy  # noqa: F401
        return
    except Exception:
        pass

    import glob
    import sysconfig
    import zipfile

    here = os.path.dirname(os.path.abspath(__file__))
    wheels = sorted(glob.glob(os.path.join(here, "bpy-*.whl")))
    if not wheels:
        print("[demo] WARNING: bpy not importable and no bundled wheel found", flush=True)
        return

    wheel = wheels[-1]
    wheel_name = os.path.basename(wheel)

    # Detect LFS pointer (text stub starting with "version https://git-lfs...").
    is_real_zip = False
    try:
        with open(wheel, "rb") as f:
            is_real_zip = f.read(4).startswith(b"PK")
    except Exception:
        pass

    if not is_real_zip:
        print(
            f"[demo] {wheel_name} on disk is an LFS pointer ({os.path.getsize(wheel)} B); "
            f"fetching real wheel from HF Hub...",
            flush=True,
        )
        from huggingface_hub import hf_hub_download

        space_id = os.environ.get("SPACE_ID", "VAST-AI/SkinTokens")
        token = os.environ.get("HF_TOKEN")  # set as a Space secret for private repos
        wheel = hf_hub_download(
            repo_id=space_id,
            repo_type="space",
            filename=wheel_name,
            token=token,
        )
        print(f"[demo] fetched -> {wheel} ({os.path.getsize(wheel)} B)", flush=True)

    site = sysconfig.get_paths()["purelib"]
    print(f"[demo] Extracting {wheel_name} into {site}", flush=True)
    with zipfile.ZipFile(wheel) as z:
        z.extractall(site)
    print("[demo] bpy wheel extracted.", flush=True)


# Runs at import time, before the project (`src.*`) imports further down.
_ensure_bpy_installed()


# ---------------------------------------------------------------------------
# Download model checkpoints (TokenRig + SkinTokens FSQ-CVAE) and the Qwen3
# tokenizer/config on first cold-start.
#
# These live in the *model* repo `VAST-AI/SkinTokens` (private), separate
# from this Space repo, so they aren't COPYed into the container. Re-uses
# `HF_TOKEN` from the Space secrets.
# ---------------------------------------------------------------------------
def _ensure_models_downloaded():
    here = os.path.dirname(os.path.abspath(__file__))
    needed_ckpts = [
        "experiments/skin_vae_2_10_32768/last.ckpt",
        "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
    ]
    qwen_dir = os.path.join(here, "models", "Qwen3-0.6B")

    all_present = (
        all(os.path.exists(os.path.join(here, p)) for p in needed_ckpts)
        and os.path.exists(os.path.join(qwen_dir, "tokenizer.json"))
    )
    if all_present:
        return

    from huggingface_hub import hf_hub_download, snapshot_download

    token = os.environ.get("HF_TOKEN")

    for rel in needed_ckpts:
        target = os.path.join(here, rel)
        if os.path.exists(target):
            continue
        print(f"[demo] Downloading checkpoint: {rel}", flush=True)
        hf_hub_download(
            repo_id="VAST-AI/SkinTokens",
            filename=rel,
            local_dir=here,
            token=token,
        )

    if not os.path.exists(os.path.join(qwen_dir, "tokenizer.json")):
        print("[demo] Downloading Qwen3-0.6B tokenizer/config", flush=True)
        snapshot_download(
            repo_id="Qwen/Qwen3-0.6B",
            local_dir=qwen_dir,
            ignore_patterns=["*.bin", "*.safetensors"],
        )

    print("[demo] All checkpoints ready.", flush=True)


# Runs at import time, before the `src.*` imports below, so the checkpoint and
# tokenizer files are on disk before any model code is loaded.
_ensure_models_downloaded()


from src.data.dataset import DatasetConfig, RigDatasetModule
from src.data.transform import Transform
from src.model.tokenrig import TokenRigResult
from src.tokenizer.parse import get_tokenizer
from src.server.spec import (
    BPY_SERVER,
    get_model,
    object_to_bytes,
    bytes_to_object,
)
from src.data.vertex_group import voxel_skin


# ---------------------------------------------------------------------------
# bpy_server and ZeroGPU.
#
# On ZeroGPU each user request runs inside a fresh `@spaces.GPU` worker
# process with a hard time budget (≈60 s on free tier). Importing the Blender
# shared object takes 30–60 s on a cold container. That can exhaust the
# budget *during* the bpy import, which shows up as "GPU task aborted" before
# any model code runs.
#
# Starting `bpy_server.py` once in the always-running main process would avoid
# this, but it is NOT done here. `bpy_server.py` needs a CUDA device at import
# time, and the main process on ZeroGPU has none (see the note after
# `demo = build_gradio_app()`).
#
# Instead, the server is started lazily on the first request by
# `ensure_bpy_server_started`. `_gpu_duration` requests a generous time budget
# to absorb that cold start. After that, workers reach the server over HTTP at
# `BPY_SERVER` (localhost:59876).
# ---------------------------------------------------------------------------

# Checkpoints offered in the UI dropdown. The first entry is the default for
# both the UI and the CLI `--model_ckpt`. The path is relative;
# _ensure_models_downloaded() fetches it under this file's directory.
MODEL_CKPTS = [
    "experiments/articulation_xl_quantization_256_token_4/grpo_1400.ckpt",
]

# Choices for the hidden `hf_path` dropdown. The string "None" is a sentinel
# that load_model() converts to a real None.
HF_PATHS = [
    "None",
]


def get_dataloader_workers() -> int:
    """Return the DataLoader worker count for prediction.

    Returns 0 (load in the calling process) when the ``SPACE_ID`` environment
    variable is set to a non-empty value, presumably meaning "running on HF
    Spaces". Otherwise returns 1.
    """
    on_hf_space = bool(os.getenv("SPACE_ID"))
    return 0 if on_hf_space else 1


# ---------------------------------------------------------------------------
# bpy_server lifecycle — lazy start so the heavy import doesn't fight ZeroGPU
# during module load.
# ---------------------------------------------------------------------------
# Popen handle of the bpy_server.py child launched by this process. It stays
# None until ensure_bpy_server_started() has to start one.
_BPY_SERVER_PROC = None


def is_bpy_server_alive(timeout: float = 1.0) -> bool:
    """Report whether ``GET {BPY_SERVER}/ping`` answers with HTTP 200.

    Any failure while issuing the request (connection refused, timeout, ...)
    is reported as ``False`` rather than raised.
    """
    try:
        reply = requests.get(f"{BPY_SERVER}/ping", timeout=timeout)
    except Exception:
        return False
    return reply.status_code == 200


def start_bpy_server():
    """Launch ``bpy_server.py`` as a child process in its own session.

    The script is resolved next to this file, so the server also starts when
    the current working directory is elsewhere. The child inherits this
    process's stdout/stderr. An ``atexit`` hook sends SIGTERM to the child's
    whole process group when this interpreter exits.

    Returns:
        The ``subprocess.Popen`` handle of the child.
    """
    script = os.path.join(os.path.dirname(os.path.abspath(__file__)), "bpy_server.py")
    # start_new_session=True runs setsid() in the child without executing
    # Python code between fork and exec. preexec_fn=os.setsid does run Python
    # there, which is documented as unsafe when threads exist, and this is
    # reached from Gradio event handlers, which may run in worker threads.
    proc = subprocess.Popen(
        [sys.executable, script],
        stdout=None,
        stderr=None,
        start_new_session=True,
    )
    print(f"[Main] bpy_server.py started (pid={proc.pid})")

    def cleanup():
        print(f"[Main] Terminating bpy_server.py (pid={proc.pid})")
        try:
            # The child leads its own session/group, so this also reaches any
            # processes it spawned.
            os.killpg(os.getpgid(proc.pid), signal.SIGTERM)
        except ProcessLookupError:
            # Already gone: nothing to terminate.
            pass

    atexit.register(cleanup)
    return proc


def wait_for_bpy_server(timeout: float = 120, proc: Optional[subprocess.Popen] = None):
    """Block until bpy_server answers ``/ping`` with HTTP 200.

    The first start of bpy_server is slow because importing the Blender `.so`
    (~200 MB shared object) takes 30–60 s on a cold container, so the default
    budget is 120 s. A progress line is printed on the first miss and then
    every 10 s.

    Args:
        timeout: Maximum number of seconds to wait.
        proc: Optional handle of the server child process. If it exits before
            the server becomes ready, fail immediately instead of waiting out
            the whole timeout.

    Raises:
        RuntimeError: If the server is not ready within ``timeout`` seconds,
            or ``proc`` exits first.
    """
    # Monotonic clock: elapsed time is unaffected by wall-clock adjustments.
    t0 = time.monotonic()
    last_log = float("-inf")
    while True:
        # Same readiness criterion as the fast path in ensure_bpy_server_started.
        if is_bpy_server_alive(timeout=1):
            print(f"[Main] bpy_server is ready (after {time.monotonic() - t0:.1f}s)")
            return
        if proc is not None and proc.poll() is not None:
            raise RuntimeError(
                f"bpy_server exited with code {proc.returncode} before becoming ready"
            )
        now = time.monotonic()
        if now - t0 > timeout:
            raise RuntimeError(
                f"bpy_server failed to start after {timeout:.0f}s"
            )
        if now - last_log > 10:  # progress every 10s
            print(f"[Main] still waiting for bpy_server ({now - t0:.0f}s elapsed)")
            last_log = now
        time.sleep(0.5)


def ensure_bpy_server_started():
    """Make sure a bpy_server is reachable, starting one if necessary.

    - If a server already answers ``/ping``, return at once.
    - If this process has already launched a server that is still running but
      not yet ready, wait for it instead of returning early. Callers send it
      requests immediately afterwards.
    - Otherwise, launch a new server and wait for it.

    Raises:
        RuntimeError: Propagated from ``wait_for_bpy_server``.
    """
    global _BPY_SERVER_PROC
    if is_bpy_server_alive():
        return
    if _BPY_SERVER_PROC is None or _BPY_SERVER_PROC.poll() is not None:
        _BPY_SERVER_PROC = start_bpy_server()
    wait_for_bpy_server(proc=_BPY_SERVER_PROC)


# ---------------------------------------------------------------------------
# Lazy model loading.
# ---------------------------------------------------------------------------
# Lazily-loaded inference state, shared across requests. It is (re)populated
# by load_model() and read by run_rig().
model = None
tokenizer = None
transform = None
# Cache key for the currently loaded model (checkpoint path, HF path).
CURRENT_MODEL_CKPT: Optional[str] = None
CURRENT_HF_PATH: Optional[str] = None


def load_model(model_ckpt: str, hf_path: Optional[str]) -> Tuple[str, str]:
    """Load the requested TokenRig model, or reuse it if already loaded.

    Args:
        model_ckpt: Checkpoint path passed to ``get_model``.
        hf_path: Optional HF path; the string ``"None"`` (the UI sentinel) is
            treated as ``None``.

    Returns:
        ``(status_message, model_ckpt)``.

    Raises:
        RuntimeError: If ``model_ckpt`` is empty, or the loaded model has no
            tokenizer config.
    """
    global model, tokenizer, transform, CURRENT_MODEL_CKPT, CURRENT_HF_PATH
    if hf_path == "None":
        hf_path = None
    if model is not None and model_ckpt == CURRENT_MODEL_CKPT and hf_path == CURRENT_HF_PATH:
        return ("Model already loaded.", model_ckpt)

    if not model_ckpt:
        raise RuntimeError("model_ckpt is empty. Please select a checkpoint.")

    print(f"Loading model: {model_ckpt}, hf_path={hf_path}")
    # Build everything into locals first and publish only once all of it has
    # succeeded. Otherwise a failure in get_tokenizer/Transform.parse would
    # leave the new model paired with the old tokenizer/transform.
    new_model = get_model(model_ckpt, hf_path=hf_path)
    if new_model.tokenizer_config is None:
        raise RuntimeError(f"Checkpoint {model_ckpt!r} has no tokenizer_config.")
    new_tokenizer = get_tokenizer(**new_model.tokenizer_config)
    new_transform = Transform.parse(**new_model.transform_config["predict_transform"])

    model, tokenizer, transform = new_model, new_tokenizer, new_transform
    CURRENT_MODEL_CKPT = model_ckpt
    CURRENT_HF_PATH = hf_path
    return ("Model loaded.", model_ckpt)


# ---------------------------------------------------------------------------
# File utilities (CLI-side).
# ---------------------------------------------------------------------------
# Mesh extensions picked up when --input is a directory (compared lowercased).
SUPPORTED_EXT = {".obj", ".fbx", ".glb"}


def collect_files(input_path: Path) -> List[Path]:
    """Return the mesh files to process for *input_path*.

    - A file path is returned as-is; its extension is not checked.
    - A directory is searched recursively for regular files whose suffix
      (case-insensitive) is in ``SUPPORTED_EXT``. Directories that happen to
      be named like ``foo.obj`` are skipped.
    - The result is sorted, so batch runs process inputs in a reproducible
      order.
    - A path that does not exist yields an empty list.
    """
    if input_path.is_file():
        return [input_path]
    return sorted(
        p
        for p in input_path.rglob("*")
        if p.suffix.lower() in SUPPORTED_EXT and p.is_file()
    )


def map_output_path(in_path: Path, input_root: Path, output_root: Path) -> Path:
    """Map an input mesh to its ``.glb`` output path under *output_root*.

    The path of *in_path* relative to *input_root* is mirrored under
    *output_root*, with its suffix replaced by ``.glb``.

    When *input_root* is the file itself (single-file input), the file's own
    name is used. The result is then ``output_root/<stem>.glb``, rather than
    *output_root* with ``.glb`` appended.
    """
    if in_path == input_root:
        rel = Path(in_path.name)
    else:
        rel = in_path.relative_to(input_root)
    return (output_root / rel).with_suffix(".glb")


# ---------------------------------------------------------------------------
# Core inference (shared by CLI and Gradio).
# ---------------------------------------------------------------------------
def run_rig(
    filepaths: List[Path],
    top_k: int,
    top_p: float,
    temperature: float,
    repetition_penalty: float,
    num_beams: int,
    use_skeleton: bool,
    use_transfer: bool,
    use_postprocess: bool,
    output_paths: List[Path],
    model_ckpt: str,
    hf_path: Optional[str],
):
    """Rig each mesh in *filepaths* and export the result to *output_paths*.

    Shared by the CLI and the Gradio handler. It first ensures bpy_server is
    running and the requested checkpoint is loaded. Then, for every input, it:

    - runs TokenRig generation;
    - optionally masks the skin weights with ``voxel_skin`` and renormalizes;
    - asks bpy_server either to export the predicted asset (``/export``) or to
      transfer the rig onto the original mesh (``/transfer``).

    Args:
        filepaths: Input mesh files.
        top_k, top_p, temperature, repetition_penalty, num_beams: Generation
            parameters (sampling is always enabled).
        use_skeleton: Keep ``skeleton_tokens``/``skeleton_mask`` from the batch
            (skinning-only prediction). When False they are dropped.
        use_transfer: Transfer the rig back onto the original mesh.
        use_postprocess: Apply voxel-based skin masking before normalization.
        output_paths: Destination ``.glb`` path for each input, in the same
            order. Assumes the dataloader yields inputs in ``filepaths`` order
            (shuffle=False, batch_size=1) — TODO confirm.
        model_ckpt, hf_path: Forwarded to ``load_model``.

    Returns:
        The output path of every processed input. Export failures are printed
        but their paths are still included.

    Raises:
        ValueError: If ``filepaths`` and ``output_paths`` differ in length.
        RuntimeError: If the model returned no asset for an input.
    """
    # Explicit check rather than `assert`: asserts are stripped under
    # `python -O`, and the mismatch would then surface later as an IndexError
    # or as misaligned outputs.
    if len(filepaths) != len(output_paths):
        raise ValueError(
            f"Got {len(filepaths)} input files but {len(output_paths)} output paths."
        )
    ensure_bpy_server_started()
    load_model(model_ckpt, hf_path)

    datapath = {
        "data_name": None,
        "loader": "bpy_server",
        "filepaths": {"articulation": [str(p) for p in filepaths]},
    }

    num_workers = get_dataloader_workers()
    dataset_config = DatasetConfig.parse(
        shuffle=False,
        batch_size=1,
        num_workers=num_workers,
        pin_memory=num_workers > 0,
        persistent_workers=False,
        datapath=datapath,
    ).split_by_cls()

    module = RigDatasetModule(
        predict_dataset_config=dataset_config,
        predict_transform=transform,
        tokenizer=tokenizer,
        process_fn=model._process_fn,
    )

    dataloader = module.predict_dataloader()["articulation"]

    results_out = []
    infer_device = model.device if model is not None else "cuda"

    for i, batch in tqdm(enumerate(dataloader), total=len(dataloader)):
        batch = {
            k: v.to(infer_device) if isinstance(v, Tensor) else v
            for k, v in batch.items()
        }

        if not use_skeleton:
            # Skinning-only mode is opt-in. Without the flag, any skeleton
            # found in the input is discarded.
            batch.pop("skeleton_tokens", None)
            batch.pop("skeleton_mask", None)

        batch["generate_kwargs"] = dict(
            max_length=2048,
            top_k=int(top_k),
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=1,
            num_beams=int(num_beams),
            do_sample=True,
        )

        if "skeleton_tokens" in batch and "skeleton_mask" in batch:
            # Keep only the valid (mask == 1) tokens of the single batch item.
            mask = batch["skeleton_mask"][0] == 1
            skeleton_tokens = batch["skeleton_tokens"][0][mask].cpu().numpy()
        else:
            skeleton_tokens = None

        preds: List[TokenRigResult] = model.predict_step(
            batch,
            skeleton_tokens=[skeleton_tokens] if skeleton_tokens is not None else None,
            make_asset=True,
        )["results"]

        asset = preds[0].asset
        if asset is None:
            raise RuntimeError(f"TokenRig produced no asset for input #{i}.")

        if use_postprocess:
            voxel = asset.voxel(resolution=196)
            asset.skin *= voxel_skin(
                grid=0,
                grid_coords=voxel.coords,
                joints=asset.joints,
                vertices=asset.vertices,
                faces=asset.faces,
                mode="square",
                voxel_size=voxel.voxel_size,
            )
            asset.normalize_skin()

        out_path = output_paths[i]
        out_path.parent.mkdir(parents=True, exist_ok=True)

        if use_transfer:
            payload = dict(
                source_asset=asset,
                target_path=asset.path,
                export_path=str(out_path),
                group_per_vertex=4,
            )
            res = bytes_to_object(
                requests.post(
                    f"{BPY_SERVER}/transfer",
                    data=object_to_bytes(payload),
                ).content
            )
        else:
            payload = dict(
                asset=asset,
                filepath=str(out_path),
                group_per_vertex=4,
            )
            res = bytes_to_object(
                requests.post(
                    f"{BPY_SERVER}/export",
                    data=object_to_bytes(payload),
                ).content
            )

        # The server replies with the object "ok" on success; anything else
        # is logged as the error and processing continues with the next input.
        if res != "ok":
            print(f"[Error] {res}")
        else:
            print(f"[OK] Exported: {out_path}")

        results_out.append(out_path)

    return results_out


# ---------------------------------------------------------------------------
# CLI entry point.
# ---------------------------------------------------------------------------
def run_cli(args):
    """Run rigging from parsed command-line arguments.

    ``args.input`` may be a single mesh or a directory (searched recursively).
    If exactly one file is found and ``args.output`` has a suffix, that path is
    used as the output file. Otherwise ``args.output`` is treated as a
    directory: each input is written under it as a ``.glb``, mirroring its
    path relative to ``args.input``.

    Raises:
        RuntimeError: If ``--output`` is missing or no supported files are
            found.
    """
    if not args.output:
        # Path(None) would otherwise fail with an unhelpful TypeError.
        raise RuntimeError("--output is required in CLI mode.")
    input_path = Path(args.input).resolve()
    output_path = Path(args.output).resolve()

    files = collect_files(input_path)
    if not files:
        raise RuntimeError("No valid 3D files found.")

    if len(files) == 1 and output_path.suffix:
        outputs = [output_path]
    else:
        outputs = [map_output_path(f, input_path, output_path) for f in files]

    run_rig(
        files,
        args.top_k,
        args.top_p,
        args.temperature,
        args.repetition_penalty,
        args.num_beams,
        args.use_skeleton,
        args.use_transfer,
        args.use_postprocess,
        outputs,
        args.model_ckpt,
        args.hf_path,
    )


# ---------------------------------------------------------------------------
# Gradio wrapper (with ZeroGPU duration estimator).
# ---------------------------------------------------------------------------
TOT = 0


def _gpu_duration(
    files,
    top_k,
    top_p,
    temperature,
    repetition_penalty,
    num_beams,
    use_skeleton,
    use_transfer,
    use_postprocess,
    model_ckpt,
    hf_path,
):
    # Cold workers spend ~30–60 s importing bpy + loading the model before
    # any GPU work. Give every request a generous 240 s floor.
    file_count = len(files) if files is not None else 1
    return min(900, max(240, 240 + 60 * file_count))


@spaces.GPU(duration=_gpu_duration)
def run_gradio(
    files,
    top_k,
    top_p,
    temperature,
    repetition_penalty,
    num_beams,
    use_skeleton,
    use_transfer,
    use_postprocess,
    model_ckpt,
    hf_path,
):
    """Gradio click handler: rig the uploaded meshes and return the GLBs.

    It is wrapped by ``@spaces.GPU``, with its time budget computed by
    ``_gpu_duration``. Each call writes into a fresh temporary directory,
    naming outputs ``res_<N>.glb`` from the global ``TOT`` counter.

    Returns:
        ``(status_message, list_of_glb_paths_or_None)`` for the ``log`` and
        ``output`` components. Only outputs that were actually written are
        returned; missing ones are counted in the status message.
    """
    if not files:
        return "Please upload at least one 3D model.", None

    global TOT
    tmp_out = Path(tempfile.mkdtemp(prefix="tokenrig_"))
    # Gradio 4+ passes file paths as `str` subclasses that carry `.name`.
    # Older versions pass tempfile wrappers. Also accept plain path strings.
    filepaths = [Path(getattr(f, "name", f)) for f in files]
    outputs = []
    for _ in filepaths:
        TOT += 1
        outputs.append(tmp_out / f"res_{TOT}.glb")

    run_rig(
        filepaths,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        num_beams,
        use_skeleton,
        use_transfer,
        use_postprocess,
        outputs,
        model_ckpt,
        hf_path,
    )

    # run_rig only logs export failures; it does not raise. Returning a path
    # that was never written would likely make the gr.File output error out,
    # so only existing files are handed back.
    produced = [str(p) for p in outputs if p.exists()]
    if len(produced) == len(outputs):
        return f"Processed {len(outputs)} models.", produced
    failed = len(outputs) - len(produced)
    return (
        f"Processed {len(produced)} of {len(outputs)} models "
        f"({failed} failed; see server logs).",
        produced or None,
    )


# ---------------------------------------------------------------------------
# Gradio UI.
# ---------------------------------------------------------------------------
def build_gradio_app():
    """Build the Gradio Blocks UI for the TokenRig demo.

    The layout has:

    - an upload widget;
    - a collapsible settings panel (checkpoint, sampling parameters and
      pipeline toggles);
    - a Run button, wired to ``run_gradio``;
    - a status textbox and a file output.

    The ``hf_path`` dropdown is kept but hidden, so the callback signature
    stays unchanged.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    model_ckpts = MODEL_CKPTS
    hf_paths = HF_PATHS
    default_ckpt = model_ckpts[0] if model_ckpts else ""
    default_hf = hf_paths[0] if hf_paths else "None"

    with gr.Blocks(title="SkinTokens · TokenRig Demo") as app:
        gr.Markdown(
            """
            ## 🦴 Mesh to Rig with [SkinTokens](https://zjp-shadow.github.io/works/SkinTokens/) · TokenRig

            Automated **skeleton generation + skinning weight prediction** for any 3D mesh, via a unified
            autoregressive model over learned *SkinTokens*. Successor to
            [UniRig](https://github.com/VAST-AI-Research/UniRig) (SIGGRAPH&nbsp;'25).

            * Upload one or more meshes → click **Run** → download a rigged `.glb`.
            * **Paper**: [arXiv&nbsp;2602.04805](https://arxiv.org/abs/2602.04805) &nbsp;·&nbsp;
              **Code**: [VAST-AI-Research/SkinTokens](https://github.com/VAST-AI-Research/SkinTokens) &nbsp;·&nbsp;
              **Weights**: [🤗&nbsp;VAST-AI/SkinTokens](https://huggingface.co/VAST-AI/SkinTokens)
            * Looking for **image → rigged 3D** instead? Try our sibling Space
              [🤗&nbsp;VAST-AI/AniGen](https://huggingface.co/spaces/VAST-AI/AniGen).
            * Want a full AI-powered 3D workspace? → [Tripo](https://www.tripo3d.ai)
            """
        )

        gr.HTML(
            """
<style>
@keyframes gentle-pulse {
    0%, 100% { opacity: 1; }
    50% { opacity: 0.35; }
}
</style>
<div style="text-align:left; color:#888; font-size:1em; line-height:1.6; margin: 4px 0 -4px 0;">
    <span style="animation: gentle-pulse 3s ease-in-out infinite; display:inline-block;">&#128161; <b>Tips</b></span>&ensp;
    Defaults work well for most meshes.
    &nbsp;• If your mesh already has a skeleton and you only want skinning, enable
    <b>Use existing skeleton</b> below.
    &nbsp;• To keep your original textures and world scale, enable <b>Preserve original texture &amp; scale</b>.
</div>
"""
        )

        with gr.Row():
            with gr.Column(scale=1):
                files = gr.File(
                    label="3D Models  ( .obj / .fbx / .glb, up to a few at a time )",
                    file_count="multiple",
                    file_types=[".obj", ".fbx", ".glb"],
                )

                with gr.Accordion("⚙️ Generation Settings", open=False):
                    model_ckpt = gr.Dropdown(
                        choices=model_ckpts,
                        value=default_ckpt,
                        label="Model checkpoint",
                        info="TokenRig autoregressive rigging model. The default is the GRPO-refined checkpoint recommended for most assets.",
                        interactive=True,
                    )
                    # Keep the hf_path component for callback compatibility, but hide it
                    # from the UI since it currently only exposes the default ("None") option.
                    hf_path = gr.Dropdown(
                        choices=hf_paths,
                        value=default_hf,
                        label="HF path (advanced)",
                        visible=False,
                    )

                    gr.Markdown("**Sampling parameters** — control autoregressive decoding of the rig.")
                    top_k = gr.Slider(
                        1, 200, value=5, step=1,
                        label="top_k",
                        info="Sample from the K most likely next tokens at each step. Lower = more deterministic output.",
                    )
                    top_p = gr.Slider(
                        0.1, 1.0, value=0.95, step=0.01,
                        label="top_p (nucleus)",
                        info="Sample from the smallest set of tokens whose cumulative probability ≥ p.",
                    )
                    temperature = gr.Slider(
                        0.1, 2.0, value=1.0, step=0.1,
                        label="temperature",
                        info="Softmax temperature. <1 sharpens the distribution (more conservative), >1 makes it flatter (more diverse).",
                    )
                    repetition_penalty = gr.Slider(
                        0.5, 3.0, value=2.0, step=0.1,
                        label="repetition_penalty",
                        info="Multiplicative penalty on tokens that have already been generated. 1.0 = no penalty.",
                    )
                    num_beams = gr.Slider(
                        1, 20, value=10, step=1,
                        label="num_beams",
                        info="Beam-search width. Larger = higher quality but slower; 1 disables beam search.",
                    )

                    gr.Markdown("**Pipeline toggles**")
                    use_skeleton = gr.Checkbox(
                        False,
                        label="Use existing skeleton (predict skinning only)",
                        info="If the uploaded file already contains a skeleton, keep it and only predict per-vertex skinning weights.",
                    )
                    use_transfer = gr.Checkbox(
                        False,
                        label="Preserve original texture & scale",
                        info="Transfer the predicted rig back onto the original (unprocessed) mesh, so textures and world units are preserved.",
                    )
                    use_postprocess = gr.Checkbox(
                        False,
                        label="Voxel skin post-processing",
                        info="Apply a voxel-based mask to the predicted skin weights before normalization. Slower.",
                    )

                run_btn = gr.Button("🚀 Run", variant="primary")

            with gr.Column(scale=1):
                log = gr.Textbox(label="Status", lines=2, interactive=False)
                output = gr.File(label="Rigged GLB output", interactive=False)
                gr.Markdown(
                    """
                    **Notes**
                    - The output `.glb` contains the predicted **skeleton + skinning weights**. Import it in Blender (File → Import → glTF&nbsp;2.0) or any DCC tool that reads glTF.
                    - In Blender, if you see a `glTF_not_exported` placeholder node, you can safely remove it.
                    - On busy moments Zero-GPU may queue your request for ~10–30&nbsp;s before inference starts — the status box will update once the GPU is attached.
                    - Please do **not** upload confidential or NSFW content. See the
                      [project page](https://zjp-shadow.github.io/works/SkinTokens/) for paper-accurate results and the
                      [code repo](https://github.com/VAST-AI-Research/SkinTokens) for local / batch inference.
                    """
                )

        # Input order must match run_gradio's positional parameters.
        run_btn.click(
            run_gradio,
            inputs=[
                files,
                top_k,
                top_p,
                temperature,
                repetition_penalty,
                num_beams,
                use_skeleton,
                use_transfer,
                use_postprocess,
                model_ckpt,
                hf_path,
            ],
            outputs=[log, output],
        )

    return app


# Built at import time. The `__main__` block launches it unless CLI arguments
# (`--input` without `--gradio`) select batch mode.
demo = build_gradio_app()


# Note: we do NOT pre-warm `bpy_server` in the main process. `bpy_server.py`
# transitively imports `src.model.michelangelo.utils.misc`, whose
# module-level `use_flash3 = FLASH3()` calls `torch.cuda.get_device_name(0)`
# at import time. That call fails ("RuntimeError: No CUDA GPUs are
# available") in the main Gradio process on ZeroGPU, where the GPU is only
# attached inside `@spaces.GPU`-decorated workers. So the bpy_server boot
# happens on first request, inside the worker.


# ---------------------------------------------------------------------------
# Entry point.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    parser = argparse.ArgumentParser("TokenRig Demo")
    parser.add_argument("--input", help="Input file or directory")
    parser.add_argument("--output", help="Output file or directory")

    # Sampling parameters; the defaults mirror the Gradio sliders.
    for flag, kind, fallback in (
        ("--top_k", int, 5),
        ("--top_p", float, 0.95),
        ("--temperature", float, 1.0),
        ("--repetition_penalty", float, 2.0),
        ("--num_beams", int, 10),
    ):
        parser.add_argument(flag, type=kind, default=fallback)

    # Pipeline toggles (all off unless given).
    for flag in ("--use_skeleton", "--use_transfer", "--use_postprocess"):
        parser.add_argument(flag, action="store_true")

    parser.add_argument("--model_ckpt", default=MODEL_CKPTS[0] if MODEL_CKPTS else "")
    parser.add_argument("--hf_path", default=None)

    parser.add_argument("--gradio", action="store_true")

    args = parser.parse_args()

    # Without an --input (or with an explicit --gradio) serve the web UI;
    # otherwise run a one-shot batch job.
    launch_ui = args.gradio or not args.input
    if launch_ui:
        demo.queue()
        demo.launch(ssr_mode=False)
    else:
        ensure_bpy_server_started()
        run_cli(args)