UniRigExtras / app.py
ivalenzuela's picture
Añadir dependencia de litmodels en la instalación de paquetes
b143961
import shutil
import subprocess
import sys
import time
from pathlib import Path
from typing import Tuple, List

import gradio as gr
import lightning as L
import spaces
import torch
import yaml
from box import Box
from lightning.pytorch.callbacks import ModelCheckpoint
# Install dependencies
torch_version = torch.__version__.split("+")[0]
cuda_version = torch.version.cuda
spconv_version = "-cu121" if cuda_version else ""
if cuda_version:
cuda_version = f"cu{cuda_version.replace('.', '')}"
else:
cuda_version = "cpu"
subprocess.run(f'pip install spconv{spconv_version}', shell=True)
subprocess.run(f'pip install torch_scatter torch_cluster -f https://data.pyg.org/whl/torch-{torch_version}+{cuda_version}.html --no-cache-dir', shell=True)
subprocess.run(f'pip uninstall flash-attn -y && pip install flash-attn --no-build-isolation --no-cache-dir', shell=True)
subprocess.run(f'pip install bpy==3.6.0 --extra-index-url https://download.blender.org/pypi/', shell=True)
subprocess.run(f'pip install lightning[extra]', shell=True)
subprocess.run(f'pip install litmodels', shell=True)
def validate_input_file(file_path: str, supported_formats: list) -> bool:
    """Return True when *file_path* names an existing file whose (lowercased)
    extension is one of *supported_formats*."""
    if not file_path:
        return False
    candidate = Path(file_path)
    return candidate.exists() and candidate.suffix.lower() in supported_formats
def extract_mesh_python(input_file: str, output_dir: str, target_count: int) -> str:
    """Run UniRig's mesh extraction for *input_file*, producing raw_data.npz
    under *output_dir*, and return the directory that contains it.

    Raises RuntimeError when nothing could be queued for extraction or when
    the expected NPZ file is missing afterwards.
    """
    from src.data.extract import extract_builtin, get_files

    file_list = get_files(
        data_name="raw_data.npz",
        inputs=str(input_file),
        input_dataset_dir=None,
        output_dataset_dir=output_dir,
        force_override=True,
        warning=False,
    )
    if not file_list:
        raise RuntimeError("No files to extract")

    extract_builtin(
        output_folder=output_dir,
        target_count=target_count,
        num_runs=1,
        id=0,
        time=str(int(time.time())),  # unique run tag
        files=file_list,
    )

    # get_files returns (input, output_dir) pairs; verify the NPZ landed there.
    npz_dir = file_list[0][1]
    npz_path = Path(npz_dir) / "raw_data.npz"
    if not npz_path.exists():
        raise RuntimeError(f"Extraction failed: {npz_path} not found")
    return npz_dir
def run_inference_python(
    input_file: str,
    output_file: str,
    inference_type: str,
    seed: int = 12345,
    npz_dir: str = None,
    target_count: int = 50000,
    task_config_path: str = None,
    transform_config_path: str = None,
    model_config_path: str = None,
    system_config_path: str = None,
    tokenizer_config_path: str = None,
    data_name: str = None,
) -> str:
    """Run one UniRig inference stage ("skeleton" or "skin") and return *output_file*.

    Args:
        input_file: For "skeleton", the raw 3D model to rig; for "skin", the
            skeleton FBX produced by a previous skeleton run (its parent
            directory is searched recursively for the NPZ data).
        output_file: Desired output FBX path; if the pipeline writes elsewhere,
            the result is located and copied here.
        inference_type: "skeleton" or "skin"; selects dataset wiring, tokenizer
            usage and writer configuration.
        seed: Seed for Lightning's global seeding (skeleton stage only).
        npz_dir: Directory for extracted mesh NPZ data (skeleton stage only);
            defaults to "<output dir>/npz".
        target_count: Target point count for mesh extraction (skeleton only).
        task_config_path: Task YAML (trainer, writer, checkpoint); required.
        transform_config_path: Transform YAML; required.
        model_config_path: Model YAML; required.
        system_config_path: System YAML; required.
        tokenizer_config_path: Tokenizer YAML; required for "skeleton" only.
        data_name: Dataset file name the data module looks for,
            e.g. "raw_data.npz" or "predict_skeleton.npz".

    Returns:
        str(output_file) on success.

    Raises:
        FileNotFoundError: When a required configuration file is missing.
        RuntimeError: When NPZ inputs or the expected FBX output cannot be found.
    """
    # Project imports are deferred so that the pip installs at module top run
    # before these heavy dependencies are imported.
    from src.data.datapath import Datapath
    from src.data.dataset import DatasetConfig, UniRigDatasetModule
    from src.data.transform import TransformConfig
    from src.inference.download import download
    from src.model.parse import get_model
    from src.system.parse import get_system, get_writer
    from src.tokenizer.parse import get_tokenizer
    from src.tokenizer.spec import TokenizerConfig

    # Only the autoregressive skeleton stage is seeded for reproducibility.
    if inference_type == "skeleton":
        L.seed_everything(seed, workers=True)

    if task_config_path is None or not Path(task_config_path).exists():
        raise FileNotFoundError(f"Task configuration file not found: {task_config_path}")
    with open(task_config_path, 'r') as f:
        task = Box(yaml.safe_load(f))

    if inference_type == "skeleton":
        # Extract mesh data from the raw model into NPZ form, then point the
        # datapath at the directory containing the extracted NPZ.
        if npz_dir is None:
            npz_dir = Path(output_file).parent / "npz"
        npz_dir = Path(npz_dir)
        npz_dir.mkdir(exist_ok=True)
        npz_data_dir = extract_mesh_python(input_file, npz_dir, target_count)
        datapath = Datapath(files=[npz_data_dir], cls=None)
    else:
        # Skin stage: reuse the NPZ data the skeleton stage wrote next to the
        # skeleton FBX. Takes the first NPZ found anywhere under that tree.
        skeleton_work_dir = Path(input_file).parent
        all_npz_files = list(skeleton_work_dir.rglob("**/*.npz"))
        if not all_npz_files:
            raise RuntimeError(f"No NPZ files found for skin inference in {skeleton_work_dir}")
        skeleton_npz_dir = all_npz_files[0].parent
        datapath = Datapath(files=[str(skeleton_npz_dir)], cls=None)

    if not Path("configs/data/quick_inference.yaml").exists():
        raise FileNotFoundError("Missing configs/data/quick_inference.yaml")
    # NOTE(review): the open() calls below never close their file handles;
    # harmless for a short-lived process but a `with` block would be cleaner.
    data_config = Box(yaml.safe_load(open("configs/data/quick_inference.yaml", 'r')))

    if transform_config_path is None or not Path(transform_config_path).exists():
        raise FileNotFoundError(f"Transform configuration file not found: {transform_config_path}")
    transform_config = Box(yaml.safe_load(open(transform_config_path, 'r')))

    if inference_type == "skeleton":
        # Skeleton generation is token-based and therefore needs a tokenizer.
        if tokenizer_config_path is None or not Path(tokenizer_config_path).exists():
            raise FileNotFoundError(f"Tokenizer configuration file not found: {tokenizer_config_path}")
        tokenizer_config = TokenizerConfig.parse(config=Box(yaml.safe_load(open(tokenizer_config_path, 'r'))))
        tokenizer = get_tokenizer(config=tokenizer_config)
        if model_config_path is None or not Path(model_config_path).exists():
            raise FileNotFoundError(f"Model configuration file not found: {model_config_path}")
        model_config = Box(yaml.safe_load(open(model_config_path, 'r')))
        model = get_model(tokenizer=tokenizer, **model_config)
    else:
        # Skin prediction runs without a tokenizer.
        tokenizer_config = None
        tokenizer = None
        if model_config_path is None or not Path(model_config_path).exists():
            raise FileNotFoundError(f"Model configuration file not found: {model_config_path}")
        model_config = Box(yaml.safe_load(open(model_config_path, 'r')))
        model = get_model(tokenizer=None, **model_config)

    predict_dataset_config = DatasetConfig.parse(config=data_config.predict_dataset_config).split_by_cls()
    predict_transform_config = TransformConfig.parse(config=transform_config.predict_transform_config)
    data = UniRigDatasetModule(
        process_fn=model._process_fn,
        predict_dataset_config=predict_dataset_config,
        predict_transform_config=predict_transform_config,
        tokenizer_config=tokenizer_config,
        debug=False,
        data_name=data_name,
        datapath=datapath,
        cls=None,
    )

    callbacks = []  # NOTE(review): dead assignment; reassigned below before use
    # Writer settings differ per stage: the skeleton writer emits into a
    # per-model subdirectory, the skin writer exports the FBX directly.
    writer_config = task.writer.copy()
    if inference_type == "skeleton":
        writer_config['npz_dir'] = str(npz_dir)
        writer_config['output_dir'] = str(Path(output_file).parent)
        writer_config['output_name'] = Path(output_file).name
        writer_config['user_mode'] = False
    else:
        writer_config['npz_dir'] = str(skeleton_npz_dir)
        writer_config['output_name'] = str(output_file)
        writer_config['user_mode'] = True
        writer_config['export_fbx'] = True

    # Instantiate any ModelCheckpoint callbacks declared in the task config;
    # the prediction writer is always appended last.
    checkpoint_callbacks = []
    if hasattr(task, 'callbacks') and task.callbacks:
        for cb in task.callbacks:
            if isinstance(cb, dict) and cb.get('__target__', '').startswith('ModelCheckpoint'):
                cb_kwargs = {k: v for k, v in cb.items() if k != '__target__'}
                checkpoint_callbacks.append(ModelCheckpoint(**cb_kwargs))
    callbacks = checkpoint_callbacks + [get_writer(**writer_config, order_config=predict_transform_config.order_config)]

    if system_config_path is None or not Path(system_config_path).exists():
        raise FileNotFoundError(f"System configuration file not found: {system_config_path}")
    system_config = Box(yaml.safe_load(open(system_config_path, 'r')))
    system = get_system(**system_config, model=model, steps_per_epoch=1)

    trainer_config = task.trainer
    # Resolve (download or use cached) the pretrained checkpoint named in the task.
    resume_from_checkpoint = download(task.resume_from_checkpoint)
    trainer = L.Trainer(callbacks=callbacks, logger=None, **trainer_config)
    trainer.predict(system, datamodule=data, ckpt_path=resume_from_checkpoint, return_predictions=False)

    if inference_type == "skeleton":
        # The writer emits "<output dir>/<model name>/skeleton.fbx"; fall back
        # to a recursive search, then copy the result to the requested path.
        input_name_stem = Path(input_file).stem
        actual_output_dir = Path(output_file).parent / input_name_stem
        actual_output_file = actual_output_dir / "skeleton.fbx"
        if not actual_output_file.exists():
            alt_files = list(Path(output_file).parent.rglob("skeleton.fbx"))
            if alt_files:
                actual_output_file = alt_files[0]
            else:
                all_files = list(Path(output_file).parent.rglob("*"))  # NOTE(review): unused; presumably left for debugging
                raise RuntimeError(f"Skeleton FBX file not found. Expected at: {actual_output_file}")
        if actual_output_file != Path(output_file):
            shutil.copy2(actual_output_file, output_file)
    else:
        # Skin stage: if the exact requested path is missing, accept any
        # "*skin*.fbx" the writer produced under the output directory.
        if not Path(output_file).exists():
            skin_files = list(Path(output_file).parent.rglob("*skin*.fbx"))
            if skin_files:
                actual_output_file = skin_files[0]
                shutil.copy2(actual_output_file, output_file)
            else:
                raise RuntimeError(f"Skin FBX file not found. Expected at: {output_file}")
    return str(output_file)
def merge_results_python(source_file: str, target_file: str, output_file: str) -> str:
    """Transfer the rig from *source_file* onto *target_file* and write the
    merged model to *output_file*, returning its resolved absolute path.

    Raises ValueError for missing inputs and RuntimeError when the merge
    produces no usable output file.
    """
    from src.inference.merge import transfer

    if not Path(source_file).exists():
        raise ValueError(f"Source file does not exist: {source_file}")
    if not Path(target_file).exists():
        raise ValueError(f"Target file does not exist: {target_file}")

    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # add_root=False: keep the source hierarchy as-is, no extra root bone.
    transfer(source=str(source_file), target=str(target_file), output=str(output_path), add_root=False)

    # Verify the merge actually produced a regular file before reporting success.
    if not output_path.exists():
        raise RuntimeError(f"Merge failed: Output file not created at {output_path}")
    if not output_path.is_file():
        raise RuntimeError(f"Merge failed: Output path is not a valid file: {output_path}")
    return str(output_path.resolve())
@spaces.GPU()
def main(
    input_file: str,
    seed: int = 12345,
    target_count: int = 50000,
    supported_formats: tuple = ('.obj', '.fbx', '.glb'),
) -> Tuple[List[str], List[str]]:
    """Run the full rigging pipeline (skeleton, then skinning) on one model.

    Args:
        input_file: Path to the uploaded 3D model.
        seed: Random seed forwarded to the inference stages.
        target_count: Target point count for mesh extraction.
        supported_formats: Accepted file extensions. (Tuple default instead of
            the previous mutable list default to avoid the shared-mutable-default
            pitfall; callers may still pass a list.)

    Returns:
        (generated_files, completed_files) — lists of result file paths. The
        pipeline is best-effort: whatever finished before a failure is returned.

    Raises:
        gr.Error: When the input file is missing or has an unsupported extension.
    """
    import traceback  # local: only needed for best-effort failure logging

    base_dir = Path(__file__).parent
    temp_dir = base_dir / "tmp"
    temp_dir.mkdir(exist_ok=True)
    generated_files = []
    completed_files = []

    if not validate_input_file(input_file, supported_formats):
        raise gr.Error(f"Error: Invalid or unsupported file format. Supported formats: {', '.join(supported_formats)}")

    # Work inside a per-run directory keyed by file name and seed so repeated
    # runs don't collide; copy the upload in so all artifacts live together.
    file_stem = Path(input_file).stem
    input_model_dir = temp_dir / f"{file_stem}_{seed}"
    input_model_dir.mkdir(exist_ok=True)
    input_file_path = Path(input_file)
    shutil.copy2(input_file_path, input_model_dir / input_file_path.name)
    input_file_path = input_model_dir / input_file_path.name

    # Stage 1: skeleton prediction, then merge the skeleton back onto the
    # original mesh so it can be previewed/downloaded in the input format.
    try:
        intermediate_skeleton_file = input_model_dir / f"{file_stem}_skeleton.fbx"
        final_skeleton_file = input_model_dir / f"{file_stem}_skeleton_only{input_file_path.suffix}"
        run_inference_python(
            input_file=str(input_file_path),
            output_file=str(intermediate_skeleton_file),
            inference_type="skeleton",
            seed=seed,
            target_count=target_count,
            task_config_path="configs/task/quick_inference_skeleton_articulationxl_ar_256.yaml",
            transform_config_path="configs/transform/inference_ar_transform.yaml",
            model_config_path="configs/model/unirig_ar_350m_1024_81920_float32.yaml",
            system_config_path="configs/system/ar_inference_articulationxl.yaml",
            tokenizer_config_path="configs/tokenizer/tokenizer_parts_articulationxl_256.yaml",
            data_name="raw_data.npz",
        )
        merge_results_python(str(intermediate_skeleton_file), str(input_file_path), str(final_skeleton_file))
        generated_files.append(str(final_skeleton_file))
        completed_files.append(str(final_skeleton_file))
    except Exception:
        # Best-effort: surface partial results in the UI instead of an error,
        # but log the failure server-side rather than swallowing it silently.
        traceback.print_exc()
        return generated_files, completed_files

    # Stage 2: skinning-weight prediction on top of the predicted skeleton.
    try:
        intermediate_skin_file = input_model_dir / f"{file_stem}_skin.fbx"
        final_skin_file = input_model_dir / f"{file_stem}_skeleton_and_skinning{input_file_path.suffix}"
        run_inference_python(
            input_file=str(intermediate_skeleton_file),
            output_file=str(intermediate_skin_file),
            inference_type="skin",
            seed=seed,
            task_config_path="configs/task/quick_inference_unirig_skin.yaml",
            transform_config_path="configs/transform/inference_skin_transform.yaml",
            model_config_path="configs/model/unirig_skin.yaml",
            system_config_path="configs/system/skin.yaml",
            tokenizer_config_path=None,
            data_name="predict_skeleton.npz",
        )
        merge_results_python(str(intermediate_skin_file), str(input_file_path), str(final_skin_file))
        generated_files.append(str(final_skin_file))
        completed_files.append(str(final_skin_file))
    except Exception:
        # Same best-effort policy as stage 1: keep the skeleton-only result.
        traceback.print_exc()

    return generated_files, completed_files
def create_app():
    """Build and return the Gradio Blocks UI that wires the rigging pipeline
    (``main``) to upload/preview/download widgets."""
    with gr.Blocks(title="UniRig - 3D Model Rigging Demo") as interface:
        # Page header banner.
        gr.HTML(
            """
<div class="title" style="text-align: center">
<h1>馃幆 UniRig: Automated 3D Model Rigging</h1>
<p style="font-size: 1.1em; color: #6b7280;">
Leverage deep learning to automatically generate skeletons and skinning weights for your 3D models
</p>
</div>
"""
        )
        # Usage notes shown above the controls.
        gr.Markdown(
            """## Notes:
- If you are not seeing the 3D model preview and you are using chrome, go to `chrome://flags/#enable-unsafe-webgpu` and enable the flag.
- Supported File Formats are `.obj`, `.fbx`, `.glb`
- The process may take a few minutes depending on the model complexity and server load.
"""
        )
        with gr.Row(equal_height=True):
            # Left column: input model plus run parameters.
            with gr.Column(scale=1):
                input_3d_model = gr.Model3D(label="Upload 3D Model")
                with gr.Group():
                    with gr.Row(equal_height=True):
                        seed = gr.Number(
                            value=int(torch.randint(0, 100000, (1,)).item()),  # random initial seed
                            label="Random Seed (for reproducible results)",
                            scale=4,
                        )
                        target_count = gr.Number(
                            value=50000,
                            label="Target Count (points for mesh extraction)",
                            precision=0,
                            interactive=True,
                        )
                        random_btn = gr.Button("馃攧 Random Seed", variant="secondary", scale=1)
                pipeline_btn = gr.Button("馃幆 Start Processing", variant="primary", size="lg")
            # Right column: previews of both pipeline stages and downloads.
            with gr.Column():
                skeleton_output = gr.Model3D(label="Skeleton Output")
                skin_output = gr.Model3D(label="Skin Output")
                files_to_download = gr.Files(label="Download Files")
        # Re-roll the seed field with a fresh random value.
        random_btn.click(
            fn=lambda: int(torch.randint(0, 100000, (1,)).item()),
            outputs=seed,
        )

        def pipeline_wrapper(input_file, seed_val, target_count_val):
            # Run the pipeline and route the produced files to the matching
            # widgets by filename convention (see main()'s output naming).
            generated_files, completed_files = main(input_file, seed_val, int(target_count_val))
            skeleton_file = None
            skin_file = None
            for f in completed_files:
                if "skeleton_only" in f:
                    skeleton_file = f
                elif "skeleton_and_skinning" in f:
                    skin_file = f
            # Fall back to clearing a preview when its stage produced nothing.
            return skeleton_file or gr.update(value=None), skin_file or gr.update(value=None), completed_files

        pipeline_btn.click(
            fn=pipeline_wrapper,
            inputs=[input_3d_model, seed, target_count],
            outputs=[skeleton_output, skin_output, files_to_download],
        )
        # Footer with project links.
        gr.HTML(
            """
<div style="text-align: center; margin-top: 2em; padding: 1em; border-radius: 8px;">
<p style="color: #6b7280;">
馃敩 <strong>UniRig</strong> - Research by Tsinghua University & Tripo<br>
馃搫 <a href="https://arxiv.org/abs/2504.12451" target="_blank">Paper</a> |
馃彔 <a href="https://zjp-shadow.github.io/works/UniRig/" target="_blank">Project Page</a> |
馃 <a href="https://huggingface.co/VAST-AI/UniRig" target="_blank">Models</a>
</p>
</div>
"""
        )
    return interface
if __name__ == "__main__":
    # Build the UI and serve it with request queueing enabled.
    create_app().queue().launch()