import subprocess
import gradio as gr
import os
import time
import requests
import spaces
import sys
import shutil
import tempfile
import torch
import cv2
import numpy as np
import trimesh
import open3d as o3d
from huggingface_hub import hf_hub_download
import html
import base64
import inspect
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR = os.path.join(BASE_DIR, "static")
EXAMPLES_DIR = os.path.join(BASE_DIR, "examples")
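# Three.js and its addons are fetched once at startup and cached in static/.
# They are later embedded into the viewer iframe as data URLs (see
# get_player_html), since an iframe built with `srcdoc` cannot load relative
# paths from the Space. Pinned to three r160 via the npmmirror registry.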
def prepare_local_assets():
    os.makedirs(STATIC_DIR, exist_ok=True)
    base_url = "https://registry.npmmirror.com/three/0.160.0/files"
    assets = {
        "three.module.js": f"{base_url}/build/three.module.js",
        "OrbitControls.js": f"{base_url}/examples/jsm/controls/OrbitControls.js",
        "GLTFLoader.js": f"{base_url}/examples/jsm/loaders/GLTFLoader.js",
        "BufferGeometryUtils.js": f"{base_url}/examples/jsm/utils/BufferGeometryUtils.js"
    }
    for name, url in assets.items():
        path = os.path.join(STATIC_DIR, name)
        if not os.path.exists(path):
            try:
                # verify=False skips TLS verification for the mirror; the assets
                # are public, version-pinned JS files, so this is a pragmatic trade-off.
                r = requests.get(url, verify=False, timeout=10)
                if r.status_code == 200:
                    with open(path, "wb") as f:
                        f.write(r.content)
            except Exception as e:
                print(f"Error downloading {name}: {e}")

prepare_local_assets()
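# ---------------------------------------------------------------------------
# Runtime dependency installation.
# pytorch3d, mmcv and SAM 2 only ship prebuilt wheels for specific
# Python/torch/CUDA combinations, so the correct download URL can only be
# determined from the live environment; requirements.txt cannot express this,
# hence these installers run at startup.
# ---------------------------------------------------------------------------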
def install_pytorch3d():
    try:
        import pytorch3d
        print("✅ PyTorch3D already installed.")
        return
    except ImportError:
        print("⏳ PyTorch3D not found. Starting dynamic installation...")
    # Build the wheel tag from the running interpreter, CUDA and torch versions,
    # e.g. py310_cu121_pyt240.
    pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
    version_str = "".join([
        f"py3{sys.version_info.minor}_",
        f"cu{torch.version.cuda.replace('.', '')}_",
        f"pyt{pyt_version_str}"
    ])
    whl_url = f"https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/pytorch3d-0.7.8-cp3{sys.version_info.minor}-cp3{sys.version_info.minor}-linux_x86_64.whl"
    print(f"🔍 Detected Env: {version_str}")
    print(f"⬇️ Attempting to install from Wheel: {whl_url}")
    try:
        subprocess.run([sys.executable, "-m", "pip", "install", whl_url], check=True)
        print("✅ PyTorch3D installed via Wheel!")
    except subprocess.CalledProcessError:
        print("⚠️ Wheel installation failed (maybe version mismatch).")
        print("🔨 Falling back to source compilation (this will take a few minutes)...")
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--no-build-isolation", "git+https://github.com/facebookresearch/pytorch3d.git@stable"],
            check=True
        )
        print("✅ PyTorch3D installed via Source Build!")

install_pytorch3d()
def install_mmcv():
    try:
        import mmcv
        print(f"✅ mmcv {mmcv.__version__} is already installed.")
        return
    except ImportError:
        print("⏳ mmcv not found. Starting dynamic installation...")
    # 1. Detect environment versions to construct the URL dynamically.
    #    Example: CUDA 12.1 -> "121", Torch 2.4.0 -> "2.4"
    cuda_ver = torch.version.cuda.replace(".", "")
    torch_ver = ".".join(torch.__version__.split(".")[:2])
    # 2. Construct the find-links URL matching OpenMMLab's structure:
    #    https://download.openmmlab.com/mmcv/dist/cu{CUDA}/torch{TORCH}/index.html
    find_links_url = f"https://download.openmmlab.com/mmcv/dist/cu{cuda_ver}/torch{torch_ver}/index.html"
    print(f"🔍 Detected Env: CUDA={cuda_ver}, Torch={torch_ver}")
    print(f"⬇️ Installing mmcv==2.2.0 from: {find_links_url}")
    try:
        # 3. Run pip install with the specific version and dynamic link.
        subprocess.run([
            sys.executable, "-m", "pip", "install",
            "mmcv==2.2.0",
            "--find-links", find_links_url
        ], check=True)
        print("✅ mmcv installed successfully.")
    except subprocess.CalledProcessError:
        print("⚠️ Installation failed. The specific version might not exist for this environment.")
        print("🔄 Attempting fallback using openmim (auto-resolve mode)...")
        # Fallback: install openmim and let it resolve a compatible build.
        subprocess.run([sys.executable, "-m", "pip", "install", "openmim"], check=True)
        subprocess.run(["mim", "install", "mmcv==2.2.0"], check=True)

# mmcv must be available before the `unish` inference utilities are imported
# below (assumption based on the mm-style tooling), so run the installer now.
install_mmcv()
def install_sam2():
    try:
        import sam2
    except ImportError:
        print("Installing SAM 2 with patch...")
        subprocess.run(["git", "clone", "https://github.com/facebookresearch/segment-anything-2.git", "_tmp_sam2"], check=True)
        # Relax the torch pin in SAM 2's setup.py so it accepts the torch build
        # already installed in this image instead of forcing an upgrade.
        setup_path = "_tmp_sam2/setup.py"
        with open(setup_path, "r") as f:
            content = f.read()
        content = content.replace("torch>=2.5.1", "torch>=2.4.1")
        with open(setup_path, "w") as f:
            f.write(content)
        subprocess.run([sys.executable, "-m", "pip", "install", "--no-build-isolation", "--no-deps", "-v", "."], cwd="_tmp_sam2", check=True)
        shutil.rmtree("_tmp_sam2")

install_sam2()
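# The `unish` inference utilities depend on the packages installed above
# (assumption based on the installer ordering), so their import is deliberately
# deferred to this point rather than placed at the top of the file.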
sys.path.append(BASE_DIR)
from unish.utils.inference_utils import (
    load_model, process_video, run_inference,
    generate_mixed_geometries_in_memory,
    save_smpl_meshes_per_frame
)
MODEL = None
BODY_MODELS_PATH = "body_models/"
# ==========================================
# 4. Helper functions
# ==========================================
def download_smpl_assets(body_models_path):
    if 'smpl' not in body_models_path:
        model_path = os.path.join(body_models_path, 'smpl')
    else:
        model_path = body_models_path
    target_dir = os.path.join(model_path, 'smpl')
    os.makedirs(target_dir, exist_ok=True)
    files = ["SMPL_NEUTRAL.pkl", "SMPL_MALE.pkl", "SMPL_FEMALE.pkl"]
    repo_id = "Murphyyyy/UniSH-Private-Assets"  # <--- change this to your own repo
    # The SMPL body models are license-restricted and cannot be redistributed
    # publicly, hence the private repo plus an access token.
    token = os.environ.get("SMPL_DOWNLOAD_TOKEN")
    if not token:
        print("❌ CRITICAL ERROR: 'SMPL_DOWNLOAD_TOKEN' not found in environment variables!")
        print("👉 Since 'UniSH-Private-Assets' is likely private, inference WILL fail without a token.")
    for filename in files:
        file_path = os.path.join(target_dir, filename)
        if not os.path.exists(file_path):
            try:
                print(f"📥 Downloading {filename} from {repo_id}...")
                hf_hub_download(
                    repo_id=repo_id,
                    filename=filename,
                    token=token,
                    local_dir=target_dir,
                    local_dir_use_symlinks=False
                )
                print(f"✅ Downloaded to: {file_path}")
            except Exception as e:
                print(f"❌ Failed to download {filename}: {e}")
                print(f"   (Check if your HF Token has access to {repo_id})")
def pack_sequence_to_glb(base_dir, output_path, start_frame, end_frame, scene_rate=1.0):
    # NOTE: scene_rate is accepted for API compatibility but currently unused.
    scene = trimesh.Scene()
    scene_cloud_dir = os.path.join(base_dir, "scene_clouds_per_frame")
    smpl_mesh_dir = os.path.join(base_dir, "smpl_meshes_per_frame")
    MAX_POINTS_PER_FRAME = 60000
    for i in range(start_frame, end_frame):
        candidates = [
            os.path.join(smpl_mesh_dir, f"combined_smpl_mesh_frame_{i:04d}.ply"),
            os.path.join(smpl_mesh_dir, f"smpl_mesh_frame_{i:04d}.ply")
        ]
        target_human_path = None
        for p in candidates:
            if os.path.exists(p):
                target_human_path = p
                break
        if target_human_path:
            try:
                human_mesh = trimesh.load(target_human_path)
                node_name = f"frame_{i}_human"
                scene.add_geometry(human_mesh, node_name=node_name, geom_name=node_name)
            except Exception:
                pass
        scene_pcd_path = os.path.join(scene_cloud_dir, f"scene_frame_{i:04d}.ply")
        if os.path.exists(scene_pcd_path):
            try:
                scene_pc = trimesh.load(scene_pcd_path)
                if hasattr(scene_pc, 'vertices') and len(scene_pc.vertices) > 0:
                    num_points = len(scene_pc.vertices)
                    if num_points > MAX_POINTS_PER_FRAME:
                        # Subsample vertices and colors with the same index set
                        # before mutating the cloud, so the arrays stay in sync.
                        choice = np.random.choice(num_points, MAX_POINTS_PER_FRAME, replace=False)
                        colors = None
                        if hasattr(scene_pc, 'colors') and len(scene_pc.colors) > 0:
                            colors = scene_pc.colors[choice]
                        scene_pc = trimesh.PointCloud(scene_pc.vertices[choice], colors=colors)
                    node_name = f"frame_{i}_scene"
                    scene.add_geometry(scene_pc, node_name=node_name, geom_name=node_name)
            except Exception:
                pass
    if len(scene.geometry) == 0:
        # Guarantee a non-empty scene so the GLB export cannot fail.
        dummy = trimesh.creation.box(extents=[0.01, 0.01, 0.01])
        scene.add_geometry(dummy, node_name='dummy')
    scene.export(output_path)
    if not os.path.exists(output_path):
        raise FileNotFoundError(f"Export failed: {output_path}")
def get_video_duration(video_path):
    if not video_path:
        return 10.0
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return 10.0
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        duration = frame_count / fps if fps > 0 else 10.0
        cap.release()
        return duration
    except Exception:
        return 10.0
def get_loading_html(message="Processing..."):
    return f"""
    <div style="height: 600px; width: 100%; background: #f9fafb; border-radius: 12px; border: 1px solid #e5e7eb; display: flex; flex-direction: column; align-items: center; justify-content: center; font-family: sans-serif; color: #4b5563;">
        <div class="loader-ring"></div>
        <p style="margin-top: 20px; font-weight: 500; font-size: 1.1em; animation: pulse 2s infinite;">{message}</p>
        <style>
            .loader-ring {{ display: inline-block; width: 64px; height: 64px; }}
            .loader-ring:after {{ content: " "; display: block; width: 46px; height: 46px; margin: 8px; border-radius: 50%; border: 5px solid #4f46e5; border-color: #4f46e5 transparent #4f46e5 transparent; animation: ring-spin 1.2s linear infinite; }}
            @keyframes ring-spin {{ 0% {{ transform: rotate(0deg); }} 100% {{ transform: rotate(360deg); }} }}
            @keyframes pulse {{ 0% {{ opacity: 0.6; }} 50% {{ opacity: 1; }} 100% {{ opacity: 0.6; }} }}
        </style>
    </div>
    """
def get_player_html(abs_glb_path):
    def read_and_patch(filename):
        path = os.path.join(STATIC_DIR, filename)
        if not os.path.exists(path):
            return ""
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
        # Rewrite the addons' relative imports to the bare specifiers declared
        # in the import map of the iframe below.
        content = content.replace('../utils/BufferGeometryUtils.js', 'BufferGeometryUtils')
        content = content.replace('./BufferGeometryUtils.js', 'BufferGeometryUtils')
        content = content.replace('../../build/three.module.js', 'three')
        content = content.replace('../build/three.module.js', 'three')
        return content
    js_three = read_and_patch("three.module.js")
    js_orbit = read_and_patch("OrbitControls.js")
    js_loader = read_and_patch("GLTFLoader.js")
    js_buffer = read_and_patch("BufferGeometryUtils.js")
    def to_data_url(content, mime="text/javascript"):
        b64 = base64.b64encode(content.encode('utf-8')).decode('utf-8')
        return f"data:{mime};base64,{b64}"
    blob_three = to_data_url(js_three)
    blob_orbit = to_data_url(js_orbit)
    blob_loader = to_data_url(js_loader)
    blob_buffer = to_data_url(js_buffer)
    if not os.path.exists(abs_glb_path):
        return '<div style="color:red; padding:20px;">Error: Output file not found.</div>'
    with open(abs_glb_path, "rb") as f:
        glb_data = f.read()
    model_data_url = f"data:model/gltf-binary;base64,{base64.b64encode(glb_data).decode('utf-8')}"
    raw_html = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
    body {{ margin: 0; background: transparent; height: 100vh; overflow: hidden; font-family: sans-serif; }}
    #container {{ width: 100%; height: 100%; }}
    .controls {{ position: absolute; bottom: 20px; left: 50%; transform: translateX(-50%); display: flex; gap: 12px; padding: 12px 20px; background: rgba(255, 255, 255, 0.9); border-radius: 30px; z-index: 100; box-shadow: 0 4px 15px rgba(0,0,0,0.15); backdrop-filter: blur(5px); align-items: center; }}
    button {{ padding: 8px 20px; cursor: pointer; background: #4f46e5; color: white; border: none; border-radius: 20px; font-weight: 600; font-size: 14px; transition: background 0.2s; }}
    button:hover {{ background: #4338ca; }}
    input[type=range] {{ width: 200px; cursor: pointer; accent-color: #4f46e5; }}
</style>
<script type="importmap">
    {{ "imports": {{ "three": "{blob_three}", "three/addons/controls/OrbitControls.js": "{blob_orbit}", "three/addons/loaders/GLTFLoader.js": "{blob_loader}", "BufferGeometryUtils": "{blob_buffer}" }} }}
</script>
</head>
<body>
<div id="container"></div>
<div class="controls"><button id="btn-play">Play</button><input type="range" id="slider" min="0" max="0" value="0" step="1"></div>
<script type="module">
    import * as THREE from 'three';
    import {{ OrbitControls }} from 'three/addons/controls/OrbitControls.js';
    import {{ GLTFLoader }} from 'three/addons/loaders/GLTFLoader.js';
    const container = document.getElementById('container');
    const scene = new THREE.Scene(); scene.background = new THREE.Color(0xf9fafb);
    const camera = new THREE.PerspectiveCamera(45, window.innerWidth / window.innerHeight, 0.1, 100);
    camera.position.set(0, 1.5, 4);
    const renderer = new THREE.WebGLRenderer({{ antialias: true, alpha: true }}); renderer.setSize(window.innerWidth, window.innerHeight); renderer.setPixelRatio(window.devicePixelRatio); container.appendChild(renderer.domElement);
    const controls = new OrbitControls(camera, renderer.domElement); controls.enableDamping = true;
    scene.add(new THREE.AmbientLight(0xffffff, 0.8)); const dirLight = new THREE.DirectionalLight(0xffffff, 1.2); dirLight.position.set(5, 10, 7); scene.add(dirLight);
    let frames = []; let currentFrame = 0; let isPlaying = false; let timer = null;
    new GLTFLoader().load("{model_data_url}", (gltf) => {{
        gltf.scene.rotation.x = Math.PI; scene.add(gltf.scene);
        gltf.scene.traverse(n => {{
            // Group nodes by frame index parsed from names like "frame_12_human".
            if(n.name && n.name.startsWith('frame_')) {{
                let parts = n.name.split('_'); let idx = parseInt(parts[1]);
                if(!isNaN(idx)) {{ if(!frames[idx]) frames[idx] = []; frames[idx].push(n); n.visible = false; }}
            }}
            if(n.isMesh) {{ n.geometry.computeVertexNormals(); n.material = new THREE.MeshStandardMaterial({{ color: 0x6366f1, roughness: 0.4, metalness: 0.1 }}); }}
            if(n.isPoints) {{
                let size = n.name.includes('scene') ? 0.05 : 0.005;
                n.material.size = size;
            }}
        }});
        if(frames.length > 0) {{ document.getElementById('slider').max = frames.length - 1; showFrame(0); animate(); }}
    }}, undefined, (e) => console.error(e));
    function showFrame(idx) {{
        if(frames[currentFrame]) frames[currentFrame].forEach(o => o.visible = false);
        if(frames[idx]) frames[idx].forEach(o => o.visible = true);
        currentFrame = idx;
        const slider = document.getElementById('slider');
        if(slider) slider.value = idx;
    }}
    function animate() {{ requestAnimationFrame(animate); controls.update(); renderer.render(scene, camera); }}
    document.getElementById('btn-play').onclick = () => {{
        isPlaying = !isPlaying; const btn = document.getElementById('btn-play');
        if(isPlaying) {{
            btn.innerText = "Pause"; btn.style.background = "#ef4444";
            // 166 ms per frame is roughly 6 fps, matching the inference frame rate.
            timer = setInterval(() => {{ if(frames.length > 0) {{ let next = (currentFrame + 1) % frames.length; showFrame(next); }} }}, 166);
        }} else {{ clearInterval(timer); btn.innerText = "Play"; btn.style.background = "#4f46e5"; }}
    }};
    document.getElementById('slider').oninput = (e) => {{ if(isPlaying) document.getElementById('btn-play').click(); showFrame(parseInt(e.target.value)); }};
    window.onresize = () => {{ camera.aspect = window.innerWidth / window.innerHeight; camera.updateProjectionMatrix(); renderer.setSize(window.innerWidth, window.innerHeight); }};
</script>
</body>
</html>
"""
    return f'<iframe srcdoc="{html.escape(raw_html)}" width="100%" height="600px" style="border:none; border-radius: 12px;"></iframe>'
| if os.environ.get("SPACE_ID"): | |
| from spaces import GPU | |
| gpu_decorator = GPU(duration=120) | |
| else: | |
| def gpu_decorator(func): | |
| def wrapper(*args, **kwargs): | |
| if inspect.isgeneratorfunction(func): | |
| yield from func(*args, **kwargs) | |
| else: | |
| return func(*args, **kwargs) | |
| return wrapper | |
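# On ZeroGPU Spaces, `spaces.GPU(duration=120)` requests a GPU slot for up to
# 120 s per call; locally the passthrough above keeps the same decorator shape.
# `predict` below is decorated with it so its CUDA work runs on the allocated
# GPU (assumption: it is the only GPU-bound entry point in this app).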
@gpu_decorator
def predict(video_path, start_time=0.0, end_time=10.0):
    duration = end_time - start_time
    if duration > 10.0:
        raise gr.Error(f"Video limit exceeded ({duration:.1f}s). Please keep it under 10 seconds.")
    if start_time >= end_time:
        raise gr.Error("Error: End time must be greater than Start time.")
    yield get_loading_html("Processing...")
    output_dir = tempfile.mkdtemp()
    trimmed_video_path = os.path.join(output_dir, "input_trim.mp4")
    # Trim the requested window and re-encode so downstream readers get a clean clip.
    cmd = ["ffmpeg", "-y", "-ss", str(start_time), "-i", video_path, "-t", str(duration), "-c:v", "libx264", "-c:a", "aac", trimmed_video_path]
    subprocess.run(cmd, check=True)
    global MODEL
    if MODEL is None:
        MODEL = load_model()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    MODEL.to(device)
    MODEL.eval()
    download_smpl_assets(BODY_MODELS_PATH)
    data_dict = process_video(trimmed_video_path, 6.0, 0, 518, bbox_scale=1.0)
    results = run_inference(MODEL, data_dict, device, chunk_size=300)
    seq_name = results['seq_name']
    viz_scene, viz_smpl, viz_scene_only, _ = generate_mixed_geometries_in_memory(
        results, BODY_MODELS_PATH, fps=6.0, conf_thres=0.1
    )
    save_smpl_meshes_per_frame(results, output_dir, BODY_MODELS_PATH)
    base_dir = os.path.join(output_dir, seq_name)
    scene_cloud_dir = os.path.join(base_dir, "scene_clouds_per_frame")
    os.makedirs(scene_cloud_dir, exist_ok=True)
    for i, pcd in enumerate(viz_scene_only):
        if len(pcd.points) > 0:
            o3d.io.write_point_cloud(os.path.join(scene_cloud_dir, f"scene_frame_{i:04d}.ply"), pcd)
    tmp_glb_path = os.path.join(output_dir, "output.glb")
    pack_sequence_to_glb(base_dir, tmp_glb_path, 0, len(viz_scene), 0.5)
    yield get_player_html(tmp_glb_path)
examples_list = []
if os.path.exists(EXAMPLES_DIR):
    # Each example row must supply a value for every input component, so pad
    # with the default trim window (0 s to 3 s) alongside the video path.
    examples_list = [
        [os.path.join("examples", f), 0, 3]
        for f in os.listdir(EXAMPLES_DIR) if f.endswith(".mp4")
    ]
# Seek the upload preview to the slider value so the trim window can be checked visually.
js_scrub = """(val) => {
    var video = document.querySelector('#input-video video');
    if (video) {
        video.currentTime = val;
    }
    return val;
}"""
# Rewind and pause the preview shortly after a new video is loaded.
js_reset_video = """() => {
    setTimeout(() => {
        var video = document.querySelector('#input-video video');
        if (video) {
            video.currentTime = 0;
            video.pause();
        }
    }, 200);
}"""
custom_css = """
footer {visibility: hidden}
h1.header-title { text-align: center; font-family: 'Segoe UI', sans-serif; font-weight: 700; color: #1f2937; margin-bottom: 2rem; }
.alert-box { padding: 1rem; border-radius: 0.5rem; margin-bottom: 1rem; font-size: 0.9rem; line-height: 1.5; }
.info-box { background-color: #eff6ff; border-left: 4px solid #3b82f6; color: #1e40af; }
.warning-box { background-color: #fefce8; border-left: 4px solid #eab308; color: #854d0e; }
.tip-box { background-color: #f0fdf4; border-left: 4px solid #22c55e; color: #15803d; }
.viewer-container { box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1); border-radius: 12px; overflow: hidden; border: 1px solid #e5e7eb; background: #f9fafb; }
#run-btn { background: linear-gradient(135deg, #6366f1 0%, #a855f7 100%); border: none; color: white; font-weight: bold; transition: all 0.2s; }
#run-btn:hover { transform: translateY(-2px); box-shadow: 0 10px 20px rgba(99, 102, 241, 0.4); }
#input-video { margin-bottom: 15px !important; }
"""
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", radius_size="md"), css=custom_css, title="UniSH Demo") as demo:
    with gr.Column(elem_classes=["header-container"]):
        gr.Markdown("# UniSH: Unifying Scene and Human Reconstruction in a Feed-Forward Pass", elem_classes=["header-title"])
    with gr.Row(equal_height=False):
        with gr.Column(scale=3, variant="panel"):
            gr.Markdown("### 🛠️ Configuration")
            with gr.Group():
                input_video = gr.Video(label="Upload Video", format="mp4", height=260, interactive=True, elem_id="input-video")
                with gr.Row():
                    start_time = gr.Slider(minimum=0, maximum=10, value=0, step=0.01, label="Start Time (s)")
                    end_time = gr.Slider(minimum=0, maximum=10, value=3, step=0.01, label="End Time (s)")
            gr.HTML("""
                <div class="alert-box tip-box">
                    <strong>💡 Usage Tips:</strong>
                    <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 10px; margin-top: 8px; font-size: 0.85em; line-height: 1.4;">
                        <div>1. Show only a <strong>single person</strong>.</div>
                        <div>2. Avoid <strong>occlusion</strong> (self-occlusion is fine).</div>
                        <div>3. Keep the <strong>full body</strong> mostly visible.</div>
                        <div>4. The subject should <strong>not be too small</strong>.</div>
                    </div>
                </div>
            """)
            submit_btn = gr.Button("🚀 Start Reconstruction", variant="primary", elem_id="run-btn", size="lg")
            if examples_list:
                gr.Markdown("### 🎥 Examples")
                gr.Examples(
                    examples=examples_list,
                    inputs=[input_video, start_time, end_time],
                    label="Click to try:",
                    cache_examples=False
                )
        with gr.Column(scale=7):
            gr.Markdown("### ▶️ Interactive Results")
            with gr.Group(elem_classes=["viewer-container"]):
                output_html = gr.HTML(
                    label="3D Viewer",
                    min_height=600,
                    value='<div style="height:600px; display:flex; align-items:center; justify-content:center; color:#aaa; font-family:sans-serif; background:#f9fafb;">Upload a video and click Start to view result.</div>'
                )
            gr.HTML("""
                <div class="alert-box warning-box">
                    <strong>⚡ Performance Note:</strong><br>
                    Inference itself (a single feed-forward pass) is fast; generating the visualization assets accounts for most of the processing time.
                </div>
                <div class="alert-box info-box">
                    <strong>👁️ Visual Quality:</strong><br>
                    The displayed results are downsampled for better web rendering performance.
                </div>
            """)
    def update_slider_range(video_path):
        if not video_path:
            return gr.update(value=0), gr.update(value=0)
        dur = get_video_duration(video_path)
        dur = round(dur, 2)
        return gr.update(maximum=dur, value=0), gr.update(maximum=dur, value=dur)

    # .upload fires for newly uploaded files; .change also covers clearing or
    # replacing the value, so both are wired to keep the sliders in sync.
    input_video.change(fn=update_slider_range, inputs=[input_video], outputs=[start_time, end_time])
    input_video.upload(fn=update_slider_range, inputs=[input_video], outputs=[start_time, end_time])
    input_video.change(fn=None, inputs=[], outputs=[], js=js_reset_video)
    input_video.upload(fn=None, inputs=[], outputs=[], js=js_reset_video)
    start_time.change(fn=None, inputs=[start_time], outputs=None, js=js_scrub)
    end_time.change(fn=None, inputs=[end_time], outputs=None, js=js_scrub)
    submit_btn.click(
        fn=predict,
        inputs=[input_video, start_time, end_time],
        outputs=[output_html]
    )

demo.queue()
demo.launch(
    server_name="0.0.0.0",
    server_port=7860,
    allowed_paths=[BASE_DIR, "/tmp", EXAMPLES_DIR]
)