Hugging Face Spaces capture — at export time this Space's build status was reported as "Runtime error".
| import gradio as gr | |
| import cv2 | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| import plotly.io as pio | |
| import json | |
| import os | |
| from datetime import datetime | |
| import torch | |
| from PIL import Image | |
| from depth_anything_3.api import DepthAnything3 | |
# ====================== DEPTH ANYTHING 3 (CPU) ======================
# Model weights are loaded once at import time so every Gradio request reuses
# the same in-memory model instead of reloading per call.
# DA3_MODEL_DIR lets a deployment override the default checkpoint id.
MODEL_DIR = os.environ.get("DA3_MODEL_DIR", "depth-anything/DA3NESTED-GIANT-LARGE")
print(f"🔄 Loading DepthAnything3 '{MODEL_DIR}' on CPU (16GB RAM / 2 vCPU optimized)...")
model = DepthAnything3.from_pretrained(MODEL_DIR)
model.to(torch.device("cpu"))  # force CPU; no GPU assumed on this Space tier
model.eval()  # inference mode (disables dropout / BN updates)
torch.set_num_threads(os.cpu_count() or 4)  # Use all available cores for maximum CPU speedup
print(f"✅ Depth model ready on CPU with {os.cpu_count()} threads")
def extract_depth_from_pred(pred):
    """Pull a 2-D depth array out of a DA3 prediction, whatever its wrapping.

    Accepts an object exposing a ``.depth`` attribute, a dict with a
    ``"depth"`` key, or an object whose ``.predictions[0]`` carries the depth.
    Torch tensors are moved to CPU numpy, non-empty lists/tuples are unwrapped
    to their first element, and 3-D arrays are collapsed to 2-D (a leading
    1-channel axis is squeezed, a 3-channel axis is averaged).

    Returns None when no depth can be located.
    """
    # Locate the raw depth payload, in the same priority order the API uses.
    if hasattr(pred, "depth"):
        raw = pred.depth
    elif isinstance(pred, dict) and "depth" in pred:
        raw = pred["depth"]
    else:
        raw = None
        candidates = getattr(pred, "predictions", None)
        if candidates:
            head = candidates[0]
            if hasattr(head, "depth"):
                raw = head.depth

    # Normalize the container: tensor -> numpy, [array] -> array.
    if isinstance(raw, torch.Tensor):
        raw = raw.cpu().numpy()
    if isinstance(raw, (list, tuple)) and raw:
        raw = raw[0]

    # Collapse a channel axis so callers always get an (H, W) map.
    if isinstance(raw, np.ndarray) and raw.ndim == 3:
        leading = raw.shape[0]
        if leading == 1:
            raw = raw[0]
        elif leading == 3:
            raw = raw.mean(axis=0)
    return raw
def get_normalized_depth(depth: np.ndarray) -> np.ndarray:
    """Rescale a depth map into [0, 1] using robust 1%/99% percentile bounds.

    NaN/inf values are sanitized first, then the map is linearly stretched
    between its 1st and 99th percentiles and clipped. A missing, empty, or
    (near-)constant input yields a uniform 0.5 map so downstream code always
    receives usable data.
    """
    # Guard: no input at all -> neutral mid-depth placeholder.
    if depth is None or depth.size == 0:
        return np.full((256, 256), 0.5, dtype=np.float32)

    values = np.nan_to_num(
        np.asarray(depth, dtype=np.float32),
        nan=0.0, posinf=1.0, neginf=0.0,
    )

    # Percentile bounds ignore the extreme 1% tails (robust to outliers).
    lo, hi = np.percentile(values, (1.0, 99.0))
    span = hi - lo
    if span < 1e-6:
        # Effectively flat map: normalization would divide by ~zero.
        return np.full_like(values, 0.5)
    return np.clip((values - lo) / span, 0.0, 1.0)
# ====================== MAIN FUNCTION ======================
def create_animated_point_cloud(
    video,
    resolution: int = 256,
    depth_process_res: int = 384,
    density: float = 0.25,
    depth: float = 0.5,  # renamed label only - still "Depth Intensity"
    point_size: float = 3.0,
    max_frames: int = 10,
    frame_step: int = 5,
    progress=gr.Progress()
):
    """Turn a video into an animated 3D point cloud driven by real depth.

    For every sampled frame: run Depth Anything 3 on CPU to get a depth map
    (grayscale-brightness fallback if inference fails), randomly sample pixels
    as 3D points (x/y from image coordinates, z from normalized depth), then
    emit a self-contained Plotly HTML file with an injected JavaScript overlay
    providing play/pause, FPS, point-size, and camera-reset controls.

    Args:
        video: Path to the uploaded video (Gradio Video component value).
        resolution: Square resolution at which points/colors are sampled.
        depth_process_res: Resolution fed to the depth model (sharper = slower).
        density: Probability in [0, 1] that a pixel becomes a point.
        depth: Depth-intensity multiplier applied to the z axis.
        point_size: Plotly marker size.
        max_frames: Maximum number of animation frames.
        frame_step: Stride between sampled video frames.
        progress: Gradio progress tracker (injected by Gradio at call time).

    Returns:
        Tuple of (preview Figure, status text, iframe HTML, HTML file path),
        or (None, error message, None, None) on failure.
    """
    if video is None:
        return None, "Upload a short video first", None, None
    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        return None, "Cannot open video", None, None
    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    # Upper bound on frames actually processed; used only for progress display.
    max_possible = min(max_frames, (total_frames // frame_step) + 1)
    progress(0, desc="Reading video & computing real depth...")
    all_points = []
    processed = 0
    used_real_depth = True  # flips to False if any frame used the fallback
    for i in range(0, total_frames, frame_step):
        if len(all_points) >= max_frames:
            break
        cap.set(cv2.CAP_PROP_POS_FRAMES, i)
        ret, frame = cap.read()
        if not ret:
            break
        # RGB for coloring (always at point-cloud resolution)
        small = cv2.resize(frame, (resolution, resolution))
        rgb = cv2.cvtColor(small, cv2.COLOR_BGR2RGB).astype(np.float32) / 255.0
        # === REAL DEPTH INFERENCE (single frame, CPU) ===
        try:
            frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            pil_img = Image.fromarray(frame_rgb)
            pred = model.inference(
                [pil_img],
                process_res=depth_process_res,
                process_res_method="upper_bound_resize",
                export_format="mini_npz",
            )
            depth_map = extract_depth_from_pred(pred)
            if depth_map is None:
                raise ValueError("No depth map in prediction")
            # Resize depth to match point-cloud resolution (handles non-square original frames)
            if depth_map.shape[:2] != (resolution, resolution):
                depth_map = cv2.resize(
                    depth_map.squeeze() if depth_map.ndim == 3 else depth_map,
                    (resolution, resolution),
                    interpolation=cv2.INTER_LINEAR
                )
            depth_norm = get_normalized_depth(depth_map)
        except Exception as e:
            # Best-effort fallback: brightness as pseudo-depth keeps the
            # animation rendering even when the model fails on a frame.
            print(f"Depth inference failed for frame {i}: {e} → falling back to grayscale")
            used_real_depth = False
            gray = cv2.cvtColor(small, cv2.COLOR_BGR2GRAY).astype(np.float32) / 255.0
            depth_norm = gray
        # Sample points exactly like original (but with real depth)
        mask = np.random.rand(resolution, resolution) < density
        ys, xs = np.nonzero(mask)
        zs = (depth_norm[mask] - 0.5) * depth * 8  # same scaling as original
        xs_norm = (xs / resolution - 0.5) * 12
        # y is flipped so image "up" maps to scene "up".
        ys_norm = (0.5 - ys / resolution) * 12
        colors = rgb[mask].tolist()
        all_points.append({
            'x': xs_norm.tolist(),
            'y': ys_norm.tolist(),
            'z': zs.tolist(),
            'color': colors
        })
        processed += 1
        progress(processed / max_possible, desc=f"Frame {processed}/{max_possible} — real depth + points")
    cap.release()
    if not all_points:
        return None, "No frames processed", None, None
    # ====================== BUILD FIGURE (unchanged) ======================
    # The figure holds only the FIRST frame; later frames are swapped in by
    # the injected JS via Plotly.restyle on trace 0.
    initial = all_points[0]
    initial_colors = [f"rgb({int(255*r)},{int(255*g)},{int(255*b)})" for r,g,b in initial['color']]
    trace = go.Scatter3d(
        x=initial['x'], y=initial['y'], z=initial['z'],
        mode='markers',
        marker=dict(
            size=point_size,
            color=initial_colors,
            opacity=0.85
        )
    )
    fig = go.Figure(data=[trace])
    fig.update_layout(
        scene=dict(
            aspectmode='cube',
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Depth (real DA3)',
            camera=dict(eye=dict(x=0, y=0, z=2.8), up=dict(x=0, y=1, z=0), center=dict(x=0, y=0, z=0))
        ),
        title=f"Real-Depth Animated Point Cloud — {len(all_points)} frames",
        height=650,
        margin=dict(l=0, r=0, b=0, t=90),
    )
    # Pre-render every frame's colors as Plotly-ready "rgb(...)" strings so the
    # browser-side loop never has to convert floats.
    points_data = []
    for pts in all_points:
        rgb_colors = [f"rgb({int(255*r)},{int(255*g)},{int(255*b)})" for r,g,b in pts['color']]
        points_data.append({'x': pts['x'], 'y': pts['y'], 'z': pts['z'], 'color': rgb_colors})
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    html_path = f"real_depth_pointcloud_{timestamp}.html"
    html_str = pio.to_html(
        fig,
        include_plotlyjs='cdn',
        full_html=True,
        default_width="100%",
        default_height="650px"
    )
    overlay_id = f"pc_ctrl_{timestamp}"
    # Injected JavaScript: a fixed-position control overlay (play/pause, FPS
    # slider, point-size slider, camera reset) that animates trace 0 with
    # Plotly.restyle. Doubled braces {{ }} are f-string escapes for literal
    # JS braces; only json.dumps(points_data), overlay_id and point_size are
    # interpolated from Python.
    loop_script = f"""
<script>
(function() {{
    const pointsData = {json.dumps(points_data)};
    let loopInterval = null;
    let gd = null;
    let idx = 0;
    let fps = 12;
    const ctrlId = "{overlay_id}";
    function createOverlay() {{
        if (document.getElementById(ctrlId)) return;
        const wrap = document.createElement('div');
        wrap.id = ctrlId;
        wrap.style = "position:fixed; left:12px; top:12px; z-index:9999; background:rgba(0,0,0,0.55); padding:8px 10px; border-radius:8px; color:#fff; font-family:Arial,monospace; display:flex; gap:8px; align-items:center; flex-wrap:wrap;";
        wrap.innerHTML = `
            <button id="{overlay_id}_play" style="padding:6px 8px;border-radius:6px;background:#16a085;color:white;border:none;cursor:pointer;">▶ Play</button>
            <button id="{overlay_id}_pause" style="padding:6px 8px;border-radius:6px;background:#7f8c8d;color:white;border:none;cursor:pointer;">⏸ Pause</button>
            <label style="font-size:12px;margin-left:6px;">FPS</label>
            <input id="{overlay_id}_fps" type="range" min="1" max="60" value="12" style="vertical-align:middle;">
            <span id="{overlay_id}_fps_txt" style="min-width:30px; text-align:center; display:inline-block;">12</span>
            <label style="font-size:12px;margin-left:8px;">Size</label>
            <input id="{overlay_id}_size" type="range" min="1" max="12" step="0.5" value="{point_size}" style="vertical-align:middle;">
            <span id="{overlay_id}_size_txt" style="min-width:30px; text-align:center; display:inline-block;">{point_size}</span>
            <button id="{overlay_id}_resetcam" title="Reset camera" style="padding:6px 8px;border-radius:6px;background:#2d6cdf;color:white;border:none;cursor:pointer;">↺ Cam</button>
        `;
        document.body.appendChild(wrap);
    }}
    function findPlotlyDiv() {{
        const direct = document.querySelector('.js-plotly-plot');
        if (direct) return direct;
        const anyDiv = document.querySelector('[data-plotly]') || document.querySelector('.plotly-graph-div');
        if (anyDiv) return anyDiv;
        return null;
    }}
    function waitForGd(cb) {{
        const existing = findPlotlyDiv();
        if (existing) return cb(existing);
        const mo = new MutationObserver((mut, obs) => {{
            const d = findPlotlyDiv();
            if (d) {{
                obs.disconnect();
                cb(d);
            }}
        }});
        mo.observe(document.body, {{ childList:true, subtree:true }});
        setTimeout(() => {{
            const d = findPlotlyDiv();
            if (!d) {{
                try {{ mo.disconnect(); }} catch(e){{}}
                cb(null);
            }}
        }}, 8000);
    }}
    function updateFrame(i, customSize = null) {{
        if (!gd) return;
        const p = pointsData[i];
        const sizeToUse = customSize !== null ? customSize : {point_size};
        try {{
            Plotly.restyle(gd, {{
                x: [p.x],
                y: [p.y],
                z: [p.z],
                'marker.color': [p.color],
                'marker.size': [sizeToUse]
            }}, [0]);
        }} catch (e) {{
            console.error('restyle failed', e);
        }}
    }}
    function startLoop() {{
        if (loopInterval) return;
        const fpsSlider = document.getElementById('{overlay_id}_fps');
        const sizeSlider = document.getElementById('{overlay_id}_size');
        const val = parseInt(fpsSlider?.value || 12);
        const delay = Math.max(1, Math.round(1000 / val));
        idx = (idx + 1) % pointsData.length;
        updateFrame(idx, parseFloat(sizeSlider?.value || {point_size}));
        loopInterval = setInterval(() => {{
            idx = (idx + 1) % pointsData.length;
            updateFrame(idx, parseFloat(sizeSlider?.value || {point_size}));
        }}, delay);
    }}
    function pauseLoop() {{
        if (loopInterval) {{
            clearInterval(loopInterval);
            loopInterval = null;
        }}
    }}
    function resetCamera() {{
        if (!gd) return;
        try {{
            Plotly.relayout(gd, {{ 'scene.camera.eye': {{x:0,y:0,z:2.8}} }});
        }} catch(e) {{}}
    }}
    function bindOverlay() {{
        createOverlay();
        const playBtn = document.getElementById('{overlay_id}_play');
        const pauseBtn = document.getElementById('{overlay_id}_pause');
        const fpsSlider = document.getElementById('{overlay_id}_fps');
        const fpsTxt = document.getElementById('{overlay_id}_fps_txt');
        const sizeSlider= document.getElementById('{overlay_id}_size');
        const sizeTxt = document.getElementById('{overlay_id}_size_txt');
        const resetBtn = document.getElementById('{overlay_id}_resetcam');
        playBtn.onclick = () => startLoop();
        pauseBtn.onclick = () => pauseLoop();
        resetBtn.onclick = () => resetCamera();
        fpsSlider.oninput = (e) => {{
            const v = e.target.value;
            fpsTxt.innerText = v;
            if (loopInterval) {{ pauseLoop(); startLoop(); }}
        }};
        sizeSlider.oninput = (e) => {{
            const v = parseFloat(e.target.value);
            sizeTxt.innerText = v.toFixed(1);
            updateFrame(idx, v);
        }};
    }}
    function init() {{
        bindOverlay();
        waitForGd(function(div) {{
            if (!div) return;
            gd = div;
            updateFrame(0);
            try {{
                const ro = new ResizeObserver(() => Plotly.Plots.resize(gd));
                ro.observe(gd);
            }} catch(e) {{}}
            window.addEventListener('keydown', (ev) => {{
                if (ev.code === 'Space') {{
                    ev.preventDefault();
                    if (loopInterval) pauseLoop(); else startLoop();
                }}
            }});
        }});
    }}
    window.addEventListener('load', init);
    window.addEventListener('beforeunload', () => {{ if (loopInterval) clearInterval(loopInterval); }});
}})();
</script>
"""
    # Splice the animation script just before </body> of the Plotly HTML.
    html_str = html_str.replace('</body>', loop_script + '</body>')
    with open(html_path, 'w', encoding='utf-8') as f:
        f.write(html_str)
    # Serve the file through Gradio's /file= route: absolute URL on Spaces,
    # relative URL when running locally.
    # NOTE(review): SPACE_ID is typically "owner/name"; the hf.space hostname
    # usually uses the "owner-name" form — confirm this URL resolves on Spaces.
    if os.environ.get("SPACE_ID"):
        space_url = f"https://{os.environ['SPACE_ID']}.hf.space"
        iframe_src = f"{space_url}/file={html_path}"
    else:
        iframe_src = f"/file={html_path}"
    iframe_html = f'''
<iframe
    src="{iframe_src}"
    width="100%"
    height="720"
    style="border:none; border-radius:8px; box-shadow:0 4px 20px rgba(0,0,0,0.1);"
    allowfullscreen
    sandbox="allow-scripts allow-same-origin allow-popups">
</iframe>
'''
    # Separate static preview figure (first frame only) for the gr.Plot output.
    preview_fig = go.Figure(data=[trace])
    preview_fig.update_layout(scene=fig.layout.scene, height=500)
    depth_str = "Real Depth (Depth Anything 3 Giant)" if used_real_depth else "Grayscale fallback"
    status = f"✅ Done! {len(all_points)} frames • {depth_str} • Intensity: {depth} • Download: {html_path}"
    return preview_fig, status, iframe_html, html_path
# ====================== GRADIO INTERFACE ======================
# Slider order/defaults must mirror create_animated_point_cloud's signature,
# since gr.Interface passes inputs positionally.
demo = gr.Interface(
    fn=create_animated_point_cloud,
    inputs=[
        gr.Video(label="Upload Video (short = faster)"),
        gr.Slider(128, 1024, value=256, step=64, label="Point Cloud Resolution"),
        gr.Slider(128, 768, value=384, step=64, label="Depth Process Resolution (higher = sharper depth, slower)"),
        gr.Slider(0.05, 0.5, value=0.25, step=0.05, label="Point Density"),
        gr.Slider(0.2, 1.5, value=0.5, step=0.1, label="Depth Intensity"),
        gr.Slider(0.2, 12, value=3.0, step=0.2, label="Point Size"),
        gr.Slider(4, 65, value=10, step=1, label="Max Frames"),
        gr.Slider(2, 10, value=5, step=1, label="Frame Step (higher = faster)")
    ],
    outputs=[
        gr.Plot(label="Static Preview (first frame)"),
        gr.Textbox(label="Status"),
        gr.HTML(label="🎥 Live Interactive 3D Animation (real depth)"),
        gr.File(label="↓ Download HTML (offline use)")
    ],
    title="Video → Real-Depth Animated 3D Point Cloud (Depth Anything 3 on CPU)",
    description="""Uses the **real Depth Anything 3 Giant model** (CPU) instead of fake brightness.
Fully optimized for 16GB RAM / 2 vCPU: model loaded once, all CPU threads used, batch-friendly inference path ready.
Live controls + offline HTML download. Short videos recommended (<30s).""",
    flagging_mode="never"
)
if __name__ == "__main__":
    # Bind to all interfaces on the standard Hugging Face Spaces port.
    demo.launch(server_name="0.0.0.0", server_port=7860)