Spaces:
Running on Zero
| import os | |
| import subprocess | |
| import time | |
| import sys | |
| import shutil | |
| import tarfile | |
| import urllib.request | |
| import site | |
| from pathlib import Path | |
| from unittest.mock import MagicMock | |
# ==========================================================
# 0. GLOBALS (Blender userland download)
# ==========================================================
# Blender 3.6 LTS bundles Python 3.10, which matches this Space's Python.
BLENDER_VERSION = "3.6.5"
# The release folder on download.blender.org is named after the major.minor
# series ("Blender3.6" for 3.6.5). Derive it so that bumping BLENDER_VERSION
# cannot leave the URL pointing at the wrong (or a non-existent) folder.
BLENDER_SERIES = ".".join(BLENDER_VERSION.split(".")[:2])
BLENDER_TARBALL = f"blender-{BLENDER_VERSION}-linux-x64.tar.xz"
BLENDER_URL = f"https://download.blender.org/release/Blender{BLENDER_SERIES}/{BLENDER_TARBALL}"
# Cache location that is writable without root.
BLENDER_CACHE_DIR = Path.home() / ".cache" / "unirig" / f"blender-{BLENDER_VERSION}"
BLENDER_EXTRACT_DIR = BLENDER_CACHE_DIR / f"blender-{BLENDER_VERSION}-linux-x64"
BLENDER_BIN = BLENDER_EXTRACT_DIR / "blender"
# Where the extraction runner script for Blender's Python is written at runtime.
BLENDER_SCRIPT_PATH = BLENDER_CACHE_DIR / "hf_blender_extract.py"
# ==========================================================
# 1. SYSTEM SETUP (No Xvfb needed when using Blender -b)
# ==========================================================
# NOTE: Xvfb is deliberately NOT started. HF blocks the creation of
# /tmp/.X11-unix, and Blender is run headless via `-b` anyway.
# ==========================================================
# 2. BUGFIXES & MOCKS
# ==========================================================
# Fix A: Gradio schema error. Both the private and public JSON-schema ->
# Python-type converters of gradio_client are stubbed to always return "Any".
import gradio_client.utils as client_utils

client_utils._json_schema_to_python_type = client_utils.json_schema_to_python_type = (
    lambda *args, **kwargs: "Any"
)

# Fix B: Flash Attention mocking. When flash_attn is not installed, one shared
# MagicMock is registered in sys.modules for the package and the two
# submodule names listed below, so later imports of them succeed.
try:
    import flash_attn  # noqa: F401
except ImportError:
    mock = MagicMock()
    sys.modules.update(
        dict.fromkeys(
            ("flash_attn", "flash_attn.modules", "flash_attn.modules.mha"),
            mock,
        )
    )
    print("Flash Attention gemockt.")
# ==========================================================
# 3. CORE IMPORTS
# ==========================================================
# open3d is treated as optional here. When it imports, its log output is
# reduced to errors only. Failure is reported but never blocks app start-up.
try:
    import open3d as o3d

    o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)
except Exception as exc:  # best effort: keep starting even if open3d is broken/missing
    print(f"⚠️ open3d verbosity not configured ({type(exc).__name__}: {exc})")
| import gradio as gr | |
| import spaces | |
| import torch # noqa: F401 | |
| import lightning as L | |
| import yaml | |
| from box import Box | |
# ==========================================================
# 4. BLENDER HELPERS (download + run headless extraction)
# ==========================================================
def ensure_blender() -> str:
    """
    Download and extract Blender into the user cache dir (no root needed).

    The tarball is downloaded under a temporary ``.part`` name and renamed only
    after the download completed. An interrupted download therefore can no
    longer leave a truncated archive that every later call would trust. If
    extraction fails because the archive is unreadable, the archive and any
    partially extracted tree are removed, so the next call starts from scratch.

    Returns:
        Path to the Blender executable, as a string.

    Raises:
        RuntimeError: If the archive cannot be extracted, or the binary is
            still missing after extraction.
    """
    import lzma

    if BLENDER_BIN.exists():
        return str(BLENDER_BIN)

    BLENDER_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    tar_path = BLENDER_CACHE_DIR / BLENDER_TARBALL

    if not tar_path.exists():
        print(f"⬇️ Downloading Blender {BLENDER_VERSION} from: {BLENDER_URL}")
        part_path = tar_path.with_name(tar_path.name + ".part")
        try:
            urllib.request.urlretrieve(BLENDER_URL, part_path)
        except BaseException:
            # Never leave a half-written file behind (also on Ctrl-C).
            part_path.unlink(missing_ok=True)
            raise
        # Atomic on the same filesystem: tar_path only ever holds a full download.
        part_path.replace(tar_path)

    print(f"📦 Extracting Blender to: {BLENDER_CACHE_DIR}")
    try:
        with tarfile.open(tar_path, "r:xz") as tf:
            tf.extractall(path=BLENDER_CACHE_DIR)
    except (tarfile.TarError, EOFError, lzma.LZMAError) as err:
        # A corrupt archive would otherwise be reused on every call, and a
        # half-extracted tree could make BLENDER_BIN.exists() lie next time.
        tar_path.unlink(missing_ok=True)
        shutil.rmtree(BLENDER_EXTRACT_DIR, ignore_errors=True)
        raise RuntimeError(f"Failed to extract Blender archive {tar_path}: {err}") from err

    if not BLENDER_BIN.exists():
        raise RuntimeError(f"Blender binary not found after extract: {BLENDER_BIN}")
    return str(BLENDER_BIN)
def ensure_blender_script():
    """
    Write the extraction runner script that is executed INSIDE Blender's Python.

    Running the extraction inside Blender means the Space's own Python runtime
    never needs ``import bpy``. The script is (re)written whenever the cached
    copy is missing or differs from the template below. An updated template
    therefore replaces a stale copy left in the cache by an earlier version of
    this app, instead of being silently ignored.
    """
    # This text runs in Blender's bundled Python. There it may import bpy and
    # the project's extraction pipeline (src.data.extract).
    script = r'''
import sys
import time
from pathlib import Path


def _parse(argv):
    args = {"input": None, "output_dir": None, "target_count": 50000}
    it = iter(argv)
    for k in it:
        if k not in ("--input", "--output_dir", "--target_count"):
            # Unknown tokens are ignored.
            continue
        value = next(it, None)
        if value is None:
            raise SystemExit(f"Missing value for {k}")
        if k == "--input":
            args["input"] = value
        elif k == "--output_dir":
            args["output_dir"] = value
        else:
            args["target_count"] = int(value)
    if not args["input"] or not args["output_dir"]:
        raise SystemExit("Usage: --input <file> --output_dir <dir> [--target_count N]")
    return args


def main():
    # Blender keeps its own CLI args in sys.argv; ours follow the "--" separator.
    argv = sys.argv
    argv = argv[argv.index("--") + 1 :] if "--" in argv else []
    args = _parse(argv)

    out = Path(args["output_dir"])
    out.mkdir(parents=True, exist_ok=True)

    # Import the project's extractor here (it may import bpy, which exists inside Blender).
    from src.data.extract import extract_builtin, get_files

    files = get_files(
        data_name="raw_data.npz",
        inputs=str(args["input"]),
        input_dataset_dir=None,
        output_dataset_dir=str(out),
        force_override=True,
        warning=False,
    )
    if not files:
        raise RuntimeError("No files to extract")

    timestamp = str(int(time.time()))
    extract_builtin(
        output_folder=str(out),
        target_count=int(args["target_count"]),
        num_runs=1,
        id=0,
        time=timestamp,
        files=files,
    )


if __name__ == "__main__":
    main()
'''
    # Skip the write only when the cached copy is byte-for-byte current.
    if BLENDER_SCRIPT_PATH.is_file() and BLENDER_SCRIPT_PATH.read_text(encoding="utf-8") == script:
        return
    BLENDER_CACHE_DIR.mkdir(parents=True, exist_ok=True)
    BLENDER_SCRIPT_PATH.write_text(script, encoding="utf-8")
def run_blender_extract(input_file: str, output_dir: str, target_count: int = 50000):
    """
    Run Blender headless (``-b``) and execute the extraction runner script.

    PYTHONPATH is set to the repo root, then this interpreter's site-packages,
    then any pre-existing PYTHONPATH. This lets the runner import
    ``src.data.extract`` and its pip dependencies. Blender's bundled Python
    ignores PYTHONPATH unless started with ``--python-use-system-env``, so that
    flag is passed. Blender also exits with status 0 when a ``--python`` script
    raises, unless ``--python-exit-code`` is given. ``--python-exit-code 1``
    makes extraction failures surface here instead of passing silently.

    Args:
        input_file: Mesh file to extract.
        output_dir: Directory the runner writes its dataset output into.
        target_count: Forwarded to the runner's ``--target_count``.

    Raises:
        subprocess.CalledProcessError: If Blender or the runner script fails.
    """
    blender = ensure_blender()
    ensure_blender_script()

    repo_root = Path(__file__).parent.resolve()

    # Repo root first, so `src` resolves to this project like it does in the
    # Space's own interpreter; then the installed pip packages.
    py_paths = [str(repo_root)]
    try:
        py_paths += site.getsitepackages()
    except Exception:
        # Some virtualenv flavours do not provide site.getsitepackages().
        pass
    existing = os.environ.get("PYTHONPATH")
    if existing:
        # Only append when non-empty: a trailing separator would add the cwd.
        py_paths.append(existing)

    env = os.environ.copy()
    env["PYTHONPATH"] = os.pathsep.join(p for p in py_paths if p)

    cmd = [
        blender,
        "-b",
        "-noaudio",
        "--python-use-system-env",
        "--python-exit-code",
        "1",
        "--python",
        str(BLENDER_SCRIPT_PATH),
        "--",
        "--input",
        str(input_file),
        "--output_dir",
        str(output_dir),
        "--target_count",
        str(target_count),
    ]
    print("🧩 Running Blender extract:")
    print(" " + " ".join(cmd))
    subprocess.check_call(cmd, env=env)
# ==========================================================
# 5. PIPELINE FUNCTIONS (with Blender fallback)
# ==========================================================
def validate_input_file(file_path: str) -> bool:
    """
    Return True if *file_path* names an existing regular file with a supported
    3D-model extension (.obj, .fbx, .glb; matched case-insensitively).

    ``None`` or empty paths, missing paths, and directories are rejected. This
    includes directories whose names look like model files (e.g. ``model.glb``),
    which the later copy step could not handle.
    """
    supported_formats = {".obj", ".fbx", ".glb"}
    if not file_path:
        return False
    path = Path(file_path)
    return path.is_file() and path.suffix.lower() in supported_formats
def _raw_data_files(input_file: str, output_dir: str):
    """Return the ``get_files`` entries for *input_file* with ``raw_data.npz`` output under *output_dir*."""
    from src.data.extract import get_files

    return get_files(
        data_name="raw_data.npz",
        inputs=str(input_file),
        input_dataset_dir=None,
        output_dataset_dir=output_dir,
        force_override=True,
        warning=False,
    )


def extract_mesh_python(input_file: str, output_dir: str, target_count: int = 50000) -> str:
    """
    Extract *input_file* into ``raw_data.npz`` data below *output_dir*.

    1) Try native ``bpy`` (in case it is ever importable in the Space).
    2) Otherwise run Blender headless in a subprocess that generates the npz.

    Args:
        input_file: Mesh file to extract.
        output_dir: Dataset output directory.
        target_count: Forwarded to the extractor as ``target_count``
            (defaults to the previously hard-coded 50000).

    Returns:
        The second element of the first ``get_files`` entry. The caller hands
        it to ``Datapath`` as the per-input data directory.

    Raises:
        RuntimeError: If there is nothing to extract, or the Blender fallback
            left no output at the expected location.
    """
    try:
        import bpy  # noqa: F401
        from src.data.extract import extract_builtin

        files = _raw_data_files(input_file, output_dir)
        if not files:
            raise RuntimeError("No files to extract")
        extract_builtin(
            output_folder=output_dir,
            target_count=target_count,
            num_runs=1,
            id=0,
            time=str(int(time.time())),
            files=files,
        )
        return files[0][1]
    except Exception as e:
        # Usually ImportError for bpy. Any other native failure also falls back.
        print(f"⚠️ Native bpy extraction failed ({type(e).__name__}: {e}) -> using Blender subprocess fallback.")

    # Blender subprocess fallback
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    run_blender_extract(input_file=input_file, output_dir=output_dir, target_count=target_count)

    # Recompute the expected output location with the same helper. It
    # presumably does not check that Blender actually wrote anything there
    # (TODO confirm), so existence is verified explicitly.
    files = _raw_data_files(input_file, output_dir)
    if not files:
        raise RuntimeError("No files produced by Blender extraction")
    npz_data_dir = files[0][1]
    if not Path(npz_data_dir).exists():
        raise RuntimeError(f"Blender extraction produced no output at: {npz_data_dir}")
    return npz_data_dir
def _load_yaml(path) -> "Box":
    """Load the YAML file at *path* into a ``Box``, closing the file handle."""
    with open(path, "r") as f:
        return Box(yaml.safe_load(f))


def _find_skeleton_npz_dir(work_dir: Path, data_name: str) -> Path:
    """
    Return the directory holding the skeleton-stage ``.npz`` output below *work_dir*.

    Prefers a directory containing *data_name*. Otherwise falls back to the
    first ``.npz`` file found anywhere below *work_dir*. Both searches are
    sorted, so the choice is deterministic.

    Raises:
        FileNotFoundError: If no ``.npz`` file exists below *work_dir*.
    """
    candidates = sorted(work_dir.rglob(data_name)) or sorted(work_dir.rglob("*.npz"))
    if not candidates:
        raise FileNotFoundError(f"No .npz files found below {work_dir}; run the skeleton stage first.")
    return candidates[0].parent


def run_inference_python(
    input_file: str,
    output_file: str,
    inference_type: str,
    seed: int = 12345,
    npz_dir: "str | Path | None" = None,
) -> str:
    """
    Run one UniRig prediction stage with Lightning and write *output_file*.

    Args:
        input_file: For ``"skeleton"``, the mesh to rig. It is first extracted
            into npz data. For ``"skin"``, a file from the skeleton stage whose
            directory tree contains that stage's ``.npz`` output.
        output_file: Result path handed to the writer callback.
        inference_type: ``"skeleton"`` or ``"skin"``.
        seed: Global seed, applied for the skeleton stage only.
        npz_dir: Skeleton stage only. Where extracted npz data goes. Defaults
            to ``<output_file dir>/npz`` and is created if missing.

    Returns:
        ``str(output_file)``.

    Raises:
        ValueError: For an unknown *inference_type*.
        FileNotFoundError: For the skin stage when no ``.npz`` is found.
    """
    if inference_type not in ("skeleton", "skin"):
        raise ValueError(f"Unknown inference_type: {inference_type!r} (expected 'skeleton' or 'skin')")

    from src.data.datapath import Datapath
    from src.data.dataset import DatasetConfig, UniRigDatasetModule
    from src.data.transform import TransformConfig
    from src.inference.download import download
    from src.model.parse import get_model
    from src.system.parse import get_system, get_writer
    from src.tokenizer.parse import get_tokenizer
    from src.tokenizer.spec import TokenizerConfig

    is_skeleton = inference_type == "skeleton"
    # configs = [task, transform, model, system, tokenizer (skeleton only)]
    if is_skeleton:
        L.seed_everything(seed, workers=True)
        configs = [
            "configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml",
            "configs/transform/inference_ar_transform.yaml",
            "configs/model/unirig_ar_350m_1024_81920_float32.yaml",
            "configs/system/ar_inference_articulationxl.yaml",
            "configs/tokenizer/tokenizer_parts_articulationxl_256.yaml",
        ]
        data_name = "raw_data.npz"
    else:
        configs = [
            "configs/task/quick_inference_unirig_skin.yaml",
            "configs/transform/inference_skin_transform.yaml",
            "configs/model/unirig_skin.yaml",
            "configs/system/skin.yaml",
            None,
        ]
        data_name = "predict_skeleton.npz"

    task = _load_yaml(configs[0])

    if is_skeleton:
        # Accept str or Path for npz_dir; a str previously broke .mkdir().
        npz_dir = Path(output_file).parent / "npz" if npz_dir is None else Path(npz_dir)
        npz_dir.mkdir(parents=True, exist_ok=True)
        npz_data_dir = extract_mesh_python(input_file, str(npz_dir))
        datapath = Datapath(files=[npz_data_dir], cls=None)
    else:
        skeleton_npz_dir = _find_skeleton_npz_dir(Path(input_file).parent, data_name)
        datapath = Datapath(files=[str(skeleton_npz_dir)], cls=None)

    data_config = _load_yaml("configs/data/quick_inference.yaml")
    transform_config = _load_yaml(configs[1])

    tokenizer = None
    if is_skeleton:
        tokenizer = get_tokenizer(config=TokenizerConfig.parse(config=_load_yaml(configs[4])))
    model = get_model(tokenizer=tokenizer, **_load_yaml(configs[2]))

    data = UniRigDatasetModule(
        process_fn=model._process_fn,
        predict_dataset_config=DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls(),
        predict_transform_config=TransformConfig.parse(config=transform_config.predict_transform_config),
        tokenizer_config=None if tokenizer is None else tokenizer.config,
        data_name=data_name,
        datapath=datapath,
        cls=None,
    )

    writer_config = task.writer.copy()
    if is_skeleton:
        writer_config.update(
            {
                "npz_dir": str(npz_dir),
                "output_dir": str(Path(output_file).parent),
                "output_name": Path(output_file).name,
                "user_mode": False,
            }
        )
    else:
        writer_config.update(
            {
                "npz_dir": str(skeleton_npz_dir),
                "output_name": str(output_file),
                "user_mode": True,
                "export_fbx": True,
            }
        )

    callbacks = [get_writer(**writer_config, order_config=data.predict_transform_config.order_config)]
    system = get_system(**_load_yaml(configs[3]), model=model, steps_per_epoch=1)
    trainer = L.Trainer(callbacks=callbacks, logger=None, **task.trainer)
    trainer.predict(
        system,
        datamodule=data,
        ckpt_path=download(task.resume_from_checkpoint),
        return_predictions=False,
    )
    return str(output_file)
def merge_results_python(source_file: str, target_file: str, output_file: str) -> str:
    """
    Combine *source_file*'s result with *target_file* via UniRig's ``transfer``.

    The combined model is written to *output_file* (``add_root=False``).
    Returns the output path as a string.
    """
    from src.inference.merge import transfer

    destination = str(output_file)
    transfer(
        source=str(source_file),
        target=str(target_file),
        output=destination,
        add_root=False,
    )
    return destination
# ==========================================================
# 6. GRADIO APP
# ==========================================================
def main(input_file: str, seed: int = 12345):
    """
    Rig an uploaded 3D model: predict a skeleton, then skinning.

    Steps (each output feeds the next):
      1. skeleton inference on a copy of the upload -> ``*_skeleton.fbx``
      2. merge that skeleton with the original mesh -> ``*_skeleton_only<ext>``
      3. skin inference on the skeleton FBX -> ``*_skin.fbx``
      4. merge the skin result with the original mesh -> ``*_skeleton_and_skinning<ext>``

    Each request works in its own fresh directory under ``./tmp``. Reruns with
    the same file name and seed, or concurrent users, therefore never pick up
    each other's intermediate ``.npz`` files in the skin stage.

    Args:
        input_file: Path of the uploaded model (from ``gr.Model3D``).
        seed: Seed for skeleton inference. Floats are truncated to int, and
            ``None`` (a cleared field) falls back to 12345.

    Returns:
        ``(final_model_path, [skeleton_only_path, final_model_path])``.

    Raises:
        gr.Error: If the upload is missing or has an unsupported format.
    """
    import tempfile

    if not validate_input_file(input_file):
        raise gr.Error("Invalid file format")

    # gr.Number delivers floats (12345.0) unless precision=0, and None when cleared.
    seed = 12345 if seed is None else int(seed)

    temp_dir = Path(__file__).parent / "tmp"
    temp_dir.mkdir(exist_ok=True)

    file_stem = Path(input_file).stem
    input_model_dir = Path(tempfile.mkdtemp(prefix=f"{file_stem}_{seed}_", dir=temp_dir))
    input_path = input_model_dir / Path(input_file).name
    shutil.copy2(input_file, input_path)

    skel_fbx = input_model_dir / f"{file_stem}_skeleton.fbx"
    skel_only = input_model_dir / f"{file_stem}_skeleton_only{input_path.suffix}"
    skin_fbx = input_model_dir / f"{file_stem}_skin.fbx"
    final_out = input_model_dir / f"{file_stem}_skeleton_and_skinning{input_path.suffix}"

    run_inference_python(str(input_path), str(skel_fbx), "skeleton", seed)
    merge_results_python(str(skel_fbx), str(input_path), str(skel_only))
    run_inference_python(str(skel_fbx), str(skin_fbx), "skin")
    merge_results_python(str(skin_fbx), str(input_path), str(final_out))

    return str(final_out), [str(skel_only), str(final_out)]
def create_app():
    """
    Build the Gradio Blocks UI and wire its button to ``main``.

    Left column: model upload, seed, and start button. Right column: the
    rigged model preview and downloadable result files.
    """
    with gr.Blocks(title="UniRig Demo") as interface:
        gr.Markdown("# 🎯 UniRig: Automated 3D Model Rigging")
        with gr.Row():
            with gr.Column():
                input_3d = gr.Model3D(label="Upload 3D Model")
                # precision=0 makes Gradio hand the seed to main() as an int, not 12345.0.
                seed = gr.Number(value=12345, label="Seed", precision=0)
                btn = gr.Button("Start Rigging", variant="primary")
            with gr.Column():
                out_3d = gr.Model3D(label="Result")
                out_files = gr.Files(label="Download Files")
        btn.click(fn=main, inputs=[input_3d, seed], outputs=[out_3d, out_files])
    return interface
if __name__ == "__main__":
    # Build the UI, enable request queuing, then serve it with show_api=False.
    app = create_app()
    app.queue()
    app.launch(show_api=False)