|
|
import gradio as gr |
|
|
import tempfile |
|
|
import os |
|
|
import sys |
|
|
import shutil |
|
|
import subprocess |
|
|
import traceback |
|
|
from pathlib import Path |
|
|
from typing import Optional, Tuple, List |
|
|
import spaces |
|
|
|
|
|
import subprocess |
|
|
# Install flash-attn when the module is loaded.
# NOTE(review): FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is presumably read by the
# flash-attn build to skip compiling CUDA kernels -- confirm against its setup.py.
# The variable is merged into the current environment (passing a bare dict as
# ``env`` would drop PATH, HOME, virtualenv settings, ...), and pip is invoked
# through the running interpreter so the package lands in this environment.
subprocess.run(
    [sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation'],
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
)

# Put this file's directory first on the import path.
sys.path.insert(0, str(Path(__file__).parent))
|
|
|
|
|
import trimesh |
|
|
import yaml |
|
|
|
|
|
class UniRigDemo:
    """Main class for the UniRig Gradio demo application.

    Every pipeline stage is executed as a subprocess started from the
    directory that contains this file.  The stage methods never raise to the
    caller: they return ``(status_message, output_file_path, preview_info)``
    where an empty ``output_file_path`` means the stage failed and
    ``status_message`` explains why.
    """

    def __init__(self):
        # Scratch directory that holds every per-run working directory.
        self.temp_dir = tempfile.mkdtemp()
        self.results_dir = os.path.join(self.temp_dir, "results")
        os.makedirs(self.results_dir, exist_ok=True)

        # Accepted input extensions (lower-case, including the dot).
        self.supported_formats = ['.obj', '.fbx', '.glb', '.gltf', '.vrm']

        # Placeholders; no method of this class assigns them.
        self.skeleton_model = None
        self.skin_model = None

    def _run_command(self, cmd: List[str]) -> "subprocess.CompletedProcess":
        """Run *cmd* from this file's directory, capturing stdout/stderr as text."""
        return subprocess.run(
            cmd,
            cwd=str(Path(__file__).parent),
            capture_output=True,
            text=True,
        )

    def load_config(self, config_path: str) -> dict:
        """Load a YAML configuration file.

        Args:
            config_path: Path to the YAML file.

        Returns:
            The parsed document as returned by ``yaml.safe_load``.

        Raises:
            RuntimeError: If the file cannot be opened or parsed; the original
                exception is attached as the cause.
        """
        try:
            with open(config_path, 'r', encoding='utf-8') as f:
                return yaml.safe_load(f)
        except Exception as e:
            raise RuntimeError(f"Failed to load config {config_path}: {str(e)}") from e

    def validate_input_file(self, file_path: str) -> bool:
        """Return True when *file_path* exists and has a supported extension."""
        if not file_path or not os.path.exists(file_path):
            return False
        return Path(file_path).suffix.lower() in self.supported_formats

    def preprocess_model(self, input_file: str, output_dir: str) -> str:
        """Preprocess the 3D model for inference.

        Runs ``src.data.extract`` on *input_file* and looks for the ``.npz``
        output it writes into *output_dir*.

        Args:
            input_file: Path to the 3D model to extract.
            output_dir: Directory handed to the extractor as ``--output_dir``.

        Returns:
            Path of the first (alphabetically) ``.npz`` file in *output_dir*.

        Raises:
            RuntimeError: If the extractor exits non-zero, produces no
                ``.npz`` file, or cannot be started.
        """
        extract_cmd = [
            # Use the interpreter running this app so the extractor sees the
            # same installed packages.
            sys.executable or 'python', '-m', 'src.data.extract',
            '--config', 'configs/data/quick_inference.yaml',
            '--input', input_file,
            '--output_dir', output_dir,
            '--force_override', 'true',
            '--faces_target_count', '50000',
        ]
        try:
            result = self._run_command(extract_cmd)
            if result.returncode != 0:
                raise RuntimeError(f"Extraction failed: {result.stderr}")

            # Sorted so the choice is deterministic if several files exist.
            npz_files = sorted(Path(output_dir).glob("*.npz"))
            if not npz_files:
                raise RuntimeError("No .npz file generated during preprocessing")
            return str(npz_files[0])
        except Exception as e:
            raise RuntimeError(f"Preprocessing failed: {str(e)}") from e

    def generate_skeleton(self, input_file: str, seed: int = 12345) -> Tuple[str, str, str]:
        """Generate a skeleton for the input 3D model.

        Args:
            input_file: Path to the input 3D model file.
            seed: Random seed forwarded to the generation script.

        Returns:
            Tuple of (status_message, output_file_path, preview_info).  On
            failure the last two items are empty strings.
        """
        try:
            if not self.validate_input_file(input_file):
                return (
                    "Error: Invalid or unsupported file format. Supported: "
                    + ", ".join(self.supported_formats),
                    "",
                    "",
                )

            # A UI slider may deliver the seed as a float; the script is
            # given a plain integer string.
            seed = int(seed)

            # A fresh directory per run: requests that share a seed must not
            # overwrite each other's files or pick up a stale output.
            work_dir = tempfile.mkdtemp(prefix=f"skeleton_{seed}_", dir=self.temp_dir)

            input_name = Path(input_file).name
            work_input = os.path.join(work_dir, input_name)
            shutil.copy2(input_file, work_input)

            output_file = os.path.join(work_dir, f"{Path(input_name).stem}_skeleton.fbx")

            result = self._run_command([
                'bash', 'launch/inference/generate_skeleton.sh',
                '--input', work_input,
                '--output', output_file,
                '--seed', str(seed),
            ])
            if result.returncode != 0:
                return f"Error: Skeleton generation failed: {result.stderr}", "", ""
            if not os.path.exists(output_file):
                return "Error: Skeleton file was not generated", "", ""

            preview_info = self.generate_model_preview(output_file)
            return "✅ Skeleton generated successfully!", output_file, preview_info

        except Exception as e:
            traceback.print_exc()
            return f"Error: {str(e)}", "", ""

    def generate_skinning(self, skeleton_file: str) -> Tuple[str, str, str]:
        """Generate skinning weights for the skeleton.

        Args:
            skeleton_file: Path to the skeleton file (from skeleton generation).

        Returns:
            Tuple of (status_message, output_file_path, preview_info).  On
            failure the last two items are empty strings.
        """
        try:
            if not skeleton_file or not os.path.exists(skeleton_file):
                return "Error: No skeleton file provided or file doesn't exist", "", ""

            # The result is written next to the skeleton file.
            work_dir = Path(skeleton_file).parent
            output_file = os.path.join(work_dir, f"{Path(skeleton_file).stem}_skin.fbx")

            result = self._run_command([
                'bash', 'launch/inference/generate_skin.sh',
                '--input', skeleton_file,
                '--output', output_file,
            ])
            if result.returncode != 0:
                return f"Error: Skinning generation failed: {result.stderr}", "", ""
            if not os.path.exists(output_file):
                return "Error: Skinning file was not generated", "", ""

            preview_info = self.generate_model_preview(output_file)
            return "✅ Skinning weights generated successfully!", output_file, preview_info

        except Exception as e:
            traceback.print_exc()
            return f"Error: {str(e)}", "", ""

    def merge_results(self, original_file: str, rigged_file: str) -> Tuple[str, str, str]:
        """Merge the rigged skeleton/skin with the original model.

        Args:
            original_file: Path to the original 3D model.
            rigged_file: Path to the rigged file (skeleton or skin).

        Returns:
            Tuple of (status_message, output_file_path, preview_info).  On
            failure the last two items are empty strings.
        """
        try:
            if not original_file or not os.path.exists(original_file):
                return "Error: Original file not provided or doesn't exist", "", ""
            if not rigged_file or not os.path.exists(rigged_file):
                return "Error: Rigged file not provided or doesn't exist", "", ""

            # The merged model is written next to the rigged file.
            work_dir = Path(rigged_file).parent
            output_file = os.path.join(work_dir, f"{Path(original_file).stem}_rigged.glb")

            result = self._run_command([
                'bash', 'launch/inference/merge.sh',
                '--source', rigged_file,
                '--target', original_file,
                '--output', output_file,
            ])
            if result.returncode != 0:
                return f"Error: Merge failed: {result.stderr}", "", ""
            if not os.path.exists(output_file):
                return "Error: Merged file was not generated", "", ""

            preview_info = self.generate_model_preview(output_file)
            return "✅ Model rigging completed successfully!", output_file, preview_info

        except Exception as e:
            traceback.print_exc()
            return f"Error: {str(e)}", "", ""

    def generate_model_preview(self, model_path: str) -> str:
        """Generate preview information for a 3D model.

        Args:
            model_path: Path to the model file.

        Returns:
            HTML string with the file name, size, vertex/face counts and
            format, or a plain error message if the file is missing or the
            preview cannot be built.
        """
        import html  # local import: only needed to escape the file name

        try:
            if not os.path.exists(model_path):
                return "Model file not found"

            # Geometry statistics are best-effort: any loader problem (for
            # example a format the loader cannot read) is shown as zero
            # counts rather than an error.
            try:
                loaded = trimesh.load(model_path)
                # Multi-mesh files load as a Scene; count every geometry it
                # holds instead of reporting zero.
                if isinstance(loaded, trimesh.Scene):
                    parts = list(loaded.geometry.values())
                else:
                    parts = [loaded]
                vertices_count = sum(len(p.vertices) for p in parts if hasattr(p, 'vertices'))
                faces_count = sum(len(p.faces) for p in parts if hasattr(p, 'faces'))
            except Exception:
                vertices_count = 0
                faces_count = 0

            file_size_mb = os.path.getsize(model_path) / (1024 * 1024)

            # The name can come from a user upload, so escape it before it is
            # embedded in markup.
            model_file = Path(model_path)
            safe_name = html.escape(model_file.name)
            safe_format = html.escape(model_file.suffix.upper())

            return f"""
            <div style="padding: 10px; border: 1px solid #ddd; border-radius: 5px; background-color: #f9f9f9;">
                <h4>Model Information</h4>
                <p><strong>File:</strong> {safe_name}</p>
                <p><strong>Size:</strong> {file_size_mb:.2f} MB</p>
                <p><strong>Vertices:</strong> {vertices_count:,}</p>
                <p><strong>Faces:</strong> {faces_count:,}</p>
                <p><strong>Format:</strong> {safe_format}</p>
            </div>
            """

        except Exception as e:
            return f"Error generating preview: {str(e)}"

    def complete_pipeline(self, input_file: str, seed: int = 12345) -> Tuple[str, str, str, str, str]:
        """Run the complete rigging pipeline: skeleton generation -> skinning -> merge.

        Args:
            input_file: Path to the input 3D model file.
            seed: Random seed forwarded to skeleton generation.

        Returns:
            Tuple of (status, skeleton_file, skin_file, final_file,
            final_preview).  The pipeline stops at the first failing stage;
            the paths of that stage and of all later ones are empty strings
            and ``status`` carries the messages gathered so far.
        """
        try:
            skeleton_status, skeleton_file, _ = self.generate_skeleton(input_file, seed)
            if not skeleton_file:
                return skeleton_status, "", "", "", ""

            skin_status, skin_file, _ = self.generate_skinning(skeleton_file)
            if not skin_file:
                return f"{skeleton_status}\n{skin_status}", skeleton_file, "", "", ""

            merge_status, final_file, final_preview = self.merge_results(input_file, skin_file)
            if not final_file:
                # Do not announce success when the merge produced nothing.
                return (
                    f"{skeleton_status}\n{skin_status}\n{merge_status}",
                    skeleton_file,
                    skin_file,
                    "",
                    "",
                )

            combined_status = "\n".join([
                "",
                "**Pipeline Complete!**",
                "",
                "**Step 1 - Skeleton Generation:** ✅ Complete",
                "**Step 2 - Skinning Weights:** ✅ Complete",
                "**Step 3 - Final Merge:** ✅ Complete",
                "",
                merge_status,
                "",
            ])
            return combined_status, skeleton_file, skin_file, final_file, final_preview

        except Exception as e:
            traceback.print_exc()
            return f"Pipeline Error: {str(e)}", "", "", "", ""
|
|
|
|
|
|
|
|
def create_demo_interface():
    """Create and configure the Gradio interface.

    Builds a two-tab ``gr.Blocks`` app around one ``UniRigDemo`` instance: a
    one-click "complete pipeline" tab and a step-by-step tab whose stage
    outputs are forwarded to the next stage's input.

    Returns:
        The configured ``gr.Blocks`` object (not yet launched).
    """
    demo_instance = UniRigDemo()
    # Single source of truth for the accepted extensions.
    supported_formats = list(demo_instance.supported_formats)

    def _file_update(path):
        """Reveal a download slot when a stage produced a file, else clear and hide it."""
        # The download components start hidden; without an explicit update
        # they would stay hidden even after a successful run.  A failed stage
        # reports an empty path, which is mapped to "no file".
        if path:
            return gr.update(value=path, visible=True)
        return gr.update(value=None, visible=False)

    def _run_pipeline(input_file, seed):
        """Run all stages and adapt the file paths for the download slots."""
        status, skeleton, skin, final, preview = demo_instance.complete_pipeline(input_file, seed)
        return status, _file_update(skeleton), _file_update(skin), _file_update(final), preview

    def _run_skeleton(input_file, seed):
        """Step 1 handler: skeleton generation."""
        status, path, preview = demo_instance.generate_skeleton(input_file, seed)
        return status, _file_update(path), preview

    def _run_skinning(skeleton_file):
        """Step 2 handler: skinning weight generation."""
        status, path, preview = demo_instance.generate_skinning(skeleton_file)
        return status, _file_update(path), preview

    def _run_merge(original_file, rigged_file):
        """Step 3 handler: merge with the original model."""
        status, path, preview = demo_instance.merge_results(original_file, rigged_file)
        return status, _file_update(path), preview

    with gr.Blocks(title="UniRig - 3D Model Rigging Demo") as interface:

        # Page header
        gr.HTML("""
        <div class="title" style="text-align: center">
            <h1>🎯 UniRig: Automated 3D Model Rigging</h1>
            <p style="font-size: 1.1em; color: #6b7280;">
                Leverage deep learning to automatically generate skeletons and skinning weights for your 3D models
            </p>
        </div>
        """)

        # Short description of the project
        gr.HTML(f"""
        <h3>About UniRig</h3>
        <p>UniRig is a state-of-the-art framework that automates the complex process of 3D model rigging:</p>
        <ul>
            <li><strong>Skeleton Generation:</strong> AI predicts optimal bone structures</li>
            <li><strong>Skinning Weights:</strong> Automatic vertex-to-bone weight assignment</li>
            <li><strong>Universal Support:</strong> Works with humans, animals, and objects</li>
        </ul>
        <p><strong>Supported formats:</strong> {", ".join(supported_formats)}</p>
        """)

        with gr.Tabs():

            # ---- Tab 1: everything in one click ----
            with gr.Tab("Complete Pipeline", elem_id="pipeline-tab"):
                with gr.Row():
                    with gr.Column(scale=1):
                        pipeline_input = gr.Model3D(
                            label="Upload 3D Model",
                            display_mode="solid",
                        )
                        pipeline_seed = gr.Slider(
                            minimum=1,
                            maximum=99999,
                            value=12345,
                            step=1,
                            label="Random Seed (for reproducible results)",
                        )
                        pipeline_btn = gr.Button("🎯 Start Complete Pipeline", variant="primary", size="lg")

                    with gr.Column(scale=1):
                        pipeline_status = gr.Markdown("Ready to process your 3D model...")
                        pipeline_preview = gr.HTML("")

                with gr.Row():
                    with gr.Column():
                        gr.HTML("<h4>📥 Download Results</h4>")
                        # Hidden until the corresponding stage delivers a file.
                        pipeline_skeleton_out = gr.File(label="Skeleton (.fbx)", visible=False)
                        pipeline_skin_out = gr.File(label="Skinning Weights (.fbx)", visible=False)
                        pipeline_final_out = gr.File(label="Final Rigged Model (.glb)", visible=False)

                pipeline_btn.click(
                    fn=_run_pipeline,
                    inputs=[pipeline_input, pipeline_seed],
                    outputs=[pipeline_status, pipeline_skeleton_out, pipeline_skin_out,
                             pipeline_final_out, pipeline_preview],
                )

            # ---- Tab 2: one stage at a time ----
            with gr.Tab("🔧 Step-by-Step Process", elem_id="stepwise-tab"):
                gr.HTML("<h3>Manual Step-by-Step Rigging Process</h3>")
                gr.HTML("<p>Process your model step by step with full control over each stage.</p>")

                with gr.Group():
                    gr.HTML("<h4>Step 1: Skeleton Generation</h4>")
                    with gr.Row():
                        with gr.Column():
                            step1_input = gr.File(
                                label="Upload 3D Model",
                                file_types=supported_formats,
                                type="filepath",
                            )
                            step1_seed = gr.Slider(
                                minimum=1,
                                maximum=99999,
                                value=12345,
                                step=1,
                                label="Random Seed",
                            )
                            step1_btn = gr.Button("Generate Skeleton", variant="secondary")

                        with gr.Column():
                            step1_status = gr.Markdown("Upload a model to start...")
                            step1_preview = gr.HTML("")
                            step1_output = gr.File(label="Skeleton File (.fbx)", visible=False)

                with gr.Group():
                    gr.HTML("<h4>Step 2: Skinning Weight Generation</h4>")
                    with gr.Row():
                        with gr.Column():
                            step2_input = gr.File(
                                label="Skeleton File (from Step 1)",
                                file_types=[".fbx"],
                                type="filepath",
                            )
                            step2_btn = gr.Button("Generate Skinning Weights", variant="secondary")

                        with gr.Column():
                            step2_status = gr.Markdown("Complete Step 1 first...")
                            step2_preview = gr.HTML("")
                            step2_output = gr.File(label="Skinning File (.fbx)", visible=False)

                with gr.Group():
                    gr.HTML("<h4>Step 3: Merge with Original Model</h4>")
                    with gr.Row():
                        with gr.Column():
                            step3_original = gr.File(
                                label="Original Model",
                                file_types=supported_formats,
                                type="filepath",
                            )
                            step3_rigged = gr.File(
                                label="Rigged File (from Step 2)",
                                file_types=[".fbx"],
                                type="filepath",
                            )
                            step3_btn = gr.Button("Merge Results", variant="secondary")

                        with gr.Column():
                            step3_status = gr.Markdown("Complete previous steps first...")
                            step3_preview = gr.HTML("")
                            step3_output = gr.File(label="Final Rigged Model (.glb)", visible=False)

                step1_btn.click(
                    fn=_run_skeleton,
                    inputs=[step1_input, step1_seed],
                    outputs=[step1_status, step1_output, step1_preview],
                )

                step2_btn.click(
                    fn=_run_skinning,
                    inputs=[step2_input],
                    outputs=[step2_status, step2_output, step2_preview],
                )

                step3_btn.click(
                    fn=_run_merge,
                    inputs=[step3_original, step3_rigged],
                    outputs=[step3_status, step3_output, step3_preview],
                )

                # Forward each stage's result to the next stage's input.
                step1_output.change(
                    fn=lambda x: x,
                    inputs=[step1_output],
                    outputs=[step2_input],
                )

                step2_output.change(
                    fn=lambda x: x,
                    inputs=[step2_output],
                    outputs=[step3_rigged],
                )

        # Page footer
        gr.HTML("""
        <div style="text-align: center; margin-top: 2em; padding: 1em; border-radius: 8px;">
            <p style="color: #6b7280;">
                🔬 <strong>UniRig</strong> - Research by Tsinghua University & Tripo<br>
                <a href="https://arxiv.org/abs/2504.12451" target="_blank">Paper</a> |
                <a href="https://zjp-shadow.github.io/works/UniRig/" target="_blank">Project Page</a> |
                🤗 <a href="https://huggingface.co/VAST-AI/UniRig" target="_blank">Models</a>
            </p>
            <p style="color: #9ca3af; font-size: 0.9em;">
                ⚡ Powered by PyTorch & Gradio | 🎯 GPU recommended for optimal performance
            </p>
        </div>
        """)

    return interface
|
|
|
|
|
|
|
|
def main():
    """Build the Gradio interface, enable its request queue and launch it."""
    create_demo_interface().queue().launch()


# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()