# Hugging Face Space status at time of capture: Build error
import logging
import shutil
import subprocess
import sys
import time
from pathlib import Path
from typing import List, Tuple

import gradio as gr
import lightning as L
import spaces
import torch
import yaml
from box import Box
from lightning.pytorch.callbacks import ModelCheckpoint
# Install dependencies whose wheel selection depends on the torch/CUDA build
# that is already installed. This runs at import time and is best-effort:
# a failing pip command is logged as a warning but does not abort start-up.
torch_version = torch.__version__.split("+")[0]
cuda_version = torch.version.cuda
# NOTE(review): the spconv wheel suffix is pinned to cu121 whenever torch
# reports any CUDA version — confirm this matches the Space's CUDA runtime.
spconv_version = "-cu121" if cuda_version else ""
if cuda_version:
    cuda_version = f"cu{cuda_version.replace('.', '')}"
else:
    cuda_version = "cpu"


def _pip(*args: str) -> bool:
    """Run ``python -m pip <args>`` with the interpreter running this file.

    Arguments are passed as an argv list (no shell), so values such as
    ``lightning[extra]`` are never re-parsed or glob-expanded by a shell.

    Returns:
        True when pip exits with status 0. A non-zero exit is logged as a
        warning instead of raised, preserving best-effort installation.
    """
    result = subprocess.run([sys.executable, "-m", "pip", *args])
    if result.returncode != 0:
        logging.getLogger(__name__).warning(
            "pip %s failed with exit code %d", " ".join(args), result.returncode
        )
    return result.returncode == 0


_pip("install", f"spconv{spconv_version}")
_pip(
    "install",
    "torch_scatter",
    "torch_cluster",
    "-f",
    f"https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html",
    "--no-cache-dir",
)
# Equivalent of the shell chain `pip uninstall ... && pip install ...`:
# the reinstall only runs if the uninstall exited successfully.
if _pip("uninstall", "flash-attn", "-y"):
    _pip("install", "flash-attn", "--no-build-isolation", "--no-cache-dir")
_pip("install", "bpy==3.6.0", "--extra-index-url", "https://download.blender.org/pypi/")
_pip("install", "lightning[extra]")
_pip("install", "litmodels")
def validate_input_file(file_path: str, supported_formats: list) -> bool:
    """Check that *file_path* names an existing path with an allowed extension.

    Only the file's suffix is lower-cased before the membership test, so the
    entries of *supported_formats* are expected in lower case with a leading
    dot (e.g. ``".obj"``).

    Returns:
        False for an empty/None path or a path that does not exist; otherwise
        whether the lower-cased suffix appears in *supported_formats*.
    """
    if file_path:
        candidate = Path(file_path)
        if candidate.exists():
            return candidate.suffix.lower() in supported_formats
    return False
def extract_mesh_python(input_file: str, output_dir: str, target_count: int) -> str:
    """Extract mesh data for *input_file* into a ``raw_data.npz`` below *output_dir*.

    Delegates to the project's ``src.data.extract.get_files`` and
    ``extract_builtin`` helpers (a single run, id 0, tagged with the current
    Unix time).

    Returns:
        The directory taken from the first ``get_files`` entry (index 1),
        which is expected to contain ``raw_data.npz`` after extraction.

    Raises:
        RuntimeError: if ``get_files`` returns nothing, or if
            ``raw_data.npz`` is missing from that directory afterwards.
    """
    from src.data.extract import extract_builtin, get_files

    file_list = get_files(
        data_name="raw_data.npz",
        inputs=str(input_file),
        input_dataset_dir=None,
        output_dataset_dir=output_dir,
        force_override=True,
        warning=False,
    )
    if not file_list:
        raise RuntimeError("No files to extract")

    extract_builtin(
        output_folder=output_dir,
        target_count=target_count,
        num_runs=1,
        id=0,
        time=str(int(time.time())),
        files=file_list,
    )

    npz_dir = file_list[0][1]
    npz_path = Path(npz_dir) / "raw_data.npz"
    if not npz_path.exists():
        raise RuntimeError(f"Extraction failed: {npz_path} not found")
    return npz_dir
def _load_yaml_config(config_path, label: str) -> Box:
    """Load the YAML file at *config_path* into a ``Box`` (file closed afterwards).

    Raises:
        FileNotFoundError: if *config_path* is None or does not exist; the
            message reads "<label> configuration file not found: <path>".
    """
    if config_path is None or not Path(config_path).exists():
        raise FileNotFoundError(f"{label} configuration file not found: {config_path}")
    with open(config_path, "r", encoding="utf-8") as f:
        return Box(yaml.safe_load(f))


def _find_skeleton_npz_dir(work_dir: Path, data_name: str = None) -> Path:
    """Return the directory of the skeleton-stage NPZ found below *work_dir*.

    All ``*.npz`` files below *work_dir* are considered, sorted so the choice
    is deterministic. If *data_name* is given, files with exactly that name
    are preferred (presumably the file the dataset module loads from each
    datapath directory — confirm); otherwise, or if none match, the first
    NPZ found is used.

    Raises:
        RuntimeError: if no NPZ file exists below *work_dir*.
    """
    npz_files = sorted(work_dir.rglob("*.npz"))
    if not npz_files:
        raise RuntimeError(f"No NPZ files found for skin inference in {work_dir}")
    matching = [p for p in npz_files if data_name and p.name == data_name]
    return (matching or npz_files)[0].parent


def run_inference_python(
    input_file: str,
    output_file: str,
    inference_type: str,
    seed: int = 12345,
    npz_dir: str = None,
    target_count: int = 50000,
    task_config_path: str = None,
    transform_config_path: str = None,
    model_config_path: str = None,
    system_config_path: str = None,
    tokenizer_config_path: str = None,
    data_name: str = None,
) -> str:
    """Run one UniRig prediction stage through a Lightning ``Trainer``.

    With ``inference_type == "skeleton"`` the mesh in *input_file* is first
    extracted to NPZ (see ``extract_mesh_python``) and a skeleton is
    predicted. Any other value runs the skinning stage on NPZ data found
    below *input_file*'s parent directory (the skeleton stage's work dir).

    Args:
        input_file: Mesh to rig (skeleton) or the skeleton-stage FBX (skin).
        output_file: Where the resulting FBX must end up.
        inference_type: ``"skeleton"`` or anything else for skinning.
        seed: Passed to ``L.seed_everything`` (skeleton stage only).
        npz_dir: Extraction directory; defaults to ``<output dir>/npz``.
        target_count: Forwarded to mesh extraction.
        task_config_path / transform_config_path / model_config_path /
        system_config_path / tokenizer_config_path: YAML config files; the
            tokenizer config is only read for the skeleton stage.
        data_name: NPZ file name handed to the dataset module.

    Returns:
        ``str(output_file)``. If the writer put its result elsewhere, that
        file is copied to *output_file* first.

    Raises:
        FileNotFoundError: a required configuration file is missing.
        RuntimeError: extraction failed, no NPZ input was found, or the
            expected FBX output was not produced.
    """
    from src.data.datapath import Datapath
    from src.data.dataset import DatasetConfig, UniRigDatasetModule
    from src.data.transform import TransformConfig
    from src.inference.download import download
    from src.model.parse import get_model
    from src.system.parse import get_system, get_writer
    from src.tokenizer.parse import get_tokenizer
    from src.tokenizer.spec import TokenizerConfig

    is_skeleton = inference_type == "skeleton"
    if is_skeleton:
        L.seed_everything(seed, workers=True)

    task = _load_yaml_config(task_config_path, "Task")

    # Input data: fresh extraction for the skeleton stage, the previous
    # stage's NPZ directory for the skin stage.
    if is_skeleton:
        npz_dir = Path(output_file).parent / "npz" if npz_dir is None else Path(npz_dir)
        npz_dir.mkdir(parents=True, exist_ok=True)
        npz_data_dir = extract_mesh_python(input_file, npz_dir, target_count)
        datapath = Datapath(files=[npz_data_dir], cls=None)
    else:
        skeleton_npz_dir = _find_skeleton_npz_dir(Path(input_file).parent, data_name)
        datapath = Datapath(files=[str(skeleton_npz_dir)], cls=None)

    data_config_path = Path("configs/data/quick_inference.yaml")
    if not data_config_path.exists():
        raise FileNotFoundError("Missing configs/data/quick_inference.yaml")
    with open(data_config_path, "r", encoding="utf-8") as f:
        data_config = Box(yaml.safe_load(f))
    transform_config = _load_yaml_config(transform_config_path, "Transform")

    # Only the skeleton stage builds a tokenizer; the skin model gets None.
    if is_skeleton:
        tokenizer_config = TokenizerConfig.parse(
            config=_load_yaml_config(tokenizer_config_path, "Tokenizer")
        )
        tokenizer = get_tokenizer(config=tokenizer_config)
    else:
        tokenizer_config = None
        tokenizer = None
    model_config = _load_yaml_config(model_config_path, "Model")
    model = get_model(tokenizer=tokenizer, **model_config)

    predict_dataset_config = DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls()
    predict_transform_config = TransformConfig.parse(config=transform_config.predict_transform_config)
    data = UniRigDatasetModule(
        process_fn=model._process_fn,
        predict_dataset_config=predict_dataset_config,
        predict_transform_config=predict_transform_config,
        tokenizer_config=tokenizer_config,
        debug=False,
        data_name=data_name,
        datapath=datapath,
        cls=None,
    )

    writer_config = task.writer.copy()
    if is_skeleton:
        writer_config['npz_dir'] = str(npz_dir)
        writer_config['output_dir'] = str(Path(output_file).parent)
        writer_config['output_name'] = Path(output_file).name
        writer_config['user_mode'] = False
    else:
        writer_config['npz_dir'] = str(skeleton_npz_dir)
        writer_config['output_name'] = str(output_file)
        writer_config['user_mode'] = True
        writer_config['export_fbx'] = True

    # Only ModelCheckpoint entries from the task's callback list are built.
    checkpoint_callbacks = []
    if hasattr(task, 'callbacks') and task.callbacks:
        for cb in task.callbacks:
            if isinstance(cb, dict) and cb.get('__target__', '').startswith('ModelCheckpoint'):
                cb_kwargs = {k: v for k, v in cb.items() if k != '__target__'}
                checkpoint_callbacks.append(ModelCheckpoint(**cb_kwargs))
    writer = get_writer(**writer_config, order_config=predict_transform_config.order_config)
    callbacks = checkpoint_callbacks + [writer]

    system_config = _load_yaml_config(system_config_path, "System")
    system = get_system(**system_config, model=model, steps_per_epoch=1)

    trainer_config = task.trainer
    resume_from_checkpoint = download(task.resume_from_checkpoint)
    trainer = L.Trainer(callbacks=callbacks, logger=None, **trainer_config)
    trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)

    output_path = Path(output_file)
    if is_skeleton:
        # The writer's result is looked for at <output dir>/<input stem>/skeleton.fbx,
        # falling back to any skeleton.fbx below the output directory.
        produced = output_path.parent / Path(input_file).stem / "skeleton.fbx"
        if not produced.exists():
            candidates = sorted(output_path.parent.rglob("skeleton.fbx"))
            if not candidates:
                raise RuntimeError(f"Skeleton FBX file not found. Expected at: {produced}")
            produced = candidates[0]
        if produced != output_path:
            shutil.copy2(produced, output_path)
    elif not output_path.exists():
        candidates = sorted(output_path.parent.rglob("*skin*.fbx"))
        if not candidates:
            raise RuntimeError(f"Skin FBX file not found. Expected at: {output_file}")
        shutil.copy2(candidates[0], output_path)
    return str(output_file)
def merge_results_python(source_file: str, target_file: str, output_file: str) -> str:
    """Merge a predicted rig (*source_file*) onto a mesh (*target_file*).

    Delegates to the project's ``src.inference.merge.transfer`` with
    ``add_root=False`` and creates *output_file*'s parent directories first.

    Returns:
        The absolute (resolved) path of the written file.

    Raises:
        ValueError: if the source or target file does not exist.
        RuntimeError: if ``transfer`` did not produce a regular file at
            *output_file*.
    """
    from src.inference.merge import transfer

    for label, candidate in (("Source", source_file), ("Target", target_file)):
        if not Path(candidate).exists():
            raise ValueError(f"{label} file does not exist: {candidate}")

    destination = Path(output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)
    transfer(
        source=str(source_file),
        target=str(target_file),
        output=str(destination),
        add_root=False,
    )

    if not destination.exists():
        raise RuntimeError(f"Merge failed: Output file not created at {destination}")
    if not destination.is_file():
        raise RuntimeError(f"Merge failed: Output path is not a valid file: {destination}")
    return str(destination.resolve())
def main(
    input_file: str,
    seed: int = 12345,
    target_count: int = 50000,
    supported_formats: list = None,
) -> Tuple[List[str], List[str]]:
    """Rig one uploaded 3D model: predict a skeleton, then skinning weights.

    Work happens in ``tmp/<file stem>_<seed>/`` next to this script, on a copy
    of the uploaded file. Each stage's FBX is merged back onto that copy with
    ``merge_results_python``.

    Args:
        input_file: Path of the uploaded model.
        seed: Forwarded to both inference stages and used in the work-dir name.
        target_count: Forwarded to mesh extraction in the skeleton stage.
        supported_formats: Allowed lower-case extensions; defaults to
            ``[".obj", ".fbx", ".glb"]``.

    Returns:
        ``(generated_files, completed_files)`` — both list the final files
        produced so far (currently always identical). A failing stage is
        logged and ends the pipeline early with whatever finished before it,
        so the UI never receives an error from inference itself.

    Raises:
        gr.Error: if the input is missing or has an unsupported extension.
    """
    # NOTE(review): `spaces` is imported by this file but nothing is decorated
    # with `@spaces.GPU`; if the Space runs on ZeroGPU hardware this pipeline
    # probably needs that decorator to obtain a GPU — confirm.
    if supported_formats is None:
        supported_formats = [".obj", ".fbx", ".glb"]
    logger = logging.getLogger(__name__)
    generated_files: List[str] = []
    completed_files: List[str] = []

    if not validate_input_file(input_file, supported_formats):
        raise gr.Error(f"Error: Invalid or unsupported file format. Supported formats: {', '.join(supported_formats)}")

    temp_dir = Path(__file__).parent / "tmp"
    temp_dir.mkdir(exist_ok=True)
    file_stem = Path(input_file).stem
    input_model_dir = temp_dir / f"{file_stem}_{seed}"
    input_model_dir.mkdir(exist_ok=True)
    # Work on a copy inside the per-run directory.
    input_file_path = input_model_dir / Path(input_file).name
    shutil.copy2(input_file, input_file_path)

    # Stage 1: skeleton prediction, merged onto the input mesh.
    intermediate_skeleton_file = input_model_dir / f"{file_stem}_skeleton.fbx"
    final_skeleton_file = input_model_dir / f"{file_stem}_skeleton_only{input_file_path.suffix}"
    try:
        run_inference_python(
            input_file=str(input_file_path),
            output_file=str(intermediate_skeleton_file),
            inference_type="skeleton",
            seed=seed,
            target_count=target_count,
            task_config_path="configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml",
            transform_config_path="configs/transform/inference_ar_transform.yaml",
            model_config_path="configs/model/unirig_ar_350m_1024_81920_float32.yaml",
            system_config_path="configs/system/ar_inference_articulationxl.yaml",
            tokenizer_config_path="configs/tokenizer/tokenizer_parts_articulationxl_256.yaml",
            data_name="raw_data.npz",
        )
        merge_results_python(str(intermediate_skeleton_file), str(input_file_path), str(final_skeleton_file))
    except Exception:
        # Deliberately best-effort: no error in the UI, but keep the traceback in the logs.
        logger.exception("Skeleton stage failed for %s", input_file)
        return generated_files, completed_files
    generated_files.append(str(final_skeleton_file))
    completed_files.append(str(final_skeleton_file))

    # Stage 2: skinning on top of the predicted skeleton, merged onto the input mesh.
    intermediate_skin_file = input_model_dir / f"{file_stem}_skin.fbx"
    final_skin_file = input_model_dir / f"{file_stem}_skeleton_and_skinning{input_file_path.suffix}"
    try:
        run_inference_python(
            input_file=str(intermediate_skeleton_file),
            output_file=str(intermediate_skin_file),
            inference_type="skin",
            seed=seed,
            task_config_path="configs/task/quick_inference_unirig_skin.yaml",
            transform_config_path="configs/transform/inference_skin_transform.yaml",
            model_config_path="configs/model/unirig_skin.yaml",
            system_config_path="configs/system/skin.yaml",
            tokenizer_config_path=None,
            data_name="predict_skeleton.npz",
        )
        merge_results_python(str(intermediate_skin_file), str(input_file_path), str(final_skin_file))
    except Exception:
        logger.exception("Skinning stage failed for %s", input_file)
        return generated_files, completed_files
    generated_files.append(str(final_skin_file))
    completed_files.append(str(final_skin_file))

    return generated_files, completed_files
def create_app():
    """Build the Gradio Blocks UI for the UniRig rigging demo.

    The page takes an uploaded model, a seed and a target count, runs
    ``main`` on them, and shows the skeleton-only and skeleton+skinning
    results plus a download list.

    Returns:
        The (not yet launched) ``gr.Blocks`` interface.
    """
    with gr.Blocks(title="UniRig - 3D Model Rigging Demo") as interface:
        gr.HTML(
            """
            <div class="title" style="text-align: center">
                <h1>🎯 UniRig: Automated 3D Model Rigging</h1>
                <p style="font-size: 1.1em; color: #6b7280;">
                    Leverage deep learning to automatically generate skeletons and skinning weights for your 3D models
                </p>
            </div>
            """
        )
        # Kept flush-left: indented lines inside a Markdown string would render as code blocks.
        gr.Markdown(
            """## Notes:
- If you are not seeing the 3D model preview and you are using chrome, go to `chrome://flags/#enable-unsafe-webgpu` and enable the flag.
- Supported File Formats are `.obj`, `.fbx`, `.glb`
- The process may take a few minutes depending on the model complexity and server load.
"""
        )
        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                input_3d_model = gr.Model3D(label="Upload 3D Model")
                with gr.Group():
                    with gr.Row(equal_height=True):
                        # precision=0 makes Gradio deliver an int; without it the seed
                        # arrives as a float (e.g. 12345.0) in the work-dir name and seeding.
                        seed = gr.Number(
                            value=int(torch.randint(0, 100000, (1,)).item()),
                            label="Random Seed (for reproducible results)",
                            precision=0,
                            scale=4,
                        )
                        target_count = gr.Number(
                            value=50000,
                            label="Target Count (points for mesh extraction)",
                            precision=0,
                            interactive=True,
                        )
                        random_btn = gr.Button("🔄 Random Seed", variant="secondary", scale=1)
                pipeline_btn = gr.Button("🎯 Start Processing", variant="primary", size="lg")
            with gr.Column():
                skeleton_output = gr.Model3D(label="Skeleton Output")
                skin_output = gr.Model3D(label="Skin Output")
                files_to_download = gr.Files(label="Download Files")

        random_btn.click(
            fn=lambda: int(torch.randint(0, 100000, (1,)).item()),
            outputs=seed,
        )

        def pipeline_wrapper(input_file, seed_val, target_count_val):
            """Run ``main`` and route its files to the three output components.

            Files are matched to the viewers by the name fragments ``main``
            uses ("skeleton_only" / "skeleton_and_skinning"); a stage with no
            output clears its viewer.
            """
            _generated_files, completed_files = main(input_file, seed_val, int(target_count_val))
            skeleton_file = None
            skin_file = None
            for f in completed_files:
                if "skeleton_only" in f:
                    skeleton_file = f
                elif "skeleton_and_skinning" in f:
                    skin_file = f
            return skeleton_file or gr.update(value=None), skin_file or gr.update(value=None), completed_files

        pipeline_btn.click(
            fn=pipeline_wrapper,
            inputs=[input_3d_model, seed, target_count],
            outputs=[skeleton_output, skin_output, files_to_download],
        )
        gr.HTML(
            """
            <div style="text-align: center; margin-top: 2em; padding: 1em; border-radius: 8px;">
                <p style="color: #6b7280;">
                    🔬 <strong>UniRig</strong> - Research by Tsinghua University & Tripo<br>
                    📄 <a href="https://arxiv.org/abs/2504.12451" target="_blank">Paper</a> |
                    🏠 <a href="https://zjp-shadow.github.io/works/UniRig/" target="_blank">Project Page</a> |
                    🤗 <a href="https://huggingface.co/VAST-AI/UniRig" target="_blank">Models</a>
                </p>
            </div>
            """
        )
    return interface
if __name__ == "__main__":
    # Build the UI, enable Gradio's request queue, then start the server.
    app = create_app()
    app.queue()
    app.launch()