# Source: Hugging Face Space commit page residue — author "wop", message "Update app.py", commit 37adf47 (verified).
import gradio as gr
import cv2
import numpy as np
import plotly.graph_objects as go
import plotly.io as pio
import json
import os
from datetime import datetime
import torch
from PIL import Image
from depth_anything_3.api import DepthAnything3
# ====================== DEPTH ANYTHING 3 (CPU) ======================
# Checkpoint id or local directory; override with the DA3_MODEL_DIR environment variable.
MODEL_DIR = os.environ.get("DA3_MODEL_DIR", "depth-anything/DA3NESTED-GIANT-LARGE")
# os.cpu_count() may return None; fall back to 4 threads in that case.
NUM_THREADS = os.cpu_count() or 4
print(f"🔄 Loading DepthAnything3 '{MODEL_DIR}' on CPU (16GB RAM / 2 vCPU optimized)...")
model = DepthAnything3.from_pretrained(MODEL_DIR)
model.to(torch.device("cpu"))
model.eval()
# Use all available cores for intra-op parallelism on CPU.
torch.set_num_threads(NUM_THREADS)
# Report the value actually applied (not the possibly-None os.cpu_count()).
print(f"✅ Depth model ready on CPU with {NUM_THREADS} threads")
def extract_depth_from_pred(pred):
    """Pull a single depth map out of a DepthAnything3 prediction object.

    Containers are checked in this order:
      * an object with a ``.depth`` attribute,
      * a dict with a ``"depth"`` key,
      * an object whose non-empty ``.predictions`` has a first element with ``.depth``.

    The value found is normalised as follows:
      * a list/tuple is unwrapped to its first element (empty -> ``None``);
      * a torch tensor is detached, moved to CPU and converted to a float32 array;
      * a 3-D array is reduced to 2-D when its leading axis is 1 (squeezed) or
        3 (averaged over that axis), or when its trailing axis is 1 (squeezed).
        Other shapes are returned unchanged.

    Returns:
        The depth map (normally a 2-D ``np.ndarray``), or ``None`` when no depth
        could be located.
    """
    depth_map = None
    if hasattr(pred, "depth"):
        depth_map = pred.depth
    elif isinstance(pred, dict) and "depth" in pred:
        depth_map = pred["depth"]
    elif hasattr(pred, "predictions") and len(pred.predictions or []) > 0:
        depth_map = getattr(pred.predictions[0], "depth", None)

    # Unwrap sequences BEFORE tensor conversion so that a list of tensors is
    # still turned into a NumPy array below.
    if isinstance(depth_map, (list, tuple)):
        depth_map = depth_map[0] if len(depth_map) > 0 else None

    if isinstance(depth_map, torch.Tensor):
        # detach(): tensors that still carry autograd history cannot call .numpy();
        # float(): NumPy has no bfloat16, so normalise the dtype first.
        depth_map = depth_map.detach().cpu().float().numpy()

    if isinstance(depth_map, np.ndarray) and depth_map.ndim == 3:
        if depth_map.shape[0] == 1:
            depth_map = depth_map[0]
        elif depth_map.shape[0] == 3:
            depth_map = depth_map.mean(axis=0)
        elif depth_map.shape[-1] == 1:
            # Channel-last single-channel map, e.g. (H, W, 1).
            depth_map = depth_map[..., 0]
    return depth_map
def get_normalized_depth(depth: np.ndarray) -> np.ndarray:
    """Rescale a depth map to float32 values in [0, 1] using robust percentiles.

    The 1st percentile maps to 0 and the 99th to 1, with everything outside
    clipped, so a few outlier pixels cannot flatten the contrast. The mapping is
    monotonically increasing: larger raw values become larger normalized values.
    NOTE(review): whether 1 means "near" or "far" depends on whether the model
    emits disparity or metric-style depth — confirm for DepthAnything3.

    Args:
        depth: Array-like depth map (any shape; lists are accepted too).

    Returns:
        A float32 array of the same shape, or a constant 0.5 map when the input
        is flat. ``None`` or empty input yields a (256, 256) map of 0.5.
    """
    d = None if depth is None else np.asarray(depth, dtype=np.float32)
    if d is None or d.size == 0:
        return np.full((256, 256), 0.5, dtype=np.float32)
    # NaN -> 0, +inf -> 1, -inf -> 0 so the percentile statistics stay finite.
    d = np.nan_to_num(d, nan=0.0, posinf=1.0, neginf=0.0)
    vmin, vmax = np.percentile(d, [1.0, 99.0])
    if vmax - vmin < 1e-6:
        return np.full_like(d, 0.5)
    d_norm = (d - vmin) / (vmax - vmin)
    return np.clip(d_norm, 0.0, 1.0).astype(np.float32, copy=False)
# ====================== MAIN FUNCTION ======================
def _to_rgb_strings(colors):
    """Convert [r, g, b] triples in [0, 1] into Plotly "rgb(R,G,B)" strings."""
    return [f"rgb({int(255*r)},{int(255*g)},{int(255*b)})" for r, g, b in colors]


def _frame_depth(frame, resolution, depth_process_res):
    """Estimate depth for one BGR frame and return it normalized on a (resolution, resolution) grid.

    Runs the module-level DepthAnything3 ``model`` on the full-resolution frame,
    resizes the map to the point-cloud grid (also handles non-square frames) and
    rescales it with ``get_normalized_depth``.

    Raises:
        Exception: any inference/extraction failure, so the caller can fall back.
    """
    pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    pred = model.inference(
        [pil_img],
        process_res=depth_process_res,
        process_res_method="upper_bound_resize",
        export_format="mini_npz",
    )
    depth_map = extract_depth_from_pred(pred)
    if depth_map is None:
        raise ValueError("No depth map in prediction")
    # cv2.resize does not accept every NumPy dtype; float32 is always safe.
    depth_map = np.asarray(depth_map, dtype=np.float32)
    if depth_map.ndim == 3:
        depth_map = depth_map.squeeze()
    if depth_map.ndim != 2:
        # A non-2-D map would yield nested z-lists below; let the caller fall back.
        raise ValueError(f"Unexpected depth map shape {depth_map.shape}")
    if depth_map.shape != (resolution, resolution):
        depth_map = cv2.resize(
            depth_map,
            (resolution, resolution),
            interpolation=cv2.INTER_LINEAR,
        )
    return get_normalized_depth(depth_map)


def _sample_points(rgb, depth_norm, resolution, density, depth):
    """Keep a random ~`density` fraction of pixels and map them to centred 3-D points.

    x/y span [-6, 6] (y flipped so image-up is +y); z is the normalized depth
    centred on 0.5 and scaled by ``depth * 8``. Colours are the matching RGB
    triples in [0, 1].
    """
    mask = np.random.rand(resolution, resolution) < density
    ys, xs = np.nonzero(mask)
    return {
        'x': ((xs / resolution - 0.5) * 12).tolist(),
        'y': ((0.5 - ys / resolution) * 12).tolist(),
        'z': ((depth_norm[mask] - 0.5) * depth * 8).tolist(),
        'color': rgb[mask].tolist(),
    }


def _iframe_src_for(html_path):
    """Return the URL under which Gradio serves *html_path*.

    On Hugging Face Spaces ``SPACE_HOST`` holds the Space's hostname (``SPACE_ID``
    is "owner/name" and is not a valid hostname). Otherwise a relative ``/file=``
    URL resolves against the current Gradio origin.
    """
    space_host = os.environ.get("SPACE_HOST")
    if space_host:
        return f"https://{space_host}/file={html_path}"
    return f"/file={html_path}"


def _build_loop_script(points_data, overlay_id, point_size):
    """Return the <script> that adds the play/pause/FPS/size/camera overlay to the exported HTML.

    ``points_data`` is embedded as JSON; each frame is swapped in with Plotly.restyle.
    """
    return f"""
<script>
(function() {{
    const pointsData = {json.dumps(points_data)};
    let loopInterval = null;
    let gd = null;
    let idx = 0;
    let fps = 12;
    const ctrlId = "{overlay_id}";
    function createOverlay() {{
        if (document.getElementById(ctrlId)) return;
        const wrap = document.createElement('div');
        wrap.id = ctrlId;
        wrap.style = "position:fixed; left:12px; top:12px; z-index:9999; background:rgba(0,0,0,0.55); padding:8px 10px; border-radius:8px; color:#fff; font-family:Arial,monospace; display:flex; gap:8px; align-items:center; flex-wrap:wrap;";
        wrap.innerHTML = `
            <button id="{overlay_id}_play" style="padding:6px 8px;border-radius:6px;background:#16a085;color:white;border:none;cursor:pointer;">▶ Play</button>
            <button id="{overlay_id}_pause" style="padding:6px 8px;border-radius:6px;background:#7f8c8d;color:white;border:none;cursor:pointer;">⏸ Pause</button>
            <label style="font-size:12px;margin-left:6px;">FPS</label>
            <input id="{overlay_id}_fps" type="range" min="1" max="60" value="12" style="vertical-align:middle;">
            <span id="{overlay_id}_fps_txt" style="min-width:30px; text-align:center; display:inline-block;">12</span>
            <label style="font-size:12px;margin-left:8px;">Size</label>
            <input id="{overlay_id}_size" type="range" min="1" max="12" step="0.5" value="{point_size}" style="vertical-align:middle;">
            <span id="{overlay_id}_size_txt" style="min-width:30px; text-align:center; display:inline-block;">{point_size}</span>
            <button id="{overlay_id}_resetcam" title="Reset camera" style="padding:6px 8px;border-radius:6px;background:#2d6cdf;color:white;border:none;cursor:pointer;">↺ Cam</button>
        `;
        document.body.appendChild(wrap);
    }}
    function findPlotlyDiv() {{
        const direct = document.querySelector('.js-plotly-plot');
        if (direct) return direct;
        const anyDiv = document.querySelector('[data-plotly]') || document.querySelector('.plotly-graph-div');
        if (anyDiv) return anyDiv;
        return null;
    }}
    function waitForGd(cb) {{
        const existing = findPlotlyDiv();
        if (existing) return cb(existing);
        const mo = new MutationObserver((mut, obs) => {{
            const d = findPlotlyDiv();
            if (d) {{
                obs.disconnect();
                cb(d);
            }}
        }});
        mo.observe(document.body, {{ childList:true, subtree:true }});
        setTimeout(() => {{
            const d = findPlotlyDiv();
            if (!d) {{
                try {{ mo.disconnect(); }} catch(e){{}}
                cb(null);
            }}
        }}, 8000);
    }}
    function updateFrame(i, customSize = null) {{
        if (!gd) return;
        const p = pointsData[i];
        const sizeToUse = customSize !== null ? customSize : {point_size};
        try {{
            Plotly.restyle(gd, {{
                x: [p.x],
                y: [p.y],
                z: [p.z],
                'marker.color': [p.color],
                'marker.size': [sizeToUse]
            }}, [0]);
        }} catch (e) {{
            console.error('restyle failed', e);
        }}
    }}
    function startLoop() {{
        if (loopInterval) return;
        const fpsSlider = document.getElementById('{overlay_id}_fps');
        const sizeSlider = document.getElementById('{overlay_id}_size');
        const val = parseInt(fpsSlider?.value || 12);
        const delay = Math.max(1, Math.round(1000 / val));
        idx = (idx + 1) % pointsData.length;
        updateFrame(idx, parseFloat(sizeSlider?.value || {point_size}));
        loopInterval = setInterval(() => {{
            idx = (idx + 1) % pointsData.length;
            updateFrame(idx, parseFloat(sizeSlider?.value || {point_size}));
        }}, delay);
    }}
    function pauseLoop() {{
        if (loopInterval) {{
            clearInterval(loopInterval);
            loopInterval = null;
        }}
    }}
    function resetCamera() {{
        if (!gd) return;
        try {{
            Plotly.relayout(gd, {{ 'scene.camera.eye': {{x:0,y:0,z:2.8}} }});
        }} catch(e) {{}}
    }}
    function bindOverlay() {{
        createOverlay();
        const playBtn = document.getElementById('{overlay_id}_play');
        const pauseBtn = document.getElementById('{overlay_id}_pause');
        const fpsSlider = document.getElementById('{overlay_id}_fps');
        const fpsTxt = document.getElementById('{overlay_id}_fps_txt');
        const sizeSlider = document.getElementById('{overlay_id}_size');
        const sizeTxt = document.getElementById('{overlay_id}_size_txt');
        const resetBtn = document.getElementById('{overlay_id}_resetcam');
        playBtn.onclick = () => startLoop();
        pauseBtn.onclick = () => pauseLoop();
        resetBtn.onclick = () => resetCamera();
        fpsSlider.oninput = (e) => {{
            const v = e.target.value;
            fpsTxt.innerText = v;
            if (loopInterval) {{ pauseLoop(); startLoop(); }}
        }};
        sizeSlider.oninput = (e) => {{
            const v = parseFloat(e.target.value);
            sizeTxt.innerText = v.toFixed(1);
            updateFrame(idx, v);
        }};
    }}
    function init() {{
        bindOverlay();
        waitForGd(function(div) {{
            if (!div) return;
            gd = div;
            updateFrame(0);
            try {{
                const ro = new ResizeObserver(() => Plotly.Plots.resize(gd));
                ro.observe(gd);
            }} catch(e) {{}}
            window.addEventListener('keydown', (ev) => {{
                if (ev.code === 'Space') {{
                    ev.preventDefault();
                    if (loopInterval) pauseLoop(); else startLoop();
                }}
            }});
        }});
    }}
    window.addEventListener('load', init);
    window.addEventListener('beforeunload', () => {{ if (loopInterval) clearInterval(loopInterval); }});
}})();
</script>
"""


def create_animated_point_cloud(
    video,
    resolution: int = 256,
    depth_process_res: int = 384,
    density: float = 0.25,
    depth: float = 0.5,  # shown as "Depth Intensity" in the UI; scales the z extent
    point_size: float = 3.0,
    max_frames: int = 10,
    frame_step: int = 5,
    progress=gr.Progress()
):
    """Turn sampled video frames into an animated, depth-based 3-D point cloud.

    Every ``frame_step``-th frame (up to ``max_frames``) is passed through the
    DepthAnything3 model. On failure a frame falls back to grayscale brightness
    as pseudo-depth. Each frame is then sampled into coloured 3-D points.

    Args:
        video: Path to the uploaded video (``None`` when nothing was uploaded).
        resolution: Side length of the square point-sampling grid.
        depth_process_res: ``process_res`` passed to ``model.inference``.
        density: Fraction of grid pixels kept as points.
        depth: Multiplier for the z extent of the cloud.
        point_size: Initial Plotly marker size.
        max_frames: Maximum number of frames to process.
        frame_step: Stride between sampled frame indices (clamped to >= 1).
        progress: Gradio progress tracker.

    Returns:
        ``(preview_figure, status_text, iframe_html, html_path)``. On early failure
        this is ``(None, message, None, None)``. As a side effect, a standalone
        HTML animation is written to the current working directory.
    """
    if video is None:
        return None, "Upload a short video first", None, None

    # Gradio sliders may deliver floats; range() and np.random.rand need ints.
    resolution = int(resolution)
    depth_process_res = int(depth_process_res)
    max_frames = int(max_frames)
    frame_step = max(1, int(frame_step))

    cap = cv2.VideoCapture(video)
    try:
        if not cap.isOpened():
            return None, "Cannot open video", None, None
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Number of indices range(0, total, step) yields, i.e. ceil(total / step), capped.
        max_possible = max(1, min(max_frames, -(-total_frames // frame_step)))
        progress(0, desc="Reading video & computing real depth...")
        all_points = []
        used_real_depth = True
        for i in range(0, total_frames, frame_step):
            if len(all_points) >= max_frames:
                break
            cap.set(cv2.CAP_PROP_POS_FRAMES, i)
            ret, frame = cap.read()
            if not ret:
                break
            # RGB for colouring, always at point-cloud resolution.
            small = cv2.resize(frame, (resolution, resolution))
            rgb = cv2.cvtColor(small, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
            try:
                depth_norm = _frame_depth(frame, resolution, depth_process_res)
            except Exception as e:
                # Best-effort: one failed frame degrades to brightness-as-depth
                # instead of aborting the whole run.
                print(f"Depth inference failed for frame {i}: {e} → falling back to grayscale")
                used_real_depth = False
                depth_norm = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
            all_points.append(_sample_points(rgb, depth_norm, resolution, density, depth))
            processed = len(all_points)
            progress(processed / max_possible, desc=f"Frame {processed}/{max_possible} — real depth + points")
    finally:
        # Release the decoder even if reading/resizing raised.
        cap.release()

    if not all_points:
        return None, "No frames processed", None, None

    # ====================== BUILD FIGURE ======================
    # Convert colours once; frame 0 doubles as the static trace.
    points_data = [
        {'x': pts['x'], 'y': pts['y'], 'z': pts['z'], 'color': _to_rgb_strings(pts['color'])}
        for pts in all_points
    ]
    initial = points_data[0]
    trace = go.Scatter3d(
        x=initial['x'], y=initial['y'], z=initial['z'],
        mode='markers',
        marker=dict(
            size=point_size,
            color=initial['color'],
            opacity=0.85
        )
    )
    fig = go.Figure(data=[trace])
    fig.update_layout(
        scene=dict(
            aspectmode='cube',
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Depth (real DA3)',
            camera=dict(eye=dict(x=0, y=0, z=2.8), up=dict(x=0, y=1, z=0), center=dict(x=0, y=0, z=0))
        ),
        title=f"Real-Depth Animated Point Cloud — {len(all_points)} frames",
        height=650,
        margin=dict(l=0, r=0, b=0, t=90),
    )

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    html_path = f"real_depth_pointcloud_{timestamp}.html"
    html_str = pio.to_html(
        fig,
        include_plotlyjs='cdn',
        full_html=True,
        default_width="100%",
        default_height="650px"
    )
    loop_script = _build_loop_script(points_data, f"pc_ctrl_{timestamp}", point_size)
    html_str = html_str.replace('</body>', loop_script + '</body>')
    with open(html_path, 'w', encoding='utf-8') as f:
        f.write(html_str)

    iframe_html = f'''
<iframe
    src="{_iframe_src_for(html_path)}"
    width="100%"
    height="720"
    style="border:none; border-radius:8px; box-shadow:0 4px 20px rgba(0,0,0,0.1);"
    allowfullscreen
    sandbox="allow-scripts allow-same-origin allow-popups">
</iframe>
'''
    preview_fig = go.Figure(data=[trace])
    preview_fig.update_layout(scene=fig.layout.scene, height=500)
    depth_str = "Real Depth (Depth Anything 3 Giant)" if used_real_depth else "Grayscale fallback"
    status = f"✅ Done! {len(all_points)} frames • {depth_str} • Intensity: {depth} • Download: {html_path}"
    return preview_fig, status, iframe_html, html_path
# ====================== GRADIO INTERFACE ======================
# Widgets listed in the same order as create_animated_point_cloud's positional parameters.
_INPUT_COMPONENTS = [
    gr.Video(label="Upload Video (short = faster)"),
    gr.Slider(128, 1024, value=256, step=64, label="Point Cloud Resolution"),
    gr.Slider(128, 768, value=384, step=64, label="Depth Process Resolution (higher = sharper depth, slower)"),
    gr.Slider(0.05, 0.5, value=0.25, step=0.05, label="Point Density"),
    gr.Slider(0.2, 1.5, value=0.5, step=0.1, label="Depth Intensity"),
    gr.Slider(0.2, 12, value=3.0, step=0.2, label="Point Size"),
    gr.Slider(4, 65, value=10, step=1, label="Max Frames"),
    gr.Slider(2, 10, value=5, step=1, label="Frame Step (higher = faster)"),
]

# One component per element of the function's 4-tuple return value.
_OUTPUT_COMPONENTS = [
    gr.Plot(label="Static Preview (first frame)"),
    gr.Textbox(label="Status"),
    gr.HTML(label="🎥 Live Interactive 3D Animation (real depth)"),
    gr.File(label="↓ Download HTML (offline use)"),
]

_DESCRIPTION = (
    "Uses the **real Depth Anything 3 Giant model** (CPU) instead of fake brightness.\n"
    "Fully optimized for 16GB RAM / 2 vCPU: model loaded once, all CPU threads used, batch-friendly inference path ready.\n"
    "Live controls + offline HTML download. Short videos recommended (<30s)."
)

demo = gr.Interface(
    fn=create_animated_point_cloud,
    inputs=_INPUT_COMPONENTS,
    outputs=_OUTPUT_COMPONENTS,
    title="Video → Real-Depth Animated 3D Point Cloud (Depth Anything 3 on CPU)",
    description=_DESCRIPTION,
    flagging_mode="never",
)

if __name__ == "__main__":
    # Bind on all interfaces so the app is reachable from outside a container.
    demo.launch(server_name="0.0.0.0", server_port=7860)