# UniSH / app.py
# Last change: "update" by murphylmf (commit cb911dd)
import subprocess
import gradio as gr
import os
import time
import requests
import spaces
import sys
import shutil
import tempfile
import torch
import cv2
import subprocess
import numpy as np
import trimesh
import open3d as o3d
from huggingface_hub import hf_hub_download
import html
import base64
import inspect
# Every bundled resource is resolved relative to the folder that holds this script:
# the vendored three.js modules live in "static" and the demo clips in "examples".
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
STATIC_DIR, EXAMPLES_DIR = (os.path.join(BASE_DIR, sub) for sub in ("static", "examples"))
def prepare_local_assets(verify=True, timeout=10):
    """Download the three.js ES modules used by the 3D viewer into STATIC_DIR.

    Files that already exist are left untouched. Each download is written to a
    ``.part`` file first and moved into place only when complete, so an
    interrupted transfer never leaves a truncated module behind (the modules are
    later inlined into the viewer page and executed by the browser).

    Args:
        verify: TLS verification setting passed to ``requests.get``. Defaults to
            True because the fetched JavaScript is executed client-side; the
            original code disabled verification.
        timeout: Per-request timeout in seconds.

    Failures (network errors, non-success HTTP status, I/O errors) are printed
    and otherwise ignored: this is a best-effort step at import time.
    """
    os.makedirs(STATIC_DIR, exist_ok=True)
    base_url = "https://registry.npmmirror.com/three/0.160.0/files"
    assets = {
        "three.module.js": f"{base_url}/build/three.module.js",
        "OrbitControls.js": f"{base_url}/examples/jsm/controls/OrbitControls.js",
        "GLTFLoader.js": f"{base_url}/examples/jsm/loaders/GLTFLoader.js",
        "BufferGeometryUtils.js": f"{base_url}/examples/jsm/utils/BufferGeometryUtils.js",
    }
    for name, url in assets.items():
        path = os.path.join(STATIC_DIR, name)
        if os.path.exists(path):
            continue
        tmp_path = path + ".part"
        try:
            r = requests.get(url, verify=verify, timeout=timeout)
            # Surface 4xx/5xx instead of silently skipping the asset.
            r.raise_for_status()
            with open(tmp_path, "wb") as f:
                f.write(r.content)
            # Atomic rename: the final path only ever holds a complete file.
            os.replace(tmp_path, path)
        except Exception as e:
            print(f"Error downloading {name}: {e}")
            try:
                os.remove(tmp_path)
            except OSError:
                pass


prepare_local_assets()
def install_pytorch3d():
    """Make ``pytorch3d`` importable, installing it at runtime when missing.

    First tries the prebuilt 0.7.8 wheel from ``dl.fbaipublicfiles.com`` whose
    tag matches the running Python minor version, CUDA version and PyTorch
    version. If that pip install fails, or the installed PyTorch reports no CUDA
    version (``torch.version.cuda is None``, where no such wheel tag can be
    built), it falls back to building the ``stable`` branch from GitHub.

    Raises:
        subprocess.CalledProcessError: if the source-build fallback fails.
    """
    try:
        import pytorch3d  # noqa: F401  (import only checks availability)
        print("✅ PyTorch3D already installed.")
        return
    except ImportError:
        print("⏳ PyTorch3D not found. Starting dynamic installation...")

    py_minor = sys.version_info.minor
    cuda_version = torch.version.cuda
    if cuda_version is not None:
        # e.g. torch "2.4.0+cu121" -> "240", CUDA "12.1" -> "121"
        pyt_version_str = torch.__version__.split("+")[0].replace(".", "")
        version_str = f"py3{py_minor}_cu{cuda_version.replace('.', '')}_pyt{pyt_version_str}"
        whl_url = (
            "https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/"
            f"{version_str}/pytorch3d-0.7.8-cp3{py_minor}-cp3{py_minor}-linux_x86_64.whl"
        )
        print(f"🔍 Detected Env: {version_str}")
        print(f"⬇️ Attempting to install from Wheel: {whl_url}")
        try:
            subprocess.run([sys.executable, "-m", "pip", "install", whl_url], check=True)
            print("✅ PyTorch3D installed via Wheel!")
            return
        except subprocess.CalledProcessError:
            print("⚠️ Wheel installation failed (maybe version mismatch).")
    else:
        print("⚠️ PyTorch reports no CUDA version; no prebuilt PyTorch3D wheel applies.")

    print("🔨 Falling back to source compilation (this will take a few minutes)...")
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--no-build-isolation",
         "git+https://github.com/facebookresearch/pytorch3d.git@stable"],
        check=True,
    )
    print("✅ PyTorch3D installed via Source Build!")


install_pytorch3d()
def install_mmcv():
    """Make ``mmcv==2.2.0`` importable, installing it at runtime when missing.

    Tries OpenMMLab's prebuilt wheel index for the running CUDA / PyTorch pair
    first, then falls back to installing ``openmim`` and running
    ``mim install mmcv==2.2.0``. The wheel-index attempt is skipped when the
    installed PyTorch reports no CUDA version (``torch.version.cuda is None``).

    NOTE(review): no call to this function is visible in this module, so unlike
    the other installers it does not run at import time.

    Raises:
        subprocess.CalledProcessError: if the openmim fallback fails.
    """
    try:
        import mmcv
        print(f"✅ mmcv {mmcv.__version__} is already installed.")
        return
    except ImportError:
        print("⏳ mmcv not found. Starting dynamic installation...")

    cuda_version = torch.version.cuda
    if cuda_version is not None:
        # e.g. CUDA "12.1" -> "121"; torch "2.4.0+cu121" -> "2.4"
        cuda_ver = cuda_version.replace(".", "")
        torch_ver = ".".join(torch.__version__.split("+")[0].split(".")[:2])
        # OpenMMLab layout: https://download.openmmlab.com/mmcv/dist/cu{CUDA}/torch{TORCH}/index.html
        find_links_url = f"https://download.openmmlab.com/mmcv/dist/cu{cuda_ver}/torch{torch_ver}/index.html"
        print(f"🔍 Detected Env: CUDA={cuda_ver}, Torch={torch_ver}")
        print(f"⬇️ Installing mmcv==2.2.0 from: {find_links_url}")
        try:
            subprocess.run([
                sys.executable, "-m", "pip", "install",
                "mmcv==2.2.0",
                "--find-links", find_links_url,
            ], check=True)
            print("✅ mmcv installed successfully.")
            return
        except subprocess.CalledProcessError:
            print("⚠️ Installation failed. The specific version might not exist for this environment.")
    else:
        print("⚠️ PyTorch reports no CUDA version; skipping the OpenMMLab wheel index.")

    print("🔄 Attempting fallback using openmim (auto-resolve mode)...")
    subprocess.run([sys.executable, "-m", "pip", "install", "openmim"], check=True)
    subprocess.run(["mim", "install", "mmcv==2.2.0"], check=True)
    print("✅ mmcv installed successfully.")
def install_sam2():
    """Install SAM 2 from GitHub when ``sam2`` is not importable.

    The repository is cloned into a fresh temporary directory. The ``torch>=2.5.1``
    requirement in its ``setup.py`` is relaxed to ``torch>=2.4.1`` (presumably to
    match the PyTorch already installed in this environment), and it is
    installed with ``--no-build-isolation --no-deps`` using the *current*
    interpreter's pip, consistent with the other installers. The clone is always
    removed afterwards, even when a step fails.

    Raises:
        subprocess.CalledProcessError: if ``git clone`` or ``pip install`` fails.
    """
    try:
        import sam2  # noqa: F401  (import only checks availability)
        return
    except ImportError:
        pass

    print("Installing SAM 2 with patch...")
    # A unique temp dir avoids clashing with leftovers from a previous failed run.
    clone_dir = tempfile.mkdtemp(prefix="_tmp_sam2_")
    try:
        subprocess.run(
            ["git", "clone", "https://github.com/facebookresearch/segment-anything-2.git", clone_dir],
            check=True,
        )
        setup_path = os.path.join(clone_dir, "setup.py")
        with open(setup_path, "r", encoding="utf-8") as f:
            content = f.read()
        content = content.replace("torch>=2.5.1", "torch>=2.4.1")
        with open(setup_path, "w", encoding="utf-8") as f:
            f.write(content)
        subprocess.run(
            [sys.executable, "-m", "pip", "install", "--no-build-isolation", "--no-deps", "-v", "."],
            cwd=clone_dir,
            check=True,
        )
    finally:
        shutil.rmtree(clone_dir, ignore_errors=True)


install_sam2()
# Put this script's folder on the import path so the local `unish` package
# (presumably located in BASE_DIR) can be imported below.
sys.path.append(BASE_DIR)
from unish.utils.inference_utils import (
    load_model, process_video, run_inference,
    generate_mixed_geometries_in_memory,
    save_smpl_meshes_per_frame
)
# Shared model instance; `predict` loads it on its first call and reuses it afterwards.
MODEL = None
# Root folder for the SMPL body-model files (a relative path, so it resolves
# against the process working directory).
BODY_MODELS_PATH = "body_models/"
# ==========================================
# 4. Helper functions
# ==========================================
def download_smpl_assets(body_models_path):
    """Fetch any missing SMPL body-model pickles from the private HF Hub repo.

    Target folder: ``<body_models_path>/smpl/smpl/``, or ``<body_models_path>/smpl/``
    when *body_models_path* already contains the substring ``'smpl'``. The folder
    is created if needed, and files already present are not downloaded again.
    The access token comes from the ``SMPL_DOWNLOAD_TOKEN`` environment variable.
    A missing token and every failed download are reported with ``print``. This
    is best effort and nothing is raised.
    """
    model_path = body_models_path if 'smpl' in body_models_path else os.path.join(body_models_path, 'smpl')
    target_dir = os.path.join(model_path, 'smpl')
    os.makedirs(target_dir, exist_ok=True)

    repo_id = "Murphyyyy/UniSH-Private-Assets"  # Hub repository hosting the SMPL files
    token = os.environ.get("SMPL_DOWNLOAD_TOKEN")
    if not token:
        print("❌ CRITICAL ERROR: 'SMPL_DOWNLOAD_TOKEN' not found in environment variables!")
        print("👉 Since 'UniSH-Private-Assets' is likely private, inference WILL fail without a token.")

    for filename in ("SMPL_NEUTRAL.pkl", "SMPL_MALE.pkl", "SMPL_FEMALE.pkl"):
        file_path = os.path.join(target_dir, filename)
        if os.path.exists(file_path):
            continue
        try:
            print(f"📥 Downloading {filename} from {repo_id}...")
            hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                token=token,
                local_dir=target_dir,
                local_dir_use_symlinks=False,
            )
            print(f"✅ Downloaded to: {file_path}")
        except Exception as e:
            print(f"❌ Failed to download {filename}: {e}")
            print(f" (Check if your HF Token has access to {repo_id})")
def pack_sequence_to_glb(base_dir, output_path, start_frame, end_frame, scene_rate=1.0,
                         max_points_per_frame=60000):
    """Bundle per-frame human meshes and scene point clouds into one GLB file.

    For each frame index ``i`` in ``[start_frame, end_frame)`` under *base_dir*:

    * the human mesh is read from ``smpl_meshes_per_frame/combined_smpl_mesh_frame_{i:04d}.ply``,
      or else ``smpl_meshes_per_frame/smpl_mesh_frame_{i:04d}.ply``, and added as node
      ``frame_{i}_human``;
    * the scene cloud ``scene_clouds_per_frame/scene_frame_{i:04d}.ply`` is added as node
      ``frame_{i}_scene``. It is first randomly subsampled (without replacement) to
      *max_points_per_frame* points when it is larger.

    Missing files are skipped. Files that fail to load are skipped with a printed
    warning, which the original code suppressed silently. An empty result gets a
    tiny placeholder box so there is always something to export.

    Args:
        base_dir: Sequence output folder containing the two per-frame subfolders.
        output_path: Destination ``.glb`` path.
        start_frame: First frame index (inclusive).
        end_frame: Last frame index (exclusive).
        scene_rate: Unused; kept so existing callers keep working.
        max_points_per_frame: Upper bound on scene points kept per frame.

    Raises:
        FileNotFoundError: if *output_path* does not exist after export.
    """
    scene = trimesh.Scene()
    scene_cloud_dir = os.path.join(base_dir, "scene_clouds_per_frame")
    smpl_mesh_dir = os.path.join(base_dir, "smpl_meshes_per_frame")

    for i in range(start_frame, end_frame):
        # Prefer the "combined_" mesh file when both variants exist.
        candidates = (
            os.path.join(smpl_mesh_dir, f"combined_smpl_mesh_frame_{i:04d}.ply"),
            os.path.join(smpl_mesh_dir, f"smpl_mesh_frame_{i:04d}.ply"),
        )
        human_path = next((p for p in candidates if os.path.exists(p)), None)
        if human_path:
            try:
                human_mesh = trimesh.load(human_path)
                node_name = f"frame_{i}_human"
                scene.add_geometry(human_mesh, node_name=node_name, geom_name=node_name)
            except Exception as exc:
                print(f"⚠️ Skipping human mesh for frame {i} ({human_path}): {exc}")

        scene_pcd_path = os.path.join(scene_cloud_dir, f"scene_frame_{i:04d}.ply")
        if os.path.exists(scene_pcd_path):
            try:
                scene_pc = trimesh.load(scene_pcd_path)
                vertices = getattr(scene_pc, "vertices", None)
                if vertices is not None and len(vertices) > 0:
                    num_points = len(vertices)
                    if num_points > max_points_per_frame:
                        choice = np.random.choice(num_points, max_points_per_frame, replace=False)
                        scene_pc.vertices = vertices[choice]
                        # Guard against a missing/None colour attribute instead of letting
                        # len(None) raise and drop the whole frame.
                        colors = getattr(scene_pc, "colors", None)
                        if colors is not None and len(colors) > 0:
                            scene_pc.colors = colors[choice]
                    node_name = f"frame_{i}_scene"
                    scene.add_geometry(scene_pc, node_name=node_name, geom_name=node_name)
            except Exception as exc:
                print(f"⚠️ Skipping scene cloud for frame {i} ({scene_pcd_path}): {exc}")

    if len(scene.geometry) == 0:
        # Never export an empty scene: add a 0.01-unit placeholder box.
        dummy = trimesh.creation.box(extents=[0.01, 0.01, 0.01])
        scene.add_geometry(dummy, node_name='dummy')

    scene.export(output_path)
    if not os.path.exists(output_path):
        raise FileNotFoundError(f"Export failed: {output_path}")
def get_video_duration(video_path, default=10.0):
    """Return the length of *video_path* in seconds, or *default* when unknown.

    The duration is OpenCV's ``CAP_PROP_FRAME_COUNT / CAP_PROP_FPS``. *default* is
    returned for an empty path, a file OpenCV cannot open, a non-positive FPS or
    frame count, or any error while probing. The capture is always released.
    """
    if not video_path:
        return default
    cap = None
    try:
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            return default
        fps = cap.get(cv2.CAP_PROP_FPS)
        frame_count = cap.get(cv2.CAP_PROP_FRAME_COUNT)
    except Exception as exc:
        # Narrower than the old bare `except:`, which also swallowed KeyboardInterrupt.
        print(f"⚠️ Could not read video duration for {video_path}: {exc}")
        return default
    finally:
        if cap is not None:
            cap.release()
    # Streams/corrupt containers may report 0 or negative counts; a 0 s duration
    # would collapse the trim sliders.
    if fps <= 0 or frame_count <= 0:
        return default
    return frame_count / fps
def get_loading_html(message="Processing..."):
    """Return a 600px-tall placeholder panel with a CSS spinner and *message*.

    *message* is HTML-escaped before insertion, so it is always rendered as
    plain text and never interpreted as markup.
    """
    safe_message = html.escape(message)
    return f"""
<div style="height: 600px; width: 100%; background: #f9fafb; border-radius: 12px; border: 1px solid #e5e7eb; display: flex; flex-direction: column; align-items: center; justify-content: center; font-family: sans-serif; color: #4b5563;">
<div class="loader-ring"></div>
<p style="margin-top: 20px; font-weight: 500; font-size: 1.1em; animation: pulse 2s infinite;">{safe_message}</p>
<style>
.loader-ring {{ display: inline-block; width: 64px; height: 64px; }}
.loader-ring:after {{ content: " "; display: block; width: 46px; height: 46px; margin: 8px; border-radius: 50%; border: 5px solid #4f46e5; border-color: #4f46e5 transparent #4f46e5 transparent; animation: ring-spin 1.2s linear infinite; }}
@keyframes ring-spin {{ 0% {{ transform: rotate(0deg); }} 100% {{ transform: rotate(360deg); }} }}
@keyframes pulse {{ 0% {{ opacity: 0.6; }} 50% {{ opacity: 1; }} 100% {{ opacity: 0.6; }} }}
</style>
</div>
"""
def get_player_html(abs_glb_path):
    """Build an ``<iframe>`` that plays a GLB frame sequence in a three.js viewer.

    The four vendored three.js modules from STATIC_DIR and the GLB itself are
    inlined as base64 ``data:`` URLs. The iframe's ``srcdoc`` therefore needs no
    file serving. Relative imports inside the addon modules are rewritten to the
    bare specifiers declared in the page's import map.

    In the page, nodes named ``frame_<idx>_...`` are grouped by ``<idx>`` and
    shown one group at a time. "Play" advances a frame every 166 ms, and the
    slider scrubs manually (pausing playback).

    Args:
        abs_glb_path: Path to the ``.glb`` file to embed.

    Returns:
        The iframe HTML, or a red error ``<div>`` when the GLB or any viewer
        module is missing. The GLB check now runs first, so no JS is read on that
        path. Missing modules previously produced a silently blank viewer.
    """
    if not os.path.exists(abs_glb_path):
        return '<div style="color:red; padding:20px;">Error: Output file not found.</div>'

    module_files = ("three.module.js", "OrbitControls.js", "GLTFLoader.js", "BufferGeometryUtils.js")
    missing = [name for name in module_files if not os.path.exists(os.path.join(STATIC_DIR, name))]
    if missing:
        return ('<div style="color:red; padding:20px;">Error: 3D viewer assets missing from the static folder: '
                f'{html.escape(", ".join(missing))}.</div>')

    def read_and_patch(filename):
        # Rewrite relative module paths to import-map specifiers; data: URLs cannot resolve relative imports.
        path = os.path.join(STATIC_DIR, filename)
        if not os.path.exists(path):
            return ""
        with open(path, "r", encoding="utf-8") as f:
            content = f.read()
        content = content.replace('../utils/BufferGeometryUtils.js', 'BufferGeometryUtils')
        content = content.replace('./BufferGeometryUtils.js', 'BufferGeometryUtils')
        content = content.replace('../../build/three.module.js', 'three')
        content = content.replace('../build/three.module.js', 'three')
        return content

    def to_data_url(content, mime="text/javascript"):
        b64 = base64.b64encode(content.encode('utf-8')).decode('utf-8')
        return f"data:{mime};base64,{b64}"

    blob_three = to_data_url(read_and_patch("three.module.js"))
    blob_orbit = to_data_url(read_and_patch("OrbitControls.js"))
    blob_loader = to_data_url(read_and_patch("GLTFLoader.js"))
    blob_buffer = to_data_url(read_and_patch("BufferGeometryUtils.js"))

    with open(abs_glb_path, "rb") as f:
        glb_data = f.read()
    model_data_url = f"data:model/gltf-binary;base64,{base64.b64encode(glb_data).decode('utf-8')}"

    # f-string template: literal JS/CSS braces are doubled ({{ }}).
    raw_html = f"""
<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<style>
body {{ margin: 0; background: transparent; height: 100vh; overflow: hidden; font-family: sans-serif; }}
#container {{ width: 100%; height: 100%; }}
.controls {{ position: absolute; bottom: 20px; left: 50%; transform: translateX(-50%); display: flex; gap: 12px; padding: 12px 20px; background: rgba(255, 255, 255, 0.9); border-radius: 30px; z-index: 100; box-shadow: 0 4px 15px rgba(0,0,0,0.15); backdrop-filter: blur(5px); align-items: center; }}
button {{ padding: 8px 20px; cursor: pointer; background: #4f46e5; color: white; border: none; border-radius: 20px; font-weight: 600; font-size: 14px; transition: background 0.2s; }}
button:hover {{ background: #4338ca; }}
input[type=range] {{ width: 200px; cursor: pointer; accent-color: #4f46e5; }}
</style>
<script type="importmap">
{{ "imports": {{ "three": "{blob_three}", "three/addons/controls/OrbitControls.js": "{blob_orbit}", "three/addons/loaders/GLTFLoader.js": "{blob_loader}", "BufferGeometryUtils": "{blob_buffer}" }} }}
</script>
</head>
<body>
<div id="container"></div>
<div class="controls"><button id="btn-play">Play</button><input type="range" id="slider" min="0" max="0" value="0" step="1"></div>
<script type="module">
import * as THREE from 'three';
import {{ OrbitControls }} from 'three/addons/controls/OrbitControls.js';
import {{ GLTFLoader }} from 'three/addons/loaders/GLTFLoader.js';
const container = document.getElementById('container');
const scene = new THREE.Scene(); scene.background = new THREE.Color(0xf9fafb);
const camera = new THREE.PerspectiveCamera(45, window.innerWidth / window.innerHeight, 0.1, 100);
camera.position.set(0, 1.5, 4);
const renderer = new THREE.WebGLRenderer({{ antialias: true, alpha: true }}); renderer.setSize(window.innerWidth, window.innerHeight); renderer.setPixelRatio(window.devicePixelRatio); container.appendChild(renderer.domElement);
const controls = new OrbitControls(camera, renderer.domElement); controls.enableDamping = true;
scene.add(new THREE.AmbientLight(0xffffff, 0.8)); const dirLight = new THREE.DirectionalLight(0xffffff, 1.2); dirLight.position.set(5, 10, 7); scene.add(dirLight);
let frames = []; let currentFrame = 0; let isPlaying = false; let timer = null;
new GLTFLoader().load("{model_data_url}", (gltf) => {{
gltf.scene.rotation.x = Math.PI; scene.add(gltf.scene);
gltf.scene.traverse(n => {{
if(n.name && n.name.startsWith('frame_')) {{
let parts = n.name.split('_'); let idx = parseInt(parts[1]);
if(!isNaN(idx)) {{ if(!frames[idx]) frames[idx] = []; frames[idx].push(n); n.visible = false; }}
}}
if(n.isMesh) {{ n.geometry.computeVertexNormals(); n.material = new THREE.MeshStandardMaterial({{ color: 0x6366f1, roughness: 0.4, metalness: 0.1 }}); }}
if(n.isPoints) {{
let size = n.name.includes('scene') ? 0.05 : 0.005;
n.material.size = size;
}}
}});
if(frames.length > 0) {{ document.getElementById('slider').max = frames.length - 1; showFrame(0); animate(); }}
}}, undefined, (e) => console.error(e));
function showFrame(idx) {{
if(frames[currentFrame]) frames[currentFrame].forEach(o => o.visible = false);
if(frames[idx]) frames[idx].forEach(o => o.visible = true);
currentFrame = idx;
const slider = document.getElementById('slider');
if(slider) slider.value = idx;
}}
function animate() {{ requestAnimationFrame(animate); controls.update(); renderer.render(scene, camera); }}
document.getElementById('btn-play').onclick = () => {{
isPlaying = !isPlaying; const btn = document.getElementById('btn-play');
if(isPlaying) {{
btn.innerText = "Pause"; btn.style.background = "#ef4444";
timer = setInterval(() => {{ if(frames.length > 0) {{ let next = (currentFrame + 1) % frames.length; showFrame(next); }} }}, 166);
}} else {{ clearInterval(timer); btn.innerText = "Play"; btn.style.background = "#4f46e5"; }}
}};
document.getElementById('slider').oninput = (e) => {{ if(isPlaying) document.getElementById('btn-play').click(); showFrame(parseInt(e.target.value)); }};
window.onresize = () => {{ camera.aspect = window.innerWidth / window.innerHeight; camera.updateProjectionMatrix(); renderer.setSize(window.innerWidth, window.innerHeight); }};
</script>
</body>
</html>
"""
    # srcdoc is an attribute value, so the whole document must be HTML-escaped (quotes included).
    return f'<iframe srcdoc="{html.escape(raw_html)}" width="100%" height="600px" style="border:none; border-radius: 12px;"></iframe>'
if os.environ.get("SPACE_ID"):
    # Running on Hugging Face Spaces: wrap handlers with `spaces.GPU`, requesting a 120 s duration.
    from spaces import GPU
    gpu_decorator = GPU(duration=120)
else:
    def gpu_decorator(func):
        """Local stand-in for ``spaces.GPU``: wrap *func* transparently.

        Generator functions get a generator wrapper and ordinary functions get
        an ordinary wrapper. Previously one wrapper containing ``yield from``
        served both, which made it always a generator and discarded a plain
        function's return value. ``functools.wraps`` keeps the wrapped
        function's name and signature visible to introspection.
        """
        import functools

        if inspect.isgeneratorfunction(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                yield from func(*args, **kwargs)
        else:
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
        return wrapper
@gpu_decorator
def predict(video_path, start_time=0.0, end_time=10.0):
    """Reconstruct scene + human from a trimmed clip and stream viewer HTML.

    Gradio generator handler. It first yields a loading placeholder, then the
    interactive 3D viewer.

    Pipeline:
    1. Trim ``[start_time, start_time + duration)`` with ffmpeg into a fresh temp dir.
    2. Load the model lazily into the global ``MODEL``.
    3. Ensure the SMPL assets are present.
    4. Run ``process_video`` / ``run_inference``.
    5. Write per-frame SMPL meshes and non-empty scene point clouds, then pack them
       into one GLB.

    ``get_player_html`` inlines the GLB into the returned HTML, so the temp dir is
    deleted before the final yield; it previously leaked on every request.

    Raises:
        gr.Error: no video, invalid or over-10 s time range, or ffmpeg failure.
    """
    global MODEL

    if not video_path:
        raise gr.Error("Please upload a video first.")
    duration = end_time - start_time
    if duration > 10.0:
        raise gr.Error(f"Video limit exceeded ({duration:.1f}s). Please keep it under 10 seconds.")
    if start_time >= end_time:
        raise gr.Error("Error: End time must be greater than Start time.")

    yield get_loading_html("Processing...")

    output_dir = tempfile.mkdtemp()
    try:
        trimmed_video_path = os.path.join(output_dir, "input_trim.mp4")
        cmd = ["ffmpeg", "-y", "-ss", str(start_time), "-i", video_path, "-t", str(duration),
               "-c:v", "libx264", "-c:a", "aac", trimmed_video_path]
        try:
            subprocess.run(cmd, check=True)
        except subprocess.CalledProcessError as exc:
            raise gr.Error(f"ffmpeg failed to trim the video (exit code {exc.returncode}).") from exc

        if MODEL is None:
            MODEL = load_model()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        MODEL.to(device)
        MODEL.eval()

        download_smpl_assets(BODY_MODELS_PATH)
        data_dict = process_video(trimmed_video_path, 6.0, 0, 518, bbox_scale=1.0)
        results = run_inference(MODEL, data_dict, device, chunk_size=300)
        seq_name = results['seq_name']
        viz_scene, _, viz_scene_only, _ = generate_mixed_geometries_in_memory(
            results, BODY_MODELS_PATH, fps=6.0, conf_thres=0.1
        )
        save_smpl_meshes_per_frame(results, output_dir, BODY_MODELS_PATH)

        base_dir = os.path.join(output_dir, seq_name)
        scene_cloud_dir = os.path.join(base_dir, "scene_clouds_per_frame")
        os.makedirs(scene_cloud_dir, exist_ok=True)
        for i, pcd in enumerate(viz_scene_only):
            if len(pcd.points) > 0:
                o3d.io.write_point_cloud(os.path.join(scene_cloud_dir, f"scene_frame_{i:04d}.ply"), pcd)

        tmp_glb_path = os.path.join(output_dir, "output.glb")
        pack_sequence_to_glb(base_dir, tmp_glb_path, 0, len(viz_scene), 0.5)
        viewer_html = get_player_html(tmp_glb_path)
    finally:
        shutil.rmtree(output_dir, ignore_errors=True)

    yield viewer_html
# Demo clips for gr.Examples: one row per ``.mp4`` directly inside EXAMPLES_DIR
# (only the video column is filled). Absolute paths keep the rows valid whatever
# the working directory is, and sorting makes the order deterministic.
examples_list = []
if os.path.exists(EXAMPLES_DIR):
    examples_list = [
        [os.path.join(EXAMPLES_DIR, f)]
        for f in sorted(os.listdir(EXAMPLES_DIR))
        if f.endswith(".mp4")
    ]
# Browser-side snippets and page CSS used by the UI. Both JS snippets look up
# the <video> element inside the component whose elem_id is "input-video".

# Slider `change` hook: seek the input-video preview to the slider value
# (seconds) and pass the value through unchanged.
js_scrub = """(val) => {
var video = document.querySelector('#input-video video');
if (video) {
video.currentTime = val;
}
return val;
}"""
# Video `change`/`upload` hook: after a 200 ms delay, rewind the preview to 0 and pause it.
js_reset_video = """() => {
setTimeout(() => {
var video = document.querySelector('#input-video video');
if (video) {
video.currentTime = 0;
video.pause();
}
}, 200);
}"""
# Page styles: hide the Gradio footer; style the title, the alert boxes
# (info / warning / tip), the viewer container, the run button and the input video spacing.
custom_css = """
footer {visibility: hidden}
h1.header-title { text-align: center; font-family: 'Segoe UI', sans-serif; font-weight: 700; color: #1f2937; margin-bottom: 2rem; }
.alert-box { padding: 1rem; border-radius: 0.5rem; margin-bottom: 1rem; font-size: 0.9rem; line-height: 1.5; }
.info-box { background-color: #eff6ff; border-left: 4px solid #3b82f6; color: #1e40af; }
.warning-box { background-color: #fefce8; border-left: 4px solid #eab308; color: #854d0e; }
.tip-box { background-color: #f0fdf4; border-left: 4px solid #22c55e; color: #15803d; }
.viewer-container { box-shadow: 0 10px 15px -3px rgba(0, 0, 0, 0.1); border-radius: 12px; overflow: hidden; border: 1px solid #e5e7eb; background: #f9fafb; }
#run-btn { background: linear-gradient(135deg, #6366f1 0%, #a855f7 100%); border: none; color: white; font-weight: bold; transition: all 0.2s; }
#run-btn:hover { transform: translateY(-2px); box-shadow: 0 10px 20px rgba(99, 102, 241, 0.4); }
#input-video { margin-bottom: 15px !important; }
"""
# Gradio UI: a configuration column (upload, trim sliders, tips, run button,
# examples) beside a results column holding the 3D viewer, plus event wiring.
with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo", radius_size="md"), css=custom_css, title="UniSH Demo") as demo:
    with gr.Column(elem_classes=["header-container"]):
        gr.Markdown("# UniSH: Unifying Scene and Human Reconstruction in a Feed-Forward Pass", elem_classes=["header-title"])
    with gr.Row(equal_height=False):
        # Left column: inputs and controls.
        with gr.Column(scale=3, variant="panel"):
            gr.Markdown("### 🛠️ Configuration")
            with gr.Group():
                # elem_id "input-video" is the selector hook used by js_scrub / js_reset_video.
                input_video = gr.Video(label="Upload Video", format="mp4", height=260, interactive=True, elem_id="input-video")
                with gr.Row():
                    # Trim window in seconds, passed to predict(); update_slider_range re-fits both to each new clip.
                    start_time = gr.Slider(minimum=0, maximum=10, value=0, step=0.01, label="Start Time (s)")
                    end_time = gr.Slider(minimum=0, maximum=10, value=3, step=0.01, label="End Time (s)")
            gr.HTML("""
<div class="alert-box tip-box">
<strong>💡 Use Tips:</strong>
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 10px; margin-top: 8px; font-size: 0.85em; line-height: 1.4;">
<div>1. Contain only a <strong>single person</strong>.</div>
<div>2. <strong>No occlusion</strong> (self-occlusion is fine).</div>
<div>3. Keep the <strong>full body</strong> mostly visible.</div>
<div>4. Subject should <strong>not be too small</strong>.</div>
</div>
</div>
""")
            submit_btn = gr.Button("🚀 Start Reconstruction", variant="primary", elem_id="run-btn", size="lg")
            if examples_list:
                gr.Markdown("### 🎥 Examples")
                # Each example row only supplies the video; the slider columns are left empty.
                gr.Examples(
                    examples=examples_list,
                    inputs=[input_video, start_time, end_time],
                    label="Click to try:",
                    cache_examples=False
                )
        # Right column: 3D viewer output and notes.
        with gr.Column(scale=7):
            gr.Markdown("### ▶️ Interactive Results")
            with gr.Group(elem_classes=["viewer-container"]):
                # Receives the HTML streamed by predict() (loading panel, then viewer iframe).
                output_html = gr.HTML(
                    label="3D Viewer",
                    min_height=600,
                    value='<div style="height:600px; display:flex; align-items:center; justify-content:center; color:#aaa; font-family:sans-serif; background:#f9fafb;">Upload a video and click Start to view result.</div>'
                )
            gr.HTML("""
<div class="alert-box warning-box">
<strong>⚡ Performance Note:</strong><br>
Inference (feed-forward) is very fast, but generating visualization assets takes up most of the processing time.
</div>
<div class="alert-box info-box">
<strong>👁️ Visual Quality:</strong><br>
The displayed results are downsampled for better web rendering performance.
</div>
""")

    def update_slider_range(video_path):
        """Fit the trim sliders to a new clip: start = 0, end = full duration (2 dp).

        When the video is cleared both values reset to 0; their maxima are left unchanged.
        NOTE(review): for clips longer than 10 s the end value defaults beyond
        predict()'s 10 s limit, so running immediately raises — consider capping it.
        """
        if not video_path:
            return gr.update(value=0), gr.update(value=0)
        dur = get_video_duration(video_path)
        dur = round(dur, 2)
        return gr.update(maximum=dur, value=0), gr.update(maximum=dur, value=dur)

    # NOTE(review): `change` likely also fires on upload, so the `upload` listeners
    # below may run the same handlers a second time — confirm before removing them.
    input_video.change(fn=update_slider_range, inputs=[input_video], outputs=[start_time, end_time])
    input_video.upload(fn=update_slider_range, inputs=[input_video], outputs=[start_time, end_time])
    # Client-side only (fn=None): rewind/pause the preview on a new video.
    input_video.change(fn=None, inputs=[], outputs=[], js=js_reset_video)
    input_video.upload(fn=None, inputs=[], outputs=[], js=js_reset_video)
    # Client-side only: seek the preview while dragging either slider.
    start_time.change(fn=None, inputs=[start_time], outputs=None, js=js_scrub)
    end_time.change(fn=None, inputs=[end_time], outputs=None, js=js_scrub)
    submit_btn.click(
        fn=predict,
        inputs=[input_video, start_time, end_time],
        outputs=[output_html]
    )
# Enable the request queue, then serve on every interface at port 7860.
# allowed_paths lists the folders Gradio may serve files from: the app folder,
# /tmp and the examples folder. Blocks.queue() returns the Blocks itself, so the calls chain.
demo.queue().launch(
    server_name="0.0.0.0",
    server_port=7860,
    allowed_paths=[BASE_DIR, "/tmp", EXAMPLES_DIR],
)