|
|
import shutil |
|
|
import subprocess |
|
|
import time |
|
|
from pathlib import Path |
|
|
from typing import Tuple |
|
|
|
|
|
import gradio as gr |
|
|
import lightning as L |
|
|
import spaces |
|
|
import torch |
|
|
import yaml |
|
|
|
|
|
# Runtime dependency bootstrap: these wheels must match the torch/CUDA build
# present in the container, so they are installed at startup rather than
# pinned in requirements. shell=True is acceptable here because the command
# strings contain no untrusted input.
subprocess.run('pip install flash-attn --no-build-isolation', shell=True)


# Probe the installed torch and CUDA versions to pick matching wheels.
torch_version = torch.__version__.split("+")[0]  # e.g. "2.4.1+cu121" -> "2.4.1"
cuda_version = torch.version.cuda  # None on CPU-only builds
# NOTE(review): the spconv suffix is hard-coded to cu121 whenever any CUDA is
# present — confirm this matches the actual runtime CUDA version.
spconv_version = "-cu121" if cuda_version else ""


# Normalize to the PyG wheel-index tag format, e.g. "12.1" -> "cu121".
if cuda_version:
    cuda_version = f"cu{cuda_version.replace('.', '')}"
else:
    cuda_version = "cpu"


subprocess.run(f'pip install spconv{spconv_version}', shell=True)
subprocess.run(f'pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html --no-cache-dir', shell=True)
|
|
|
|
|
|
|
|
|
|
|
def validate_input_file(file_path: str) -> bool:
    """Return True if *file_path* names an existing, supported 3D model file."""
    if not file_path:
        return False
    path = Path(file_path)
    if not path.exists():
        return False
    # Comparison is case-insensitive on the extension.
    return path.suffix.lower() in ('.obj', '.fbx', '.glb')
|
|
|
|
|
def extract_mesh_python(input_file: str, output_dir: str) -> str:
    """
    Extract mesh data from a 3D model into a raw_data.npz file (replaces extract.sh).

    Returns the directory that contains the generated .npz file.
    """
    # Imported lazily so the module-level pip installs complete first.
    from src.data.extract import get_files, extract_builtin

    files = get_files(
        data_name="raw_data.npz",
        inputs=str(input_file),
        input_dataset_dir=None,
        output_dataset_dir=output_dir,
        force_override=True,
        warning=False,
    )
    if not files:
        raise RuntimeError("No files to extract")

    # A unique timestamp tags this extraction run.
    extract_builtin(
        output_folder=output_dir,
        target_count=50000,
        num_runs=1,
        id=0,
        time=str(int(time.time())),
        files=files,
    )

    # get_files returns (source, destination-dir) pairs; verify the output.
    result_dir = files[0][1]
    result_npz = Path(result_dir) / "raw_data.npz"
    if not result_npz.exists():
        raise RuntimeError(f"Extraction failed: {result_npz} not found")
    return result_dir
|
|
|
|
|
def run_skeleton_inference_python(input_file: str, output_file: str, seed: int = 12345) -> str:
    """
    Run skeleton inference using Python (replaces skeleton part of generate_skeleton.sh).

    Args:
        input_file: Path to the input 3D model (.obj/.fbx/.glb).
        output_file: Desired path of the skeleton FBX file.
        seed: Random seed for reproducible generation.

    Returns:
        Path (str) to the skeleton FBX file at ``output_file``.

    Raises:
        FileNotFoundError: If the task configuration file is missing.
        RuntimeError: If mesh extraction fails or no skeleton FBX is produced.
    """
    from box import Box

    from src.data.datapath import Datapath
    from src.data.dataset import DatasetConfig, UniRigDatasetModule
    from src.data.transform import TransformConfig
    from src.inference.download import download
    from src.model.parse import get_model
    from src.system.parse import get_system, get_writer
    from src.tokenizer.parse import get_tokenizer
    from src.tokenizer.spec import TokenizerConfig

    def _load_yaml_box(path: str) -> Box:
        # Load a YAML config into a Box. A context manager closes the file
        # handle promptly (the original leaked handles via
        # yaml.safe_load(open(...))).
        with open(path, 'r') as f:
            return Box(yaml.safe_load(f))

    L.seed_everything(seed, workers=True)

    task_config_path = "configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml"
    if not Path(task_config_path).exists():
        raise FileNotFoundError(f"Task configuration file not found: {task_config_path}")
    task = _load_yaml_box(task_config_path)

    # Working directory for the intermediate .npz mesh extraction.
    npz_dir = Path(output_file).parent / "npz"
    npz_dir.mkdir(exist_ok=True)
    npz_data_dir = extract_mesh_python(input_file, npz_dir)

    datapath = Datapath(files=[npz_data_dir], cls=None)

    data_config = _load_yaml_box("configs/data/quick_inference.yaml")
    transform_config = _load_yaml_box("configs/transform/inference_ar_transform.yaml")

    tokenizer_config = TokenizerConfig.parse(config=_load_yaml_box("configs/tokenizer/tokenizer_parts_articulationxl_256.yaml"))
    tokenizer = get_tokenizer(config=tokenizer_config)

    model_config = _load_yaml_box("configs/model/unirig_ar_350m_1024_81920_float32.yaml")
    model = get_model(tokenizer=tokenizer, **model_config)

    predict_dataset_config = DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls()
    predict_transform_config = TransformConfig.parse(config=transform_config.predict_transform_config)

    data = UniRigDatasetModule(
        process_fn=model._process_fn,
        predict_dataset_config=predict_dataset_config,
        predict_transform_config=predict_transform_config,
        tokenizer_config=tokenizer_config,
        debug=False,
        data_name="raw_data.npz",
        datapath=datapath,
        cls=None,
    )

    # Writer callback exports the predicted skeleton to FBX during predict.
    callbacks = []
    writer_config = task.writer.copy()
    writer_config['npz_dir'] = str(npz_dir)
    writer_config['output_dir'] = str(Path(output_file).parent)
    writer_config['output_name'] = Path(output_file).name
    writer_config['user_mode'] = False
    print(f"Writer config: {writer_config}")
    callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))

    system_config = _load_yaml_box("configs/system/ar_inference_articulationxl.yaml")
    system = get_system(**system_config, model=model, steps_per_epoch=1)

    trainer_config = task.trainer
    resume_from_checkpoint = download(task.resume_from_checkpoint)

    trainer = L.Trainer(callbacks=callbacks, logger=None, **trainer_config)
    trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)

    # The writer appears to place skeleton.fbx under a folder named after the
    # input file's stem — TODO confirm against the writer implementation.
    input_name_stem = Path(input_file).stem
    actual_output_dir = Path(output_file).parent / input_name_stem
    actual_output_file = actual_output_dir / "skeleton.fbx"

    if not actual_output_file.exists():
        # Fall back to searching the whole output tree for the file.
        alt_files = list(Path(output_file).parent.rglob("skeleton.fbx"))
        if alt_files:
            actual_output_file = alt_files[0]
            print(f"Found skeleton at alternative location: {actual_output_file}")
        else:
            # Dump the directory contents to aid debugging before failing.
            all_files = list(Path(output_file).parent.rglob("*"))
            print(f"Available files: {[str(f) for f in all_files]}")
            raise RuntimeError(f"Skeleton FBX file not found. Expected at: {actual_output_file}")

    # Normalize the result to the requested output path.
    if actual_output_file != Path(output_file):
        shutil.copy2(actual_output_file, output_file)
        print(f"Copied skeleton from {actual_output_file} to {output_file}")

    print(f"Generated skeleton at: {output_file}")
    return str(output_file)
|
|
|
|
|
def run_skin_inference_python(skeleton_file: str, output_file: str) -> str:
    """
    Run skin inference using Python (replaces skin part of generate_skin.sh).

    Args:
        skeleton_file: Path to the skeleton FBX produced by skeleton inference;
            the skeleton's .npz data is located under this file's directory.
        output_file: Desired path of the skinned FBX file.

    Returns:
        Path (str) to the skin FBX file at ``output_file``.

    Raises:
        RuntimeError: If no skeleton .npz data is found, or the skin FBX was
            not produced.
    """
    from box import Box

    from src.data.datapath import Datapath
    from src.data.dataset import DatasetConfig, UniRigDatasetModule
    from src.data.transform import TransformConfig
    from src.inference.download import download
    from src.model.parse import get_model
    from src.system.parse import get_system, get_writer

    def _load_yaml_box(path: str) -> Box:
        # Load a YAML config into a Box, closing the file handle (the
        # original leaked handles via yaml.safe_load(open(...))).
        with open(path, 'r') as f:
            return Box(yaml.safe_load(f))

    task = _load_yaml_box("configs/task/quick_inference_unirig_skin.yaml")

    # Locate the .npz data the skeleton stage wrote under its output folder.
    skeleton_work_dir = Path(skeleton_file).parent
    all_npz_files = list(skeleton_work_dir.rglob("**/*.npz"))
    if not all_npz_files:
        # Guard added: indexing an empty list raised an opaque IndexError.
        raise RuntimeError(f"No skeleton .npz data found under: {skeleton_work_dir}")
    skeleton_npz_dir = all_npz_files[0].parent
    datapath = Datapath(files=[str(skeleton_npz_dir)], cls=None)

    data_config = _load_yaml_box("configs/data/quick_inference.yaml")
    transform_config = _load_yaml_box("configs/transform/inference_skin_transform.yaml")

    # The skin model runs without a tokenizer.
    model_config = _load_yaml_box("configs/model/unirig_skin.yaml")
    model = get_model(tokenizer=None, **model_config)

    predict_dataset_config = DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls()
    predict_transform_config = TransformConfig.parse(config=transform_config.predict_transform_config)

    data = UniRigDatasetModule(
        process_fn=model._process_fn,
        predict_dataset_config=predict_dataset_config,
        predict_transform_config=predict_transform_config,
        tokenizer_config=None,
        debug=False,
        data_name="predict_skeleton.npz",
        datapath=datapath,
        cls=None,
    )

    # Writer callback exports the skinned model to FBX during predict.
    callbacks = []
    writer_config = task.writer.copy()
    writer_config['npz_dir'] = str(skeleton_npz_dir)
    writer_config['output_name'] = str(output_file)
    writer_config['user_mode'] = True
    writer_config['export_fbx'] = True
    callbacks.append(get_writer(**writer_config, order_config=predict_transform_config.order_config))

    system_config = _load_yaml_box("configs/system/skin.yaml")
    system = get_system(**system_config, model=model, steps_per_epoch=1)

    trainer_config = task.trainer
    resume_from_checkpoint = download(task.resume_from_checkpoint)

    trainer = L.Trainer(callbacks=callbacks, logger=None, **trainer_config)
    trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)

    if not Path(output_file).exists():
        # The writer may have placed the file elsewhere; search for it and
        # copy it to the requested location.
        skin_files = list(Path(output_file).parent.rglob("*skin*.fbx"))
        if skin_files:
            actual_output_file = skin_files[0]
            shutil.copy2(actual_output_file, output_file)
        else:
            raise RuntimeError(f"Skin FBX file not found. Expected at: {output_file}")

    return str(output_file)
|
|
|
|
|
def merge_results_python(source_file: str, target_file: str, output_file: str) -> str:
    """
    Merge the skinned result into the original model (replaces merge.sh).

    Returns the resolved path to the merged output file.
    """
    from src.inference.merge import transfer

    # Both inputs must exist before attempting the transfer.
    if not Path(source_file).exists():
        raise ValueError(f"Source file does not exist: {source_file}")
    if not Path(target_file).exists():
        raise ValueError(f"Target file does not exist: {target_file}")

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    transfer(source=str(source_file), target=str(target_file), output=str(output_path), add_root=False)

    # Verify the merge actually produced a regular file.
    if not output_path.exists():
        raise RuntimeError(f"Merge failed: Output file not created at {output_path}")
    if not output_path.is_file():
        raise RuntimeError(f"Merge failed: Output path is not a valid file: {output_path}")

    return str(output_path.resolve())
|
|
|
|
|
@spaces.GPU()
def complete_pipeline(input_file: str, seed: int = 12345) -> Tuple[str, list]:
    """
    Run the full rigging pipeline: skeleton generation -> skinning -> merge.

    Args:
        input_file: Path to the input 3D model file.
        seed: Random seed for reproducible results.

    Returns:
        Tuple of (final_file_path, list_of_intermediate_files).

    Raises:
        gr.Error: If the input file is missing or has an unsupported format.
    """
    work_root = Path(__file__).parent / "tmp"
    work_root.mkdir(exist_ok=True)

    supported_formats = ['.obj', '.fbx', '.glb']
    if not validate_input_file(input_file):
        raise gr.Error(f"Error: Invalid or unsupported file format. Supported formats: {', '.join(supported_formats)}")

    # One working directory per (model, seed) pair.
    file_stem = Path(input_file).stem
    model_dir = work_root / f"{file_stem}_{seed}"
    model_dir.mkdir(exist_ok=True)

    # Copy the upload into the working directory and operate on the copy.
    src_path = Path(input_file)
    local_input = model_dir / src_path.name
    shutil.copy2(src_path, local_input)
    input_file = local_input
    print(f"New input file path: {input_file}")

    # Stage 1: skeleton prediction.
    skeleton_path = model_dir / f"{file_stem}_skeleton.fbx"
    run_skeleton_inference_python(input_file, skeleton_path, seed)

    # Stage 2: skinning weights.
    skin_path = model_dir / f"{file_stem}_skin.fbx"
    run_skin_inference_python(skeleton_path, skin_path)

    # Stage 3: merge the rig back onto the original model.
    rigged_path = model_dir / f"{file_stem}_rigged.glb"
    merge_results_python(skin_path, input_file, rigged_path)

    return str(rigged_path), [str(skeleton_path), str(skin_path), str(rigged_path)]
|
|
|
|
|
|
|
|
def create_app():
    """Create and configure the Gradio interface."""

    with gr.Blocks(title="UniRig - 3D Model Rigging Demo") as interface:

        # Page header. NOTE(review): the emoji characters in the strings
        # below appear mojibake-encoded in the source; preserved verbatim.
        gr.HTML("""
        <div class="title" style="text-align: center">
            <h1>π― UniRig: Automated 3D Model Rigging</h1>
            <p style="font-size: 1.1em; color: #6b7280;">
                Leverage deep learning to automatically generate skeletons and skinning weights for your 3D models
            </p>
        </div>
        """)

        # Usage instructions shown above the controls.
        gr.Markdown("""
        ## π How to Use ?
        1. **Upload your 3D model** - Drop your .obj, .fbx, or .glb file in the upload area
        2. **Set random seed** (optional) - Use the same seed for reproducible results
        3. **Click "Start Complete Pipeline"** - The AI will automatically rig your model
        4. **Download results** - `_skeleton.fbx` is the base model with skeleton, `_skin.fbx` is the base model with armature/skeleton and skinning weights, and `_rigged.*` is the final rigged model ready for use.

        **Supported File Formats:** .obj, .fbx, .glb
        **Note:** The process may take a few minutes depending on the model complexity and server load.
        """)

        with gr.Row(equal_height=True):
            with gr.Column(scale=1):
                # Input column: model upload, seed controls, run button.
                input_3d_model = gr.Model3D(label="Upload 3D Model")

                with gr.Row(equal_height=True):
                    seed = gr.Number(
                        value=12345,
                        label="Random Seed (for reproducible results)",
                        scale=4,
                    )
                    random_btn = gr.Button("π Random Seed", variant="secondary", scale=1)

                pipeline_btn = gr.Button("π― Start Complete Pipeline", variant="primary", size="lg")

            with gr.Column():
                # Output column: rigged model preview and downloadable files.
                pipeline_skeleton_out = gr.Model3D(label="Final Rigged Model", scale=4)
                files_to_download = gr.Files(label="Download Files", scale=1)

        # Fill the seed box with a fresh random value on click.
        random_btn.click(
            fn=lambda: int(torch.randint(0, 100000, (1,)).item()),
            outputs=seed
        )

        # Launch the full skeleton -> skin -> merge pipeline.
        pipeline_btn.click(
            fn=complete_pipeline,
            inputs=[input_3d_model, seed],
            outputs=[pipeline_skeleton_out, files_to_download]
        )

        # Footer with paper / project / model links.
        gr.HTML("""
        <div style="text-align: center; margin-top: 2em; padding: 1em; border-radius: 8px;">
            <p style="color: #6b7280;">
                π¬ <strong>UniRig</strong> - Research by Tsinghua University & Tripo<br>
                π <a href="https://arxiv.org/abs/2504.12451" target="_blank">Paper</a> |
                π <a href="https://zjp-shadow.github.io/works/UniRig/" target="_blank">Project Page</a> |
                π€ <a href="https://huggingface.co/VAST-AI/UniRig" target="_blank">Models</a>
            </p>
        </div>
        """)

    return interface
|
|
|
|
|
if __name__ == "__main__":
    # Build the interface and serve it with request queuing enabled.
    demo = create_app()
    demo.queue().launch()
|
|
|