# UniRig / app.py
import os
import subprocess
import time
import sys
import shutil
import tarfile
import urllib.request
import site
from pathlib import Path
from unittest.mock import MagicMock
# ==========================================================
# 0. GLOBALS (Blender userland download)
# ==========================================================
# Blender 3.6 LTS ships with Python 3.10, which matches this Space's runtime
BLENDER_VERSION = "3.6.5"
BLENDER_TARBALL = f"blender-{BLENDER_VERSION}-linux-x64.tar.xz"
BLENDER_URL = f"https://download.blender.org/release/Blender3.6/{BLENDER_TARBALL}"
# Cache location writable without root
BLENDER_CACHE_DIR = Path.home() / ".cache" / "unirig" / f"blender-{BLENDER_VERSION}"
BLENDER_EXTRACT_DIR = BLENDER_CACHE_DIR / f"blender-{BLENDER_VERSION}-linux-x64"
BLENDER_BIN = BLENDER_EXTRACT_DIR / "blender"
# Where we will write a temporary Blender python script at runtime
BLENDER_SCRIPT_PATH = BLENDER_CACHE_DIR / "hf_blender_extract.py"
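# Resulting cache layout (derived from the paths above):
#   ~/.cache/unirig/blender-3.6.5/blender-3.6.5-linux-x64.tar.xz    <- downloaded archive
#   ~/.cache/unirig/blender-3.6.5/blender-3.6.5-linux-x64/blender   <- headless executable
#   ~/.cache/unirig/blender-3.6.5/hf_blender_extract.py             <- runner script written at runtime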
# ==========================================================
# 1. SYSTEM SETUP (No Xvfb needed when using Blender -b)
# ==========================================================
# NOTE: We intentionally do NOT start Xvfb because HF blocks /tmp/.X11-unix creation
# and Blender is run headless via `-b`.
# ==========================================================
# 2. BUGFIXES & MOCKS
# ==========================================================
# Fix A: Gradio schema error -- stub the JSON-schema-to-type conversion so unsupported schemas cannot crash the client
import gradio_client.utils as client_utils
client_utils._json_schema_to_python_type = lambda *args, **kwargs: "Any"
client_utils.json_schema_to_python_type = lambda *args, **kwargs: "Any"
# Fix B: Flash Attention Mocking
try:
import flash_attn # noqa: F401
except ImportError:
mock = MagicMock()
sys.modules["flash_attn"] = mock
sys.modules["flash_attn.modules"] = mock
sys.modules["flash_attn.modules.mha"] = mock
print("Flash Attention gemockt.")
# ==========================================================
# 3. CORE IMPORTS
# ==========================================================
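# Open3D is optional here; if it is installed, silence its logging so console output stays readable.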
try:
import open3d as o3d
o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)
except Exception:
pass
import gradio as gr
import spaces
import torch # noqa: F401
import lightning as L
import yaml
from box import Box
# ==========================================================
# 4. BLENDER HELPERS (download + run headless extraction)
# ==========================================================
def ensure_blender() -> str:
"""
Download and extract Blender into user cache dir (no root).
Returns path to blender executable.
"""
if BLENDER_BIN.exists():
return str(BLENDER_BIN)
BLENDER_CACHE_DIR.mkdir(parents=True, exist_ok=True)
tar_path = BLENDER_CACHE_DIR / BLENDER_TARBALL
if not tar_path.exists():
print(f"⬇️ Downloading Blender {BLENDER_VERSION} from: {BLENDER_URL}")
urllib.request.urlretrieve(BLENDER_URL, tar_path)
print(f"📦 Extracting Blender to: {BLENDER_CACHE_DIR}")
with tarfile.open(tar_path, "r:xz") as tf:
tf.extractall(path=BLENDER_CACHE_DIR)
if not BLENDER_BIN.exists():
raise RuntimeError(f"Blender binary not found after extract: {BLENDER_BIN}")
return str(BLENDER_BIN)
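# Usage sketch: the first call downloads and unpacks the tarball once; later calls
# just return the cached binary, e.g.
#   blender_bin = ensure_blender()
#   subprocess.check_call([blender_bin, "-b", "--version"])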
def ensure_blender_script():
"""
Writes a tiny extraction runner script that will be executed INSIDE Blender's Python.
This avoids needing `import bpy` in the Space's Python runtime.
"""
if BLENDER_SCRIPT_PATH.exists():
return
BLENDER_CACHE_DIR.mkdir(parents=True, exist_ok=True)
# This script runs inside Blender's Python; it can import bpy and then call your extraction pipeline.
script = r'''
import sys
import time
from pathlib import Path
def _parse(argv):
args = {"input": None, "output_dir": None, "target_count": 50000}
it = iter(argv)
for k in it:
if k == "--input":
args["input"] = next(it)
elif k == "--output_dir":
args["output_dir"] = next(it)
elif k == "--target_count":
args["target_count"] = int(next(it))
if not args["input"] or not args["output_dir"]:
raise SystemExit("Usage: --input <file> --output_dir <dir> [--target_count N]")
return args
def main():
argv = sys.argv
if "--" in argv:
argv = argv[argv.index("--") + 1 :]
else:
argv = []
args = _parse(argv)
out = Path(args["output_dir"])
out.mkdir(parents=True, exist_ok=True)
# Now import your project's extractor (this will import bpy inside Blender, which is fine)
from src.data.extract import extract_builtin, get_files
files = get_files(
data_name="raw_data.npz",
inputs=str(args["input"]),
input_dataset_dir=None,
output_dataset_dir=str(out),
force_override=True,
warning=False,
)
if not files:
raise RuntimeError("No files to extract")
timestamp = str(int(time.time()))
extract_builtin(
output_folder=str(out),
target_count=int(args["target_count"]),
num_runs=1,
id=0,
time=timestamp,
files=files,
)
if __name__ == "__main__":
main()
'''
BLENDER_SCRIPT_PATH.write_text(script, encoding="utf-8")
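# Note: the runner script is written only once; delete BLENDER_SCRIPT_PATH (or the whole
# cache dir) to regenerate it after editing the embedded script above.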
def run_blender_extract(input_file: str, output_dir: str, target_count: int = 50000):
"""
Runs Blender headless (-b) and executes the extraction script.
We also pass PYTHONPATH so Blender's Python can import this repo + site-packages.
"""
blender = ensure_blender()
ensure_blender_script()
repo_root = Path(__file__).parent.resolve()
# Make installed pip packages visible to Blender-Python (in case extract.py needs them)
py_paths = []
try:
py_paths += site.getsitepackages()
except Exception:
pass
py_paths.append(str(repo_root))
env = os.environ.copy()
env["PYTHONPATH"] = os.pathsep.join([p for p in py_paths if p] + [env.get("PYTHONPATH", "")])
cmd = [
blender,
"-b",
"-noaudio",
"--python",
str(BLENDER_SCRIPT_PATH),
"--",
"--input",
str(input_file),
"--output_dir",
str(output_dir),
"--target_count",
str(target_count),
]
print("🧩 Running Blender extract:")
print(" " + " ".join(cmd))
subprocess.check_call(cmd, env=env)
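# The assembled command looks roughly like this (cache paths abbreviated):
#   .../blender -b -noaudio --python .../hf_blender_extract.py -- \
#       --input model.glb --output_dir <output_dir> --target_count 50000
# Everything after the bare "--" is forwarded to the runner script's own argument parser.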
# ==========================================================
# 5. PIPELINE FUNCTIONS (with Blender fallback)
# ==========================================================
def validate_input_file(file_path: str) -> bool:
supported_formats = [".obj", ".fbx", ".glb"]
if not file_path or not Path(file_path).exists():
return False
return Path(file_path).suffix.lower() in supported_formats
def extract_mesh_python(input_file: str, output_dir: str) -> str:
"""
1) Try native bpy (if it ever exists in the Space)
2) Otherwise run Blender headless subprocess that generates the npz
"""
try:
import bpy # noqa: F401
from src.data.extract import extract_builtin, get_files
files = get_files(
data_name="raw_data.npz",
inputs=str(input_file),
input_dataset_dir=None,
output_dataset_dir=output_dir,
force_override=True,
warning=False,
)
if not files:
raise RuntimeError("No files to extract")
timestamp = str(int(time.time()))
extract_builtin(
output_folder=output_dir,
target_count=50000,
num_runs=1,
id=0,
time=timestamp,
files=files,
)
return files[0][1]
except Exception as e:
print(f"⚠️ Native bpy extraction failed ({type(e).__name__}: {e}) -> using Blender subprocess fallback.")
# Blender subprocess fallback
Path(output_dir).mkdir(parents=True, exist_ok=True)
run_blender_extract(input_file=input_file, output_dir=output_dir, target_count=50000)
# Recompute expected output path using existing helper
from src.data.extract import get_files
files = get_files(
data_name="raw_data.npz",
inputs=str(input_file),
input_dataset_dir=None,
output_dataset_dir=output_dir,
force_override=True,
warning=False,
)
if not files:
raise RuntimeError("No files produced by Blender extraction")
return files[0][1]
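# Note: get_files() is assumed to return (input, output) pairs, so files[0][1] above is the
# per-model output directory that should now contain raw_data.npz.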
def run_inference_python(
input_file: str,
output_file: str,
inference_type: str,
seed: int = 12345,
    npz_dir: str | None = None,
) -> str:
from src.data.datapath import Datapath
from src.data.dataset import DatasetConfig, UniRigDatasetModule
from src.data.transform import TransformConfig
from src.inference.download import download
from src.model.parse import get_model
from src.system.parse import get_system, get_writer
from src.tokenizer.parse import get_tokenizer
from src.tokenizer.spec import TokenizerConfig
if inference_type == "skeleton":
L.seed_everything(seed, workers=True)
configs = [
"configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml",
"configs/transform/inference_ar_transform.yaml",
"configs/model/unirig_ar_350m_1024_81920_float32.yaml",
"configs/system/ar_inference_articulationxl.yaml",
"configs/tokenizer/tokenizer_parts_articulationxl_256.yaml",
]
data_name = "raw_data.npz"
else:
configs = [
"configs/task/quick_inference_unirig_skin.yaml",
"configs/transform/inference_skin_transform.yaml",
"configs/model/unirig_skin.yaml",
"configs/system/skin.yaml",
None,
]
data_name = "predict_skeleton.npz"
with open(configs[0], "r") as f:
task = Box(yaml.safe_load(f))
if inference_type == "skeleton":
if npz_dir is None:
npz_dir = Path(output_file).parent / "npz"
npz_dir.mkdir(exist_ok=True)
npz_data_dir = extract_mesh_python(input_file, str(npz_dir))
datapath = Datapath(files=[npz_data_dir], cls=None)
    else:
        skeleton_work_dir = Path(input_file).parent
        # rglob is already recursive; fail with a clear error if no skeleton npz was produced
        npz_files = sorted(skeleton_work_dir.rglob("*.npz"))
        if not npz_files:
            raise RuntimeError(f"No skeleton .npz found under {skeleton_work_dir}")
        skeleton_npz_dir = npz_files[0].parent
        datapath = Datapath(files=[str(skeleton_npz_dir)], cls=None)
data_config = Box(yaml.safe_load(open("configs/data/quick_inference.yaml", "r")))
transform_config = Box(yaml.safe_load(open(configs[1], "r")))
if inference_type == "skeleton":
tokenizer = get_tokenizer(
config=TokenizerConfig.parse(config=Box(yaml.safe_load(open(configs[4], "r"))))
)
model = get_model(tokenizer=tokenizer, **Box(yaml.safe_load(open(configs[2], "r"))))
else:
model = get_model(tokenizer=None, **Box(yaml.safe_load(open(configs[2], "r"))))
data = UniRigDatasetModule(
process_fn=model._process_fn,
predict_dataset_config=DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls(),
predict_transform_config=TransformConfig.parse(config=transform_config.predict_transform_config),
tokenizer_config=None if inference_type == "skin" else tokenizer.config,
data_name=data_name,
datapath=datapath,
cls=None,
)
writer_config = task.writer.copy()
if inference_type == "skeleton":
writer_config.update(
{
"npz_dir": str(npz_dir),
"output_dir": str(Path(output_file).parent),
"output_name": Path(output_file).name,
"user_mode": False,
}
)
else:
writer_config.update(
{
"npz_dir": str(skeleton_npz_dir),
"output_name": str(output_file),
"user_mode": True,
"export_fbx": True,
}
)
callbacks = [get_writer(**writer_config, order_config=data.predict_transform_config.order_config)]
system = get_system(**Box(yaml.safe_load(open(configs[3], "r"))), model=model, steps_per_epoch=1)
trainer = L.Trainer(callbacks=callbacks, logger=None, **task.trainer)
trainer.predict(
system,
datamodule=data,
ckpt_path=download(task.resume_from_checkpoint),
return_predictions=False,
)
return str(output_file)
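# Call order used by main() below: the skeleton pass consumes the raw mesh, and the skin
# pass consumes the skeleton FBX produced by the first pass, e.g.
#   run_inference_python(str(input_path), str(skel_fbx), "skeleton", seed)
#   run_inference_python(str(skel_fbx), str(skin_fbx), "skin")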
def merge_results_python(source_file: str, target_file: str, output_file: str) -> str:
from src.inference.merge import transfer
transfer(source=str(source_file), target=str(target_file), output=str(output_file), add_root=False)
return str(output_file)
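# transfer() from src.inference.merge is assumed to copy the predicted rig from `source`
# onto the original geometry in `target` and write the merged result to `output`.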
# ==========================================================
# 6. GRADIO APP
# ==========================================================
@spaces.GPU()
def main(input_file: str, seed: int = 12345):
temp_dir = Path(__file__).parent / "tmp"
temp_dir.mkdir(exist_ok=True)
if not validate_input_file(input_file):
raise gr.Error("Invalid file format")
file_stem = Path(input_file).stem
input_model_dir = temp_dir / f"{file_stem}_{seed}"
input_model_dir.mkdir(exist_ok=True)
input_path = input_model_dir / Path(input_file).name
shutil.copy2(input_file, input_path)
skel_fbx = input_model_dir / f"{file_stem}_skeleton.fbx"
skel_only = input_model_dir / f"{file_stem}_skeleton_only{input_path.suffix}"
skin_fbx = input_model_dir / f"{file_stem}_skin.fbx"
final_out = input_model_dir / f"{file_stem}_skeleton_and_skinning{input_path.suffix}"
run_inference_python(str(input_path), str(skel_fbx), "skeleton", seed)
merge_results_python(str(skel_fbx), str(input_path), str(skel_only))
run_inference_python(str(skel_fbx), str(skin_fbx), "skin")
merge_results_python(str(skin_fbx), str(input_path), str(final_out))
return str(final_out), [str(skel_only), str(final_out)]
def create_app():
with gr.Blocks(title="UniRig Demo") as interface:
gr.Markdown("# 🎯 UniRig: Automated 3D Model Rigging")
with gr.Row():
with gr.Column():
input_3d = gr.Model3D(label="Upload 3D Model")
seed = gr.Number(value=12345, label="Seed")
btn = gr.Button("Start Rigging", variant="primary")
with gr.Column():
out_3d = gr.Model3D(label="Result")
out_files = gr.Files(label="Download Files")
btn.click(fn=main, inputs=[input_3d, seed], outputs=[out_3d, out_files])
return interface
if __name__ == "__main__":
create_app().queue().launch(show_api=False)