import shutil import subprocess import time from pathlib import Path from typing import Tuple, List import gradio as gr import lightning as L import spaces import torch import yaml from box import Box from lightning.pytorch.callbacks import ModelCheckpoint # Install dependencies torch_version = torch.__version__.split("+")[0] cuda_version = torch.version.cuda spconv_version = "-cu121" if cuda_version else "" if cuda_version: cuda_version = f"cu{cuda_version.replace('.', '')}" else: cuda_version = "cpu" subprocess.run(f'pip install spconv{spconv_version}', shell=True) subprocess.run(f'pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html --no-cache-dir', shell=True) subprocess.run(f'pip uninstall flash-attn -y && pip install flash-attn --no-build-isolation --no-cache-dir', shell=True) subprocess.run(f'pip install bpy==3.6.0 --extra-index-url https://download.blender.org/pypi/', shell=True) subprocess.run(f'pip install lightning[extra]', shell=True) subprocess.run(f'pip install litmodels', shell=True) def validate_input_file(file_path: str, supported_formats: list) -> bool: if not file_path or not Path(file_path).exists(): return False file_ext = Path(file_path).suffix.lower() return file_ext in supported_formats def extract_mesh_python(input_file: str, output_dir: str, target_count: int) -> str: from src.data.extract import extract_builtin, get_files files = get_files( data_name="raw_data.npz", inputs=str(input_file), input_dataset_dir=None, output_dataset_dir=output_dir, force_override=True, warning=False, ) if not files: raise RuntimeError("No files to extract") timestamp = str(int(time.time())) extract_builtin( output_folder=output_dir, target_count=target_count, num_runs=1, id=0, time=timestamp, files=files, ) expected_npz_dir = files[0][1] expected_npz_file = Path(expected_npz_dir) / "raw_data.npz" if not expected_npz_file.exists(): raise RuntimeError(f"Extraction failed: {expected_npz_file} not found") return expected_npz_dir def run_inference_python( input_file: str, output_file: str, inference_type: str, seed: int = 12345, npz_dir: str = None, target_count: int = 50000, task_config_path: str = None, transform_config_path: str = None, model_config_path: str = None, system_config_path: str = None, tokenizer_config_path: str = None, data_name: str = None, ) -> str: from src.data.datapath import Datapath from src.data.dataset import DatasetConfig, UniRigDatasetModule from src.data.transform import TransformConfig from src.inference.download import download from src.model.parse import get_model from src.system.parse import get_system, get_writer from src.tokenizer.parse import get_tokenizer from src.tokenizer.spec import TokenizerConfig if inference_type == "skeleton": L.seed_everything(seed, workers=True) if task_config_path is None or not Path(task_config_path).exists(): raise FileNotFoundError(f"Task configuration file not found: {task_config_path}") with open(task_config_path, 'r') as f: task = Box(yaml.safe_load(f)) if inference_type == "skeleton": if npz_dir is None: npz_dir = Path(output_file).parent / "npz" npz_dir = Path(npz_dir) npz_dir.mkdir(exist_ok=True) npz_data_dir = extract_mesh_python(input_file, npz_dir, target_count) datapath = Datapath(files=[npz_data_dir], cls=None) else: skeleton_work_dir = Path(input_file).parent all_npz_files = list(skeleton_work_dir.rglob("**/*.npz")) if not all_npz_files: raise RuntimeError(f"No NPZ files found for skin inference in {skeleton_work_dir}") skeleton_npz_dir = all_npz_files[0].parent datapath = Datapath(files=[str(skeleton_npz_dir)], cls=None) if not Path("configs/data/quick_inference.yaml").exists(): raise FileNotFoundError("Missing configs/data/quick_inference.yaml") data_config = Box(yaml.safe_load(open("configs/data/quick_inference.yaml", 'r'))) if transform_config_path is None or not Path(transform_config_path).exists(): raise FileNotFoundError(f"Transform configuration file not found: {transform_config_path}") transform_config = Box(yaml.safe_load(open(transform_config_path, 'r'))) if inference_type == "skeleton": if tokenizer_config_path is None or not Path(tokenizer_config_path).exists(): raise FileNotFoundError(f"Tokenizer configuration file not found: {tokenizer_config_path}") tokenizer_config = TokenizerConfig.parse(config=Box(yaml.safe_load(open(tokenizer_config_path, 'r')))) tokenizer = get_tokenizer(config=tokenizer_config) if model_config_path is None or not Path(model_config_path).exists(): raise FileNotFoundError(f"Model configuration file not found: {model_config_path}") model_config = Box(yaml.safe_load(open(model_config_path, 'r'))) model = get_model(tokenizer=tokenizer, **model_config) else: tokenizer_config = None tokenizer = None if model_config_path is None or not Path(model_config_path).exists(): raise FileNotFoundError(f"Model configuration file not found: {model_config_path}") model_config = Box(yaml.safe_load(open(model_config_path, 'r'))) model = get_model(tokenizer=None, **model_config) predict_dataset_config = DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls() predict_transform_config = TransformConfig.parse(config=transform_config.predict_transform_config) data = UniRigDatasetModule( process_fn=model._process_fn, predict_dataset_config=predict_dataset_config, predict_transform_config=predict_transform_config, tokenizer_config=tokenizer_config, debug=False, data_name=data_name, datapath=datapath, cls=None, ) callbacks = [] writer_config = task.writer.copy() if inference_type == "skeleton": writer_config['npz_dir'] = str(npz_dir) writer_config['output_dir'] = str(Path(output_file).parent) writer_config['output_name'] = Path(output_file).name writer_config['user_mode'] = False else: writer_config['npz_dir'] = str(skeleton_npz_dir) writer_config['output_name'] = str(output_file) writer_config['user_mode'] = True writer_config['export_fbx'] = True checkpoint_callbacks = [] if hasattr(task, 'callbacks') and task.callbacks: for cb in task.callbacks: if isinstance(cb, dict) and cb.get('__target__', '').startswith('ModelCheckpoint'): cb_kwargs = {k: v for k, v in cb.items() if k != '__target__'} checkpoint_callbacks.append(ModelCheckpoint(**cb_kwargs)) callbacks = checkpoint_callbacks + [get_writer(**writer_config, order_config=predict_transform_config.order_config)] if system_config_path is None or not Path(system_config_path).exists(): raise FileNotFoundError(f"System configuration file not found: {system_config_path}") system_config = Box(yaml.safe_load(open(system_config_path, 'r'))) system = get_system(**system_config, model=model, steps_per_epoch=1) trainer_config = task.trainer resume_from_checkpoint = download(task.resume_from_checkpoint) trainer = L.Trainer(callbacks=callbacks, logger=None, **trainer_config) trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False) if inference_type == "skeleton": input_name_stem = Path(input_file).stem actual_output_dir = Path(output_file).parent / input_name_stem actual_output_file = actual_output_dir / "skeleton.fbx" if not actual_output_file.exists(): alt_files = list(Path(output_file).parent.rglob("skeleton.fbx")) if alt_files: actual_output_file = alt_files[0] else: all_files = list(Path(output_file).parent.rglob("*")) raise RuntimeError(f"Skeleton FBX file not found. Expected at: {actual_output_file}") if actual_output_file != Path(output_file): shutil.copy2(actual_output_file, output_file) else: if not Path(output_file).exists(): skin_files = list(Path(output_file).parent.rglob("*skin*.fbx")) if skin_files: actual_output_file = skin_files[0] shutil.copy2(actual_output_file, output_file) else: raise RuntimeError(f"Skin FBX file not found. Expected at: {output_file}") return str(output_file) def merge_results_python(source_file: str, target_file: str, output_file: str) -> str: from src.inference.merge import transfer if not Path(source_file).exists(): raise ValueError(f"Source file does not exist: {source_file}") if not Path(target_file).exists(): raise ValueError(f"Target file does not exist: {target_file}") output_path = Path(output_file) output_path.parent.mkdir(parents=True, exist_ok=True) transfer(source=str(source_file), target=str(target_file), output=str(output_path), add_root=False) if not output_path.exists(): raise RuntimeError(f"Merge failed: Output file not created at {output_path}") if not output_path.is_file(): raise RuntimeError(f"Merge failed: Output path is not a valid file: {output_path}") return str(output_path.resolve()) @spaces.GPU() def main( input_file: str, seed: int = 12345, target_count: int = 50000, supported_formats: list = ['.obj', '.fbx', '.glb'], ) -> Tuple[List[str], List[str]]: base_dir = Path(__file__).parent temp_dir = base_dir / "tmp" temp_dir.mkdir(exist_ok=True) generated_files = [] completed_files = [] if not validate_input_file(input_file, supported_formats): raise gr.Error(f"Error: Invalid or unsupported file format. Supported formats: {', '.join(supported_formats)}") file_stem = Path(input_file).stem input_model_dir = temp_dir / f"{file_stem}_{seed}" input_model_dir.mkdir(exist_ok=True) input_file_path = Path(input_file) shutil.copy2(input_file_path, input_model_dir / input_file_path.name) input_file_path = input_model_dir / input_file_path.name try: intermediate_skeleton_file = input_model_dir / f"{file_stem}_skeleton.fbx" final_skeleton_file = input_model_dir / f"{file_stem}_skeleton_only{input_file_path.suffix}" run_inference_python( input_file=str(input_file_path), output_file=str(intermediate_skeleton_file), inference_type="skeleton", seed=seed, target_count=target_count, task_config_path="configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml", transform_config_path="configs/transform/inference_ar_transform.yaml", model_config_path="configs/model/unirig_ar_350m_1024_81920_float32.yaml", system_config_path="configs/system/ar_inference_articulationxl.yaml", tokenizer_config_path="configs/tokenizer/tokenizer_parts_articulationxl_256.yaml", data_name="raw_data.npz", ) merge_results_python(str(intermediate_skeleton_file), str(input_file_path), str(final_skeleton_file)) generated_files.append(str(final_skeleton_file)) completed_files.append(str(final_skeleton_file)) except Exception: # Return all generated and completed files so far, no error in UI return generated_files, completed_files try: intermediate_skin_file = input_model_dir / f"{file_stem}_skin.fbx" final_skin_file = input_model_dir / f"{file_stem}_skeleton_and_skinning{input_file_path.suffix}" run_inference_python( input_file=str(intermediate_skeleton_file), output_file=str(intermediate_skin_file), inference_type="skin", seed=seed, task_config_path="configs/task/quick_inference_unirig_skin.yaml", transform_config_path="configs/transform/inference_skin_transform.yaml", model_config_path="configs/model/unirig_skin.yaml", system_config_path="configs/system/skin.yaml", tokenizer_config_path=None, data_name="predict_skeleton.npz", ) merge_results_python(str(intermediate_skin_file), str(input_file_path), str(final_skin_file)) generated_files.append(str(final_skin_file)) completed_files.append(str(final_skin_file)) except Exception: return generated_files, completed_files return generated_files, completed_files def create_app(): with gr.Blocks(title="UniRig - 3D Model Rigging Demo") as interface: gr.HTML( """
Leverage deep learning to automatically generate skeletons and skinning weights for your 3D models
🔬 UniRig - Research by Tsinghua University & Tripo
📄 Paper |
🏠 Project Page |
🤗 Models