import subprocess
import os
import sys
import time
import shutil
import tempfile
import html
import base64
import inspect

import requests
import gradio as gr
import spaces
import torch
import cv2
import numpy as np
import trimesh
import open3d as o3d
from huggingface_hub import hf_hub_download

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(BASE_DIR, "static")
EXAMPLES_DIR = os.path.join(BASE_DIR, "examples")


def prepare_local_assets():
    """Download the three.js modules needed by the in-browser viewer."""
    os.makedirs(STATIC_DIR, exist_ok=True)
    base_url = "https://registry.npmmirror.com/three/0.160.0/files"
    assets = {
        "three.module.js": f"{base_url}/build/three.module.js",
        "OrbitControls.js": f"{base_url}/examples/jsm/controls/OrbitControls.js",
        "GLTFLoader.js": f"{base_url}/examples/jsm/loaders/GLTFLoader.js",
        "BufferGeometryUtils.js": f"{base_url}/examples/jsm/utils/BufferGeometryUtils.js",
    }
    for name, url in assets.items():
        path = os.path.join(STATIC_DIR, name)
        if not os.path.exists(path):
            try:
                r = requests.get(url, verify=False, timeout=10)
                if r.status_code == 200:
                    with open(path, "wb") as f:
                        f.write(r.content)
            except Exception as e:
                print(f"Error downloading {name}: {e}")


prepare_local_assets()


def install_pytorch3d():
    try:
        import pytorch3d  # noqa: F401
        print("✅ PyTorch3D already installed.")
        return
    except ImportError:
        print("⏳ PyTorch3D not found. Starting dynamic installation...")

    pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
    version_str = "".join([
        f"py3{sys.version_info.minor}_",
        f"cu{torch.version.cuda.replace('.', '')}_",
        f"pyt{pyt_version_str}",
    ])
    whl_url = (
        "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/"
        f"{version_str}/pytorch3d-0.7.8-cp3{sys.version_info.minor}"
        f"-cp3{sys.version_info.minor}-linux_x86_64.whl"
    )
    print(f"🔍 Detected Env: {version_str}")
    print(f"⬇️ Attempting to install from Wheel: {whl_url}")
    try:
        subprocess.run([sys.executable, "-m", "pip", "install", whl_url], check=True)
        print("✅ PyTorch3D installed via Wheel!")
    except subprocess.CalledProcessError:
        print("⚠️ Wheel installation failed (maybe version mismatch).")
        print("🔨 Falling back to source compilation (this will take a few minutes)...")
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--no-build-isolation",
             "git+https://github.com/facebookresearch/pytorch3d.git@stable"],
            check=True,
        )
        print("✅ PyTorch3D installed via Source Build!")


install_pytorch3d()


def install_mmcv():
    try:
        import mmcv
        print(f"✅ mmcv {mmcv.__version__} is already installed.")
        return
    except ImportError:
        print("⏳ mmcv not found. Starting dynamic installation...")

    # 1. Detect environment versions to construct the URL dynamically.
    #    Example: CUDA 12.1 -> "121", Torch 2.4.0 -> "2.4"
    cuda_ver = torch.version.cuda.replace(".", "")
    torch_ver = ".".join(torch.__version__.split(".")[:2])

    # 2. Construct the find-links URL matching OpenMMLab's structure:
    #    https://download.openmmlab.com/mmcv/dist/cu{CUDA}/torch{TORCH}/index.html
    find_links_url = f"https://download.openmmlab.com/mmcv/dist/cu{cuda_ver}/torch{torch_ver}/index.html"
    print(f"🔍 Detected Env: CUDA={cuda_ver}, Torch={torch_ver}")
    print(f"⬇️ Installing mmcv==2.2.0 from: {find_links_url}")
    try:
        # 3. Run pip install with the specific version and dynamic link.
        subprocess.run([
            sys.executable, "-m", "pip", "install",
            "mmcv==2.2.0", "--find-links", find_links_url,
        ], check=True)
        print("✅ mmcv installed successfully.")
    except subprocess.CalledProcessError:
        print("⚠️ Installation failed. The specific version might not exist for this environment.")
        print("🔄 Attempting fallback using openmim (auto-resolve mode)...")
        # Fallback: install openmim and let it resolve a compatible build.
        subprocess.run([sys.executable, "-m", "pip", "install", "openmim"], check=True)
        subprocess.run(["mim", "install", "mmcv==2.2.0"], check=True)


def install_sam2():
    try:
        import sam2  # noqa: F401
    except ImportError:
        print("Installing SAM 2 with patch...")
        subprocess.run(
            ["git", "clone", "https://github.com/facebookresearch/segment-anything-2.git", "_tmp_sam2"],
            check=True,
        )
        # Relax the torch pin in setup.py so the build accepts this environment's torch.
        setup_path = "_tmp_sam2/setup.py"
        with open(setup_path, "r") as f:
            content = f.read()
        content = content.replace("torch>=2.5.1", "torch>=2.4.1")
        with open(setup_path, "w") as f:
            f.write(content)
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--no-build-isolation", "--no-deps", "-v", "."],
            cwd="_tmp_sam2", check=True,
        )
        shutil.rmtree("_tmp_sam2")


install_mmcv()
install_sam2()

sys.path.append(BASE_DIR)
from unish.utils.inference_utils import (
    load_model,
    process_video,
    run_inference,
    generate_mixed_geometries_in_memory,
    save_smpl_meshes_per_frame,
)

MODEL = None
BODY_MODELS_PATH = "body_models/"


# ==========================================
# 4. Helper functions
# ==========================================
def download_smpl_assets(body_models_path):
    if 'smpl' not in body_models_path:
        model_path = os.path.join(body_models_path, 'smpl')
    else:
        model_path = body_models_path
    target_dir = os.path.join(model_path, 'smpl')
    os.makedirs(target_dir, exist_ok=True)

    files = ["SMPL_NEUTRAL.pkl", "SMPL_MALE.pkl", "SMPL_FEMALE.pkl"]
    repo_id = "Murphyyyy/UniSH-Private-Assets"  # <-- change to your own repo
    token = os.environ.get("SMPL_DOWNLOAD_TOKEN")
    if not token:
        print("❌ CRITICAL ERROR: 'SMPL_DOWNLOAD_TOKEN' not found in environment variables!")
        print("👉 Since 'UniSH-Private-Assets' is likely private, inference WILL fail without a token.")

    for filename in files:
        file_path = os.path.join(target_dir, filename)
        if not os.path.exists(file_path):
            try:
                print(f"📥 Downloading {filename} from {repo_id}...")
                hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    token=token,
                    local_dir=target_dir,
                    local_dir_use_symlinks=False,
                )
                print(f"✅ Downloaded to: {file_path}")
            except Exception as e:
                print(f"❌ Failed to download {filename}: {e}")
                print(f"   (Check if your HF Token has access to {repo_id})")


def pack_sequence_to_glb(base_dir, output_path, start_frame, end_frame, scene_rate=1.0):
    """Pack per-frame human meshes and scene point clouds into a single GLB scene."""
    scene = trimesh.Scene()
    scene_cloud_dir = os.path.join(base_dir, "scene_clouds_per_frame")
    smpl_mesh_dir = os.path.join(base_dir, "smpl_meshes_per_frame")
    MAX_POINTS_PER_FRAME = 60000

    for i in range(start_frame, end_frame):
        candidates = [
            os.path.join(smpl_mesh_dir, f"combined_smpl_mesh_frame_{i:04d}.ply"),
            os.path.join(smpl_mesh_dir, f"smpl_mesh_frame_{i:04d}.ply"),
        ]
        target_human_path = None
        for p in candidates:
            if os.path.exists(p):
                target_human_path = p
                break
        if target_human_path:
            try:
                human_mesh = trimesh.load(target_human_path)
                node_name = f"frame_{i}_human"
                scene.add_geometry(human_mesh, node_name=node_name, geom_name=node_name)
            except Exception:
                pass

        scene_pcd_path = os.path.join(scene_cloud_dir, f"scene_frame_{i:04d}.ply")
        if os.path.exists(scene_pcd_path):
            try:
                scene_pc = trimesh.load(scene_pcd_path)
                if hasattr(scene_pc, 'vertices') and len(scene_pc.vertices) > 0:
                    # Randomly subsample dense clouds so the GLB stays web-friendly.
                    num_points = len(scene_pc.vertices)
                    if num_points > MAX_POINTS_PER_FRAME:
                        choice = np.random.choice(num_points, MAX_POINTS_PER_FRAME, replace=False)
                        scene_pc.vertices = scene_pc.vertices[choice]
                        if hasattr(scene_pc, 'colors') and len(scene_pc.colors) > 0:
                            scene_pc.colors = scene_pc.colors[choice]
                    node_name = f"frame_{i}_scene"
                    scene.add_geometry(scene_pc, node_name=node_name, geom_name=node_name)
            except Exception:
                pass

    if len(scene.geometry) == 0:
        # Export a tiny placeholder so the viewer always receives a valid GLB.
        dummy = trimesh.creation.box(extents=[0.01, 0.01, 0.01])
        scene.add_geometry(dummy, node_name='dummy')

    scene.export(output_path)
    if not os.path.exists(output_path):
        raise FileNotFoundError(f"Export failed: {output_path}")


def get_video_duration(video_path):
    if not video_path:
        return 10.0
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return 10.0
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        duration = frame_count / fps if fps > 0 else 10.0
        cap.release()
        return duration
    except Exception:
        return 10.0


def get_loading_html(message="Processing..."):
    return f"""
    <div style="display:flex; align-items:center; justify-content:center; height:600px; color:#6b7280;">
        {message}
    </div>
    """


def get_player_html(abs_glb_path):
    def read_and_patch(filename):
        """Read a local three.js asset and rewrite its relative imports to bare specifiers."""
        path = os.path.join(STATIC_DIR, filename)
        if not os.path.exists(path):
            return ""
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
        content = content.replace('../utils/BufferGeometryUtils.js', 'BufferGeometryUtils')
        content = content.replace('./BufferGeometryUtils.js', 'BufferGeometryUtils')
        content = content.replace('../../build/three.module.js', 'three')
        content = content.replace('../build/three.module.js', 'three')
        return content

    js_three = read_and_patch("three.module.js")
    js_orbit = read_and_patch("OrbitControls.js")
    js_loader = read_and_patch("GLTFLoader.js")
    js_buffer = read_and_patch("BufferGeometryUtils.js")

    def to_data_url(content, mime="text/javascript"):
        b64 = base64.b64encode(content.encode('utf-8')).decode('utf-8')
        return f"data:{mime};base64,{b64}"

    blob_three = to_data_url(js_three)
    blob_orbit = to_data_url(js_orbit)
    blob_loader = to_data_url(js_loader)
    blob_buffer = to_data_url(js_buffer)

    if not os.path.exists(abs_glb_path):
        return '<div style="color:#ef4444; text-align:center; padding:2rem;">Error: Output file not found.</div>'

    with open(abs_glb_path, "rb") as f:
        glb_data = f.read()
    model_data_url = f"data:model/gltf-binary;base64,{base64.b64encode(glb_data).decode('utf-8')}"

    # Self-contained viewer page: the importmap resolves the bare module
    # specifiers patched above to the inlined data URLs, so the iframe
    # needs no network access.
    raw_html = f"""
    <!DOCTYPE html>
    <html>
    <head>
    <style>html, body {{ margin: 0; height: 100%; overflow: hidden; }}</style>
    <script type="importmap">
    {{
        "imports": {{
            "three": "{blob_three}",
            "BufferGeometryUtils": "{blob_buffer}",
            "OrbitControls": "{blob_orbit}",
            "GLTFLoader": "{blob_loader}"
        }}
    }}
    </script>
    </head>
    <body>
    <script type="module">
        import * as THREE from 'three';
        import {{ OrbitControls }} from 'OrbitControls';
        import {{ GLTFLoader }} from 'GLTFLoader';

        const scene = new THREE.Scene();
        scene.background = new THREE.Color(0xf9fafb);
        const camera = new THREE.PerspectiveCamera(50, window.innerWidth / window.innerHeight, 0.01, 100);
        camera.position.set(0, 1, 3);
        const renderer = new THREE.WebGLRenderer({{ antialias: true }});
        renderer.setSize(window.innerWidth, window.innerHeight);
        document.body.appendChild(renderer.domElement);
        scene.add(new THREE.AmbientLight(0xffffff, 1.0));
        const controls = new OrbitControls(camera, renderer.domElement);

        // Per-frame nodes are named "frame_<i>_human" / "frame_<i>_scene";
        // play the sequence by toggling visibility one frame at a time.
        let frames = 0;
        new GLTFLoader().load("{model_data_url}", (gltf) => {{
            gltf.scene.traverse((obj) => {{
                const m = obj.name.match(/^frame_(\\d+)_/);
                if (m) {{
                    obj.visible = false;
                    frames = Math.max(frames, parseInt(m[1]) + 1);
                }}
            }});
            scene.add(gltf.scene);
            let current = -1;
            setInterval(() => {{
                if (frames === 0) return;
                const next = (current + 1) % frames;
                gltf.scene.traverse((obj) => {{
                    const m = obj.name.match(/^frame_(\\d+)_/);
                    if (m) obj.visible = parseInt(m[1]) === next;
                }});
                current = next;
            }}, 1000 / 6);  // playback matches the 6 fps used for inference
        }});

        function animate() {{
            requestAnimationFrame(animate);
            controls.update();
            renderer.render(scene, camera);
        }}
        animate();
    </script>
    </body>
    </html>
    """
    return f'<iframe srcdoc="{html.escape(raw_html)}" style="width:100%; height:600px; border:none;"></iframe>'


# On Spaces, wrap inference with the ZeroGPU decorator; locally, pass through.
if os.environ.get("SPACE_ID"):
    from spaces import GPU
    gpu_decorator = GPU(duration=120)
else:
    def gpu_decorator(func):
        # Pick the wrapper shape up front: a single wrapper containing both
        # `yield from` and `return value` would always be a generator and
        # break plain functions.
        if inspect.isgeneratorfunction(func):
            def wrapper(*args, **kwargs):
                yield from func(*args, **kwargs)
        else:
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
        return wrapper


@gpu_decorator
def predict(video_path, start_time=0.0, end_time=10.0):
    duration_input = end_time - start_time
    if duration_input > 10.0:
        raise gr.Error(f"Video limit exceeded ({duration_input:.1f}s). Please keep it under 10 seconds.")
    if start_time >= end_time:
        raise gr.Error("Error: End time must be greater than Start time.")

    yield get_loading_html("Processing...")

    # Trim the requested window with ffmpeg and re-encode for stable decoding.
    output_dir = tempfile.mkdtemp()
    trimmed_video_path = os.path.join(output_dir, "input_trim.mp4")
    duration = end_time - start_time
    cmd = ["ffmpeg", "-y", "-ss", str(start_time), "-i", video_path, "-t", str(duration),
           "-c:v", "libx264", "-c:a", "aac", trimmed_video_path]
    subprocess.run(cmd, check=True)

    global MODEL
    if MODEL is None:
        MODEL = load_model()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.to(device)
    MODEL.eval()

    download_smpl_assets(BODY_MODELS_PATH)

    data_dict = process_video(trimmed_video_path, 6.0, 0, 518, bbox_scale=1.0)
    results = run_inference(MODEL, data_dict, device, chunk_size=300)
    seq_name = results['seq_name']

    viz_scene, viz_smpl, viz_scene_only, _ = generate_mixed_geometries_in_memory(
        results, BODY_MODELS_PATH, fps=6.0, conf_thres=0.1
    )
    save_smpl_meshes_per_frame(results, output_dir, BODY_MODELS_PATH)

    base_dir = os.path.join(output_dir, seq_name)
    scene_cloud_dir = os.path.join(base_dir, "scene_clouds_per_frame")
    os.makedirs(scene_cloud_dir, exist_ok=True)
    for i, pcd in enumerate(viz_scene_only):
        if len(pcd.points) > 0:
            o3d.io.write_point_cloud(os.path.join(scene_cloud_dir, f"scene_frame_{i:04d}.ply"), pcd)

    tmp_glb_path = os.path.join(output_dir, "output.glb")
    pack_sequence_to_glb(base_dir, tmp_glb_path, 0, len(viz_scene), 0.5)
    yield get_player_html(tmp_glb_path)


examples_list = []
if os.path.exists(EXAMPLES_DIR):
    examples_list = [[os.path.join("examples", f)]
                     for f in os.listdir(EXAMPLES_DIR) if f.endswith(".mp4")]

js_scrub = """(val) => {
    var video = document.querySelector('#input-video video');
    if (video) { video.currentTime = val; }
    return val;
}"""

js_reset_video = """() => {
    setTimeout(() => {
        var video = document.querySelector('#input-video video');
        if (video) { video.currentTime = 0; video.pause(); }
    }, 200);
}"""

custom_css = """
footer {visibility: hidden}
h1.header-title { text-align: center; font-family: 'Segoe UI', sans-serif; font-weight: 700; color: #1f2937; margin-bottom: 2rem; }
.alert-box { padding: 1rem; border-radius: 0.5rem; margin-bottom: 1rem; font-size: 0.9rem; line-height: 1.5; }
.info-box { background-color: #eff6ff; border-left: 4px solid #3b82f6; color: #1e40af; }
.warning-box { background-color: #fefce8; border-left: 4px solid #eab308; color: #854d0e; }
.tip-box { background-color: #f0fdf4; border-left: 4px solid #22c55e; color: #15803d; }
.viewer-container { box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1); border-radius: 12px; overflow: hidden; border: 1px solid #e5e7eb; background: #f9fafb; }
#run-btn { background: linear-gradient(135deg, #6366f1 0%, #a855f7 100%); border: none; color: white; font-weight: bold; transition: all 0.2s; }
#run-btn:hover { transform: translateY(-2px); box-shadow: 0 10px 20px rgba(99, 102, 241, 0.4); }
#input-video { margin-bottom: 15px !important; }
"""

with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", radius_size="md"),
               css=custom_css, title="UniSH Demo") as demo:
    with gr.Column(elem_classes=["header-container"]):
        gr.Markdown("# UniSH: Unifying Scene and Human Reconstruction in a Feed-Forward Pass",
                    elem_classes=["header-title"])

    with gr.Row(equal_height=False):
        with gr.Column(scale=3, variant="panel"):
            gr.Markdown("### 🛠️ Configuration")
            with gr.Group():
                input_video = gr.Video(label="Upload Video", format="mp4", height=260,
                                       interactive=True, elem_id="input-video")
                with gr.Row():
                    start_time = gr.Slider(minimum=0, maximum=10, value=0, step=0.01, label="Start Time (s)")
                    end_time = gr.Slider(minimum=0, maximum=10, value=3, step=0.01, label="End Time (s)")
                gr.HTML("""
                <div class="alert-box tip-box">
                    <b>💡 Use Tips:</b><br>
                    1. Contain only a single person.<br>
                    2. No occlusion (self-occlusion is fine).<br>
                    3. Keep the full body mostly visible.<br>
                    4. Subject should not be too small.
                </div>
                """)
            submit_btn = gr.Button("🚀 Start Reconstruction", variant="primary", elem_id="run-btn", size="lg")
            if examples_list:
                gr.Markdown("### 🎥 Examples")
                gr.Examples(
                    examples=examples_list,
                    inputs=[input_video, start_time, end_time],
                    label="Click to try:",
                    cache_examples=False,
                )

        with gr.Column(scale=7):
            gr.Markdown("### ▶️ Interactive Results")
            with gr.Group(elem_classes=["viewer-container"]):
                output_html = gr.HTML(
                    label="3D Viewer",
                    min_height=600,
                    value='<div style="display:flex; align-items:center; justify-content:center; height:600px; color:#9ca3af;">Upload a video and click Start to view result.</div>',
                )
            gr.HTML("""
            <div class="alert-box info-box">
                <b>⚡ Performance Note:</b>
                Inference (feed-forward) is very fast, but generating visualization assets takes up most of the processing time.
            </div>
            <div class="alert-box warning-box">
                <b>👁️ Visual Quality:</b>
                The displayed results are downsampled for better web rendering performance.
            </div>
            """)

    def update_slider_range(video_path):
        if not video_path:
            return gr.update(value=0), gr.update(value=0)
        dur = round(get_video_duration(video_path), 2)
        return gr.update(maximum=dur, value=0), gr.update(maximum=dur, value=dur)

    input_video.change(fn=update_slider_range, inputs=[input_video], outputs=[start_time, end_time])
    input_video.upload(fn=update_slider_range, inputs=[input_video], outputs=[start_time, end_time])
    # Reset the preview on new uploads and let the sliders scrub the <video> element client-side.
    input_video.change(fn=None, inputs=[], outputs=[], js=js_reset_video)
    input_video.upload(fn=None, inputs=[], outputs=[], js=js_reset_video)
    start_time.change(fn=None, inputs=[start_time], outputs=None, js=js_scrub)
    end_time.change(fn=None, inputs=[end_time], outputs=None, js=js_scrub)

    submit_btn.click(
        fn=predict,
        inputs=[input_video, start_time, end_time],
        outputs=[output_html],
    )

demo.queue()
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    allowed_paths=[BASE_DIR, "/tmp", EXAMPLES_DIR],
)