+
+
+ """
+
+
+def get_description_html():
+ """
+ Generate the main description and getting started HTML.
+
+ Returns:
+ str: HTML string for the description
+ """
+ return """
+
+
+ What This Demo Does
+
+
+
+ Upload images or videos → Get Metric Point Clouds, Cameras and Novel Views → Explore in 3D
+
+
+
+
+
+ Tip: Landscape-oriented images or videos are preferred for the best 3D reconstruction.
+
+
+
+
+
+ """
+
+
+def get_acknowledgements_html():
+ """
+ Generate the acknowledgements section HTML.
+
+ Returns:
+ str: HTML string for the acknowledgements
+ """
+ return """
+
"
+ ""
+ if _gallery_dir and os.path.exists(_gallery_dir)
+ else ""
+ )
+ + """
+
+
+
+
+
+ """
+ )
+ return HTMLResponse(html_content)
+
+ @_app.get("/dashboard", response_class=HTMLResponse)
+ async def dashboard():
+ """HTML dashboard for monitoring backend status and tasks."""
+ if _backend is None:
+ return HTMLResponse("
Backend not initialized
", status_code=500)
+
+ # Get backend status
+ status = _backend.get_status()
+
+ # Safely format status values
+ if status["load_time"] is not None:
+ load_time_str = f"{status['load_time']:.2f}s"
+ else:
+ load_time_str = "Not loaded"
+
+ if status["uptime"] is not None:
+ uptime_str = f"{status['uptime']:.2f}s"
+ else:
+ uptime_str = "Not running"
+
+ # Get tasks information
+ active_tasks = [task for task in _tasks.values() if task.status in ["pending", "running"]]
+ completed_tasks = [
+ task for task in _tasks.values() if task.status in ["completed", "failed"]
+ ]
+
+ # Generate task HTML
+ active_tasks_html = ""
+ if active_tasks:
+ for task in active_tasks:
+ task_details = f"""
+
+
+ {task.task_id}
+ {task.status}
+
+
+ {task.message}
+
+
+ Images: {task.num_images or 'N/A'} |
+ Format: {task.export_format or 'N/A'} |
+ Method: {task.process_res_method or 'N/A'} |
+ Export Dir: {task.export_dir or 'N/A'}
+
+ {f' Video: {task.video_path}' if task.video_path else ''}
+
+ Last updated: {time.strftime('%Y-%m-%d %H:%M:%S')}
+
+ {active_tasks_html}
+
+
+
+
+ Recent Completed Tasks
+ {completed_tasks_html}
+
+
+
+
+
+
+ """
+
+ return HTMLResponse(html_content)
+
+ @_app.get("/status")
+ async def get_status():
+ """Get backend status with GPU memory information."""
+ if _backend is None:
+ raise HTTPException(status_code=500, detail="Backend not initialized")
+
+ status = _backend.get_status()
+
+ # Add GPU memory information
+ gpu_memory = get_gpu_memory_info()
+ if gpu_memory:
+ status["gpu_memory"] = {
+ "total_gb": round(gpu_memory["total_gb"], 2),
+ "allocated_gb": round(gpu_memory["allocated_gb"], 2),
+ "reserved_gb": round(gpu_memory["reserved_gb"], 2),
+ "free_gb": round(gpu_memory["free_gb"], 2),
+ "utilization_percent": round(gpu_memory["utilization"], 1),
+ }
+ else:
+ status["gpu_memory"] = None
+
+ return status
+
+ @_app.post("/inference", response_model=InferenceResponse)
+ async def run_inference(request: InferenceRequest):
+ """Submit inference task and return task ID."""
+ global _running_task_id
+
+ if _backend is None:
+ raise HTTPException(status_code=500, detail="Backend not initialized")
+
+ # Generate unique task ID
+ task_id = str(uuid.uuid4())
+
+ # Create task status
+ if _running_task_id is not None:
+ status_msg = f"[{task_id}] Task queued (waiting for {_running_task_id} to complete)"
+ else:
+ status_msg = f"[{task_id}] Task submitted"
+
+ _tasks[task_id] = TaskStatus(
+ task_id=task_id,
+ status="pending",
+ message=status_msg,
+ created_at=time.time(),
+ export_dir=request.export_dir,
+ request=request,
+ # Record essential parameters
+ num_images=len(request.image_paths),
+ export_format=request.export_format,
+ process_res_method=request.process_res_method,
+ video_path=(
+ request.image_paths[0] if request.image_paths else None
+ ), # Use first image path as video reference
+ )
+
+ # Add task to queue
+ _task_queue.append(task_id)
+
+ # If no task is running, start processing the queue
+ if _running_task_id is None:
+ _process_next_task()
+
+ return InferenceResponse(
+ success=True,
+ message="Task submitted successfully",
+ task_id=task_id,
+ export_dir=request.export_dir,
+ export_format=request.export_format,
+ )
+
+ @_app.get("/task/{task_id}", response_model=TaskStatus)
+ async def get_task_status(task_id: str):
+ """Get task status by task ID."""
+ if task_id not in _tasks:
+ raise HTTPException(status_code=404, detail="Task not found")
+
+ return _tasks[task_id]
+
+ @_app.get("/gpu-memory")
+ async def get_gpu_memory():
+ """Get detailed GPU memory information."""
+ gpu_memory = get_gpu_memory_info()
+ if gpu_memory is None:
+ return {
+ "available": False,
+ "message": "CUDA not available or memory info cannot be retrieved",
+ }
+
+ return {
+ "available": True,
+ "total_gb": round(gpu_memory["total_gb"], 2),
+ "allocated_gb": round(gpu_memory["allocated_gb"], 2),
+ "reserved_gb": round(gpu_memory["reserved_gb"], 2),
+ "free_gb": round(gpu_memory["free_gb"], 2),
+ "utilization_percent": round(gpu_memory["utilization"], 1),
+ "status": (
+ "healthy"
+ if gpu_memory["utilization"] < 80
+ else "warning" if gpu_memory["utilization"] < 95 else "critical"
+ ),
+ }
+
+ @_app.get("/tasks")
+ async def list_tasks():
+ """List all tasks."""
+ # Separate active and completed tasks
+ active_tasks = [task for task in _tasks.values() if task.status in ["pending", "running"]]
+ completed_tasks = [
+ task for task in _tasks.values() if task.status in ["completed", "failed"]
+ ]
+
+ return {
+ "tasks": list(_tasks.values()),
+ "active_tasks": active_tasks,
+ "completed_tasks": completed_tasks,
+ "active_count": len(active_tasks),
+ "total_count": len(_tasks),
+ }
+
+ @_app.post("/cleanup")
+ async def manual_cleanup():
+ """Manually trigger task cleanup."""
+ try:
+ _cleanup_old_tasks()
+ return {"message": "Cleanup completed", "active_tasks": len(_tasks)}
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Cleanup failed: {str(e)}")
+
+ @_app.delete("/task/{task_id}")
+ async def delete_task(task_id: str):
+ """Delete a specific task."""
+ if task_id not in _tasks:
+ raise HTTPException(status_code=404, detail="Task not found")
+
+ # Only allow deletion of completed/failed tasks
+ if _tasks[task_id].status not in ["completed", "failed"]:
+ raise HTTPException(status_code=400, detail="Cannot delete running or pending tasks")
+
+ del _tasks[task_id]
+ return {"message": f"Task {task_id} deleted successfully"}
+
+ @_app.post("/reload")
+ async def reload_model():
+ """Reload the model."""
+ if _backend is None:
+ raise HTTPException(status_code=500, detail="Backend not initialized")
+
+ try:
+ _backend.model = None
+ _backend.model_loaded = False
+ _backend.load_model()
+ return {"message": "Model reloaded successfully"}
+ except Exception as e:
+ raise HTTPException(status_code=500, detail=f"Failed to reload model: {str(e)}")
+
+ # ============================================================================
+ # Gallery routes
+ # ============================================================================
+
+ if _gallery_dir and os.path.exists(_gallery_dir):
+ # Load gallery HTML page (with modified paths for /gallery/ subdirectory)
+ _gallery_html = _load_gallery_html()
+
+ @_app.get("/gallery/", response_class=HTMLResponse)
+ @_app.get("/gallery", response_class=HTMLResponse)
+ async def gallery_home():
+ """Gallery home page."""
+ return HTMLResponse(_gallery_html)
+
+ @_app.get("/gallery/manifest.json")
+ async def gallery_manifest():
+ """Get gallery group list."""
+ try:
+ return build_group_list(_gallery_dir)
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Failed to build group list: {str(e)}"
+ )
+
+ @_app.get("/gallery/manifest/{group}.json")
+ async def gallery_group_manifest(group: str):
+ """Get manifest for a specific group."""
+ if not _is_plain_name(group):
+ raise HTTPException(status_code=400, detail="Invalid group name")
+ try:
+ return build_group_manifest(_gallery_dir, group)
+ except Exception as e:
+ raise HTTPException(
+ status_code=500, detail=f"Failed to build group manifest: {str(e)}"
+ )
+
+ @_app.get("/gallery/{path:path}")
+ async def gallery_files(path: str):
+ """Serve gallery static files (GLB, JPG, etc.)."""
+ # Security check: prevent directory traversal
+ path_parts = path.split("/")
+ if any(not _is_plain_name(part) for part in path_parts if part):
+ raise HTTPException(status_code=400, detail="Invalid path")
+
+ file_path = os.path.join(_gallery_dir, *path_parts)
+
+ # Ensure the file is within gallery directory
+ real_file_path = os.path.realpath(file_path)
+ real_gallery_dir = os.path.realpath(_gallery_dir)
+ if not real_file_path.startswith(real_gallery_dir):
+ raise HTTPException(status_code=403, detail="Access denied")
+
+ if not os.path.exists(file_path) or not os.path.isfile(file_path):
+ raise HTTPException(status_code=404, detail="File not found")
+
+ return FileResponse(file_path)
+
+ return _app
+
+
+def start_server(
+ model_dir: str,
+ device: str = "cuda",
+ host: str = "127.0.0.1",
+ port: int = 8000,
+ gallery_dir: Optional[str] = None,
+):
+ """Start the backend server."""
+ app = create_app(model_dir, device, gallery_dir)
+
+ print("Starting Depth Anything 3 Backend...")
+ print(f"Model directory: {model_dir}")
+ print(f"Device: {device}")
+ print(f"Server: http://{host}:{port}")
+ print(f"Dashboard: http://{host}:{port}/dashboard")
+ print(f"API Status: http://{host}:{port}/status")
+
+ if gallery_dir and os.path.exists(gallery_dir):
+ print(f"Gallery: http://{host}:{port}/gallery/")
+
+ print("=" * 60)
+ print("Backend is running! You can now:")
+ print(f" β’ Open home page: http://{host}:{port}")
+ print(f" β’ Open dashboard: http://{host}:{port}/dashboard")
+ print(f" β’ Check API status: http://{host}:{port}/status")
+
+ if gallery_dir and os.path.exists(gallery_dir):
+ print(f" β’ Browse gallery: http://{host}:{port}/gallery/")
+
+ print(" β’ Submit inference tasks via API")
+ print("=" * 60)
+
+ uvicorn.run(app, host=host, port=port, log_level="info")
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser(description="Depth Anything 3 Backend Server")
+ parser.add_argument("--model-dir", required=True, help="Model directory path")
+ parser.add_argument("--device", default="cuda", help="Device to use")
+ parser.add_argument("--host", default="127.0.0.1", help="Host to bind to")
+ parser.add_argument("--port", type=int, default=8000, help="Port to bind to")
+ parser.add_argument("--gallery-dir", help="Gallery directory path (optional)")
+
+ args = parser.parse_args()
+ start_server(args.model_dir, args.device, args.host, args.port, args.gallery_dir)
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/services/gallery.py b/Depth-Anything-3-anysize/src/depth_anything_3/services/gallery.py
new file mode 100644
index 0000000000000000000000000000000000000000..f72bb5e5f6defbc24cf2278a53b7162a8ad5519d
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/services/gallery.py
@@ -0,0 +1,806 @@
+#!/usr/bin/env python3
+# flake8: noqa: E501
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Depth Anything 3 Gallery Server (two-level, single-file)
+Now supports paginated depth preview (4 per page).
+"""
+
+import argparse
+import json
+import mimetypes
+import os
+import posixpath
+import sys
+from functools import partial
+from http import HTTPStatus
+from http.server import SimpleHTTPRequestHandler, ThreadingHTTPServer
+from urllib.parse import quote, unquote
+
+# ------------------------------ Embedded HTML ------------------------------ #
+
+HTML_PAGE = r"""
+
+
+
+ Depth Anything 3 Gallery
+
+
+
+
+
+
+
+
+
+
+
+ Depth Anything 3 Gallery
+
+
+
+
+ Level 1 shows groups only; click a group to browse scenes and previews.
+
+
+
+
+
+
+ 🎯 Depth Anything 3 Gallery
+
+
+ Explore 3D reconstructions and depth visualizations from Depth Anything 3.
+ Browse through groups of scenes, preview 3D models, and examine depth maps interactively.
+
+
+
+
+
+
No available groups
+
+
+
+
+
+
No available scenes in this group
+
+
+
+
+
+
+
+ Loading…
+
+
+
+
+ ⬅
+
+
+
+
+
+
+
+
+
+
+
+
+"""
+
+# ------------------------------ Utilities ------------------------------ #
+
+IMAGE_EXTS = (".png", ".jpg", ".jpeg", ".webp", ".bmp")
+
+
+def _url_join(*parts: str) -> str:
+ norm = posixpath.join(*[p.replace("\\", "/") for p in parts])
+ segs = [s for s in norm.split("/") if s not in ("", ".")]
+ return "/".join(quote(s) for s in segs)
+
+
+def _is_plain_name(name: str) -> bool:
+ return all(c not in name for c in ("/", "\\")) and name not in (".", "..")
+
+
+def build_group_list(root_dir: str) -> dict:
+ groups = []
+ try:
+ for gname in sorted(os.listdir(root_dir)):
+ gpath = os.path.join(root_dir, gname)
+ if not os.path.isdir(gpath):
+ continue
+ has_scene = False
+ try:
+ for sname in os.listdir(gpath):
+ spath = os.path.join(gpath, sname)
+ if not os.path.isdir(spath):
+ continue
+ if os.path.exists(os.path.join(spath, "scene.glb")) and os.path.exists(
+ os.path.join(spath, "scene.jpg")
+ ):
+ has_scene = True
+ break
+ except Exception:
+ pass
+ if has_scene:
+ groups.append({"id": gname, "title": gname})
+ except Exception as e:
+ print(f"[warn] build_group_list failed: {e}", file=sys.stderr)
+ return {"groups": groups}
+
+
+def build_group_manifest(root_dir: str, group: str) -> dict:
+ items = []
+ gpath = os.path.join(root_dir, group)
+ try:
+ if not os.path.isdir(gpath):
+ return {"group": group, "items": []}
+ for sname in sorted(os.listdir(gpath)):
+ spath = os.path.join(gpath, sname)
+ if not os.path.isdir(spath):
+ continue
+ glb_fs = os.path.join(spath, "scene.glb")
+ jpg_fs = os.path.join(spath, "scene.jpg")
+ if not (os.path.exists(glb_fs) and os.path.exists(jpg_fs)):
+ continue
+ depth_images = []
+ dpath = os.path.join(spath, "depth_vis")
+ if os.path.isdir(dpath):
+ files = [
+ f for f in os.listdir(dpath) if os.path.splitext(f)[1].lower() in IMAGE_EXTS
+ ]
+ for fn in sorted(files):
+ depth_images.append("/" + _url_join(group, sname, "depth_vis", fn))
+ items.append(
+ {
+ "id": sname,
+ "title": sname,
+ "model": "/" + _url_join(group, sname, "scene.glb"),
+ "thumbnail": "/" + _url_join(group, sname, "scene.jpg"),
+ "depth_images": depth_images,
+ }
+ )
+ except Exception as e:
+ print(f"[warn] build_group_manifest failed for {group}: {e}", file=sys.stderr)
+ return {"group": group, "items": items}
+
+
+class GalleryHandler(SimpleHTTPRequestHandler):
+ def __init__(self, *args, directory=None, **kwargs):
+ super().__init__(*args, directory=directory, **kwargs)
+
+ def do_GET(self):
+ if self.path in ("/", "/index.html") or self.path.startswith("/?"):
+ content = HTML_PAGE.encode("utf-8")
+ self.send_response(HTTPStatus.OK)
+ self.send_header("Content-Type", "text/html; charset=utf-8")
+ self.send_header("Content-Length", str(len(content)))
+ self.send_header("Cache-Control", "no-store")
+ self.end_headers()
+ self.wfile.write(content)
+ return
+ if self.path == "/manifest.json":
+ data = json.dumps(
+ build_group_list(self.directory), ensure_ascii=False, indent=2
+ ).encode("utf-8")
+ self.send_response(HTTPStatus.OK)
+ self.send_header("Content-Type", "application/json; charset=utf-8")
+ self.send_header("Content-Length", str(len(data)))
+ self.send_header("Cache-Control", "no-store")
+ self.end_headers()
+ self.wfile.write(data)
+ return
+ if self.path.startswith("/manifest/") and self.path.endswith(".json"):
+ group_enc = self.path[len("/manifest/") : -len(".json")]
+ try:
+ group = unquote(group_enc)
+ except Exception:
+ group = group_enc
+ if not _is_plain_name(group):
+ self.send_error(HTTPStatus.BAD_REQUEST, "Invalid group name")
+ return
+ data = json.dumps(
+ build_group_manifest(self.directory, group), ensure_ascii=False, indent=2
+ ).encode("utf-8")
+ self.send_response(HTTPStatus.OK)
+ self.send_header("Content-Type", "application/json; charset=utf-8")
+ self.send_header("Content-Length", str(len(data)))
+ self.send_header("Cache-Control", "no-store")
+ self.end_headers()
+ self.wfile.write(data)
+ return
+ if self.path == "/favicon.ico":
+ self.send_response(HTTPStatus.NO_CONTENT)
+ self.end_headers()
+ return
+ return super().do_GET()
+
+ def list_directory(self, path):
+ self.send_error(HTTPStatus.NOT_FOUND, "Directory listing disabled")
+ return None
+
+
+def gallery():
+ parser = argparse.ArgumentParser(
+ description="Depth Anything 3 Gallery Server (two-level, with pagination)"
+ )
+ parser.add_argument(
+ "-d", "--dir", required=True, help="Gallery root directory (two-level: group/scene)"
+ )
+ parser.add_argument("-p", "--port", type=int, default=8000, help="Port (default 8000)")
+ parser.add_argument("--host", default="127.0.0.1", help="Host address (default 127.0.0.1)")
+ parser.add_argument("--open", action="store_true", help="Open browser after launch")
+ args = parser.parse_args()
+
+ root_dir = os.path.abspath(args.dir)
+ if not os.path.isdir(root_dir):
+ print(f"[error] Directory not found: {root_dir}", file=sys.stderr)
+ sys.exit(1)
+
+ Handler = partial(GalleryHandler, directory=root_dir)
+ server = ThreadingHTTPServer((args.host, args.port), Handler)
+
+ addr = f"http://{args.host}:{args.port}/"
+ print(f"[info] Serving gallery from: {root_dir}")
+ print(f"[info] Open: {addr}")
+
+ if args.open:
+ try:
+ import webbrowser
+
+ webbrowser.open(addr)
+ except Exception as e:
+ print(f"[warn] Failed to open browser: {e}", file=sys.stderr)
+
+ try:
+ server.serve_forever()
+ except KeyboardInterrupt:
+ print("\n[info] Shutting down...")
+ finally:
+ server.server_close()
+
+
+def main():
+ """Main entry point for gallery server."""
+ mimetypes.add_type("model/gltf-binary", ".glb")
+ gallery()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/services/inference_service.py b/Depth-Anything-3-anysize/src/depth_anything_3/services/inference_service.py
new file mode 100644
index 0000000000000000000000000000000000000000..07ca1657a43c7407bf8eaee080adde466480e417
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/services/inference_service.py
@@ -0,0 +1,225 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Unified Inference Service
+Provides unified interface for local and remote inference
+"""
+
+from typing import Any, Dict, List, Optional, Union
+import numpy as np
+import requests
+import typer
+
+from ..api import DepthAnything3
+
+
+class InferenceService:
+ """Unified inference service class"""
+
+ def __init__(self, model_dir: str, device: str = "cuda"):
+ self.model_dir = model_dir
+ self.device = device
+ self.model = None
+
+ def load_model(self):
+ """Load model"""
+ if self.model is None:
+ typer.echo(f"Loading model from {self.model_dir}...")
+ self.model = DepthAnything3.from_pretrained(self.model_dir).to(self.device)
+ return self.model
+
+ def run_local_inference(
+ self,
+ image_paths: List[str],
+ export_dir: str,
+ export_format: str = "mini_npz-glb",
+ process_res: Optional[int] = None,
+ process_res_method: str = "keep",
+ export_feat_layers: Optional[List[int]] = None,
+ extrinsics: Optional[np.ndarray] = None,
+ intrinsics: Optional[np.ndarray] = None,
+ align_to_input_ext_scale: bool = True,
+ conf_thresh_percentile: float = 40.0,
+ num_max_points: int = 1_000_000,
+ show_cameras: bool = True,
+ feat_vis_fps: int = 15,
+ ) -> Any:
+ """Run local inference"""
+ if export_feat_layers is None:
+ export_feat_layers = []
+
+ model = self.load_model()
+
+ # Prepare inference parameters
+ inference_kwargs = {
+ "image": image_paths,
+ "export_dir": export_dir,
+ "export_format": export_format,
+ "process_res": process_res,
+ "process_res_method": process_res_method,
+ "export_feat_layers": export_feat_layers,
+ "align_to_input_ext_scale": align_to_input_ext_scale,
+ "conf_thresh_percentile": conf_thresh_percentile,
+ "num_max_points": num_max_points,
+ "show_cameras": show_cameras,
+ "feat_vis_fps": feat_vis_fps,
+ }
+
+ # Add pose data (if exists)
+ if extrinsics is not None:
+ inference_kwargs["extrinsics"] = extrinsics
+ if intrinsics is not None:
+ inference_kwargs["intrinsics"] = intrinsics
+
+ # Run inference
+ typer.echo(f"Running inference on {len(image_paths)} images...")
+ prediction = model.inference(**inference_kwargs)
+
+ typer.echo(f"Results saved to {export_dir}")
+ typer.echo(f"Export format: {export_format}")
+
+ return prediction
+
+ def run_backend_inference(
+ self,
+ image_paths: List[str],
+ export_dir: str,
+ backend_url: str,
+ export_format: str = "mini_npz-glb",
+ process_res: Optional[int] = None,
+ process_res_method: str = "keep",
+ export_feat_layers: Optional[List[int]] = None,
+ extrinsics: Optional[np.ndarray] = None,
+ intrinsics: Optional[np.ndarray] = None,
+ align_to_input_ext_scale: bool = True,
+ conf_thresh_percentile: float = 40.0,
+ num_max_points: int = 1_000_000,
+ show_cameras: bool = True,
+ feat_vis_fps: int = 15,
+ ) -> Dict[str, Any]:
+ """Run backend inference"""
+ if export_feat_layers is None:
+ export_feat_layers = []
+
+ # Check backend status
+ if not self._check_backend_status(backend_url):
+ raise typer.BadParameter(f"Backend service is not running at {backend_url}")
+
+ # Prepare payload
+ payload = {
+ "image_paths": image_paths,
+ "export_dir": export_dir,
+ "export_format": export_format,
+ "process_res": process_res,
+ "process_res_method": process_res_method,
+ "export_feat_layers": export_feat_layers,
+ "align_to_input_ext_scale": align_to_input_ext_scale,
+ "conf_thresh_percentile": conf_thresh_percentile,
+ "num_max_points": num_max_points,
+ "show_cameras": show_cameras,
+ "feat_vis_fps": feat_vis_fps,
+ }
+
+ # Add pose data (if exists)
+ if extrinsics is not None:
+ payload["extrinsics"] = [ext.astype(np.float64).tolist() for ext in extrinsics]
+ if intrinsics is not None:
+ payload["intrinsics"] = [intr.astype(np.float64).tolist() for intr in intrinsics]
+
+ # Submit task
+ typer.echo("Submitting inference task to backend...")
+ try:
+ response = requests.post(f"{backend_url}/inference", json=payload, timeout=30)
+ response.raise_for_status()
+ result = response.json()
+
+ if result["success"]:
+ task_id = result["task_id"]
+ typer.echo("Task submitted successfully!")
+ typer.echo(f"Task ID: {task_id}")
+ typer.echo(f"Results will be saved to: {export_dir}")
+ typer.echo(f"Check backend logs for progress updates with task ID: {task_id}")
+ return result
+ else:
+ raise typer.BadParameter(
+ f"Backend inference submission failed: {result['message']}"
+ )
+ except requests.exceptions.RequestException as e:
+ raise typer.BadParameter(f"Backend inference submission failed: {e}")
+
+ def _check_backend_status(self, backend_url: str) -> bool:
+ """Check backend status"""
+ try:
+ response = requests.get(f"{backend_url}/status", timeout=5)
+ return response.status_code == 200
+ except Exception:
+ return False
+
+
+def run_inference(
+ image_paths: List[str],
+ export_dir: str,
+ model_dir: str,
+ device: str = "cuda",
+ backend_url: Optional[str] = None,
+ export_format: str = "mini_npz-glb",
+ process_res: Optional[int] = None,
+ process_res_method: str = "keep",
+ export_feat_layers: Optional[List[int]] = None,
+ extrinsics: Optional[np.ndarray] = None,
+ intrinsics: Optional[np.ndarray] = None,
+ align_to_input_ext_scale: bool = True,
+ conf_thresh_percentile: float = 40.0,
+ num_max_points: int = 1_000_000,
+ show_cameras: bool = True,
+ feat_vis_fps: int = 15,
+) -> Union[Any, Dict[str, Any]]:
+ """Unified inference interface"""
+
+ service = InferenceService(model_dir, device)
+
+ if backend_url:
+ return service.run_backend_inference(
+ image_paths=image_paths,
+ export_dir=export_dir,
+ backend_url=backend_url,
+ export_format=export_format,
+ process_res=process_res,
+ process_res_method=process_res_method,
+ export_feat_layers=export_feat_layers,
+ extrinsics=extrinsics,
+ intrinsics=intrinsics,
+ align_to_input_ext_scale=align_to_input_ext_scale,
+ conf_thresh_percentile=conf_thresh_percentile,
+ num_max_points=num_max_points,
+ show_cameras=show_cameras,
+ feat_vis_fps=feat_vis_fps,
+ )
+ else:
+ return service.run_local_inference(
+ image_paths=image_paths,
+ export_dir=export_dir,
+ export_format=export_format,
+ process_res=process_res,
+ process_res_method=process_res_method,
+ export_feat_layers=export_feat_layers,
+ extrinsics=extrinsics,
+ intrinsics=intrinsics,
+ align_to_input_ext_scale=align_to_input_ext_scale,
+ conf_thresh_percentile=conf_thresh_percentile,
+ num_max_points=num_max_points,
+ show_cameras=show_cameras,
+ feat_vis_fps=feat_vis_fps,
+ )
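+
+
+# Example usage (illustrative; the paths below are placeholders, not shipped assets):
+#
+# from depth_anything_3.services.inference_service import run_inference
+#
+# # Local inference:
+# run_inference(
+# image_paths=["frames/000000.png", "frames/000001.png"],
+# export_dir="workspace/out",
+# model_dir="checkpoints/da3",
+# )
+#
+# # Remote inference via a running backend:
+# run_inference(
+# image_paths=["frames/000000.png"],
+# export_dir="workspace/out",
+# model_dir="checkpoints/da3",
+# backend_url="http://127.0.0.1:8000",
+# )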
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/services/input_handlers.py b/Depth-Anything-3-anysize/src/depth_anything_3/services/input_handlers.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc0536b51e5c71bb6fb9d10f8e74fbb12fc42d31
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/services/input_handlers.py
@@ -0,0 +1,266 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Input Processing Service
+Handles different types of inputs (image, images, colmap, video)
+"""
+
+import glob
+import os
+from typing import List, Tuple
+import cv2
+import numpy as np
+import typer
+
+from ..utils.read_write_model import read_model
+
+
+class InputHandler:
+ """Base input handler class"""
+
+ @staticmethod
+ def validate_path(path: str, path_type: str = "file") -> str:
+ """Validate path"""
+ if not os.path.exists(path):
+ raise typer.BadParameter(f"{path_type} not found: {path}")
+ return path
+
+ @staticmethod
+ def handle_export_dir(export_dir: str, auto_cleanup: bool = False) -> str:
+ """Handle export directory"""
+ if os.path.exists(export_dir):
+ if auto_cleanup:
+ typer.echo(f"Auto-cleaning existing export directory: {export_dir}")
+ import shutil
+
+ shutil.rmtree(export_dir)
+ os.makedirs(export_dir, exist_ok=True)
+ else:
+ typer.echo(f"Export directory '{export_dir}' already exists.")
+ if typer.confirm("Do you want to clean it and continue?"):
+ import shutil
+
+ shutil.rmtree(export_dir)
+ os.makedirs(export_dir, exist_ok=True)
+ typer.echo(f"Cleaned export directory: {export_dir}")
+ else:
+ typer.echo("Operation cancelled.")
+ raise typer.Exit(0)
+ else:
+ os.makedirs(export_dir, exist_ok=True)
+ return export_dir
+
+
+class ImageHandler(InputHandler):
+ """Single image handler"""
+
+ @staticmethod
+ def process(image_path: str) -> List[str]:
+ """Process single image"""
+ InputHandler.validate_path(image_path, "Image file")
+ return [image_path]
+
+
+class ImagesHandler(InputHandler):
+ """Image directory handler"""
+
+ @staticmethod
+ def process(images_dir: str, image_extensions: str = "png,jpg,jpeg") -> List[str]:
+ """Process image directory"""
+ InputHandler.validate_path(images_dir, "Images directory")
+
+ # Parse extensions
+ extensions = [ext.strip().lower() for ext in image_extensions.split(",")]
+ extensions = [ext if ext.startswith(".") else f".{ext}" for ext in extensions]
+
+ # Find image files
+ image_files = []
+ for ext in extensions:
+ pattern = f"*{ext}"
+ image_files.extend(glob.glob(os.path.join(images_dir, pattern)))
+ image_files.extend(glob.glob(os.path.join(images_dir, pattern.upper())))
+
+ image_files = sorted(list(set(image_files))) # Remove duplicates and sort
+
+ if not image_files:
+ raise typer.BadParameter(
+ f"No image files found in {images_dir} with extensions: {extensions}"
+ )
+
+ typer.echo(f"Found {len(image_files)} images to process")
+ return image_files
+
+
+class ColmapHandler(InputHandler):
+ """COLMAP data handler"""
+
+ @staticmethod
+ def process(
+ colmap_dir: str, sparse_subdir: str = ""
+ ) -> Tuple[List[str], np.ndarray, np.ndarray]:
+ """Process COLMAP data"""
+ InputHandler.validate_path(colmap_dir, "COLMAP directory")
+
+ # Build paths
+ images_dir = os.path.join(colmap_dir, "images")
+ if sparse_subdir:
+ sparse_dir = os.path.join(colmap_dir, "sparse", sparse_subdir)
+ else:
+ sparse_dir = os.path.join(colmap_dir, "sparse")
+
+ InputHandler.validate_path(images_dir, "Images directory")
+ InputHandler.validate_path(sparse_dir, "Sparse reconstruction directory")
+
+ # Load COLMAP data
+ typer.echo("Loading COLMAP reconstruction data...")
+ try:
+ cameras, images, points3D = read_model(sparse_dir)
+
+ typer.echo(
+ f"Loaded COLMAP data: {len(cameras)} cameras, {len(images)} images, "
+ f"{len(points3D)} 3D points."
+ )
+
+ # Get image files and pose data
+ image_files = []
+ extrinsics = []
+ intrinsics = []
+
+ for image_id, image_data in images.items():
+ image_name = image_data.name
+ image_path = os.path.join(images_dir, image_name)
+
+ if os.path.exists(image_path):
+ image_files.append(image_path)
+
+ # Get camera parameters
+ camera = cameras[image_data.camera_id]
+
+ # Convert quaternion to rotation matrix
+ R = image_data.qvec2rotmat()
+ t = image_data.tvec
+
+ # Create extrinsic matrix (world to camera)
+ extrinsic = np.eye(4)
+ extrinsic[:3, :3] = R
+ extrinsic[:3, 3] = t
+ extrinsics.append(extrinsic)
+
+ # Create intrinsic matrix
+ if camera.model == "PINHOLE":
+ fx, fy, cx, cy = camera.params
+ elif camera.model == "SIMPLE_PINHOLE":
+ f, cx, cy = camera.params
+ fx = fy = f
+ else:
+ # For other models, use basic pinhole approximation
+ fx = fy = camera.params[0] if len(camera.params) > 0 else 1000
+ cx = camera.width / 2
+ cy = camera.height / 2
+
+ intrinsic = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
+ intrinsics.append(intrinsic)
+
+ if not image_files:
+ raise typer.BadParameter("No valid images found in COLMAP data")
+
+ typer.echo(f"Found {len(image_files)} valid images with pose data")
+
+ return image_files, np.array(extrinsics), np.array(intrinsics)
+
+ except Exception as e:
+ raise typer.BadParameter(f"Failed to load COLMAP data: {e}")
+
+
+class VideoHandler(InputHandler):
+ """Video handler"""
+
+ @staticmethod
+ def process(video_path: str, output_dir: str, fps: float = 1.0) -> List[str]:
+ """Process video, extract frames"""
+ InputHandler.validate_path(video_path, "Video file")
+
+ cap = cv2.VideoCapture(video_path)
+ if not cap.isOpened():
+ raise typer.BadParameter(f"Cannot open video: {video_path}")
+
+ # Get video properties (guard against OpenCV reporting 0 FPS)
+ video_fps = cap.get(cv2.CAP_PROP_FPS)
+ if video_fps <= 0:
+ cap.release()
+ raise typer.BadParameter(f"Cannot determine FPS for video: {video_path}")
+ total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+ duration = total_frames / video_fps
+
+ # Calculate frame interval (ensure at least 1)
+ frame_interval = max(1, int(video_fps / fps))
+ actual_fps = video_fps / frame_interval
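+ # e.g. a 30 FPS video sampled at fps=1.0 -> frame_interval=30, actual_fps=1.0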
+
+ typer.echo(f"Video FPS: {video_fps:.2f}, Duration: {duration:.2f}s")
+
+ # Warn if requested FPS is higher than video FPS
+ if fps > video_fps:
+ typer.echo(
+ f"β οΈ Warning: Requested sampling FPS ({fps:.2f}) exceeds video FPS ({video_fps:.2f})", # noqa: E501
+ err=True,
+ )
+ typer.echo(
+ f"β οΈ Using maximum available FPS: {actual_fps:.2f} (extracting every frame)",
+ err=True,
+ )
+
+ typer.echo(f"Extracting frames at {actual_fps:.2f} FPS (every {frame_interval} frame(s))")
+
+ # Create output directory
+ frames_dir = os.path.join(output_dir, "input_images")
+ os.makedirs(frames_dir, exist_ok=True)
+
+ frame_count = 0
+ saved_count = 0
+
+ while True:
+ ret, frame = cap.read()
+ if not ret:
+ break
+
+ if frame_count % frame_interval == 0:
+ frame_path = os.path.join(frames_dir, f"{saved_count:06d}.png")
+ cv2.imwrite(frame_path, frame)
+ saved_count += 1
+
+ frame_count += 1
+
+ cap.release()
+ typer.echo(f"Extracted {saved_count} frames to {frames_dir}")
+
+ # Get frame file list
+ frame_files = sorted(
+ [f for f in os.listdir(frames_dir) if f.endswith((".png", ".jpg", ".jpeg"))]
+ )
+ if not frame_files:
+ raise typer.BadParameter("No frames extracted from video")
+
+ return [os.path.join(frames_dir, f) for f in frame_files]
+
+
+def parse_export_feat(export_feat_str: str) -> List[int]:
+ """Parse export_feat parameter"""
+ if not export_feat_str:
+ return []
+
+ try:
+ return [int(x.strip()) for x in export_feat_str.split(",") if x.strip()]
+ except ValueError:
+ raise typer.BadParameter(
+ f"Invalid export_feat format: {export_feat_str}. "
+ "Use comma-separated integers like '0,1,2'"
+ )
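+
+
+# e.g. parse_export_feat("0,5,11") -> [0, 5, 11]; parse_export_feat("") -> []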
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/specs.py b/Depth-Anything-3-anysize/src/depth_anything_3/specs.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe5b30255e9fd48d988ff00c896fb3dbadf197ea
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/specs.py
@@ -0,0 +1,45 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Any, Optional
+import numpy as np
+import torch
+
+
+@dataclass
+class Gaussians:
+ """3DGS parameters, all in world space"""
+
+ means: torch.Tensor # world points, "batch gaussian dim"
+ scales: torch.Tensor # scales_std, "batch gaussian 3"
+ rotations: torch.Tensor # world_quat_wxyz, "batch gaussian 4"
+ harmonics: torch.Tensor # world SH, "batch gaussian 3 d_sh"
+ opacities: torch.Tensor # opacity | opacity SH, "batch gaussian" | "batch gaussian 1 d_sh"
+
+
+@dataclass
+class Prediction:
+ depth: np.ndarray # N, H, W
+ is_metric: int
+ sky: np.ndarray | None = None # N, H, W
+ conf: np.ndarray | None = None # N, H, W
+ extrinsics: np.ndarray | None = None # N, 4, 4
+ intrinsics: np.ndarray | None = None # N, 3, 3
+ processed_images: np.ndarray | None = None # N, H, W, 3 - processed images for visualization
+ gaussians: Gaussians | None = None # 3D gaussians
+ aux: dict[str, Any] | None = None # auxiliary model outputs
+ scale_factor: Optional[float] = None # metric scale
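+
+
+# Illustrative construction (shapes only; values are placeholders):
+# pred = Prediction(depth=np.zeros((2, 480, 640), np.float32), is_metric=1)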
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/alignment.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/alignment.py
new file mode 100644
index 0000000000000000000000000000000000000000..42f8e6571a8a49b17e1cf85175461c75f9fb1e80
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/alignment.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Alignment utilities for depth estimation and metric scaling.
+"""
+
+from typing import Tuple
+import torch
+
+
+def least_squares_scale_scalar(
+ a: torch.Tensor, b: torch.Tensor, eps: float = 1e-12
+) -> torch.Tensor:
+ """
+ Compute least squares scale factor s such that a ≈ s * b.
+
+ Args:
+ a: First tensor
+ b: Second tensor
+ eps: Small epsilon for numerical stability
+
+ Returns:
+ Scalar tensor containing the scale factor
+
+ Raises:
+ ValueError: If tensors have mismatched shapes or devices
+ TypeError: If tensors are not floating point
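+
+ Example (illustrative):
+ >>> a = torch.tensor([2.0, 4.0, 6.0])
+ >>> b = torch.tensor([1.0, 2.0, 3.0])
+ >>> float(least_squares_scale_scalar(a, b)) # (a.b)/(b.b) = 28/14
+ 2.0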
+ """
+ if a.shape != b.shape:
+ raise ValueError(f"Shape mismatch: {a.shape} vs {b.shape}")
+ if a.device != b.device:
+ raise ValueError(f"Device mismatch: {a.device} vs {b.device}")
+ if not a.is_floating_point() or not b.is_floating_point():
+ raise TypeError("Tensors must be floating point type")
+
+ # Compute dot products for least squares solution
+ num = torch.dot(a.reshape(-1), b.reshape(-1))
+ den = torch.dot(b.reshape(-1), b.reshape(-1)).clamp_min(eps)
+ return num / den
+
+
+def compute_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
+ """
+ Compute non-sky mask from sky prediction.
+
+ Args:
+ sky_prediction: Sky prediction tensor
+ threshold: Threshold for sky classification
+
+ Returns:
+ Boolean mask where True indicates non-sky regions
+ """
+ return sky_prediction < threshold
+
+
+def compute_alignment_mask(
+ depth_conf: torch.Tensor,
+ non_sky_mask: torch.Tensor,
+ depth: torch.Tensor,
+ metric_depth: torch.Tensor,
+ median_conf: torch.Tensor,
+ min_depth_threshold: float = 1e-3,
+ min_metric_depth_threshold: float = 1e-2,
+) -> torch.Tensor:
+ """
+ Compute mask for depth alignment based on confidence and depth thresholds.
+
+ Args:
+ depth_conf: Depth confidence tensor
+ non_sky_mask: Non-sky region mask
+ depth: Predicted depth tensor
+ metric_depth: Metric depth tensor
+ median_conf: Median confidence threshold
+ min_depth_threshold: Minimum depth threshold
+ min_metric_depth_threshold: Minimum metric depth threshold
+
+ Returns:
+ Boolean mask for valid alignment regions
+ """
+ return (
+ (depth_conf >= median_conf)
+ & non_sky_mask
+ & (metric_depth > min_metric_depth_threshold)
+ & (depth > min_depth_threshold)
+ )
+
+
+def sample_tensor_for_quantile(tensor: torch.Tensor, max_samples: int = 100000) -> torch.Tensor:
+ """
+ Sample tensor elements for quantile computation to reduce memory usage.
+
+ Args:
+ tensor: Input tensor to sample
+ max_samples: Maximum number of samples to take
+
+ Returns:
+ Sampled tensor
+ """
+ if tensor.numel() <= max_samples:
+ return tensor
+
+ idx = torch.randperm(tensor.numel(), device=tensor.device)[:max_samples]
+ return tensor.flatten()[idx]
+
+
+def apply_metric_scaling(
+ depth: torch.Tensor, intrinsics: torch.Tensor, scale_factor: float = 300.0
+) -> torch.Tensor:
+ """
+ Apply metric scaling to depth based on camera intrinsics.
+
+ Args:
+ depth: Input depth tensor
+ intrinsics: Camera intrinsics tensor
+ scale_factor: Scaling factor for metric conversion
+
+ Returns:
+ Scaled depth tensor
+ """
+ focal_length = (intrinsics[:, :, 0, 0] + intrinsics[:, :, 1, 1]) / 2
+ return depth * (focal_length[:, :, None, None] / scale_factor)
+
+
+def set_sky_regions_to_max_depth(
+ depth: torch.Tensor,
+ depth_conf: torch.Tensor,
+ non_sky_mask: torch.Tensor,
+ max_depth: float = 200.0,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Set sky regions to maximum depth and high confidence.
+
+ Args:
+ depth: Depth tensor
+ depth_conf: Depth confidence tensor
+ non_sky_mask: Non-sky region mask
+ max_depth: Maximum depth value for sky regions
+
+ Returns:
+ Tuple of (updated_depth, updated_depth_conf)
+ """
+ depth = depth.clone()
+ depth_conf = depth_conf.clone()
+
+ # Set sky regions to max depth and high confidence
+ depth[~non_sky_mask] = max_depth
+ depth_conf[~non_sky_mask] = 1.0
+
+ return depth, depth_conf
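+
+
+# Illustrative composition of these helpers (tensor names are assumed, not part of the API):
+#
+# non_sky = compute_sky_mask(sky_pred)
+# mask = compute_alignment_mask(conf, non_sky, depth, metric_depth, conf.median())
+# s = least_squares_scale_scalar(metric_depth[mask], depth[mask])
+# aligned = depth * s # relative depth brought to metric scale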
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/api_helpers.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/api_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..b327331d9ec61a3047be3e330f41156eb124adc4
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/api_helpers.py
@@ -0,0 +1,58 @@
+import argparse
+
+
+def parse_scalar(s):
+ """Parse a string into bool/None/int/float where possible; otherwise return it unchanged."""
+ if not isinstance(s, str):
+ return s
+ t = s.strip()
+ low = t.lower()
+ if low == "true":
+ return True
+ if low == "false":
+ return False
+ if low in ("none", "null"):
+ return None
+ try:
+ return int(t, 10)
+ except Exception:
+ pass
+ try:
+ return float(t)
+ except Exception:
+ return s
+
+
+def fn_kv_csv(s: str) -> dict[str, dict[str, object]]:
+ """
+ Parse a string of comma-separated triplets: fn:key:value
+
+ Returns:
+ dict[fn_name] -> dict[key] = parsed_value
+
+ Example:
+ "fn1:width:1920,fn1:height:1080,fn2:quality:0.8"
+ -> {"fn1": {"width": 1920, "height": 1080}, "fn2": {"quality": 0.8}}
+ """
+ result: dict[str, dict[str, object]] = {}
+ if not s:
+ return result
+
+ for item in s.split(","):
+ if not item:
+ continue
+ parts = item.split(":", 2) # allow value to contain ":" beyond first two separators
+ if len(parts) < 3:
+ raise argparse.ArgumentTypeError(f"Bad item '{item}', expected FN:KEY:VALUE")
+ fn, key, raw_val = parts[0], parts[1], parts[2]
+
+ if not fn:
+ raise argparse.ArgumentTypeError(f"Bad item '{item}': empty function name")
+ if not key:
+ raise argparse.ArgumentTypeError(f"Bad item '{item}': empty key")
+
+ val = parse_scalar(raw_val)
+ bucket = result.setdefault(fn, {})
+ bucket[key] = val
+ return result
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/camera_trj_helpers.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/camera_trj_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..83624f359553908abd28af2d59f2066a7f7a7b15
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/camera_trj_helpers.py
@@ -0,0 +1,479 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import cv2
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import einsum, rearrange, reduce
+
+try:
+ from scipy.spatial.transform import Rotation as R
+except ImportError:
+ from depth_anything_3.utils.logger import logger
+
+ logger.warning("Dependency 'scipy' not found. Required for interpolating camera trajectory.")
+
+from depth_anything_3.utils.geometry import as_homogeneous
+
+
+@torch.no_grad()
+def render_stabilization_path(poses, k_size=45):
+ """Rendering stabilized camera path.
+ poses: [batch, 4, 4] or [batch, 3, 4],
+ return:
+ smooth path: [batch 4 4]"""
+ num_frames = poses.shape[0]
+ device = poses.device
+ dtype = poses.dtype
+
+ # Early exit for trivial cases
+ if num_frames <= 1:
+ return as_homogeneous(poses)
+
+ # Make k_size safe: positive odd and not larger than num_frames
+ # 1) Ensure odd
+ if k_size < 1:
+ k_size = 1
+ if k_size % 2 == 0:
+ k_size += 1
+ # 2) Cap to num_frames (keep odd)
+ max_odd = num_frames if (num_frames % 2 == 1) else (num_frames - 1)
+ if max_odd < 1:
+ max_odd = 1 # covers num_frames == 0 theoretically
+ k_size = min(k_size, max_odd)
+ # 3) enforce a minimum of 3 when possible (for better smoothing)
+ if num_frames >= 3 and k_size < 3:
+ k_size = 3
+
+ input_poses = []
+ for i in range(num_frames):
+ input_poses.append(
+ torch.cat([poses[i, :3, 0:1], poses[i, :3, 1:2], poses[i, :3, 3:4]], dim=-1)
+ )
+ input_poses = torch.stack(input_poses) # (num_frames, 3, 3)
+
+ # Prepare Gaussian kernel
+ gaussian_kernel = cv2.getGaussianKernel(ksize=k_size, sigma=-1).astype(np.float32).squeeze()
+ gaussian_kernel = torch.tensor(gaussian_kernel, dtype=dtype, device=device).view(1, 1, -1)
+ pad = k_size // 2
+
+ output_vectors = []
+ for idx in range(3): # For r1, r2, t
+ # (3, 1, num_frames): batch over vector components, 1-D conv over time
+ vec = input_poses[:, :, idx].T.unsqueeze(1)
+ vec_padded = F.pad(vec, (pad, pad), mode="reflect")
+ filtered = F.conv1d(vec_padded, gaussian_kernel)
+ output_vectors.append(filtered.squeeze(1).T) # (num_frames, 3)
+
+ output_r1, output_r2, output_t = output_vectors # Each is (num_frames, 3)
+
+ # Normalize r1 and r2
+ output_r1 = output_r1 / output_r1.norm(dim=-1, keepdim=True)
+ output_r2 = output_r2 / output_r2.norm(dim=-1, keepdim=True)
+
+ output_poses = []
+ for i in range(num_frames):
+ output_r3 = torch.linalg.cross(output_r1[i], output_r2[i])
+ render_pose = torch.cat(
+ [
+ output_r1[i].unsqueeze(-1),
+ output_r2[i].unsqueeze(-1),
+ output_r3.unsqueeze(-1),
+ output_t[i].unsqueeze(-1),
+ ],
+ dim=-1,
+ )
+ output_poses.append(render_pose[:3, :])
+ output_poses = as_homogeneous(torch.stack(output_poses, dim=0))
+
+ return output_poses
+
+
+@torch.no_grad()
+def render_wander_path(
+ cam2world: torch.Tensor,
+ intrinsic: torch.Tensor,
+ h: int,
+ w: int,
+ num_frames: int = 120,
+ max_disp: float = 48.0,
+):
+ device, dtype = cam2world.device, cam2world.dtype
+ fx = intrinsic[0, 0] * w
+ r = max_disp / fx
+ th = torch.linspace(0, 2.0 * torch.pi, steps=num_frames, device=device, dtype=dtype)
+ x = r * torch.sin(th)
+ yz = r * torch.cos(th) / 3.0
+ T = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(num_frames, 1, 1)
+ T[:, :3, 3] = torch.stack([x, yz, yz], dim=-1) * -1.0
+ c2ws = cam2world.unsqueeze(0) @ T
+ # Start at reference pose and end back at reference pose
+ c2ws = torch.cat([cam2world.unsqueeze(0), c2ws, cam2world.unsqueeze(0)], dim=0)
+ Ks = intrinsic.unsqueeze(0).repeat(c2ws.shape[0], 1, 1)
+ return c2ws, Ks
+
+
+@torch.no_grad()
+def render_dolly_zoom_path(
+ cam2world: torch.Tensor,
+ intrinsic: torch.Tensor,
+ h: int,
+ w: int,
+ num_frames: int = 120,
+ max_disp: float = 0.1,
+ D_focus: float = 10.0,
+):
+ device, dtype = cam2world.device, cam2world.dtype
+ fx0, fy0 = intrinsic[0, 0] * w, intrinsic[1, 1] * h
+ t = torch.linspace(0.0, 2.0, steps=num_frames, device=device, dtype=dtype)
+ z = 0.5 * (1.0 - torch.cos(torch.pi * t)) * max_disp
+ T = torch.eye(4, device=device, dtype=dtype).unsqueeze(0).repeat(num_frames, 1, 1)
+ T[:, 2, 3] = -z
+ c2ws = cam2world.unsqueeze(0) @ T
+ Df = torch.as_tensor(D_focus, device=device, dtype=dtype)
+ scale = (Df / (Df + z)).clamp(min=1e-6)
+ Ks = intrinsic.unsqueeze(0).repeat(num_frames, 1, 1)
+ Ks[:, 0, 0] = (fx0 * scale) / w
+ Ks[:, 1, 1] = (fy0 * scale) / h
+ return c2ws, Ks
+
+
+@torch.no_grad()
+def interpolate_intrinsics(
+ initial: torch.Tensor, # "*#batch 3 3"
+ final: torch.Tensor, # "*#batch 3 3"
+ t: torch.Tensor, # " time_step"
+) -> torch.Tensor: # "*batch time_step 3 3"
+ initial = rearrange(initial, "... i j -> ... () i j")
+ final = rearrange(final, "... i j -> ... () i j")
+ t = rearrange(t, "t -> t () ()")
+ return initial + (final - initial) * t
+
+
+def intersect_rays(
+ a_origins: torch.Tensor, # "*#batch dim"
+ a_directions: torch.Tensor, # "*#batch dim"
+ b_origins: torch.Tensor, # "*#batch dim"
+ b_directions: torch.Tensor, # "*#batch dim"
+) -> torch.Tensor: # "*batch dim"
+ """Compute the least-squares intersection of rays. Uses the math from here:
+ https://math.stackexchange.com/a/1762491/286022
+ """
+
+ # Broadcast and stack the tensors.
+ a_origins, a_directions, b_origins, b_directions = torch.broadcast_tensors(
+ a_origins, a_directions, b_origins, b_directions
+ )
+ origins = torch.stack((a_origins, b_origins), dim=-2)
+ directions = torch.stack((a_directions, b_directions), dim=-2)
+
+ # Compute n_i * n_i^T - eye(3) from the equation.
+ n = einsum(directions, directions, "... n i, ... n j -> ... n i j")
+ n = n - torch.eye(3, dtype=origins.dtype, device=origins.device)
+
+ # Compute the left-hand side of the equation.
+ lhs = reduce(n, "... n i j -> ... i j", "sum")
+
+ # Compute the right-hand side of the equation.
+ rhs = einsum(n, origins, "... n i j, ... n j -> ... n i")
+ rhs = reduce(rhs, "... n i -> ... i", "sum")
+
+ # Left-matrix-multiply both sides by the inverse of lhs to find p.
+ return torch.linalg.lstsq(lhs, rhs).solution
+
+
+def normalize(a: torch.Tensor) -> torch.Tensor: # "*#batch dim" -> "*#batch dim"
+ return a / a.norm(dim=-1, keepdim=True)
+
+
+def generate_coordinate_frame(
+ y: torch.Tensor, # "*#batch 3"
+ z: torch.Tensor, # "*#batch 3"
+) -> torch.Tensor: # "*batch 3 3"
+ """Generate a coordinate frame given perpendicular, unit-length Y and Z vectors."""
+ y, z = torch.broadcast_tensors(y, z)
+ return torch.stack([y.cross(z, dim=-1), y, z], dim=-1)
+
+
+def generate_rotation_coordinate_frame(
+ a: torch.Tensor, # "*#batch 3"
+ b: torch.Tensor, # "*#batch 3"
+ eps: float = 1e-4,
+) -> torch.Tensor: # "*batch 3 3"
+ """Generate a coordinate frame where the Y direction is normal to the plane defined
+ by unit vectors a and b. The other axes are arbitrary."""
+ device = a.device
+
+ # Replace every entry in b that's parallel to the corresponding entry in a with an
+ # arbitrary vector.
+ b = b.detach().clone()
+ parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps
+ b[parallel] = torch.tensor([0, 0, 1], dtype=b.dtype, device=device)
+ parallel = (einsum(a, b, "... i, ... i -> ...").abs() - 1).abs() < eps
+ b[parallel] = torch.tensor([0, 1, 0], dtype=b.dtype, device=device)
+
+ # Generate the coordinate frame. The initial cross product defines the plane.
+ return generate_coordinate_frame(normalize(torch.linalg.cross(a, b)), a)
+
+
+def matrix_to_euler(
+ rotations: torch.Tensor, # "*batch 3 3"
+ pattern: str,
+) -> torch.Tensor: # "*batch 3"
+ *batch, _, _ = rotations.shape
+ rotations = rotations.reshape(-1, 3, 3)
+ angles_np = R.from_matrix(rotations.detach().cpu().numpy()).as_euler(pattern)
+ rotations = torch.tensor(angles_np, dtype=rotations.dtype, device=rotations.device)
+ return rotations.reshape(*batch, 3)
+
+
+def euler_to_matrix(
+ rotations: torch.Tensor, # "*batch 3"
+ pattern: str,
+) -> torch.Tensor: # "*batch 3 3"
+ *batch, _ = rotations.shape
+ rotations = rotations.reshape(-1, 3)
+ matrix_np = R.from_euler(pattern, rotations.detach().cpu().numpy()).as_matrix()
+ rotations = torch.tensor(matrix_np, dtype=rotations.dtype, device=rotations.device)
+ return rotations.reshape(*batch, 3, 3)
+
+
+def extrinsics_to_pivot_parameters(
+ extrinsics: torch.Tensor, # "*#batch 4 4"
+ pivot_coordinate_frame: torch.Tensor, # "*#batch 3 3"
+ pivot_point: torch.Tensor, # "*#batch 3"
+) -> torch.Tensor: # "*batch 5"
+ """Convert the extrinsics to a representation with 5 degrees of freedom:
+ 1. Distance from pivot point in the "X" (look cross pivot axis) direction.
+ 2. Distance from pivot point in the "Y" (pivot axis) direction.
+ 3. Distance from pivot point in the Z (look) direction
+ 4. Angle in plane
+ 5. Twist (rotation not in plane)
+ """
+
+ # The pivot coordinate frame's Z axis is normal to the plane.
+ pivot_axis = pivot_coordinate_frame[..., :, 1]
+
+ # Compute the translation elements of the pivot parametrization.
+ translation_frame = generate_coordinate_frame(pivot_axis, extrinsics[..., :3, 2])
+ origin = extrinsics[..., :3, 3]
+ delta = pivot_point - origin
+ translation = einsum(translation_frame, delta, "... i j, ... i -> ... j")
+
+ # Add the rotation elements of the pivot parametrization.
+ inverted = pivot_coordinate_frame.inverse() @ extrinsics[..., :3, :3]
+ y, _, z = matrix_to_euler(inverted, "YXZ").unbind(dim=-1)
+
+ return torch.cat([translation, y[..., None], z[..., None]], dim=-1)
+
+
+def pivot_parameters_to_extrinsics(
+ parameters: torch.Tensor, # "*#batch 5"
+ pivot_coordinate_frame: torch.Tensor, # "*#batch 3 3"
+ pivot_point: torch.Tensor, # "*#batch 3"
+) -> torch.Tensor: # "*batch 4 4"
+ translation, y, z = parameters.split((3, 1, 1), dim=-1)
+
+ euler = torch.cat((y, torch.zeros_like(y), z), dim=-1)
+ rotation = pivot_coordinate_frame @ euler_to_matrix(euler, "YXZ")
+
+ # The pivot coordinate frame's Z axis is normal to the plane.
+ pivot_axis = pivot_coordinate_frame[..., :, 1]
+
+ translation_frame = generate_coordinate_frame(pivot_axis, rotation[..., :3, 2])
+ delta = einsum(translation_frame, translation, "... i j, ... j -> ... i")
+ origin = pivot_point - delta
+
+ *batch, _ = origin.shape
+ extrinsics = torch.eye(4, dtype=parameters.dtype, device=parameters.device)
+ extrinsics = extrinsics.broadcast_to((*batch, 4, 4)).clone()
+ extrinsics[..., 3, 3] = 1
+ extrinsics[..., :3, :3] = rotation
+ extrinsics[..., :3, 3] = origin
+ return extrinsics
+
+
+def interpolate_circular(
+ a: torch.Tensor, # "*#batch"
+ b: torch.Tensor, # "*#batch"
+ t: torch.Tensor, # "*#batch"
+) -> torch.Tensor: # " *batch"
+ a, b, t = torch.broadcast_tensors(a, b, t)
+
+ tau = 2 * torch.pi
+ a = a % tau
+ b = b % tau
+
+ # Consider piecewise edge cases.
+ d = (b - a).abs()
+ a_left = a - tau
+ d_left = (b - a_left).abs()
+ a_right = a + tau
+ d_right = (b - a_right).abs()
+ use_d = (d < d_left) & (d < d_right)
+ use_d_left = (d_left < d_right) & (~use_d)
+ use_d_right = (~use_d) & (~use_d_left)
+
+ result = a + (b - a) * t
+ result[use_d_left] = (a_left + (b - a_left) * t)[use_d_left]
+ result[use_d_right] = (a_right + (b - a_right) * t)[use_d_right]
+
+ return result
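+
+
+# e.g. interpolating from a ~ 6.2 rad to b ~ 0.1 rad moves forward through 2*pi
+# (the short way around) instead of sweeping back through pi.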
+
+
+def interpolate_pivot_parameters(
+ initial: torch.Tensor, # "*#batch 5"
+ final: torch.Tensor, # "*#batch 5"
+ t: torch.Tensor, # " time_step"
+) -> torch.Tensor: # "*batch time_step 5"
+ initial = rearrange(initial, "... d -> ... () d")
+ final = rearrange(final, "... d -> ... () d")
+ t = rearrange(t, "t -> t ()")
+ ti, ri = initial.split((3, 2), dim=-1)
+ tf, rf = final.split((3, 2), dim=-1)
+
+ t_lerp = ti + (tf - ti) * t
+ r_lerp = interpolate_circular(ri, rf, t)
+
+ return torch.cat((t_lerp, r_lerp), dim=-1)
+
+
+@torch.no_grad()
+def interpolate_extrinsics(
+ initial: torch.Tensor, # "*#batch 4 4"
+ final: torch.Tensor, # "*#batch 4 4"
+ t: torch.Tensor, # " time_step"
+ eps: float = 1e-4,
+) -> torch.Tensor: # "*batch time_step 4 4"
+ """Interpolate extrinsics by rotating around their "focus point," which is the
+ least-squares intersection between the look vectors of the initial and final
+ extrinsics.
+ """
+
+ initial = initial.type(torch.float64)
+ final = final.type(torch.float64)
+ t = t.type(torch.float64)
+
+ # Based on the dot product between the look vectors, pick from one of two cases:
+ # 1. Look vectors are parallel: interpolate about their origins' midpoint.
+ # 2. Look vectors aren't parallel: interpolate about their focus point.
+ initial_look = initial[..., :3, 2]
+ final_look = final[..., :3, 2]
+ dot_products = einsum(initial_look, final_look, "... i, ... i -> ...")
+ parallel_mask = (dot_products.abs() - 1).abs() < eps
+
+ # Pick focus points.
+ initial_origin = initial[..., :3, 3]
+ final_origin = final[..., :3, 3]
+ pivot_point = 0.5 * (initial_origin + final_origin)
+ pivot_point[~parallel_mask] = intersect_rays(
+ initial_origin[~parallel_mask],
+ initial_look[~parallel_mask],
+ final_origin[~parallel_mask],
+ final_look[~parallel_mask],
+ )
+
+ # Convert to pivot parameters.
+ pivot_frame = generate_rotation_coordinate_frame(initial_look, final_look, eps=eps)
+ initial_params = extrinsics_to_pivot_parameters(initial, pivot_frame, pivot_point)
+ final_params = extrinsics_to_pivot_parameters(final, pivot_frame, pivot_point)
+
+ # Interpolate the pivot parameters.
+ interpolated_params = interpolate_pivot_parameters(initial_params, final_params, t)
+
+ # Convert back.
+ return pivot_parameters_to_extrinsics(
+ interpolated_params.type(torch.float32),
+ rearrange(pivot_frame, "... i j -> ... () i j").type(torch.float32),
+ rearrange(pivot_point, "... xyz -> ... () xyz").type(torch.float32),
+ )
+
+
+@torch.no_grad()
+def generate_wobble_transformation(
+ radius: torch.Tensor, # "*#batch"
+ t: torch.Tensor, # " time_step"
+ num_rotations: int = 1,
+ scale_radius_with_t: bool = True,
+) -> torch.Tensor: # "*batch time_step 4 4"]:
+ # Generate a translation in the image plane.
+ tf = torch.eye(4, dtype=torch.float32, device=t.device)
+ tf = tf.broadcast_to((*radius.shape, t.shape[0], 4, 4)).clone()
+ radius = radius[..., None]
+ if scale_radius_with_t:
+ radius = radius * t
+ tf[..., 0, 3] = torch.sin(2 * torch.pi * num_rotations * t) * radius
+ tf[..., 1, 3] = -torch.cos(2 * torch.pi * num_rotations * t) * radius
+ return tf
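+
+# As t runs 0 -> 1, the camera is offset by radius * (sin(2*pi*n*t), -cos(2*pi*n*t))
+# in the image plane, tracing `num_rotations` circles; with `scale_radius_with_t`
+# the offset spirals out from the original pose instead of starting at full radius.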
+
+
+@torch.no_grad()
+def render_wobble_inter_path(
+ cam2world: torch.Tensor, intr_normed: torch.Tensor, inter_len: int, n_skip: int = 3
+):
+ """
+    cam2world: [batch, views, 4, 4]
+    intr_normed: [batch, views, 3, 3]
+ """
+ frame_per_round = n_skip * inter_len
+ num_rotations = 1
+
+ t = torch.linspace(0, 1, frame_per_round, dtype=torch.float32, device=cam2world.device)
+ # t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
+ tgt_c2w_b = []
+ tgt_intr_b = []
+ for b_idx in range(cam2world.shape[0]):
+ tgt_c2w = []
+ tgt_intr = []
+ for cur_idx in range(0, cam2world.shape[1] - n_skip, n_skip):
+ origin_a = cam2world[b_idx, cur_idx, :3, 3]
+ origin_b = cam2world[b_idx, cur_idx + n_skip, :3, 3]
+ delta = (origin_a - origin_b).norm(dim=-1)
+ if cur_idx == 0:
+ delta_prev = delta
+ else:
+ delta = (delta_prev + delta) / 2
+ delta_prev = delta
+ tf = generate_wobble_transformation(
+ radius=delta * 0.5,
+ t=t,
+ num_rotations=num_rotations,
+ scale_radius_with_t=False,
+ )
+ cur_extrs = (
+ interpolate_extrinsics(
+ cam2world[b_idx, cur_idx],
+ cam2world[b_idx, cur_idx + n_skip],
+ t,
+ )
+ @ tf
+ )
+ tgt_c2w.append(cur_extrs[(0 if cur_idx == 0 else 1) :])
+ tgt_intr.append(
+ interpolate_intrinsics(
+ intr_normed[b_idx, cur_idx],
+ intr_normed[b_idx, cur_idx + n_skip],
+ t,
+ )[(0 if cur_idx == 0 else 1) :]
+ )
+ tgt_c2w_b.append(torch.cat(tgt_c2w))
+ tgt_intr_b.append(torch.cat(tgt_intr))
+ tgt_c2w = torch.stack(tgt_c2w_b) # b v 4 4
+ tgt_intr = torch.stack(tgt_intr_b) # b v 3 3
+ return tgt_c2w, tgt_intr
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/constants.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3fb4f11827c9b8f9373df7ce10141071d89300e
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/constants.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+DEFAULT_MODEL = "depth-anything/DA3NESTED-GIANT-LARGE"
+DEFAULT_EXPORT_DIR = "workspace/gallery/scene"
+DEFAULT_GALLERY_DIR = "workspace/gallery"
+DEFAULT_GRADIO_DIR = "workspace/gradio"
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/__init__.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3e4c657983b19a75865ad7d3329f9f037f60cd6
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/__init__.py
@@ -0,0 +1,59 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from depth_anything_3.specs import Prediction
+from depth_anything_3.utils.export.gs import export_to_gs_ply, export_to_gs_video
+
+from .colmap import export_to_colmap
+from .depth_vis import export_to_depth_vis
+from .feat_vis import export_to_feat_vis
+from .glb import export_to_glb
+from .npz import export_to_mini_npz, export_to_npz
+
+
+def export(
+ prediction: Prediction,
+ export_format: str,
+ export_dir: str,
+ **kwargs,
+):
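+    """Dispatch ``prediction`` to one or more format-specific exporters.
+
+    ``export_format`` is either a single format name (e.g. ``"glb"``) or several
+    names joined with ``-`` (e.g. ``"glb-npz"``), in which case each format is
+    exported in turn. Per-format options are looked up in ``kwargs`` under the
+    format name.
+    """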
+ if "-" in export_format:
+        for fmt in export_format.split("-"):
+            export(prediction, fmt, export_dir, **kwargs)
+ return # Prevent falling through to single-format handling
+
+ if export_format == "glb":
+ export_to_glb(prediction, export_dir, **kwargs.get(export_format, {}))
+ elif export_format == "mini_npz":
+ export_to_mini_npz(prediction, export_dir)
+ elif export_format == "npz":
+ export_to_npz(prediction, export_dir)
+ elif export_format == "feat_vis":
+ export_to_feat_vis(prediction, export_dir, **kwargs.get(export_format, {}))
+ elif export_format == "depth_vis":
+ export_to_depth_vis(prediction, export_dir)
+ elif export_format == "gs_ply":
+ export_to_gs_ply(prediction, export_dir, **kwargs.get(export_format, {}))
+ elif export_format == "gs_video":
+ export_to_gs_video(prediction, export_dir, **kwargs.get(export_format, {}))
+ elif export_format == "colmap":
+ export_to_colmap(prediction, export_dir, **kwargs.get(export_format, {}))
+ else:
+ raise ValueError(f"Unsupported export format: {export_format}")
+
+
+__all__ = [
+    "export",
+]
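+
+# Usage sketch (illustrative values):
+#   export(prediction, "glb-depth_vis", "workspace/gallery/scene",
+#          glb={"show_cameras": True})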
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/colmap.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/colmap.py
new file mode 100644
index 0000000000000000000000000000000000000000..81aa71365be3154dcfb5467d723fb35955418560
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/colmap.py
@@ -0,0 +1,150 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pycolmap
+import numpy as np
+
+from PIL import Image
+
+from depth_anything_3.specs import Prediction
+from depth_anything_3.utils.logger import logger
+
+from .glb import _depths_to_world_points_with_colors
+
+
+def export_to_colmap(
+ prediction: Prediction,
+ export_dir: str,
+ image_paths: list[str],
+ conf_thresh_percentile: float = 40.0,
+ process_res_method: str = "keep",
+) -> None:
+ # 1. Data preparation
+ conf_thresh = np.percentile(prediction.conf, conf_thresh_percentile)
+ points, colors = _depths_to_world_points_with_colors(
+ prediction.depth,
+ prediction.intrinsics,
+ prediction.extrinsics, # w2c
+ prediction.processed_images,
+ prediction.conf,
+ conf_thresh,
+ )
+ num_points = len(points)
+ logger.info(f"Exporting to COLMAP with {num_points} points")
+ num_frames = len(prediction.processed_images)
+ h, w = prediction.processed_images.shape[1:3]
+ points_xyf = _create_xyf(num_frames, h, w)
+ points_xyf = points_xyf[prediction.conf >= conf_thresh]
+
+ # 2. Set Reconstruction
+ reconstruction = pycolmap.Reconstruction()
+
+ point3d_ids = []
+ for vidx in range(num_points):
+ point3d_id = reconstruction.add_point3D(points[vidx], pycolmap.Track(), colors[vidx])
+ point3d_ids.append(point3d_id)
+
+ for fidx in range(num_frames):
+ orig_w, orig_h = Image.open(image_paths[fidx]).size
+
+        # Copy so that scaling to the original resolution below does not
+        # mutate prediction.intrinsics in place.
+        intrinsic = prediction.intrinsics[fidx].copy()
+ if process_res_method.endswith("resize") or process_res_method in ("keep", "original"):
+ intrinsic[:1] *= orig_w / w
+ intrinsic[1:2] *= orig_h / h
+ elif process_res_method == "crop":
+ raise NotImplementedError("COLMAP export for crop method is not implemented")
+ else:
+ raise ValueError(f"Unknown process_res_method: {process_res_method}")
+
+ pycolmap_intri = np.array(
+ [intrinsic[0, 0], intrinsic[1, 1], intrinsic[0, 2], intrinsic[1, 2]]
+ )
+
+ extrinsic = prediction.extrinsics[fidx]
+ cam_from_world = pycolmap.Rigid3d(pycolmap.Rotation3d(extrinsic[:3, :3]), extrinsic[:3, 3])
+
+ # set and add camera
+ camera = pycolmap.Camera()
+ camera.camera_id = fidx + 1
+ camera.model = pycolmap.CameraModelId.PINHOLE
+ camera.width = orig_w
+ camera.height = orig_h
+ camera.params = pycolmap_intri
+ reconstruction.add_camera(camera)
+
+ # set and add rig (from camera)
+ rig = pycolmap.Rig()
+ rig.rig_id = camera.camera_id
+ rig.add_ref_sensor(camera.sensor_id)
+ reconstruction.add_rig(rig)
+
+ # set image
+ image = pycolmap.Image()
+ image.image_id = fidx + 1
+ image.camera_id = camera.camera_id
+
+ # set and add frame (from image)
+ frame = pycolmap.Frame()
+ frame.frame_id = image.image_id
+ frame.rig_id = camera.camera_id
+ frame.add_data_id(image.data_id)
+ frame.rig_from_world = cam_from_world
+ reconstruction.add_frame(frame)
+
+ # set point2d and update track
+ point2d_list = []
+ points_in_frame = points_xyf[:, 2].astype(np.int32) == fidx
+ for vidx in np.where(points_in_frame)[0]:
+ point2d = points_xyf[vidx][:2]
+ point2d[0] *= orig_w / w
+ point2d[1] *= orig_h / h
+ point3d_id = point3d_ids[vidx]
+ point2d_list.append(pycolmap.Point2D(point2d, point3d_id))
+ reconstruction.point3D(point3d_id).track.add_element(
+ image.image_id, len(point2d_list) - 1
+ )
+
+ # set and add image
+ image.frame_id = image.image_id
+ image.name = os.path.basename(image_paths[fidx])
+ image.points2D = pycolmap.Point2DList(point2d_list)
+ reconstruction.add_image(image)
+
+ # 3. Export
+ reconstruction.write(export_dir)
+
+
+def _create_xyf(num_frames, height, width):
+ """
+ Creates a grid of pixel coordinates and frame indices (fidx) for all frames.
+ """
+ # Create coordinate grids for a single frame
+ y_grid, x_grid = np.indices((height, width), dtype=np.int32)
+ x_grid = x_grid[np.newaxis, :, :]
+ y_grid = y_grid[np.newaxis, :, :]
+
+ # Broadcast to all frames
+ x_coords = np.broadcast_to(x_grid, (num_frames, height, width))
+ y_coords = np.broadcast_to(y_grid, (num_frames, height, width))
+
+ # Create frame indices and broadcast
+ f_idx = np.arange(num_frames, dtype=np.int32)[:, np.newaxis, np.newaxis]
+ f_coords = np.broadcast_to(f_idx, (num_frames, height, width))
+
+ # Stack coordinates and frame indices
+ points_xyf = np.stack((x_coords, y_coords, f_coords), axis=-1)
+
+ return points_xyf
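+
+# e.g. _create_xyf(num_frames=2, height=480, width=640) returns an array of
+# shape (2, 480, 640, 3) whose last axis holds (x, y, frame_index) per pixel.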
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/depth_vis.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/depth_vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..8accc04e92985e26b8d78a56db80989be515aac7
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/depth_vis.py
@@ -0,0 +1,41 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import imageio
+import numpy as np
+
+from depth_anything_3.specs import Prediction
+from depth_anything_3.utils.visualize import visualize_depth
+
+
+def export_to_depth_vis(
+ prediction: Prediction,
+ export_dir: str,
+):
+ # Use prediction.processed_images, which is already processed image data
+ if prediction.processed_images is None:
+ raise ValueError("prediction.processed_images is required but not available")
+
+ images_u8 = prediction.processed_images # (N,H,W,3) uint8
+
+ os.makedirs(os.path.join(export_dir, "depth_vis"), exist_ok=True)
+ for idx in range(prediction.depth.shape[0]):
+ depth_vis = visualize_depth(prediction.depth[idx])
+ image_vis = images_u8[idx]
+ depth_vis = depth_vis.astype(np.uint8)
+ image_vis = image_vis.astype(np.uint8)
+ vis_image = np.concatenate([image_vis, depth_vis], axis=1)
+ save_path = os.path.join(export_dir, f"depth_vis/{idx:04d}.jpg")
+ imageio.imwrite(save_path, vis_image, quality=95)
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/feat_vis.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/feat_vis.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3dc780ea509d8b5f4660212c2914d3e81f2364a
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/feat_vis.py
@@ -0,0 +1,65 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import cv2
+import imageio
+import numpy as np
+from tqdm.auto import tqdm
+
+from depth_anything_3.utils.parallel_utils import async_call
+from depth_anything_3.utils.pca_utils import PCARGBVisualizer
+
+
+@async_call
+def export_to_feat_vis(
+ prediction,
+ export_dir,
+ fps=15,
+):
+ """Export feature visualization with PCA.
+
+ Args:
+ prediction: Model prediction containing feature maps
+ export_dir: Directory to export results
+ fps: Frame rate for output video (default: 15)
+ """
+ out_dir = os.path.join(export_dir, "feat_vis")
+ os.makedirs(out_dir, exist_ok=True)
+
+ images = prediction.processed_images
+ for k, v in prediction.aux.items():
+ if not k.startswith("feat_layer_"):
+ continue
+ os.makedirs(os.path.join(out_dir, k), exist_ok=True)
+ viz = PCARGBVisualizer(basis_mode="fixed", percentile_mode="global", clip_percent=10.0)
+ viz.fit_reference(v)
+ feats_vis = viz.transform_video(v)
+ for idx in tqdm(range(len(feats_vis))):
+ img = images[idx]
+ feat_vis = (feats_vis[idx] * 255).astype(np.uint8)
+ feat_vis = cv2.resize(
+ feat_vis, (img.shape[1], img.shape[0]), interpolation=cv2.INTER_NEAREST
+ )
+ save_path = os.path.join(out_dir, f"{k}/{idx:06d}.jpg")
+ save = np.concatenate([img, feat_vis], axis=1)
+ imageio.imwrite(save_path, save, quality=95)
+ cmd = (
+ "ffmpeg -loglevel error -hide_banner -y "
+ f"-framerate {fps} -start_number 0 "
+ f"-i {out_dir}/{k}/%06d.jpg "
+ f"-c:v libx264 -pix_fmt yuv420p "
+ f"{out_dir}/{k}.mp4"
+ )
+ os.system(cmd)
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/glb.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/glb.py
new file mode 100644
index 0000000000000000000000000000000000000000..ece1379d98fedceba03cb53ee4cb62bd49a1ae4f
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/glb.py
@@ -0,0 +1,432 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import os
+import numpy as np
+import trimesh
+
+from depth_anything_3.specs import Prediction
+from depth_anything_3.utils.logger import logger
+
+from .depth_vis import export_to_depth_vis
+
+
+def set_sky_depth(prediction: Prediction, sky_mask: np.ndarray, sky_depth_def: float = 98.0):
+ non_sky_mask = ~sky_mask
+ valid_depth = prediction.depth[non_sky_mask]
+ if valid_depth.size > 0:
+ max_depth = np.percentile(valid_depth, sky_depth_def)
+ prediction.depth[sky_mask] = max_depth
+
+
+def get_conf_thresh(
+ prediction: Prediction,
+ sky_mask: np.ndarray,
+ conf_thresh: float,
+ conf_thresh_percentile: float = 10.0,
+ ensure_thresh_percentile: float = 90.0,
+):
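+    """Clamp ``conf_thresh`` between the ``conf_thresh_percentile`` (lower) and
+    ``ensure_thresh_percentile`` (upper) percentiles of the confidence map,
+    computed over non-sky pixels whenever a usable sky mask is available.
+    """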
+ if sky_mask is not None and (~sky_mask).sum() > 10:
+ conf_pixels = prediction.conf[~sky_mask]
+ else:
+ conf_pixels = prediction.conf
+ lower = np.percentile(conf_pixels, conf_thresh_percentile)
+ upper = np.percentile(conf_pixels, ensure_thresh_percentile)
+ conf_thresh = min(max(conf_thresh, lower), upper)
+ return conf_thresh
+
+
+def export_to_glb(
+ prediction: Prediction,
+ export_dir: str,
+ num_max_points: int = 1_000_000,
+ conf_thresh: float = 1.05,
+ filter_black_bg: bool = False,
+ filter_white_bg: bool = False,
+ conf_thresh_percentile: float = 40.0,
+ ensure_thresh_percentile: float = 90.0,
+ sky_depth_def: float = 98.0,
+ show_cameras: bool = True,
+ camera_size: float = 0.03,
+ export_depth_vis: bool = True,
+) -> str:
+ """Generate a 3D point cloud and camera wireframes and export them as a ``.glb`` file.
+
+ The function builds a point cloud from the predicted depth maps, aligns it to the
+ first camera in glTF coordinates (X-right, Y-up, Z-backward), optionally draws
+ camera wireframes, and writes the result to ``scene.glb``. Auxiliary assets such as
+ depth visualizations can also be generated alongside the main export.
+
+ Args:
+ prediction: Model prediction containing depth, confidence, intrinsics, extrinsics,
+ and pre-processed images.
+ export_dir: Output directory where the glTF assets will be written.
+ num_max_points: Maximum number of points retained after downsampling.
+ conf_thresh: Base confidence threshold used before percentile adjustments.
+ filter_black_bg: Mark near-black background pixels for removal during confidence filtering.
+ filter_white_bg: Mark near-white background pixels for removal during confidence filtering.
+ conf_thresh_percentile: Lower percentile used when adapting the confidence threshold.
+ ensure_thresh_percentile: Upper percentile clamp for the adaptive threshold.
+ sky_depth_def: Percentile used to fill sky pixels with plausible depth values.
+ show_cameras: Whether to render camera wireframes in the exported scene.
+ camera_size: Relative camera wireframe scale as a fraction of the scene diagonal.
+        export_depth_vis: Whether to export raster depth visualizations alongside the glTF.
+
+ Returns:
+ Path to the exported ``scene.glb`` file.
+ """
+ # 1) Use prediction.processed_images, which is already processed image data
+ assert (
+ prediction.processed_images is not None
+ ), "Export to GLB: prediction.processed_images is required but not available"
+ assert (
+ prediction.depth is not None
+ ), "Export to GLB: prediction.depth is required but not available"
+ assert (
+ prediction.intrinsics is not None
+ ), "Export to GLB: prediction.intrinsics is required but not available"
+ assert (
+ prediction.extrinsics is not None
+ ), "Export to GLB: prediction.extrinsics is required but not available"
+ assert (
+ prediction.conf is not None
+ ), "Export to GLB: prediction.conf is required but not available"
+ logger.info(f"conf_thresh_percentile: {conf_thresh_percentile}")
+ logger.info(f"num max points: {num_max_points}")
+ logger.info(f"Exporting to GLB with num_max_points: {num_max_points}")
+ if prediction.processed_images is None:
+ raise ValueError("prediction.processed_images is required but not available")
+
+ images_u8 = prediction.processed_images # (N,H,W,3) uint8
+
+ # 2) Sky processing (if sky_mask is provided)
+ if getattr(prediction, "sky_mask", None) is not None:
+ set_sky_depth(prediction, prediction.sky_mask, sky_depth_def)
+
+ # 3) Confidence threshold (if no conf, then no filtering)
+ if filter_black_bg:
+ prediction.conf[(prediction.processed_images < 16).all(axis=-1)] = 1.0
+ if filter_white_bg:
+ prediction.conf[(prediction.processed_images >= 240).all(axis=-1)] = 1.0
+ conf_thr = get_conf_thresh(
+ prediction,
+ getattr(prediction, "sky_mask", None),
+ conf_thresh,
+ conf_thresh_percentile,
+ ensure_thresh_percentile,
+ )
+
+ # 4) Back-project to world coordinates and get colors (world frame)
+ points, colors = _depths_to_world_points_with_colors(
+ prediction.depth,
+ prediction.intrinsics,
+ prediction.extrinsics, # w2c
+ images_u8,
+ prediction.conf,
+ conf_thr,
+ )
+
+ # 5) Based on first camera orientation + glTF axis system, center by point cloud,
+ # construct alignment transform, and apply to point cloud
+ A = _compute_alignment_transform_first_cam_glTF_center_by_points(
+ prediction.extrinsics[0], points
+ ) # (4,4)
+
+ if points.shape[0] > 0:
+ points = trimesh.transform_points(points, A)
+
+ # 6) Clean + downsample
+ points, colors = _filter_and_downsample(points, colors, num_max_points)
+
+ # 7) Assemble scene (add point cloud first)
+ scene = trimesh.Scene()
+ if scene.metadata is None:
+ scene.metadata = {}
+ scene.metadata["hf_alignment"] = A # For camera wireframes and external reuse
+
+ if points.shape[0] > 0:
+ pc = trimesh.points.PointCloud(vertices=points, colors=colors)
+ scene.add_geometry(pc)
+
+ # 8) Draw cameras (wireframe pyramids), using the same transform A
+ if show_cameras and prediction.intrinsics is not None and prediction.extrinsics is not None:
+ scene_scale = _estimate_scene_scale(points, fallback=1.0)
+ H, W = prediction.depth.shape[1:]
+ _add_cameras_to_scene(
+ scene=scene,
+ K=prediction.intrinsics,
+ ext_w2c=prediction.extrinsics,
+ image_sizes=[(H, W)] * prediction.depth.shape[0],
+ scale=scene_scale * camera_size,
+ )
+
+ # 9) Export
+ os.makedirs(export_dir, exist_ok=True)
+ out_path = os.path.join(export_dir, "scene.glb")
+ scene.export(out_path)
+
+ if export_depth_vis:
+ export_to_depth_vis(prediction, export_dir)
+ os.system(f"cp -r {export_dir}/depth_vis/0000.jpg {export_dir}/scene.jpg")
+ return out_path
+
+
+# =========================
+# utilities
+# =========================
+
+
+def _as_homogeneous44(ext: np.ndarray) -> np.ndarray:
+ """
+ Accept (4,4) or (3,4) extrinsic parameters, return (4,4) homogeneous matrix.
+ """
+ if ext.shape == (4, 4):
+ return ext
+ if ext.shape == (3, 4):
+ H = np.eye(4, dtype=ext.dtype)
+ H[:3, :4] = ext
+ return H
+ raise ValueError(f"extrinsic must be (4,4) or (3,4), got {ext.shape}")
+
+
+def _depths_to_world_points_with_colors(
+ depth: np.ndarray,
+ K: np.ndarray,
+ ext_w2c: np.ndarray,
+ images_u8: np.ndarray,
+ conf: np.ndarray | None,
+ conf_thr: float,
+) -> tuple[np.ndarray, np.ndarray]:
+ """
+    For each frame, map pixel coordinates (u, v, 1) through K^{-1} to get rays,
+    scale them by depth to obtain camera-frame points, then apply (w2c)^{-1}
+    to move the points to the world frame. Colors are gathered from the same
+    valid pixels.
+ """
+ N, H, W = depth.shape
+ us, vs = np.meshgrid(np.arange(W), np.arange(H))
+ ones = np.ones_like(us)
+ pix = np.stack([us, vs, ones], axis=-1).reshape(-1, 3) # (H*W,3)
+
+ pts_all, col_all = [], []
+
+ for i in range(N):
+ d = depth[i] # (H,W)
+ valid = np.isfinite(d) & (d > 0)
+ if conf is not None:
+ valid &= conf[i] >= conf_thr
+ if not np.any(valid):
+ continue
+
+ d_flat = d.reshape(-1)
+ vidx = np.flatnonzero(valid.reshape(-1))
+
+ K_inv = np.linalg.inv(K[i]) # (3,3)
+ c2w = np.linalg.inv(_as_homogeneous44(ext_w2c[i])) # (4,4)
+
+ rays = K_inv @ pix[vidx].T # (3,M)
+ Xc = rays * d_flat[vidx][None, :] # (3,M)
+ Xc_h = np.vstack([Xc, np.ones((1, Xc.shape[1]))])
+ Xw = (c2w @ Xc_h)[:3].T.astype(np.float32) # (M,3)
+
+ cols = images_u8[i].reshape(-1, 3)[vidx].astype(np.uint8) # (M,3)
+
+ pts_all.append(Xw)
+ col_all.append(cols)
+
+ if len(pts_all) == 0:
+ return np.zeros((0, 3), dtype=np.float32), np.zeros((0, 3), dtype=np.uint8)
+
+ return np.concatenate(pts_all, 0), np.concatenate(col_all, 0)
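+
+# Per pixel this is the standard pinhole unprojection followed by a rigid
+# camera-to-world transform: X_w = (w2c)^{-1} @ [d * K^{-1} @ (u, v, 1)^T; 1].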
+
+
+def _filter_and_downsample(points: np.ndarray, colors: np.ndarray, num_max: int):
+ if points.shape[0] == 0:
+ return points, colors
+ finite = np.isfinite(points).all(axis=1)
+ points, colors = points[finite], colors[finite]
+ if points.shape[0] > num_max:
+ idx = np.random.choice(points.shape[0], num_max, replace=False)
+ points, colors = points[idx], colors[idx]
+ return points, colors
+
+
+def _estimate_scene_scale(points: np.ndarray, fallback: float = 1.0) -> float:
+ if points.shape[0] < 2:
+ return fallback
+ lo = np.percentile(points, 5, axis=0)
+ hi = np.percentile(points, 95, axis=0)
+ diag = np.linalg.norm(hi - lo)
+ return float(diag if np.isfinite(diag) and diag > 0 else fallback)
+
+
+def _compute_alignment_transform_first_cam_glTF_center_by_points(
+ ext_w2c0: np.ndarray,
+ points_world: np.ndarray,
+) -> np.ndarray:
+ """Computes the transformation matrix to align the scene with glTF standards.
+
+ This function calculates a 4x4 homogeneous matrix that centers the scene's
+ point cloud and transforms its coordinate system from the computer vision (CV)
+ standard to the glTF standard.
+
+ The transformation process involves three main steps:
+ 1. **Initial Alignment**: Orients the world coordinate system to match the
+ first camera's view (x-right, y-down, z-forward).
+ 2. **Coordinate System Conversion**: Converts the CV camera frame to the
+ glTF frame (x-right, y-up, z-backward) by flipping the Y and Z axes.
+ 3. **Centering**: Translates the entire scene so that the median of the
+ point cloud becomes the new origin (0,0,0).
+
+ Returns:
+        A 4x4 homogeneous transformation matrix (np.ndarray) ``A`` that applies
+        these transformations: X' = A @ [X; 1].
+ """
+
+ w2c0 = _as_homogeneous44(ext_w2c0).astype(np.float64)
+
+ # CV -> glTF axis transformation
+ M = np.eye(4, dtype=np.float64)
+ M[1, 1] = -1.0 # flip Y
+ M[2, 2] = -1.0 # flip Z
+
+ # Don't center first
+ A_no_center = M @ w2c0
+
+ # Calculate point cloud center in new coordinate system (use median to resist outliers)
+ if points_world.shape[0] > 0:
+ pts_tmp = trimesh.transform_points(points_world, A_no_center)
+ center = np.median(pts_tmp, axis=0)
+ else:
+ center = np.zeros(3, dtype=np.float64)
+
+ T_center = np.eye(4, dtype=np.float64)
+ T_center[:3, 3] = -center
+
+ A = T_center @ A_no_center
+ return A
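+
+# The result composes as A = T_center @ M @ w2c0: a world point is first
+# expressed in the first camera's frame, then flipped to glTF axes, then
+# recentered on the point-cloud median.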
+
+
+def _add_cameras_to_scene(
+ scene: trimesh.Scene,
+ K: np.ndarray,
+ ext_w2c: np.ndarray,
+ image_sizes: list[tuple[int, int]],
+ scale: float,
+) -> None:
+ """Draws camera frustums to visualize their position and orientation.
+
+ This function renders each camera as a wireframe pyramid, originating from
+ the camera's center and extending to the corners of its imaging plane.
+
+ It reads the 'hf_alignment' metadata from the scene to ensure the
+ wireframes are correctly aligned with the 3D point cloud.
+ """
+ N = K.shape[0]
+ if N == 0:
+ return
+
+ # Alignment matrix consistent with point cloud (use identity matrix if missing)
+ A = None
+ try:
+ A = scene.metadata.get("hf_alignment", None) if scene.metadata else None
+ except Exception:
+ A = None
+ if A is None:
+ A = np.eye(4, dtype=np.float64)
+
+ for i in range(N):
+ H, W = image_sizes[i]
+ segs = _camera_frustum_lines(K[i], ext_w2c[i], W, H, scale) # (8,2,3) world frame
+ # Apply unified transformation
+ segs = trimesh.transform_points(segs.reshape(-1, 3), A).reshape(-1, 2, 3)
+ path = trimesh.load_path(segs)
+ color = _index_color_rgb(i, N)
+ if hasattr(path, "colors"):
+ path.colors = np.tile(color, (len(path.entities), 1))
+ scene.add_geometry(path)
+
+
+def _camera_frustum_lines(
+ K: np.ndarray, ext_w2c: np.ndarray, W: int, H: int, scale: float
+) -> np.ndarray:
+ corners = np.array(
+ [
+ [0, 0, 1.0],
+ [W - 1, 0, 1.0],
+ [W - 1, H - 1, 1.0],
+ [0, H - 1, 1.0],
+ ],
+ dtype=float,
+ ) # (4,3)
+
+ K_inv = np.linalg.inv(K)
+ c2w = np.linalg.inv(_as_homogeneous44(ext_w2c))
+
+ # camera center in world
+ Cw = (c2w @ np.array([0, 0, 0, 1.0]))[:3]
+
+ # rays -> z=1 plane points (camera frame)
+ rays = (K_inv @ corners.T).T
+ z = rays[:, 2:3]
+ z[z == 0] = 1.0
+ plane_cam = (rays / z) * scale # (4,3)
+
+ # to world
+ plane_w = []
+ for p in plane_cam:
+ pw = (c2w @ np.array([p[0], p[1], p[2], 1.0]))[:3]
+ plane_w.append(pw)
+ plane_w = np.stack(plane_w, 0) # (4,3)
+
+ segs = []
+ # center to corners
+ for k in range(4):
+ segs.append(np.stack([Cw, plane_w[k]], 0))
+ # rectangle edges
+ order = [0, 1, 2, 3, 0]
+ for a, b in zip(order[:-1], order[1:]):
+ segs.append(np.stack([plane_w[a], plane_w[b]], 0))
+
+ return np.stack(segs, 0) # (8,2,3)
+
+
+def _index_color_rgb(i: int, n: int) -> np.ndarray:
+ h = (i + 0.5) / max(n, 1)
+ s, v = 0.85, 0.95
+ r, g, b = _hsv_to_rgb(h, s, v)
+ return (np.array([r, g, b]) * 255).astype(np.uint8)
+
+
+def _hsv_to_rgb(h: float, s: float, v: float) -> tuple[float, float, float]:
+ i = int(h * 6.0)
+ f = h * 6.0 - i
+ p = v * (1.0 - s)
+ q = v * (1.0 - f * s)
+ t = v * (1.0 - (1.0 - f) * s)
+ i = i % 6
+ if i == 0:
+ r, g, b = v, t, p
+ elif i == 1:
+ r, g, b = q, v, p
+ elif i == 2:
+ r, g, b = p, v, t
+ elif i == 3:
+ r, g, b = p, q, v
+ elif i == 4:
+ r, g, b = t, p, v
+ else:
+ r, g, b = v, p, q
+ return r, g, b
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/gs.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/gs.py
new file mode 100644
index 0000000000000000000000000000000000000000..90077cf25651c7977c1c1da320b7c0931fdb18ba
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/gs.py
@@ -0,0 +1,154 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from typing import Literal, Optional
+import moviepy.editor as mpy
+import torch
+
+from depth_anything_3.model.utils.gs_renderer import run_renderer_in_chunk_w_trj_mode
+from depth_anything_3.specs import Prediction
+from depth_anything_3.utils.gsply_helpers import save_gaussian_ply
+from depth_anything_3.utils.layout_helpers import hcat, vcat
+from depth_anything_3.utils.visualize import vis_depth_map_tensor
+
+VIDEO_QUALITY_MAP = {
+ "low": {"crf": "28", "preset": "veryfast"},
+ "medium": {"crf": "23", "preset": "medium"},
+ "high": {"crf": "18", "preset": "slow"},
+}
+
+
+def export_to_gs_ply(
+ prediction: Prediction,
+ export_dir: str,
+ gs_views_interval: Optional[
+ int
+ ] = 1, # export GS every N views, useful for extremely dense inputs
+):
+ gs_world = prediction.gaussians
+ pred_depth = torch.from_numpy(prediction.depth).unsqueeze(-1).to(gs_world.means) # v h w 1
+ idx = 0
+ os.makedirs(os.path.join(export_dir, "gs_ply"), exist_ok=True)
+ save_path = os.path.join(export_dir, f"gs_ply/{idx:04d}.ply")
+ if gs_views_interval is None: # select around 12 views in total
+ gs_views_interval = max(pred_depth.shape[0] // 12, 1)
+ save_gaussian_ply(
+ gaussians=gs_world,
+ save_path=save_path,
+ ctx_depth=pred_depth,
+ shift_and_scale=False,
+ save_sh_dc_only=True,
+ gs_views_interval=gs_views_interval,
+ inv_opacity=True,
+ prune_by_depth_percent=0.9,
+ prune_border_gs=True,
+ match_3dgs_mcmc_dev=False,
+ )
+
+
+def export_to_gs_video(
+ prediction: Prediction,
+ export_dir: str,
+ extrinsics: Optional[torch.Tensor] = None, # render views' world2cam, "b v 4 4"
+ intrinsics: Optional[torch.Tensor] = None, # render views' unnormed intrinsics, "b v 3 3"
+ out_image_hw: Optional[tuple[int, int]] = None, # render views' resolution, (h, w)
+ chunk_size: Optional[int] = 4,
+ trj_mode: Literal[
+ "original",
+ "smooth",
+ "interpolate",
+ "interpolate_smooth",
+ "wander",
+ "dolly_zoom",
+ "extend",
+ "wobble_inter",
+ ] = "extend",
+ color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+ED",
+ vis_depth: Optional[Literal["hcat", "vcat"]] = "hcat",
+ enable_tqdm: Optional[bool] = True,
+ output_name: Optional[str] = None,
+ video_quality: Literal["low", "medium", "high"] = "high",
+) -> None:
+ gs_world = prediction.gaussians
+ # if target poses are not provided, render the (smooth/interpolate) input poses
+ if extrinsics is not None:
+ tgt_extrs = extrinsics
+ else:
+ tgt_extrs = torch.from_numpy(prediction.extrinsics).unsqueeze(0).to(gs_world.means)
+ if prediction.is_metric:
+ scale_factor = prediction.scale_factor
+ if scale_factor is not None:
+ tgt_extrs[:, :, :3, 3] /= scale_factor
+ tgt_intrs = (
+ intrinsics
+ if intrinsics is not None
+ else torch.from_numpy(prediction.intrinsics).unsqueeze(0).to(gs_world.means)
+ )
+ # if render resolution is not provided, render the input ones
+ if out_image_hw is not None:
+ H, W = out_image_hw
+ else:
+ H, W = prediction.depth.shape[-2:]
+ # if single views, render wander trj
+ if tgt_extrs.shape[1] <= 1:
+ trj_mode = "wander"
+ # trj_mode = "dolly_zoom"
+
+ color, depth = run_renderer_in_chunk_w_trj_mode(
+ gaussians=gs_world,
+ extrinsics=tgt_extrs,
+ intrinsics=tgt_intrs,
+ image_shape=(H, W),
+ chunk_size=chunk_size,
+ trj_mode=trj_mode,
+ use_sh=True,
+ color_mode=color_mode,
+ enable_tqdm=enable_tqdm,
+ )
+
+ # save as video
+ ffmpeg_params = [
+ "-crf",
+ VIDEO_QUALITY_MAP[video_quality]["crf"],
+ "-preset",
+ VIDEO_QUALITY_MAP[video_quality]["preset"],
+ "-pix_fmt",
+ "yuv420p",
+ ] # best compatibility
+
+ os.makedirs(os.path.join(export_dir, "gs_video"), exist_ok=True)
+ for idx in range(color.shape[0]):
+ video_i = color[idx]
+ if vis_depth is not None:
+            depth_i = vis_depth_map_tensor(depth[idx])
+ cat_fn = hcat if vis_depth == "hcat" else vcat
+ video_i = torch.stack([cat_fn(c, d) for c, d in zip(video_i, depth_i)])
+ frames = list(
+ (video_i.clamp(0, 1) * 255).byte().permute(0, 2, 3, 1).cpu().numpy()
+ ) # T x H x W x C, uint8, numpy()
+
+ fps = 24
+ clip = mpy.ImageSequenceClip(frames, fps=fps)
+ output_name = f"{idx:04d}_{trj_mode}" if output_name is None else output_name
+ save_path = os.path.join(export_dir, f"gs_video/{output_name}.mp4")
+ # clip.write_videofile(save_path, codec="libx264", audio=False, bitrate="4000k")
+ clip.write_videofile(
+ save_path,
+ codec="libx264",
+ audio=False,
+ fps=fps,
+ ffmpeg_params=ffmpeg_params,
+ )
+ return
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/npz.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/npz.py
new file mode 100644
index 0000000000000000000000000000000000000000..35ff250571b6cbc8bc4175add37011ce29db116d
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/npz.py
@@ -0,0 +1,73 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import numpy as np
+
+from depth_anything_3.specs import Prediction
+from depth_anything_3.utils.parallel_utils import async_call
+
+
+@async_call
+def export_to_npz(
+ prediction: Prediction,
+ export_dir: str,
+):
+ output_file = os.path.join(export_dir, "exports", "npz", "results.npz")
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
+
+ # Use prediction.processed_images, which is already processed image data
+ if prediction.processed_images is None:
+ raise ValueError("prediction.processed_images is required but not available")
+
+ image = prediction.processed_images # (N,H,W,3) uint8
+
+ # Build save dict with only non-None values
+ save_dict = {
+ "image": image,
+ "depth": np.round(prediction.depth, 6),
+ }
+
+ if prediction.conf is not None:
+ save_dict["conf"] = np.round(prediction.conf, 2)
+ if prediction.extrinsics is not None:
+ save_dict["extrinsics"] = prediction.extrinsics
+ if prediction.intrinsics is not None:
+ save_dict["intrinsics"] = prediction.intrinsics
+
+ # aux = {k: np.round(v, 4) for k, v in prediction.aux.items()}
+ np.savez_compressed(output_file, **save_dict)
+
+
+@async_call
+def export_to_mini_npz(
+ prediction: Prediction,
+ export_dir: str,
+):
+ output_file = os.path.join(export_dir, "exports", "mini_npz", "results.npz")
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
+
+ # Build save dict with only non-None values
+ save_dict = {
+ "depth": np.round(prediction.depth, 6),
+ }
+
+ if prediction.conf is not None:
+ save_dict["conf"] = np.round(prediction.conf, 2)
+ if prediction.extrinsics is not None:
+ save_dict["extrinsics"] = prediction.extrinsics
+ if prediction.intrinsics is not None:
+ save_dict["intrinsics"] = prediction.intrinsics
+
+ np.savez_compressed(output_file, **save_dict)
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/utils.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..81f45fb563ce595bf547bebe829c9b83eb175f1c
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/export/utils.py
@@ -0,0 +1,30 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import torch
+
+
+def _denorm_and_to_uint8(image_tensor: torch.Tensor) -> np.ndarray:
+ """Denormalize to [0,255] and output (N, H, W, 3) uint8."""
+ resnet_mean = torch.tensor(
+ [0.485, 0.456, 0.406], dtype=image_tensor.dtype, device=image_tensor.device
+ )
+ resnet_std = torch.tensor(
+ [0.229, 0.224, 0.225], dtype=image_tensor.dtype, device=image_tensor.device
+ )
+ img = image_tensor * resnet_std[None, :, None, None] + resnet_mean[None, :, None, None]
+ img = torch.clamp(img, 0.0, 1.0)
+ img = (img.permute(0, 2, 3, 1).cpu().numpy() * 255.0).round().astype(np.uint8) # (N,H,W,3)
+ return img
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/geometry.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/geometry.py
new file mode 100644
index 0000000000000000000000000000000000000000..a88289eb0243a3b337a06933922fd6c038b54b00
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/geometry.py
@@ -0,0 +1,349 @@
+# flake8: noqa: F722
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from types import SimpleNamespace
+from typing import Optional
+import numpy as np
+import torch
+import torch.nn.functional as F
+from einops import einsum
+
+
+def as_homogeneous(ext):
+ """
+ Accept (..., 3,4) or (..., 4,4) extrinsics, return (...,4,4) homogeneous matrix.
+ Supports torch.Tensor or np.ndarray.
+ """
+ if isinstance(ext, torch.Tensor):
+ # If already in homogeneous form
+ if ext.shape[-2:] == (4, 4):
+ return ext
+ elif ext.shape[-2:] == (3, 4):
+            # Append the homogeneous bottom row [0, 0, 0, 1].
+            bottom = torch.zeros_like(ext[..., :1, :4])
+            bottom[..., 0, 3] = 1.0
+            return torch.cat([ext, bottom], dim=-2)
+ else:
+ raise ValueError(f"Invalid shape for torch.Tensor: {ext.shape}")
+
+ elif isinstance(ext, np.ndarray):
+ if ext.shape[-2:] == (4, 4):
+ return ext
+ elif ext.shape[-2:] == (3, 4):
+            bottom = np.zeros_like(ext[..., :1, :4])
+            bottom[..., 0, 3] = 1.0
+            return np.concatenate([ext, bottom], axis=-2)
+ else:
+ raise ValueError(f"Invalid shape for np.ndarray: {ext.shape}")
+
+ else:
+ raise TypeError("Input must be a torch.Tensor or np.ndarray.")
+
+
+@torch.jit.script
+def affine_inverse(A: torch.Tensor):
+ R = A[..., :3, :3] # ..., 3, 3
+ T = A[..., :3, 3:] # ..., 3, 1
+ P = A[..., 3:, :] # ..., 1, 4
+ return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2)
+
+
+def transpose_last_two_axes(arr):
+ """
+    Swap the last two axes; compatibility helper for NumPy < 2.
+ """
+ if arr.ndim < 2:
+ return arr
+ axes = list(range(arr.ndim))
+ # swap the last two
+ axes[-2], axes[-1] = axes[-1], axes[-2]
+ return arr.transpose(axes)
+
+
+def affine_inverse_np(A: np.ndarray):
+ R = A[..., :3, :3]
+ T = A[..., :3, 3:]
+ P = A[..., 3:, :]
+ return np.concatenate(
+ [
+ np.concatenate([transpose_last_two_axes(R), -transpose_last_two_axes(R) @ T], axis=-1),
+ P,
+ ],
+ axis=-2,
+ )
+
+
+def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Quaternion Order: XYZW or say ijkr, scalar-last
+
+ Convert rotations given as quaternions to rotation matrices.
+ Args:
+ quaternions: quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ i, j, k, r = torch.unbind(quaternions, -1)
+ # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`.
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
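+
+# Sanity check: the identity quaternion (scalar-last) maps to the identity matrix,
+#   quat_to_mat(torch.tensor([0.0, 0.0, 0.0, 1.0]))  # -> torch.eye(3)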
+
+
+def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part last, as tensor of shape (..., 4).
+ Quaternion Order: XYZW or say ijkr, scalar-last
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ # pyre-fixme[58]: `**` is not supported for operand types `Tensor` and
+ # `int`.
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
+ batch_dim + (4,)
+ )
+
+ # Convert from rijk to ijkr
+ out = out[..., [1, 2, 3, 0]]
+
+ out = standardize_quaternion(out)
+
+ return out
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ if torch.is_grad_enabled():
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ else:
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
+ return ret
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part last,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
+
+
+def sample_image_grid(
+ shape: tuple[int, ...],
+ device: torch.device = torch.device("cpu"),
+) -> tuple[
+ torch.Tensor, # float coordinates (xy indexing), "*shape dim"
+ torch.Tensor, # integer indices (ij indexing), "*shape dim"
+]:
+ """Get normalized (range 0 to 1) coordinates and integer indices for an image."""
+
+ # Each entry is a pixel-wise integer coordinate. In the 2D case, each entry is a
+ # (row, col) coordinate.
+ indices = [torch.arange(length, device=device) for length in shape]
+ stacked_indices = torch.stack(torch.meshgrid(*indices, indexing="ij"), dim=-1)
+
+ # Each entry is a floating-point coordinate in the range (0, 1). In the 2D case,
+ # each entry is an (x, y) coordinate.
+ coordinates = [(idx + 0.5) / length for idx, length in zip(indices, shape)]
+ coordinates = reversed(coordinates)
+ coordinates = torch.stack(torch.meshgrid(*coordinates, indexing="xy"), dim=-1)
+
+ return coordinates, stacked_indices
+
+
+def homogenize_points(points: torch.Tensor) -> torch.Tensor: # "*batch dim" # "*batch dim+1"
+ """Convert batched points (xyz) to (xyz1)."""
+ return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
+
+
+def homogenize_vectors(vectors: torch.Tensor) -> torch.Tensor: # "*batch dim" # "*batch dim+1"
+ """Convert batched vectors (xyz) to (xyz0)."""
+ return torch.cat([vectors, torch.zeros_like(vectors[..., :1])], dim=-1)
+
+
+def transform_rigid(
+ homogeneous_coordinates: torch.Tensor, # "*#batch dim"
+ transformation: torch.Tensor, # "*#batch dim dim"
+) -> torch.Tensor: # "*batch dim"
+ """Apply a rigid-body transformation to points or vectors."""
+ return einsum(
+ transformation,
+ homogeneous_coordinates.to(transformation.dtype),
+ "... i j, ... j -> ... i",
+ )
+
+
+def transform_cam2world(
+ homogeneous_coordinates: torch.Tensor, # "*#batch dim"
+ extrinsics: torch.Tensor, # "*#batch dim dim"
+) -> torch.Tensor: # "*batch dim"
+ """Transform points from 3D camera coordinates to 3D world coordinates."""
+ return transform_rigid(homogeneous_coordinates, extrinsics)
+
+
+def unproject(
+ coordinates: torch.Tensor, # "*#batch dim"
+ z: torch.Tensor, # "*#batch"
+ intrinsics: torch.Tensor, # "*#batch dim+1 dim+1"
+) -> torch.Tensor: # "*batch dim+1"
+ """Unproject 2D camera coordinates with the given Z values."""
+
+ # Apply the inverse intrinsics to the coordinates.
+ coordinates = homogenize_points(coordinates)
+ ray_directions = einsum(
+ intrinsics.float().inverse().to(intrinsics),
+ coordinates.to(intrinsics.dtype),
+ "... i j, ... j -> ... i",
+ )
+
+ # Apply the supplied depth values.
+ return ray_directions * z[..., None]
+
+
+def get_world_rays(
+ coordinates: torch.Tensor, # "*#batch dim"
+ extrinsics: torch.Tensor, # "*#batch dim+2 dim+2"
+ intrinsics: torch.Tensor, # "*#batch dim+1 dim+1"
+) -> tuple[
+ torch.Tensor, # origins, "*batch dim+1"
+ torch.Tensor, # directions, "*batch dim+1"
+]:
+ # Get camera-space ray directions.
+ directions = unproject(
+ coordinates,
+ torch.ones_like(coordinates[..., 0]),
+ intrinsics,
+ )
+ directions = directions / directions.norm(dim=-1, keepdim=True)
+
+ # Transform ray directions to world coordinates.
+ directions = homogenize_vectors(directions)
+ directions = transform_cam2world(directions, extrinsics)[..., :-1]
+
+ # Tile the ray origins to have the same shape as the ray directions.
+ origins = extrinsics[..., :-1, -1].broadcast_to(directions.shape)
+
+ return origins, directions
+
+
+def get_fov(intrinsics: torch.Tensor) -> torch.Tensor: # "batch 3 3" -> "batch 2"
+ intrinsics_inv = intrinsics.float().inverse().to(intrinsics)
+
+ def process_vector(vector):
+ vector = torch.tensor(vector, dtype=intrinsics.dtype, device=intrinsics.device)
+ vector = einsum(intrinsics_inv, vector, "b i j, j -> b i")
+ return vector / vector.norm(dim=-1, keepdim=True)
+
+ left = process_vector([0, 0.5, 1])
+ right = process_vector([1, 0.5, 1])
+ top = process_vector([0.5, 0, 1])
+ bottom = process_vector([0.5, 1, 1])
+ fov_x = (left * right).sum(dim=-1).acos()
+ fov_y = (top * bottom).sum(dim=-1).acos()
+ return torch.stack((fov_x, fov_y), dim=-1)
+
+
+def map_pdf_to_opacity(
+ pdf: torch.Tensor, # " *batch"
+ global_step: int = 0,
+ opacity_mapping: Optional[dict] = None,
+) -> torch.Tensor: # " *batch"
+ # https://www.desmos.com/calculator/opvwti3ba9
+
+ # Figure out the exponent.
+ if opacity_mapping is not None:
+ cfg = SimpleNamespace(**opacity_mapping)
+ x = cfg.initial + min(global_step / cfg.warm_up, 1) * (cfg.final - cfg.initial)
+ else:
+ x = 0.0
+ exponent = 2**x
+
+ # Map the probability density to an opacity.
+ return 0.5 * (1 - (1 - pdf) ** exponent + pdf ** (1 / exponent))
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/gsply_helpers.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/gsply_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..5733009e4e91ad80ab59179c80e2df8a0430ab5f
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/gsply_helpers.py
@@ -0,0 +1,173 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from pathlib import Path
+from typing import Optional
+import numpy as np
+import torch
+from einops import rearrange, repeat
+from plyfile import PlyData, PlyElement
+from torch import Tensor
+
+from depth_anything_3.specs import Gaussians
+
+
+def construct_list_of_attributes(num_rest: int) -> list[str]:
+ attributes = ["x", "y", "z", "nx", "ny", "nz"]
+ for i in range(3):
+ attributes.append(f"f_dc_{i}")
+ for i in range(num_rest):
+ attributes.append(f"f_rest_{i}")
+ attributes.append("opacity")
+ for i in range(3):
+ attributes.append(f"scale_{i}")
+ for i in range(4):
+ attributes.append(f"rot_{i}")
+ return attributes
+
+
+def export_ply(
+ means: Tensor, # "gaussian 3"
+ scales: Tensor, # "gaussian 3"
+ rotations: Tensor, # "gaussian 4"
+ harmonics: Tensor, # "gaussian 3 d_sh"
+ opacities: Tensor, # "gaussian"
+ path: Path,
+ shift_and_scale: bool = False,
+ save_sh_dc_only: bool = True,
+ match_3dgs_mcmc_dev: Optional[bool] = False,
+):
+ if shift_and_scale:
+ # Shift the scene so that the median Gaussian is at the origin.
+ means = means - means.median(dim=0).values
+
+ # Rescale the scene so that most Gaussians are within range [-1, 1].
+ scale_factor = means.abs().quantile(0.95, dim=0).max()
+ means = means / scale_factor
+ scales = scales / scale_factor
+
+ rotations = rotations.detach().cpu().numpy()
+
+    # The current model uses SH degree 4, which is expensive to store, so we
+    # can optionally keep only the DC band to save memory.
+ f_dc = harmonics[..., 0]
+ f_rest = harmonics[..., 1:].flatten(start_dim=1)
+
+ if match_3dgs_mcmc_dev:
+ sh_degree = 3
+ n_rest = 3 * (sh_degree + 1) ** 2 - 3
+ f_rest = repeat(
+ torch.zeros_like(harmonics[..., :1]), "... i -> ... (n i)", n=(n_rest // 3)
+ ).flatten(start_dim=1)
+ dtype_full = [
+ (attribute, "f4")
+ for attribute in construct_list_of_attributes(num_rest=n_rest)
+ if attribute not in ("nx", "ny", "nz")
+ ]
+ else:
+ dtype_full = [
+ (attribute, "f4")
+ for attribute in construct_list_of_attributes(
+ 0 if save_sh_dc_only else f_rest.shape[1]
+ )
+ ]
+ elements = np.empty(means.shape[0], dtype=dtype_full)
+ attributes = [
+ means.detach().cpu().numpy(),
+ torch.zeros_like(means).detach().cpu().numpy(),
+ f_dc.detach().cpu().contiguous().numpy(),
+ f_rest.detach().cpu().contiguous().numpy(),
+ opacities[..., None].detach().cpu().numpy(),
+ scales.log().detach().cpu().numpy(),
+ rotations,
+ ]
+ if match_3dgs_mcmc_dev:
+ attributes.pop(1) # dummy normal is not needed
+ elif save_sh_dc_only:
+ attributes.pop(3) # remove f_rest from attributes
+
+ attributes = np.concatenate(attributes, axis=1)
+ elements[:] = list(map(tuple, attributes))
+ path.parent.mkdir(exist_ok=True, parents=True)
+ PlyData([PlyElement.describe(elements, "vertex")]).write(path)
+
+
+def inverse_sigmoid(x):
+ return torch.log(x / (1 - x))
+
+
+def save_gaussian_ply(
+ gaussians: Gaussians,
+ save_path: str,
+ ctx_depth: torch.Tensor, # depth of input views; for getting shape and filtering, "v h w 1"
+ shift_and_scale: bool = False,
+ save_sh_dc_only: bool = True,
+ gs_views_interval: int = 1,
+ inv_opacity: Optional[bool] = True,
+ prune_by_depth_percent: Optional[float] = 1.0,
+ prune_border_gs: Optional[bool] = True,
+ match_3dgs_mcmc_dev: Optional[bool] = False,
+):
+ b = gaussians.means.shape[0]
+ assert b == 1, "must set batch_size=1 when exporting 3D gaussians"
+ src_v, out_h, out_w, _ = ctx_depth.shape
+
+ # extract gs params
+ world_means = gaussians.means
+ world_shs = gaussians.harmonics
+ world_rotations = gaussians.rotations
+ gs_scales = gaussians.scales
+ gs_opacities = inverse_sigmoid(gaussians.opacities) if inv_opacity else gaussians.opacities
+
+ # Create a mask to filter the Gaussians.
+
+ # TODO: prune the sky region here
+
+ # throw away Gaussians at the borders, since they're generally of lower quality.
+ if prune_border_gs:
+ mask = torch.zeros_like(ctx_depth, dtype=torch.bool)
+        # Trim at least one pixel: a zero trim would turn the slice into `0:-0`,
+        # which is empty and would prune every Gaussian.
+        gstrim_h = max(int(8 / 256 * out_h), 1)
+        gstrim_w = max(int(8 / 256 * out_w), 1)
+        mask[:, gstrim_h:-gstrim_h, gstrim_w:-gstrim_w, :] = 1
+ else:
+ mask = torch.ones_like(ctx_depth, dtype=torch.bool)
+
+ # trim the far away point based on depth;
+ if prune_by_depth_percent is not None and prune_by_depth_percent < 1:
+ in_depths = ctx_depth
+ d_percentile = torch.quantile(
+ in_depths.view(in_depths.shape[0], -1), q=prune_by_depth_percent, dim=1
+ ).view(-1, 1, 1)
+ d_mask = (in_depths[..., 0] <= d_percentile).unsqueeze(-1)
+ mask = mask & d_mask
+ mask = mask.squeeze(-1) # v h w
+
+ # helper fn, must place after mask
+ def trim_select_reshape(element):
+ selected_element = rearrange(
+ element[0], "(v h w) ... -> v h w ...", v=src_v, h=out_h, w=out_w
+ )
+ selected_element = selected_element[::gs_views_interval][mask[::gs_views_interval]]
+ return selected_element
+
+ export_ply(
+ means=trim_select_reshape(world_means),
+ scales=trim_select_reshape(gs_scales),
+ rotations=trim_select_reshape(world_rotations),
+ harmonics=trim_select_reshape(world_shs),
+ opacities=trim_select_reshape(gs_opacities),
+ path=Path(save_path),
+ shift_and_scale=shift_and_scale,
+ save_sh_dc_only=save_sh_dc_only,
+ match_3dgs_mcmc_dev=match_3dgs_mcmc_dev,
+ )
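+
+
+# Usage sketch (hypothetical names): exporting Gaussians predicted for one batch.
+# `prediction.gaussians` and `ctx_depth` are assumed outputs of the model's
+# forward pass; only batch_size=1 is supported.
+#
+#   save_gaussian_ply(
+#       gaussians=prediction.gaussians,
+#       save_path="outputs/scene.ply",
+#       ctx_depth=ctx_depth,                # (v, h, w, 1) input-view depth
+#       prune_by_depth_percent=0.98,        # drop the farthest 2% of points
+#   )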
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/input_processor.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/input_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..83114044fa6daad78abdf96412e259c9dc8fb04d
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/input_processor.py
@@ -0,0 +1,579 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Input processor for Depth Anything 3 (parallelized).
+
+This version removes the square center-crop step for "*crop" methods.
+In addition, it parallelizes per-image preprocessing using the provided `parallel_execution`.
+"""
+
+from __future__ import annotations
+
+from typing import Optional, Sequence, Tuple
+import cv2
+import numpy as np
+import torch
+import torchvision.transforms as T
+from PIL import Image, ImageOps
+
+from depth_anything_3.utils.logger import logger
+from depth_anything_3.utils.parallel_utils import parallel_execution
+
+
+class InputProcessor:
+ """Prepares a batch of images for model inference.
+ This processor converts a list of image file paths into a single, model-ready
+ tensor. The processing pipeline is executed in parallel across multiple workers
+ for efficiency.
+
+ Pipeline:
+ 1) Load image and convert to RGB
+ 2) Boundary resize (upper/lower bound, preserving aspect ratio)
+ 3) Enforce divisibility by PATCH_SIZE:
+ - "*resize" methods: each dimension is rounded to nearest multiple
+ (may up/downscale a few px)
+ - "*crop" methods: each dimension is floored to nearest multiple via center crop
+ 4) Convert to tensor and apply ImageNet normalization
+ 5) Stack into (1, N, 3, H, W)
+
+ Parallelization:
+ - Each image is processed independently in a worker.
+ - Order of outputs matches the input order.
+ """
+
+ NORMALIZE = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
+ PATCH_SIZE = 14
+
+ def __init__(self):
+ pass
+
+ # -----------------------------
+ # Public API
+ # -----------------------------
+ def __call__(
+ self,
+ image: list[np.ndarray | Image.Image | str],
+ extrinsics: np.ndarray | None = None,
+ intrinsics: np.ndarray | None = None,
+ process_res: Optional[int] = None,
+ process_res_method: str = "keep",
+ *,
+ num_workers: int = 8,
+ print_progress: bool = False,
+ sequential: bool | None = None,
+ desc: str | None = "Preprocess",
+ ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None, list[dict]]:
+ """
+ Returns:
+ (tensor, extrinsics_list, intrinsics_list, pad_meta)
+ tensor shape: (1, N, 3, H, W)
+ """
+ sequential = self._resolve_sequential(sequential, num_workers)
+ exts_list, ixts_list = self._validate_and_pack_meta(image, extrinsics, intrinsics)
+
+ results = self._run_parallel(
+ image=image,
+ exts_list=exts_list,
+ ixts_list=ixts_list,
+ process_res=process_res,
+ process_res_method=process_res_method,
+ num_workers=num_workers,
+ print_progress=print_progress,
+ sequential=sequential,
+ desc=desc,
+ )
+
+ proc_imgs, out_sizes, out_ixts, out_exts, pad_meta = self._unpack_results(results)
+ proc_imgs, out_sizes, out_ixts, pad_meta = self._unify_batch_shapes(
+ proc_imgs, out_sizes, out_ixts, pad_meta
+ )
+
+ batch_tensor = self._stack_batch(proc_imgs)
+ out_exts = (
+ torch.from_numpy(np.asarray(out_exts)).float()
+ if out_exts is not None and out_exts[0] is not None
+ else None
+ )
+ out_ixts = (
+ torch.from_numpy(np.asarray(out_ixts)).float()
+ if out_ixts is not None and out_ixts[0] is not None
+ else None
+ )
+ return (batch_tensor, out_exts, out_ixts, pad_meta)
+
+ # -----------------------------
+ # __call__ helpers
+ # -----------------------------
+ def _resolve_sequential(self, sequential: bool | None, num_workers: int) -> bool:
+ return (num_workers <= 1) if sequential is None else sequential
+
+ def _validate_and_pack_meta(
+ self,
+ images: list[np.ndarray | Image.Image | str],
+ extrinsics: np.ndarray | None,
+ intrinsics: np.ndarray | None,
+ ) -> tuple[list[np.ndarray | None] | None, list[np.ndarray | None] | None]:
+ if extrinsics is not None and len(extrinsics) != len(images):
+ raise ValueError("Length of extrinsics must match images when provided.")
+ if intrinsics is not None and len(intrinsics) != len(images):
+ raise ValueError("Length of intrinsics must match images when provided.")
+ exts_list = [e for e in extrinsics] if extrinsics is not None else None
+ ixts_list = [k for k in intrinsics] if intrinsics is not None else None
+ return exts_list, ixts_list
+
+ def _run_parallel(
+ self,
+ *,
+ image: list[np.ndarray | Image.Image | str],
+ exts_list: list[np.ndarray | None] | None,
+ ixts_list: list[np.ndarray | None] | None,
+        process_res: Optional[int],
+ process_res_method: str,
+ num_workers: int,
+ print_progress: bool,
+ sequential: bool,
+ desc: str | None,
+ ):
+ results = parallel_execution(
+ image,
+ exts_list,
+ ixts_list,
+ action=self._process_one, # (img, extrinsic, intrinsic, ...)
+ num_processes=num_workers,
+ print_progress=print_progress,
+ sequential=sequential,
+ desc=desc,
+ process_res=process_res,
+ process_res_method=process_res_method,
+ )
+ if not results:
+ raise RuntimeError(
+ "No preprocessing results returned. Check inputs and parallel_execution."
+ )
+ return results
+
+ def _unpack_results(self, results):
+ """
+ results: List[
+ Tuple[
+ torch.Tensor,
+ Tuple[H, W],
+ Optional[np.ndarray],
+ Optional[np.ndarray],
+ dict,
+ ]
+ ]
+ -> processed_images, out_sizes, out_intrinsics, out_extrinsics, pad_meta
+ """
+ try:
+ processed_images, out_sizes, out_intrinsics, out_extrinsics, pad_meta = zip(*results)
+ except Exception as e:
+ raise RuntimeError(
+ "Unexpected results structure from parallel_execution: "
+ f"{type(results)} / sample: {results[0]}"
+ ) from e
+
+ return (
+ list(processed_images),
+ list(out_sizes),
+ list(out_intrinsics),
+ list(out_extrinsics),
+ list(pad_meta),
+ )
+
+ def _unify_batch_shapes(
+ self,
+ processed_images: list[torch.Tensor],
+ out_sizes: list[tuple[int, int]],
+ out_intrinsics: list[np.ndarray | None],
+ pad_meta: list[dict],
+ ) -> tuple[list[torch.Tensor], list[tuple[int, int]], list[np.ndarray | None], list[dict]]:
+ """Center-crop all tensors to the smallest H, W; adjust intrinsics' cx, cy accordingly."""
+ if len(set(out_sizes)) <= 1:
+ return processed_images, out_sizes, out_intrinsics, pad_meta
+
+ min_h = min(h for h, _ in out_sizes)
+ min_w = min(w for _, w in out_sizes)
+ logger.warn(
+ f"Images in batch have different sizes {out_sizes}; "
+ f"center-cropping all to smallest ({min_h},{min_w})"
+ )
+
+ center_crop = T.CenterCrop((min_h, min_w))
+ new_imgs, new_sizes, new_ixts, new_meta = [], [], [], []
+ for img_t, (H, W), K, meta in zip(processed_images, out_sizes, out_intrinsics, pad_meta):
+ crop_top = max(0, (H - min_h) // 2)
+ crop_left = max(0, (W - min_w) // 2)
+ new_imgs.append(center_crop(img_t))
+ new_sizes.append((min_h, min_w))
+ if K is None:
+ new_ixts.append(None)
+ else:
+ K_adj = K.copy()
+ K_adj[0, 2] -= crop_left
+ K_adj[1, 2] -= crop_top
+ new_ixts.append(K_adj)
+ # Cropping invalidates padding meta; reset so we do not apply another crop later.
+ new_meta.append({"orig_size": (min_h, min_w), "pad": (0, 0, 0, 0)})
+ return new_imgs, new_sizes, new_ixts, new_meta
+
+ def _stack_batch(self, processed_images: list[torch.Tensor]) -> torch.Tensor:
+ return torch.stack(processed_images)
+
+ # -----------------------------
+ # Per-item worker
+ # -----------------------------
+ def _process_one(
+ self,
+ img: np.ndarray | Image.Image | str,
+ extrinsic: np.ndarray | None = None,
+ intrinsic: np.ndarray | None = None,
+ *,
+ process_res: Optional[int],
+ process_res_method: str,
+ ) -> tuple[torch.Tensor, tuple[int, int], np.ndarray | None, np.ndarray | None, dict]:
+ # Load & remember original size
+ pil_img = self._load_image(img)
+ orig_w, orig_h = pil_img.size
+
+ # Boundary resize
+ pil_img = self._resize_image(pil_img, process_res, process_res_method)
+ w, h = pil_img.size
+ intrinsic = self._resize_ixt(intrinsic, orig_w, orig_h, w, h)
+ pad_left = pad_right = pad_top = pad_bottom = 0
+
+ # Enforce divisibility by PATCH_SIZE
+ if process_res_method in ("keep", "original"):
+ pil_img, pad_left, pad_right, pad_top, pad_bottom = self._make_divisible_by_pad(
+ pil_img, self.PATCH_SIZE
+ )
+ if any((pad_left, pad_right, pad_top, pad_bottom)):
+ intrinsic = self._pad_ixt(intrinsic, pad_left, pad_top)
+ w, h = pil_img.size
+ elif process_res_method.endswith("resize"):
+ pil_img = self._make_divisible_by_resize(pil_img, self.PATCH_SIZE)
+ new_w, new_h = pil_img.size
+ intrinsic = self._resize_ixt(intrinsic, w, h, new_w, new_h)
+ w, h = new_w, new_h
+ elif process_res_method.endswith("crop"):
+ pil_img = self._make_divisible_by_crop(pil_img, self.PATCH_SIZE)
+ new_w, new_h = pil_img.size
+ intrinsic = self._crop_ixt(intrinsic, w, h, new_w, new_h)
+ w, h = new_w, new_h
+ else:
+ raise ValueError(f"Unsupported process_res_method: {process_res_method}")
+
+ # Convert to tensor & normalize
+ img_tensor = self._normalize_image(pil_img)
+ _, H, W = img_tensor.shape
+ assert (W, H) == (w, h), "Tensor size mismatch with PIL image size after processing."
+
+ meta = {
+ "orig_size": (orig_h, orig_w),
+ "pad": (pad_top, pad_bottom, pad_left, pad_right),
+ }
+
+ # Return: (img_tensor, (H, W), intrinsic, extrinsic, meta)
+ return img_tensor, (H, W), intrinsic, extrinsic, meta
+
+ # -----------------------------
+ # Intrinsics transforms
+ # -----------------------------
+ def _resize_ixt(
+ self,
+ intrinsic: np.ndarray | None,
+ orig_w: int,
+ orig_h: int,
+ w: int,
+ h: int,
+ ) -> np.ndarray | None:
+ if intrinsic is None:
+ return None
+ K = intrinsic.copy()
+ # scale fx, cx by w ratio; fy, cy by h ratio
+ K[:1] *= w / float(orig_w)
+ K[1:2] *= h / float(orig_h)
+ return K
+
+ def _crop_ixt(
+ self,
+ intrinsic: np.ndarray | None,
+ orig_w: int,
+ orig_h: int,
+ w: int,
+ h: int,
+ ) -> np.ndarray | None:
+ if intrinsic is None:
+ return None
+ K = intrinsic.copy()
+ crop_h = (orig_h - h) // 2
+ crop_w = (orig_w - w) // 2
+ K[0, 2] -= crop_w
+ K[1, 2] -= crop_h
+ return K
+
+ def _pad_ixt(
+ self,
+ intrinsic: np.ndarray | None,
+ pad_left: int,
+ pad_top: int,
+ ) -> np.ndarray | None:
+ if intrinsic is None or (pad_left == 0 and pad_top == 0):
+ return intrinsic
+ K = intrinsic.copy()
+ K[0, 2] += pad_left
+ K[1, 2] += pad_top
+ return K
+
+ # -----------------------------
+ # I/O & normalization
+ # -----------------------------
+ def _load_image(self, img: np.ndarray | Image.Image | str) -> Image.Image:
+ if isinstance(img, str):
+ return Image.open(img).convert("RGB")
+ elif isinstance(img, np.ndarray):
+ # Assume HxWxC uint8/RGB
+ return Image.fromarray(img).convert("RGB")
+ elif isinstance(img, Image.Image):
+ return img.convert("RGB")
+ else:
+ raise ValueError(f"Unsupported image type: {type(img)}")
+
+ def _normalize_image(self, img: Image.Image) -> torch.Tensor:
+ img_tensor = T.ToTensor()(img)
+ return self.NORMALIZE(img_tensor)
+
+ # -----------------------------
+ # Boundary resizing
+ # -----------------------------
+ def _resize_image(
+ self, img: Image.Image, target_size: Optional[int], method: str
+ ) -> Image.Image:
+ if method in ("keep", "original"):
+ return img
+
+ if target_size is None or target_size <= 0:
+ raise ValueError(
+ f"process_res must be set when using '{method}'. Received: {target_size}"
+ )
+
+ if method in ("upper_bound_resize", "upper_bound_crop"):
+ return self._resize_longest_side(img, target_size)
+ elif method in ("lower_bound_resize", "lower_bound_crop"):
+ return self._resize_shortest_side(img, target_size)
+ else:
+ raise ValueError(f"Unsupported resize method: {method}")
+
+ def _resize_longest_side(self, img: Image.Image, target_size: int) -> Image.Image:
+ w, h = img.size
+ longest = max(w, h)
+ if longest == target_size:
+ return img
+ scale = target_size / float(longest)
+ new_w = max(1, int(round(w * scale)))
+ new_h = max(1, int(round(h * scale)))
+ interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA
+ arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation)
+ return Image.fromarray(arr)
+
+ def _resize_shortest_side(self, img: Image.Image, target_size: int) -> Image.Image:
+ w, h = img.size
+ shortest = min(w, h)
+ if shortest == target_size:
+ return img
+ scale = target_size / float(shortest)
+ new_w = max(1, int(round(w * scale)))
+ new_h = max(1, int(round(h * scale)))
+ interpolation = cv2.INTER_CUBIC if scale > 1.0 else cv2.INTER_AREA
+ arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation)
+ return Image.fromarray(arr)
+
+ # -----------------------------
+ # Make divisible by PATCH_SIZE
+ # -----------------------------
+ def _make_divisible_by_crop(self, img: Image.Image, patch: int) -> Image.Image:
+ """
+ Floor each dimension to the nearest multiple of PATCH_SIZE via center crop.
+ Example: 504x377 -> 504x364
+ """
+ w, h = img.size
+ new_w = (w // patch) * patch
+ new_h = (h // patch) * patch
+ if new_w == w and new_h == h:
+ return img
+ left = (w - new_w) // 2
+ top = (h - new_h) // 2
+ return img.crop((left, top, left + new_w, top + new_h))
+
+ def _make_divisible_by_resize(self, img: Image.Image, patch: int) -> Image.Image:
+ """
+ Round each dimension to nearest multiple of PATCH_SIZE via small resize.
+ """
+ w, h = img.size
+
+ def nearest_multiple(x: int, p: int) -> int:
+ down = (x // p) * p
+ up = down + p
+ return up if abs(up - x) <= abs(x - down) else down
+
+ new_w = max(1, nearest_multiple(w, patch))
+ new_h = max(1, nearest_multiple(h, patch))
+ if new_w == w and new_h == h:
+ return img
+ upscale = (new_w > w) or (new_h > h)
+ interpolation = cv2.INTER_CUBIC if upscale else cv2.INTER_AREA
+ arr = cv2.resize(np.asarray(img), (new_w, new_h), interpolation=interpolation)
+ return Image.fromarray(arr)
+
+ def _make_divisible_by_pad(
+ self, img: Image.Image, patch: int
+ ) -> tuple[Image.Image, int, int, int, int]:
+ """
+ Pad each dimension up to the nearest multiple of PATCH_SIZE.
+ Returns: (padded_img, pad_left, pad_right, pad_top, pad_bottom)
+ """
+ w, h = img.size
+ new_w = ((w + patch - 1) // patch) * patch
+ new_h = ((h + patch - 1) // patch) * patch
+ pad_w = new_w - w
+ pad_h = new_h - h
+ if pad_w == 0 and pad_h == 0:
+ return img, 0, 0, 0, 0
+
+ pad_left = pad_w // 2
+ pad_right = pad_w - pad_left
+ pad_top = pad_h // 2
+ pad_bottom = pad_h - pad_top
+
+ padded = ImageOps.expand(img, border=(pad_left, pad_top, pad_right, pad_bottom))
+ return padded, pad_left, pad_right, pad_top, pad_bottom
+
+
+# Backward compatibility alias
+InputAdapter = InputProcessor
+
+
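+# Minimal usage sketch (assumed image paths; see the test runner below for a
+# fuller example):
+#
+#   proc = InputProcessor()
+#   imgs, exts, ixts, pad_meta = proc(
+#       ["a.jpg", "b.jpg"],
+#       process_res=504,
+#       process_res_method="lower_bound_resize",
+#   )
+#   # imgs: (1, 2, 3, H, W) normalized tensor; H and W are multiples of 14.
+
+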
+# ===========================
+# Minimal test runner (parallel execution)
+# ===========================
+if __name__ == "__main__":
+ """
+ Minimal test suite:
+ - Creates pairs of images so batch shapes match.
+ - Tests all four process_res_methods.
+ - Prints fx fy cx cy IN->OUT per image.
+ - Includes cases with K/E provided and with None.
+ """
+
+ def fmt_k_line(K: np.ndarray | None) -> str:
+ if K is None:
+ return "None"
+ fx, fy, cx, cy = float(K[0, 0]), float(K[1, 1]), float(K[0, 2]), float(K[1, 2])
+ return f"fx={fx:.3f} fy={fy:.3f} cx={cx:.3f} cy={cy:.3f}"
+
+ def show_result(
+ tag: str,
+ tensor: torch.Tensor,
+ Ks_in: Sequence[np.ndarray | None] | None = None,
+ Ks_out: Sequence[np.ndarray | None] | None = None,
+ ):
+ B, N, C, H, W = tensor.shape
+ print(f"[{tag}] shape={tuple(tensor.shape)} HxW=({H},{W}) div14=({H%14==0},{W%14==0})")
+ assert H % 14 == 0 and W % 14 == 0, f"{tag}: output size not divisible by 14!"
+ if Ks_in is not None or Ks_out is not None:
+            Ks_in = Ks_in if Ks_in is not None else [None] * N
+            # `or` would call bool() on a tensor; test against None explicitly
+            Ks_out = Ks_out if Ks_out is not None else [None] * N
+ for i in range(N):
+ print(f" K[{i}]: {fmt_k_line(Ks_in[i])} -> {fmt_k_line(Ks_out[i])}")
+
+ proc = InputProcessor()
+ process_res = 504
+ methods = ["upper_bound_resize", "upper_bound_crop", "lower_bound_resize", "lower_bound_crop"]
+
+ # Example sizes (two orientations)
+ small_sizes = [(680, 1208), (1208, 680)]
+ large_sizes = [(1208, 680), (680, 1208)]
+
+ def make_K(w, h, fx=1200.0, fy=1100.0):
+ cx, cy = w / 2.0, h / 2.0
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]], dtype=np.float32)
+ return K
+
+ def run_suite(suite_name: str, sizes: list[tuple[int, int]]):
+ print(f"\n===== {suite_name} =====")
+ for w, h in sizes:
+ img = Image.new("RGB", (w, h), color=(123, 222, 100))
+ batch_imgs = [img, img]
+
+ # intrinsics / extrinsics examples
+ Ks_in = [make_K(w, h), make_K(w, h)]
+ Es_in = [np.eye(4, dtype=np.float32), np.eye(4, dtype=np.float32)]
+
+ for m in methods:
+                tensor, Es_out, Ks_out, _ = proc(  # 4th element is pad_meta
+ image=batch_imgs,
+ process_res=process_res,
+ process_res_method=m,
+ num_workers=8,
+ print_progress=False,
+ intrinsics=Ks_in, # test with non-None
+ extrinsics=Es_in,
+ )
+ show_result(f"{suite_name} size=({w},{h}) | {m}", tensor, Ks_in, Ks_out)
+
+ # Also test None path
+            tensor2, Es_out2, Ks_out2, _ = proc(
+ image=batch_imgs,
+ process_res=process_res,
+ process_res_method="upper_bound_resize",
+ num_workers=8,
+ intrinsics=None,
+ extrinsics=None,
+ )
+ show_result(
+ f"{suite_name} size=({w},{h}) | upper_bound_resize | no K/E",
+ tensor2,
+ None,
+ Ks_out2,
+ )
+
+ run_suite("SMALL", small_sizes)
+ run_suite("LARGE", large_sizes)
+
+ # Extra sanity for 504x376
+ print("\n===== EXTRA sanity for 504x376 =====")
+ img_example = Image.new("RGB", (504, 376), color=(10, 20, 30))
+ Ks_in_extra = [make_K(504, 376, fx=900.0, fy=900.0), make_K(504, 376, fx=900.0, fy=900.0)]
+
+    out_r, _, Ks_out_r, _ = proc(
+ image=[img_example, img_example],
+ process_res=504,
+ process_res_method="upper_bound_resize",
+ num_workers=8,
+ intrinsics=Ks_in_extra,
+ )
+    out_c, _, Ks_out_c, _ = proc(
+ image=[img_example, img_example],
+ process_res=504,
+ process_res_method="upper_bound_crop",
+ num_workers=8,
+ intrinsics=Ks_in_extra,
+ )
+ _, _, _, Hr, Wr = out_r.shape
+ _, _, _, Hc, Wc = out_c.shape
+ print(f"upper_bound_resize -> ({Hr},{Wr}) (rounded to nearest multiple of 14)")
+ show_result("Ks after upper_bound_resize", out_r, Ks_in_extra, Ks_out_r)
+ print(f"upper_bound_crop -> ({Hc},{Wc}) (floored to multiple of 14)")
+ show_result("Ks after upper_bound_crop", out_c, Ks_in_extra, Ks_out_c)
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/output_processor.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/output_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c317eb9d596c1687b5281891a035993868cc5f8c
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/io/output_processor.py
@@ -0,0 +1,172 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Output processor for Depth Anything 3.
+
+This module handles model output processing, including tensor-to-numpy conversion,
+batch dimension removal, and Prediction object creation.
+"""
+
+from __future__ import annotations
+
+import numpy as np
+import torch
+from addict import Dict as AddictDict
+
+from depth_anything_3.specs import Prediction
+
+
+class OutputProcessor:
+ """
+ Output processor for converting model outputs to Prediction objects.
+
+ Handles tensor-to-numpy conversion, batch dimension removal,
+ and creates structured Prediction objects with proper data types.
+ """
+
+ def __init__(self) -> None:
+ """Initialize the output processor."""
+
+ def __call__(self, model_output: dict[str, torch.Tensor]) -> Prediction:
+ """
+ Convert model output to Prediction object.
+
+ Args:
+ model_output: Model output dictionary containing depth, conf, extrinsics, intrinsics
+                Expected shapes: depth (B, N, H, W, 1), conf (B, N, H, W),
+ extrinsics (B, N, 4, 4), intrinsics (B, N, 3, 3)
+
+ Returns:
+ Prediction: Object containing depth estimation results with shapes:
+ depth (N, H, W), conf (N, H, W), extrinsics (N, 4, 4), intrinsics (N, 3, 3)
+ """
+ # Extract data from batch dimension (B=1, N=number of images)
+ depth = self._extract_depth(model_output)
+ conf = self._extract_conf(model_output)
+ extrinsics = self._extract_extrinsics(model_output)
+ intrinsics = self._extract_intrinsics(model_output)
+ sky = self._extract_sky(model_output)
+ aux = self._extract_aux(model_output)
+ gaussians = model_output.get("gaussians", None)
+ scale_factor = model_output.get("scale_factor", None)
+
+ return Prediction(
+ depth=depth,
+ sky=sky,
+ conf=conf,
+ extrinsics=extrinsics,
+ intrinsics=intrinsics,
+            is_metric=model_output.get("is_metric", 0),  # dict lookup, not attribute access
+ gaussians=gaussians,
+ aux=aux,
+ scale_factor=scale_factor,
+ )
+
+ def _extract_depth(self, model_output: dict[str, torch.Tensor]) -> np.ndarray:
+ """
+ Extract depth tensor from model output and convert to numpy.
+
+ Args:
+ model_output: Model output dictionary
+
+ Returns:
+ Depth array with shape (N, H, W)
+ """
+ depth = model_output["depth"].squeeze(0).squeeze(-1).cpu().numpy() # (N, H, W)
+ return depth
+
+ def _extract_conf(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
+ """
+ Extract confidence tensor from model output and convert to numpy.
+
+ Args:
+ model_output: Model output dictionary
+
+ Returns:
+ Confidence array with shape (N, H, W) or None
+ """
+ conf = model_output.get("depth_conf", None)
+ if conf is not None:
+ conf = conf.squeeze(0).cpu().numpy() # (N, H, W)
+ return conf
+
+ def _extract_extrinsics(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
+ """
+ Extract extrinsics tensor from model output and convert to numpy.
+
+ Args:
+ model_output: Model output dictionary
+
+ Returns:
+ Extrinsics array with shape (N, 4, 4) or None
+ """
+ extrinsics = model_output.get("extrinsics", None)
+ if extrinsics is not None:
+ extrinsics = extrinsics.squeeze(0).cpu().numpy() # (N, 4, 4)
+ return extrinsics
+
+ def _extract_intrinsics(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
+ """
+ Extract intrinsics tensor from model output and convert to numpy.
+
+ Args:
+ model_output: Model output dictionary
+
+ Returns:
+ Intrinsics array with shape (N, 3, 3) or None
+ """
+ intrinsics = model_output.get("intrinsics", None)
+ if intrinsics is not None:
+ intrinsics = intrinsics.squeeze(0).cpu().numpy() # (N, 3, 3)
+ return intrinsics
+
+ def _extract_sky(self, model_output: dict[str, torch.Tensor]) -> np.ndarray | None:
+ """
+ Extract sky tensor from model output and convert to numpy.
+
+ Args:
+ model_output: Model output dictionary
+
+ Returns:
+ Sky mask array with shape (N, H, W) or None
+ """
+ sky = model_output.get("sky", None)
+ if sky is not None:
+ sky = sky.squeeze(0).cpu().numpy() >= 0.5 # (N, H, W)
+ return sky
+
+ def _extract_aux(self, model_output: dict[str, torch.Tensor]) -> AddictDict:
+ """
+ Extract auxiliary data from model output and convert to numpy.
+
+ Args:
+ model_output: Model output dictionary
+
+ Returns:
+ Dictionary containing auxiliary data
+ """
+ aux = model_output.get("aux", None)
+ ret = AddictDict()
+ if aux is not None:
+ for k in aux.keys():
+ if isinstance(aux[k], torch.Tensor):
+ ret[k] = aux[k].squeeze(0).cpu().numpy()
+ else:
+ ret[k] = aux[k]
+ return ret
+
+
+# Backward compatibility alias
+OutputAdapter = OutputProcessor
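+
+
+# Usage sketch: convert a raw forward-pass dict into a Prediction
+# (`model_output` is assumed to follow the shapes documented above).
+#
+#   prediction = OutputProcessor()(model_output)
+#   print(prediction.depth.shape)  # (N, H, W)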
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/layout_helpers.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/layout_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..189c170b2007c979e580b69ca929638560923fb2
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/layout_helpers.py
@@ -0,0 +1,216 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""This file contains useful layout utilities for images. They are:
+
+- add_border: Add a border to an image.
+- cat/hcat/vcat: Join images by arranging them in a line. If the images have different
+ sizes, they are aligned as specified (start, end, center). Allows you to specify a gap
+ between images.
+
+Images are assumed to be float32 tensors with shape (channel, height, width).
+"""
+
+from typing import Any, Generator, Iterable, Literal, Union
+import torch
+from torch import Tensor
+
+Alignment = Literal["start", "center", "end"]
+Axis = Literal["horizontal", "vertical"]
+Color = Union[
+    int,
+    float,
+    Iterable[int],
+    Iterable[float],
+    Tensor,
+]
+
+
+def _sanitize_color(color: Color) -> Tensor: # "#channel"
+ # Convert tensor to list (or individual item).
+ if isinstance(color, torch.Tensor):
+ color = color.tolist()
+
+ # Turn iterators and individual items into lists.
+ if isinstance(color, Iterable):
+ color = list(color)
+ else:
+ color = [color]
+
+ return torch.tensor(color, dtype=torch.float32)
+
+
+def _intersperse(iterable: Iterable, delimiter: Any) -> Generator[Any, None, None]:
+ it = iter(iterable)
+ yield next(it)
+ for item in it:
+ yield delimiter
+ yield item
+
+
+def _get_main_dim(main_axis: Axis) -> int:
+ return {
+ "horizontal": 2,
+ "vertical": 1,
+ }[main_axis]
+
+
+def _get_cross_dim(main_axis: Axis) -> int:
+ return {
+ "horizontal": 1,
+ "vertical": 2,
+ }[main_axis]
+
+
+def _compute_offset(base: int, overlay: int, align: Alignment) -> slice:
+ assert base >= overlay
+ offset = {
+ "start": 0,
+ "center": (base - overlay) // 2,
+ "end": base - overlay,
+ }[align]
+ return slice(offset, offset + overlay)
+
+
+def overlay(
+ base: Tensor, # "channel base_height base_width"
+ overlay: Tensor, # "channel overlay_height overlay_width"
+ main_axis: Axis,
+ main_axis_alignment: Alignment,
+ cross_axis_alignment: Alignment,
+) -> Tensor: # "channel base_height base_width"
+ # The overlay must be smaller than the base.
+ _, base_height, base_width = base.shape
+ _, overlay_height, overlay_width = overlay.shape
+ assert base_height >= overlay_height and base_width >= overlay_width
+
+ # Compute spacing on the main dimension.
+ main_dim = _get_main_dim(main_axis)
+ main_slice = _compute_offset(
+ base.shape[main_dim], overlay.shape[main_dim], main_axis_alignment
+ )
+
+ # Compute spacing on the cross dimension.
+ cross_dim = _get_cross_dim(main_axis)
+ cross_slice = _compute_offset(
+ base.shape[cross_dim], overlay.shape[cross_dim], cross_axis_alignment
+ )
+
+ # Combine the slices and paste the overlay onto the base accordingly.
+ selector = [..., None, None]
+ selector[main_dim] = main_slice
+ selector[cross_dim] = cross_slice
+ result = base.clone()
+    # Index with a tuple; list-of-slices indexing is deprecated in NumPy/PyTorch.
+    result[tuple(selector)] = overlay
+ return result
+
+
+def cat(
+ main_axis: Axis,
+ *images: Iterable[Tensor], # "channel _ _"
+ align: Alignment = "center",
+ gap: int = 8,
+ gap_color: Color = 1,
+) -> Tensor: # "channel height width"
+ """Arrange images in a line. The interface resembles a CSS div with flexbox."""
+ device = images[0].device
+ gap_color = _sanitize_color(gap_color).to(device)
+
+ # Find the maximum image side length in the cross axis dimension.
+ cross_dim = _get_cross_dim(main_axis)
+ cross_axis_length = max(image.shape[cross_dim] for image in images)
+
+ # Pad the images.
+ padded_images = []
+ for image in images:
+ # Create an empty image with the correct size.
+ padded_shape = list(image.shape)
+ padded_shape[cross_dim] = cross_axis_length
+ base = torch.ones(padded_shape, dtype=torch.float32, device=device)
+ base = base * gap_color[:, None, None]
+ padded_images.append(overlay(base, image, main_axis, "start", align))
+
+ # Intersperse separators if necessary.
+ if gap > 0:
+ # Generate a separator.
+ c, _, _ = images[0].shape
+ separator_size = [gap, gap]
+ separator_size[cross_dim - 1] = cross_axis_length
+ separator = torch.ones((c, *separator_size), dtype=torch.float32, device=device)
+ separator = separator * gap_color[:, None, None]
+
+ # Intersperse the separator between the images.
+ padded_images = list(_intersperse(padded_images, separator))
+
+ return torch.cat(padded_images, dim=_get_main_dim(main_axis))
+
+
+def hcat(
+ *images: Iterable[Tensor], # "channel _ _"
+ align: Literal["start", "center", "end", "top", "bottom"] = "start",
+ gap: int = 8,
+ gap_color: Color = 1,
+):
+ """Shorthand for a horizontal linear concatenation."""
+ return cat(
+ "horizontal",
+ *images,
+ align={
+ "start": "start",
+ "center": "center",
+ "end": "end",
+ "top": "start",
+ "bottom": "end",
+ }[align],
+ gap=gap,
+ gap_color=gap_color,
+ )
+
+
+def vcat(
+ *images: Iterable[Tensor], # "channel _ _"
+ align: Literal["start", "center", "end", "left", "right"] = "start",
+ gap: int = 8,
+ gap_color: Color = 1,
+):
+ """Shorthand for a horizontal linear concatenation."""
+ return cat(
+ "vertical",
+ *images,
+ align={
+ "start": "start",
+ "center": "center",
+ "end": "end",
+ "left": "start",
+ "right": "end",
+ }[align],
+ gap=gap,
+ gap_color=gap_color,
+ )
+
+
+def add_border(
+ image: Tensor, # "channel height width"
+ border: int = 8,
+ color: Color = 1,
+) -> Tensor: # "channel new_height new_width"
+ color = _sanitize_color(color).to(image)
+ c, h, w = image.shape
+ result = torch.empty(
+ (c, h + 2 * border, w + 2 * border), dtype=torch.float32, device=image.device
+ )
+ result[:] = color[:, None, None]
+ result[:, border : h + border, border : w + border] = image
+ return result
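+
+
+# Usage sketch: compose a 2x2 grid with white gaps and a border
+# (random tensors stand in for real images).
+#
+#   a = torch.rand(3, 128, 96)
+#   b = torch.rand(3, 100, 120)
+#   row1 = hcat(a, b, align="top", gap=4)
+#   row2 = hcat(b, a, align="bottom", gap=4)
+#   grid = add_border(vcat(row1, row2, align="left", gap=4))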
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/logger.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..0eb4f60696a085001cf4866ccfe1654170702a2d
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/logger.py
@@ -0,0 +1,82 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+
+class Color:
+ RED = "\033[91m"
+ YELLOW = "\033[93m"
+ WHITE = "\033[97m"
+ GREEN = "\033[92m"
+ RESET = "\033[0m"
+
+
+LOG_LEVELS = {"ERROR": 0, "WARN": 1, "INFO": 2, "DEBUG": 3}
+
+COLOR_MAP = {"ERROR": Color.RED, "WARN": Color.YELLOW, "INFO": Color.WHITE, "DEBUG": Color.GREEN}
+
+
+def get_env_log_level():
+ level = os.environ.get("DA3_LOG_LEVEL", "INFO").upper()
+ return LOG_LEVELS.get(level, LOG_LEVELS["INFO"])
+
+
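+# Usage sketch: control verbosity via the DA3_LOG_LEVEL environment variable.
+#
+#   DA3_LOG_LEVEL=DEBUG python your_script.py   # show DEBUG and above
+#   DA3_LOG_LEVEL=ERROR python your_script.py   # errors only
+
+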
+class Logger:
+ def __init__(self):
+ self.level = get_env_log_level()
+
+ def log(self, level_str, *args, **kwargs):
+ level_key = level_str.split(":")[0].strip()
+ level_val = LOG_LEVELS.get(level_key)
+ if level_val is None:
+ raise ValueError(f"Unknown log level: {level_str}")
+ if self.level >= level_val:
+ color = COLOR_MAP[level_key]
+ msg = " ".join(str(arg) for arg in args)
+
+ # Align log level output in square brackets
+ # ERROR and DEBUG are 5 characters, INFO and WARN have an extra space for alignment
+ tag = level_key
+ if tag in ("INFO", "WARN"):
+ tag += " "
+ print(
+ f"{color}[{tag}] {msg}{Color.RESET}",
+ file=sys.stderr if level_key == "ERROR" else sys.stdout,
+ **kwargs,
+ )
+
+ def error(self, *args, **kwargs):
+ self.log("ERROR:", *args, **kwargs)
+
+ def warn(self, *args, **kwargs):
+ self.log("WARN:", *args, **kwargs)
+
+ def info(self, *args, **kwargs):
+ self.log("INFO:", *args, **kwargs)
+
+ def debug(self, *args, **kwargs):
+ self.log("DEBUG:", *args, **kwargs)
+
+
+logger = Logger()
+
+__all__ = ["logger"]
+
+if __name__ == "__main__":
+ logger.info("This is an info message")
+ logger.warn("This is a warning message")
+ logger.error("This is an error message")
+ logger.debug("This is a debug message")
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/memory.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/memory.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5f61595d21a6c221b9be3c7954fe56ff83e5300
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/memory.py
@@ -0,0 +1,128 @@
+"""
+GPU memory utility helpers.
+
+Shared cleanup and memory checking logic used by both the backend API and
+the Gradio UI to keep memory-management behavior consistent.
+"""
+from __future__ import annotations
+
+import gc
+
+from typing import Any, Dict, Optional
+
+import torch
+
+
+def get_gpu_memory_info() -> Optional[Dict[str, Any]]:
+ """Return a snapshot of current GPU memory usage or None if CUDA not available.
+
+ Keys in returned dict: total_gb, allocated_gb, reserved_gb, free_gb, utilization
+ """
+ if not torch.cuda.is_available():
+ return None
+
+ try:
+ device = torch.cuda.current_device()
+ total_memory = torch.cuda.get_device_properties(device).total_memory
+ allocated_memory = torch.cuda.memory_allocated(device)
+ reserved_memory = torch.cuda.memory_reserved(device)
+ free_memory = total_memory - reserved_memory
+
+ return {
+ "total_gb": total_memory / 1024 ** 3,
+ "allocated_gb": allocated_memory / 1024 ** 3,
+ "reserved_gb": reserved_memory / 1024 ** 3,
+ "free_gb": free_memory / 1024 ** 3,
+ "utilization": (reserved_memory / total_memory) * 100,
+ }
+ except Exception:
+ return None
+
+
+def cleanup_cuda_memory() -> None:
+ """Perform a robust GPU cleanup sequence.
+
+ This includes synchronizing, emptying caches, collecting IPC handles and
+ running the Python garbage collector. Use this instead of a raw
+ ``torch.cuda.empty_cache()`` where you need reliable freeing of GPU memory
+ between model loads or in error handling paths.
+ """
+ try:
+ if torch.cuda.is_available():
+ mem_before = get_gpu_memory_info()
+
+ torch.cuda.synchronize()
+ torch.cuda.empty_cache()
+ # Collect cross-process cuda resources
+ try:
+ torch.cuda.ipc_collect()
+ except Exception:
+ # Older PyTorch versions or non-cuda devices may not support
+ # ipc_collect (no-op if not available)
+ pass
+ gc.collect()
+
+ mem_after = get_gpu_memory_info()
+ if mem_before and mem_after:
+ freed = mem_before["reserved_gb"] - mem_after["reserved_gb"]
+ print(
+ f"CUDA cleanup: freed {freed:.2f}GB, "
+ f"available: {mem_after['free_gb']:.2f}GB/{mem_after['total_gb']:.2f}GB"
+ )
+ else:
+ print("CUDA memory cleanup completed")
+ except Exception as e:
+ print(f"Warning: CUDA cleanup failed: {e}")
+
+
+def check_memory_availability(required_gb: float = 2.0) -> tuple[bool, str]:
+ """Return whether at least ``required_gb`` seems available on the current GPU.
+
+ The returned tuple is (is_available, message) with a human-friendly message.
+ """
+ try:
+ if not torch.cuda.is_available():
+ return False, "CUDA is not available"
+
+ mem_info = get_gpu_memory_info()
+ if mem_info is None:
+ return True, "Cannot check memory, proceeding anyway"
+
+ if mem_info["free_gb"] < required_gb:
+ return (
+ False,
+ (
+ f"Insufficient GPU memory: {mem_info['free_gb']:.2f}GB available, "
+ f"{required_gb:.2f}GB required. Total: {mem_info['total_gb']:.2f}GB, "
+ f"Used: {mem_info['reserved_gb']:.2f}GB ({mem_info['utilization']:.1f}%)"
+ ),
+ )
+
+ return (
+ True,
+ (
+ f"Memory check passed: {mem_info['free_gb']:.2f}GB available, "
+ f"{required_gb:.2f}GB required"
+ ),
+ )
+ except Exception as e:
+ return True, f"Memory check failed: {e}, proceeding anyway"
+
+
+def estimate_memory_requirement(num_images: int, process_res: int | None) -> float:
+ """Heuristic estimate for memory usage (GB) based on image count and resolution.
+
+ This mirrors the simple policy used by the backend service so other code
+ (e.g., Gradio UI) can make consistent decisions when checking available
+ memory before loading a model or running inference.
+
+ Args:
+ num_images: Number of images to process.
+ process_res: Processing resolution.
+
+ Returns:
+ Estimated memory requirement in GB.
+ """
+ base_memory = 2.0
+ effective_res = 504 if process_res is None or process_res <= 0 else process_res
+ per_image_memory = (effective_res / 504) ** 2 * 0.5
+ total_memory = base_memory + (num_images * per_image_memory * 0.1)
+ return total_memory
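+
+
+# Worked example (heuristic only): 24 images at 1008 px.
+#   per_image = (1008 / 504) ** 2 * 0.5 = 2.0
+#   total     = 2.0 + 24 * 2.0 * 0.1   = 6.8 GB estimated
+#
+#   need = estimate_memory_requirement(24, 1008)        # ~6.8
+#   ok, msg = check_memory_availability(required_gb=need)
+#   if not ok:
+#       cleanup_cuda_memory()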
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/model_loading.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/model_loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6d43b5bbab5a0989eae272422192103cd4d5bee
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/model_loading.py
@@ -0,0 +1,149 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Model loading and state dict conversion utilities.
+"""
+
+from typing import Dict, Tuple
+import torch
+
+from depth_anything_3.utils.logger import logger
+
+
+def convert_general_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """
+ Convert general model state dict to match current model architecture.
+
+ Args:
+ state_dict: Original state dictionary
+
+ Returns:
+ Converted state dictionary
+ """
+ # Replace module prefixes
+ state_dict = {k.replace("module.", "model."): v for k, v in state_dict.items()}
+ state_dict = {k.replace(".net.", ".backbone."): v for k, v in state_dict.items()}
+
+ # Remove camera token if present
+ if "model.backbone.pretrained.camera_token" in state_dict:
+ del state_dict["model.backbone.pretrained.camera_token"]
+
+ # Replace camera token naming
+ state_dict = {
+ k.replace(".camera_token_extra", ".camera_token"): v for k, v in state_dict.items()
+ }
+
+ # Replace head naming
+ state_dict = {
+ k.replace("model.all_heads.camera_cond_head", "model.cam_enc"): v
+ for k, v in state_dict.items()
+ }
+ state_dict = {
+ k.replace("model.all_heads.camera_head", "model.cam_dec"): v for k, v in state_dict.items()
+ }
+ state_dict = {k.replace(".more_mlps.", ".backbone."): v for k, v in state_dict.items()}
+ state_dict = {k.replace(".fc_rot.", ".fc_qvec."): v for k, v in state_dict.items()}
+ state_dict = {
+ k.replace("model.all_heads.head", "model.head"): v for k, v in state_dict.items()
+ }
+
+ # Replace output naming
+ state_dict = {
+ k.replace("output_conv2_additional.sky_mask", "sky_output_conv2"): v
+ for k, v in state_dict.items()
+ }
+ state_dict = {k.replace("_ray.", "_aux."): v for k, v in state_dict.items()}
+
+ # Update GS-DPT head naming and value
+ state_dict = {k.replace("gaussian_param_head.", "gs_head."): v for k, v in state_dict.items()}
+
+ return state_dict
+
+
+def convert_metric_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """
+ Convert metric model state dict to match current model architecture.
+
+ Args:
+ state_dict: Original metric state dictionary
+
+ Returns:
+ Converted state dictionary
+ """
+ # Add module prefix for metric models
+ state_dict = {"module." + k: v for k, v in state_dict.items()}
+ return convert_general_state_dict(state_dict)
+
+
+def load_pretrained_weights(model, model_path: str, is_metric: bool = False) -> Tuple[list, list]:
+ """
+ Load pretrained weights for a single model.
+
+ Args:
+ model: Model instance to load weights into
+ model_path: Path to the pretrained weights
+ is_metric: Whether this is a metric model
+
+ Returns:
+ Tuple of (missed_keys, unexpected_keys)
+ """
+ state_dict = torch.load(model_path, map_location="cpu")
+
+ if is_metric:
+ state_dict = convert_metric_state_dict(state_dict)
+ else:
+ state_dict = convert_general_state_dict(state_dict)
+
+ missed, unexpected = model.load_state_dict(state_dict, strict=False)
+ logger.info("Missed keys:", missed)
+ logger.info("Unexpected keys:", unexpected)
+
+ return missed, unexpected
+
+
+def load_pretrained_nested_weights(
+ model, main_model_path: str, metric_model_path: str
+) -> Tuple[list, list]:
+ """
+ Load pretrained weights for a nested model with both main and metric branches.
+
+ Args:
+ model: Nested model instance
+ main_model_path: Path to main model weights
+ metric_model_path: Path to metric model weights
+
+ Returns:
+ Tuple of (missed_keys, unexpected_keys)
+ """
+ # Load main model weights
+ state_dict0 = torch.load(main_model_path, map_location="cpu")
+ state_dict0 = convert_general_state_dict(state_dict0)
+ state_dict0 = {k.replace("model.", "model.da3."): v for k, v in state_dict0.items()}
+
+ # Load metric model weights
+ state_dict1 = torch.load(metric_model_path, map_location="cpu")
+ state_dict1 = convert_metric_state_dict(state_dict1)
+ state_dict1 = {k.replace("model.", "model.da3_metric."): v for k, v in state_dict1.items()}
+
+ # Combine state dictionaries
+ combined_state_dict = state_dict0.copy()
+ combined_state_dict.update(state_dict1)
+
+ missed, unexpected = model.load_state_dict(combined_state_dict, strict=False)
+
+ print("Missed keys:", missed)
+ print("Unexpected keys:", unexpected)
+
+ return missed, unexpected
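+
+
+# Usage sketch (hypothetical checkpoint paths; `build_model` stands in for the
+# project's actual model constructor):
+#
+#   model = build_model(...)
+#   load_pretrained_weights(model, "ckpts/da3.pth", is_metric=False)
+#   # or, for a nested model with a metric branch:
+#   load_pretrained_nested_weights(model, "ckpts/da3.pth", "ckpts/da3_metric.pth")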
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/parallel_utils.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/parallel_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ff108e95d205097f9f2012ab87ad9e265d58d8d
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/parallel_utils.py
@@ -0,0 +1,133 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import asyncio
+import os
+from functools import partial, wraps
+from multiprocessing.pool import ThreadPool
+from threading import Thread
+from typing import Callable, Dict, List
+import imageio
+from tqdm import tqdm
+
+
+def async_call_func(func):
+ @wraps(func)
+ async def wrapper(*args, **kwargs):
+        loop = asyncio.get_running_loop()
+        # run_in_executor does not forward kwargs; bind them with functools.partial
+        return await loop.run_in_executor(None, partial(func, *args, **kwargs))
+
+ return wrapper
+
+
+def slice_func(chunk_index, chunk_dim, chunk_size):
+    """Build an indexing list that slices `chunk_size` items along `chunk_dim`."""
+    return [slice(None)] * chunk_dim + [slice(chunk_index, chunk_index + chunk_size)]
+
+
+def async_call(fn):
+ def wrapper(*args, **kwargs):
+ Thread(target=fn, args=args, kwargs=kwargs).start()
+
+ return wrapper
+
+
+def _save_image_impl(save_img, save_path):
+ """Common implementation for saving images synchronously or asynchronously"""
+ os.makedirs(os.path.dirname(save_path), exist_ok=True)
+ imageio.imwrite(save_path, save_img)
+
+
+@async_call
+def save_image_async(save_img, save_path):
+ """Save image asynchronously"""
+ _save_image_impl(save_img, save_path)
+
+
+def save_image(save_img, save_path):
+ """Save image synchronously"""
+ _save_image_impl(save_img, save_path)
+
+
+def parallel_execution(
+ *args,
+ action: Callable,
+ num_processes=32,
+ print_progress=False,
+ sequential=False,
+ async_return=False,
+ desc=None,
+ **kwargs,
+):
+    # Partially copied from EasyVolumetricVideo (parallel_execution).
+ # NOTE: we expect first arg / or kwargs to be distributed
+ # NOTE: print_progress arg is reserved.
+ # `*args` packs all positional arguments passed to the function into a tuple
+ args = list(args)
+
+ def get_length(args: List, kwargs: Dict):
+ for a in args:
+ if isinstance(a, list):
+ return len(a)
+ for v in kwargs.values():
+ if isinstance(v, list):
+ return len(v)
+        raise NotImplementedError("at least one positional or keyword argument must be a list")
+
+ def get_action_args(length: int, args: List, kwargs: Dict, i: int):
+ action_args = [
+ (arg[i] if isinstance(arg, list) and len(arg) == length else arg) for arg in args
+ ]
+ # TODO: Support all types of iterable
+ action_kwargs = {
+ key: (
+ kwargs[key][i]
+ if isinstance(kwargs[key], list) and len(kwargs[key]) == length
+ else kwargs[key]
+ )
+ for key in kwargs
+ }
+ return action_args, action_kwargs
+
+ if not sequential:
+ # Create ThreadPool
+ pool = ThreadPool(processes=num_processes)
+
+ # Spawn threads
+ results = []
+ asyncs = []
+ length = get_length(args, kwargs)
+ for i in range(length):
+ action_args, action_kwargs = get_action_args(length, args, kwargs, i)
+ async_result = pool.apply_async(action, action_args, action_kwargs)
+ asyncs.append(async_result)
+
+ # Join threads and get return values
+ if not async_return:
+ for async_result in tqdm(asyncs, desc=desc, disable=not print_progress):
+ results.append(async_result.get()) # will sync the corresponding thread
+ pool.close()
+ pool.join()
+ return results
+ else:
+ return pool
+ else:
+ results = []
+ length = get_length(args, kwargs)
+ for i in tqdm(range(length), desc=desc, disable=not print_progress):
+ action_args, action_kwargs = get_action_args(length, args, kwargs, i)
+ async_result = action(*action_args, **action_kwargs)
+ results.append(async_result)
+ return results
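+
+
+# Usage sketch: square a list of numbers on 4 threads (output order matches
+# input order).
+#
+#   squares = parallel_execution(
+#       [1, 2, 3, 4],
+#       action=lambda x: x * x,
+#       num_processes=4,
+#       print_progress=True,
+#       desc="square",
+#   )
+#   # -> [1, 4, 9, 16]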
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/pca_utils.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/pca_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b9eee268cd8692d885bf093700a5752077b9d7d
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/pca_utils.py
@@ -0,0 +1,284 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+PCA utilities for feature visualization and dimensionality reduction (video-friendly).
+- Support frame-by-frame: transform_frame / transform_video
+- Support one-time global PCA fitting and reuse (mean, V3) for stable colors
+- Support Procrustes alignment (solving principal component order/sign/rotation jumps)
+- Support global fixed or temporal EMA for percentiles (time dimension only, no spatial)
+"""
+
+import numpy as np
+import torch
+
+
+def pca_to_rgb_4d_bf16_percentile(
+ x_np: np.ndarray,
+ device=None,
+ q_oversample: int = 6,
+ clip_percent: float = 10.0, # Percentage to clip from top and bottom (0~49.9)
+ return_uint8: bool = False,
+ enable_autocast_bf16: bool = True,
+):
+ """
+    Reduce a 4D numpy array, e.g. (49, 27, 36, 3072), to 3 channels via PCA and visualize as (..., 3).
+ - PCA uses torch.pca_lowrank (randomized SVD), defaults to GPU.
+ - Uses CUDA bf16 autocast in computation (if available),
+ then per-channel percentile clipping and normalization.
+    - By default removes 10% outliers from top and bottom (adjustable via clip_percent) to
+ improve visualization contrast.
+
+ Parameters
+ ----------
+ x_np : np.ndarray
+        Shape (B1, B2, B3, D), e.g. (49, 27, 36, 3072). dtype recommended float32/float64.
+ device : str | None
+ Specify 'cuda' or 'cpu'. Auto-select if None (prefer cuda).
+ q_oversample : int
+ Oversampling q for pca_lowrank, must be >= 3.
+ Slightly larger than target dim (3) is more stable, default 6.
+ clip_percent : float
+ Percentage to clip from top and bottom (0~49.9),
+ e.g. 5.0 means clip lowest 5% and highest 5% per channel.
+ return_uint8 : bool
+ True returns uint8(0~255), otherwise returns float32(0~1).
+ enable_autocast_bf16 : bool
+ Enable bf16 autocast on CUDA.
+
+ Returns
+ -------
+ np.ndarray
+ Array of shape (49, 27, 36, 3), float32[0,1] or uint8[0,255].
+ """
+    assert x_np.ndim == 4, f"expected a 4D array (B1, B2, B3, D), got shape {x_np.shape}"
+ B1, B2, B3, D = x_np.shape
+ N = B1 * B2 * B3
+
+ # Device selection
+ if device is None:
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # Convert input to torch, unified float32
+ X = torch.from_numpy(x_np.reshape(N, D)).to(device=device, dtype=torch.float32)
+
+ # Parameter and safety checks
+ k = 3
+ q = max(int(q_oversample), k)
+ clip_percent = float(clip_percent)
+ if not (0.0 <= clip_percent < 50.0):
+ raise ValueError(
+ "clip_percent must be in [0, 50), e.g. 5.0 means clip 5% from top and bottom"
+ )
+ low = clip_percent / 100.0
+ high = 1.0 - low
+
+ with torch.no_grad():
+ # Zero mean
+ mean = X.mean(dim=0, keepdim=True)
+ Xc = X - mean
+
+        # Main computation: PCA + projection; optionally run under bf16 autocast.
+        # Unsupported ops (e.g. the SVD inside pca_lowrank) fall back to fp32.
+        use_bf16 = device.startswith("cuda") and enable_autocast_bf16
+        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=use_bf16):
+            U, S, V = torch.pca_lowrank(Xc, q=q, center=False)  # V: (D, q)
+            V3 = V[:, :k]  # (D, 3)
+            PCs = Xc @ V3  # (N, 3)
+        # torch.quantile requires fp32/fp64; cast back in case autocast produced bf16.
+        PCs = PCs.float()
+
+ # === Per-channel percentile clipping and normalization to [0,1] ===
+ # Vectorized one-time calculation of low/high percentiles for each channel
+ qs = torch.tensor([low, high], device=PCs.device, dtype=PCs.dtype)
+ qvals = torch.quantile(PCs, q=qs, dim=0) # Shape (2, 3)
+ lo = qvals[0] # (3,)
+ hi = qvals[1] # (3,)
+
+ # Avoid degenerate case where hi==lo
+ denom = torch.clamp(hi - lo, min=1e-8)
+
+ # Broadcast clipping + normalization
+ PCs = torch.clamp(PCs, lo, hi)
+ PCs = (PCs - lo) / denom # (N, 3) in [0,1]
+
+ # Restore 4D
+ PCs = PCs.reshape(B1, B2, B3, k)
+
+ # Output
+ if return_uint8:
+ out = (PCs * 255.0).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ else:
+ out = PCs.clamp(0, 1).to(torch.float32).cpu().numpy()
+
+ return out
+
+
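+# Usage sketch: visualize a (T, H, W, D) feature map in one shot
+# (random data stands in for real backbone features).
+#
+#   feats = np.random.randn(49, 27, 36, 3072).astype(np.float32)
+#   rgb = pca_to_rgb_4d_bf16_percentile(feats, clip_percent=10.0, return_uint8=True)
+#   # rgb: (49, 27, 36, 3) uint8
+
+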
+class PCARGBVisualizer:
+ """
+ Stable PCAβRGB for video features shaped (T, H, W, D) or a single frame (H, W, D).
+ - Global mean/V3 reference for stable colors
+ - Per-frame PCA with Procrustes alignment to V3_ref (basis_mode='procrustes')
+ - Percentile normalization with global or EMA stats (time-only, no spatial smoothing)
+ """
+
+ def __init__(
+ self,
+ device=None,
+ q_oversample: int = 16,
+ clip_percent: float = 10.0,
+ return_uint8: bool = False,
+ enable_autocast_bf16: bool = True,
+ basis_mode: str = "procrustes", # 'fixed' | 'procrustes'
+ percentile_mode: str = "ema", # 'global' | 'ema'
+ ema_alpha: float = 0.1,
+ denom_eps: float = 1e-4,
+ ):
+ assert 0.0 <= clip_percent < 50.0
+ assert basis_mode in ("fixed", "procrustes")
+ assert percentile_mode in ("global", "ema")
+ self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
+ self.q = max(int(q_oversample), 6)
+ self.clip_percent = float(clip_percent)
+ self.return_uint8 = return_uint8
+ self.enable_autocast_bf16 = enable_autocast_bf16
+ self.basis_mode = basis_mode
+ self.percentile_mode = percentile_mode
+ self.ema_alpha = float(ema_alpha)
+ self.denom_eps = float(denom_eps)
+
+ # reference state
+ self.mean_ref = None # (1, D)
+ self.V3_ref = None # (D, 3)
+ self.lo_ref = None # (3,)
+ self.hi_ref = None # (3,)
+
+ @torch.no_grad()
+ def fit_reference(self, frames):
+ """
+ Fit global mean/V3 and initialize percentiles from a reference set.
+ frames: ndarray (T,H,W,D) or list of (H,W,D)
+ """
+ if isinstance(frames, np.ndarray):
+ if frames.ndim != 4:
+ raise ValueError("fit_reference expects (T,H,W,D) ndarray.")
+ T, H, W, D = frames.shape
+ X = torch.from_numpy(frames.reshape(T * H * W, D))
+ else: # list of (H,W,D)
+ xs = [torch.from_numpy(x.reshape(-1, x.shape[-1])) for x in frames]
+ D = xs[0].shape[-1]
+ X = torch.cat(xs, dim=0)
+
+ X = X.to(self.device, dtype=torch.float32)
+ X = torch.nan_to_num(X, nan=0.0, posinf=1e6, neginf=-1e6)
+
+ mean = X.mean(0, keepdim=True)
+ Xc = X - mean
+
+ U, S, V = torch.pca_lowrank(Xc, q=max(self.q, 8), center=False)
+ V3 = V[:, :3] # (D,3)
+
+ PCs = Xc @ V3
+ low = self.clip_percent / 100.0
+ high = 1.0 - low
+ qs = torch.tensor([low, high], device=PCs.device, dtype=PCs.dtype)
+ qvals = torch.quantile(PCs, q=qs, dim=0)
+ lo, hi = qvals[0], qvals[1]
+
+ self.mean_ref = mean
+ self.V3_ref = V3
+ if self.percentile_mode == "global":
+ self.lo_ref, self.hi_ref = lo, hi
+ else:
+ self.lo_ref = lo.clone()
+ self.hi_ref = hi.clone()
+
+ @torch.no_grad()
+ def _project_with_stable_colors(self, X: torch.Tensor) -> torch.Tensor:
+ """
+ X: (N,D) where N = H*W
+ Returns PCs_raw: (N,3) using stable basis (fixed or Procrustes-aligned)
+ """
+ assert self.mean_ref is not None and self.V3_ref is not None, "Call fit_reference() first."
+ X = torch.nan_to_num(X, nan=0.0, posinf=1e6, neginf=-1e6)
+ Xc = X - self.mean_ref
+
+ if self.basis_mode == "fixed":
+ V3_used = self.V3_ref
+ else:
+ U, S, V = torch.pca_lowrank(Xc, q=max(self.q, 6), center=False)
+ V3 = V[:, :3] # (D,3)
+ M = V3.T @ self.V3_ref
+ Uo, So, Vh = torch.linalg.svd(M)
+ R = Uo @ Vh
+ V3_used = V3 @ R
+ # Optional polarity fix via anchor
+ a = self.V3_ref.mean(0, keepdim=True)
+            sign = (V3_used * a).sum(0, keepdim=True).sign()
+            sign[sign == 0] = 1.0  # never zero a component on an exact-zero dot product
+ V3_used = V3_used * sign
+
+ return Xc @ V3_used
+
+ @torch.no_grad()
+ def _normalize_rgb(self, PCs_raw: torch.Tensor) -> torch.Tensor:
+ assert self.lo_ref is not None and self.hi_ref is not None
+ if self.percentile_mode == "global":
+ lo, hi = self.lo_ref, self.hi_ref
+ else:
+ low = self.clip_percent / 100.0
+ high = 1.0 - low
+ qs = torch.tensor([low, high], device=PCs_raw.device, dtype=PCs_raw.dtype)
+ qvals = torch.quantile(PCs_raw, q=qs, dim=0)
+ lo_now, hi_now = qvals[0], qvals[1]
+ a = self.ema_alpha
+ self.lo_ref = (1 - a) * self.lo_ref + a * lo_now
+ self.hi_ref = (1 - a) * self.hi_ref + a * hi_now
+ lo, hi = self.lo_ref, self.hi_ref
+
+ denom = torch.clamp(hi - lo, min=self.denom_eps)
+ PCs = torch.clamp(PCs_raw, lo, hi)
+ PCs = (PCs - lo) / denom
+ return PCs.clamp_(0, 1)
+
+ @torch.no_grad()
+ def transform_frame(self, frame: np.ndarray) -> np.ndarray:
+ """
+ frame: (H,W,D) -> (H,W,3)
+ """
+ if frame.ndim != 3:
+ raise ValueError("transform_frame expects (H,W,D).")
+ H, W, D = frame.shape
+ X = torch.from_numpy(frame.reshape(H * W, D)).to(self.device, dtype=torch.float32)
+ PCs_raw = self._project_with_stable_colors(X)
+ PCs = self._normalize_rgb(PCs_raw).reshape(H, W, 3)
+ if self.return_uint8:
+ return (PCs * 255.0).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
+ return PCs.to(torch.float32).cpu().numpy()
+
+ @torch.no_grad()
+ def transform_video(self, frames) -> np.ndarray:
+ """
+ frames: (T,H,W,D) or list of (H,W,D)
+ returns: (T,H,W,3)
+ """
+ outs = []
+ if isinstance(frames, np.ndarray):
+ if frames.ndim != 4:
+ raise ValueError("transform_video expects (T,H,W,D).")
+ T, H, W, D = frames.shape
+ for t in range(T):
+ outs.append(self.transform_frame(frames[t]))
+ else:
+ for f in frames:
+ outs.append(self.transform_frame(f))
+ return np.stack(outs, axis=0)
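+
+
+# Usage sketch (illustrative; `vis` stands for an instance of the visualizer
+# class defined above, constructed with its usual arguments):
+#   vis.fit_reference(ref_feats)             # (T,H,W,D) reference features
+#   rgb_frame = vis.transform_frame(feats)   # (H,W,D)   -> (H,W,3)
+#   rgb_video = vis.transform_video(video)   # (T,H,W,D) -> (T,H,W,3)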
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/pose_align.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/pose_align.py
new file mode 100644
index 0000000000000000000000000000000000000000..695d07fc1210e3cc3614c9b22c56d765b1106500
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/pose_align.py
@@ -0,0 +1,347 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import List
+import numpy as np
+import torch
+from evo.core.trajectory import PosePath3D
+
+from depth_anything_3.utils.geometry import affine_inverse, affine_inverse_np
+
+
+def batch_apply_alignment_to_enc(
+    rots: torch.Tensor, trans: torch.Tensor, scales: torch.Tensor, enc_list: List[torch.Tensor]
+):
+    # Not implemented yet; fail loudly instead of silently returning None.
+    raise NotImplementedError("batch_apply_alignment_to_enc is not implemented.")
+
+
+def batch_apply_alignment_to_ext(
+ rots: torch.Tensor, trans: torch.Tensor, scales: torch.Tensor, ext: torch.Tensor
+):
+    device = ext.device
+ if ext.shape[-2:] == (3, 4):
+ pad = torch.zeros((*ext.shape[:-2], 4, 4), dtype=ext.dtype, device=device)
+ pad[..., :3, :4] = ext
+ pad[..., 3, 3] = 1.0
+ ext = pad
+ pose_est = affine_inverse(ext)
+ pose_new_align_rot = rots[:, None] @ pose_est[..., :3, :3]
+ pose_new_align_trans = (
+ scales[:, None, None] * (rots[:, None] @ pose_est[..., :3, 3:])[..., 0] + trans[:, None]
+ )
+ pose_new_align = torch.zeros_like(ext)
+ pose_new_align[..., :3, :3] = pose_new_align_rot
+ pose_new_align[..., :3, 3] = pose_new_align_trans
+ pose_new_align[..., 3, 3] = 1.0
+ return affine_inverse(pose_new_align)[:, :3]
+
+
+def batch_align_poses_umeyama(ext_ref: torch.Tensor, ext_est: torch.Tensor):
+ device, dtype = ext_ref.device, ext_ref.dtype
+ assert ext_ref.dtype in [torch.float32, torch.float64]
+ assert ext_est.dtype in [torch.float32, torch.float64]
+ assert ext_ref.requires_grad is False
+ assert ext_est.requires_grad is False
+ rots, trans, scales = [], [], []
+ for b in range(ext_ref.shape[0]):
+ r, t, s = align_poses_umeyama(ext_ref[b].cpu().numpy(), ext_est[b].cpu().numpy())
+ rots.append(torch.from_numpy(r).to(device=device, dtype=dtype))
+ trans.append(torch.from_numpy(t).to(device=device, dtype=dtype))
+ scales.append(torch.tensor(s, device=device, dtype=dtype))
+ return torch.stack(rots), torch.stack(trans), torch.stack(scales)
+
+
+# The helpers below depend on affine_inverse_np and evo's PosePath3D,
+# kept consistent with the rest of the project.
+
+
+def _to44(ext):
+ if ext.shape[1] == 3:
+ out = np.eye(4)[None].repeat(len(ext), 0)
+ out[:, :3, :4] = ext
+ return out
+ return ext
+
+
+def _poses_from_ext(ext_ref, ext_est):
+ ext_ref = _to44(ext_ref)
+ ext_est = _to44(ext_est)
+ pose_ref = affine_inverse_np(ext_ref)
+ pose_est = affine_inverse_np(ext_est)
+ return pose_ref, pose_est
+
+
+def _umeyama_sim3_from_paths(pose_ref, pose_est):
+ path_ref = PosePath3D(poses_se3=pose_ref.copy())
+ path_est = PosePath3D(poses_se3=pose_est.copy())
+ r, t, s = path_est.align(path_ref, correct_scale=True)
+ pose_est_aligned = np.stack(path_est.poses_se3)
+ return r, t, s, pose_est_aligned
+
+
+def _apply_sim3_to_poses(poses, r, t, s):
+ out = poses.copy()
+ Ri = poses[:, :3, :3]
+ ti = poses[:, :3, 3]
+ out[:, :3, :3] = r @ Ri
+ out[:, :3, 3] = (r @ (s * ti.T)).T + t
+ return out
+
+
+def _median_nn_thresh(pose_ref, pose_est_aligned):
+ P_ref = pose_ref[:, :3, 3]
+ P_est = pose_est_aligned[:, :3, 3]
+ dists = []
+ for p in P_est:
+ dd = np.linalg.norm(P_ref - p[None, :], axis=1)
+ dists.append(dd.min())
+ return float(np.median(dists)) if dists else 0.0
+
+
+def _ransac_align_sim3(
+ pose_ref, pose_est, sub_n=None, inlier_thresh=None, max_iters=10, random_state=None
+):
+ rng = np.random.default_rng(random_state)
+ N = pose_ref.shape[0]
+ idx_all = np.arange(N)
+ if sub_n is None:
+ sub_n = max(3, (N + 1) // 2)
+ else:
+ sub_n = max(3, min(sub_n, N))
+
+ # Pre-alignment + default threshold
+ r0, t0, s0, pose_est0 = _umeyama_sim3_from_paths(pose_ref, pose_est)
+ if inlier_thresh is None:
+ inlier_thresh = _median_nn_thresh(pose_ref, pose_est0)
+
+ P_ref_all = pose_ref[:, :3, 3]
+
+ best_model = (r0, t0, s0)
+ best_inliers = None
+ best_score = (-1, np.inf) # (num_inliers, mean_err)
+
+ for _ in range(max_iters):
+ sample = rng.choice(idx_all, size=sub_n, replace=False)
+ try:
+ r, t, s, _ = _umeyama_sim3_from_paths(pose_ref[sample], pose_est[sample])
+ except Exception:
+ continue
+ pose_h = _apply_sim3_to_poses(pose_est, r, t, s)
+ P_h = pose_h[:, :3, 3]
+ errs = np.linalg.norm(P_h - P_ref_all, axis=1) # Match by same index
+ inliers = errs <= inlier_thresh
+ k = int(inliers.sum())
+ mean_err = float(errs[inliers].mean()) if k > 0 else np.inf
+ if (k > best_score[0]) or (k == best_score[0] and mean_err < best_score[1]):
+ best_score = (k, mean_err)
+ best_model = (r, t, s)
+ best_inliers = inliers
+
+ # Fit again with best inliers
+ if best_inliers is not None and best_inliers.sum() >= 3:
+ r, t, s, _ = _umeyama_sim3_from_paths(pose_ref[best_inliers], pose_est[best_inliers])
+ else:
+ r, t, s = best_model
+ return r, t, s
+
+
+def align_poses_umeyama(
+ ext_ref: np.ndarray,
+ ext_est: np.ndarray,
+ return_aligned=False,
+ ransac=False,
+ sub_n=None,
+ inlier_thresh=None,
+ ransac_max_iters=10,
+ random_state=None,
+):
+ """
+ Align estimated trajectory to reference using Umeyama Sim(3).
+ Default no RANSAC; if ransac=True, use RANSAC (max iterations default 10).
+ - sub_n defaults to half the number of frames (rounded up, at least 3)
+ - inlier_thresh defaults to median of "distance from each estimated pose to
+ nearest reference pose after pre-alignment"
+ Returns rotation (3x3), translation (3,), scale; optionally returns aligned extrinsics (4x4).
+ """
+ pose_ref, pose_est = _poses_from_ext(ext_ref, ext_est)
+
+ if not ransac:
+ r, t, s, pose_est_aligned = _umeyama_sim3_from_paths(pose_ref, pose_est)
+ else:
+ r, t, s = _ransac_align_sim3(
+ pose_ref,
+ pose_est,
+ sub_n=sub_n,
+ inlier_thresh=inlier_thresh,
+ max_iters=ransac_max_iters,
+ random_state=random_state,
+ )
+ pose_est_aligned = _apply_sim3_to_poses(pose_est, r, t, s)
+
+ if return_aligned:
+ ext_est_aligned = affine_inverse_np(pose_est_aligned)
+ return r, t, s, ext_est_aligned
+ return r, t, s
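+
+
+# Usage sketch (see also the __main__ demo at the bottom of this file):
+#   r, t, s = align_poses_umeyama(ext_ref, ext_est)
+#   r, t, s, ext_aligned = align_poses_umeyama(
+#       ext_ref, ext_est, return_aligned=True, ransac=True
+#   )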
+
+
+def apply_umeyama_alignment_to_ext(
+ rot: np.ndarray, # (3,3)
+ trans: np.ndarray, # (3,) or (1,3)
+ scale: float,
+ ext_est: np.ndarray, # (...,4,4) or (...,3,4)
+) -> np.ndarray:
+ """
+ Apply Sim(3) (R, t, s) to a batch of world-to-camera extrinsics ext_est.
+ Returns the aligned extrinsics, with the same shape as input.
+ """
+
+ # Allow 3x4 extrinsics: pad to 4x4
+ if ext_est.shape[-2:] == (3, 4):
+ pad = np.zeros((*ext_est.shape[:-2], 4, 4), dtype=ext_est.dtype)
+ pad[..., :3, :4] = ext_est
+ pad[..., 3, 3] = 1.0
+ ext_est = pad
+
+ # Convert world-to-camera to camera-to-world
+ pose_est = affine_inverse_np(ext_est) # (...,4,4)
+ R_e = pose_est[..., :3, :3] # (...,3,3)
+ t_e = pose_est[..., :3, 3] # (...,3)
+
+ # Apply Sim(3) transformation
+ R_a = np.einsum("ij,...jk->...ik", rot, R_e) # (...,3,3)
+ t_a = scale * np.einsum("ij,...j->...i", rot, t_e) + trans # (...,3)
+
+ # Assemble the transformed pose
+ pose_a = np.zeros_like(pose_est)
+ pose_a[..., :3, :3] = R_a
+ pose_a[..., :3, 3] = t_a
+ pose_a[..., 3, 3] = 1.0
+
+ # Convert back to world-to-camera
+ return affine_inverse_np(pose_a)
+
+
+def transform_points_sim3(points, rot, trans, scale, inverse=False):
+ """
+ Sim(3) transform point cloud
+ points: (N, 3)
+ rot: (3, 3)
+ trans: (3,) or (1, 3)
+ scale: float
+ inverse: Whether to do inverse transform (ref->est)
+ Returns: (N, 3)
+ """
+ if not inverse:
+ # Forward: est -> ref
+ return scale * (points @ rot.T) + trans
+ else:
+ # Inverse: ref -> est
+ return ((points - trans) @ rot) / scale
+
+
+def _rand_rot():
+ u1, u2, u3 = np.random.rand(3)
+ q = np.array(
+ [
+            np.sqrt(1 - u1) * np.sin(2 * np.pi * u2),
+            np.sqrt(1 - u1) * np.cos(2 * np.pi * u2),
+            np.sqrt(u1) * np.sin(2 * np.pi * u3),
+            np.sqrt(u1) * np.cos(2 * np.pi * u3),
+ ]
+ )
+ w, x, y, z = q
+ return np.array(
+ [
+ [1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)],
+ [2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)],
+ [2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)],
+ ]
+ )
+
+
+def _rand_pose():
+ R, t = _rand_rot(), np.random.randn(3)
+ P = np.eye(4)
+ P[:3, :3] = R
+ P[:3, 3] = t
+ return P
+
+
+if __name__ == "__main__":
+ np.random.seed(42)
+ # 1. Randomly generate reference trajectory and Sim(3)
+ N = 8
+    pose_ref = np.stack([_rand_pose() for _ in range(N)])  # (N,4,4) cam→world
+ rot_gt = _rand_rot()
+ scale_gt = 2.3
+ trans_gt = np.random.randn(3)
+ # 2. Generate estimated trajectory (apply Sim(3))
+ pose_est = np.zeros_like(pose_ref)
+ for i in range(N):
+ R = pose_ref[i][:3, :3]
+ t = pose_ref[i][:3, 3]
+ pose_est[i][:3, :3] = rot_gt @ R
+ pose_est[i][:3, 3] = scale_gt * (rot_gt @ t) + trans_gt
+ pose_est[i][3, 3] = 1.0
+ # 3. Get extrinsics (world->cam)
+ ext_ref = affine_inverse_np(pose_ref)
+ ext_est = affine_inverse_np(pose_est)
+    # 4. Estimate the Sim(3) via Umeyama alignment
+ r_est, t_est, s_est = align_poses_umeyama(ext_ref, ext_est)
+ print("GT scale:", scale_gt, "Estimated:", s_est)
+ print("GT trans:", trans_gt, "Estimated:", t_est)
+ print("GT rot:\n", rot_gt, "\nEstimated:\n", r_est)
+ # 5. Random point cloud, in ref frame
+ num_points = 100
+ points_ref = np.random.randn(num_points, 3)
+ # 6. Use GT Sim(3) inverse transform to est frame
+ points_est = transform_points_sim3(points_ref, rot_gt, trans_gt, scale_gt, inverse=True)
+ # 7. Use estimated Sim(3) forward transform back to ref frame
+ points_ref_recovered = transform_points_sim3(points_est, r_est, t_est, s_est, inverse=False)
+ # 8. Check error
+ err = np.abs(points_ref_recovered - points_ref)
+ print("Point cloud sim3 transform error (mean abs):", err.mean())
+ print("Point cloud sim3 transform error (max abs):", err.max())
+ assert err.mean() < 1e-6, "Mean sim3 transform error too large!"
+ assert err.max() < 1e-5, "Max sim3 transform error too large!"
+ print("Sim(3) point cloud transform & alignment test passed!")
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/read_write_model.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/read_write_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b4cf197f609665e46f14e6b3e8b6657efc5660c
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/read_write_model.py
@@ -0,0 +1,585 @@
+# Copyright (c), ETH Zurich and UNC Chapel Hill.
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+# All rights reserved.
+#
+# This file has been modified by ByteDance Ltd. and/or its affiliates. on 11/05/2025
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright
+# notice, this list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright
+# notice, this list of conditions and the following disclaimer in the
+# documentation and/or other materials provided with the distribution.
+#
+# * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
+# its contributors may be used to endorse or promote products derived
+# from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
+# ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
+# LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
+# CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
+# SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
+# INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
+# CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
+# ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+# POSSIBILITY OF SUCH DAMAGE.
+
+
+import argparse
+import collections
+import os
+import struct
+import numpy as np
+
+CameraModel = collections.namedtuple("CameraModel", ["model_id", "model_name", "num_params"])
+Camera = collections.namedtuple("Camera", ["id", "model", "width", "height", "params"])
+BaseImage = collections.namedtuple(
+ "Image", ["id", "qvec", "tvec", "camera_id", "name", "xys", "point3D_ids"]
+)
+Point3D = collections.namedtuple(
+ "Point3D", ["id", "xyz", "rgb", "error", "image_ids", "point2D_idxs"]
+)
+
+
+class Image(BaseImage):
+ def qvec2rotmat(self):
+ return qvec2rotmat(self.qvec)
+
+
+CAMERA_MODELS = {
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
+}
+CAMERA_MODEL_IDS = {camera_model.model_id: camera_model for camera_model in CAMERA_MODELS}
+CAMERA_MODEL_NAMES = {camera_model.model_name: camera_model for camera_model in CAMERA_MODELS}
+
+
+def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
+ """Read and unpack the next bytes from a binary file.
+ :param fid:
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ :param endian_character: Any of {@, =, <, >, !}
+ :return: Tuple of read and unpacked values.
+ """
+ data = fid.read(num_bytes)
+ return struct.unpack(endian_character + format_char_sequence, data)
+
+
+def write_next_bytes(fid, data, format_char_sequence, endian_character="<"):
+ """pack and write to a binary file.
+ :param fid:
+ :param data: data to send, if multiple elements are sent at the same time,
+ they should be encapsuled either in a list or a tuple
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
+ should be the same length as the data list or tuple
+ :param endian_character: Any of {@, =, <, >, !}
+ """
+    if isinstance(data, (list, tuple)):
+        packed = struct.pack(endian_character + format_char_sequence, *data)
+    else:
+        packed = struct.pack(endian_character + format_char_sequence, data)
+    fid.write(packed)  # avoid shadowing the built-in `bytes`
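+
+# Byte-layout example (illustrative): the camera header consumed by
+# read_cameras_binary below uses format_char_sequence "iiQQ", i.e.
+# int32 camera_id + int32 model_id + uint64 width + uint64 height
+# = 4 + 4 + 8 + 8 = 24 bytes, matching its num_bytes=24 argument.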
+
+
+def read_cameras_text(path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ cameras = {}
+ with open(path) as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ camera_id = int(elems[0])
+ model = elems[1]
+ width = int(elems[2])
+ height = int(elems[3])
+ params = np.array(tuple(map(float, elems[4:])))
+ cameras[camera_id] = Camera(
+ id=camera_id,
+ model=model,
+ width=width,
+ height=height,
+ params=params,
+ )
+ return cameras
+
+
+def read_cameras_binary(path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ cameras = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_cameras):
+ camera_properties = read_next_bytes(fid, num_bytes=24, format_char_sequence="iiQQ")
+ camera_id = camera_properties[0]
+ model_id = camera_properties[1]
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
+ width = camera_properties[2]
+ height = camera_properties[3]
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
+ params = read_next_bytes(
+ fid,
+ num_bytes=8 * num_params,
+ format_char_sequence="d" * num_params,
+ )
+ cameras[camera_id] = Camera(
+ id=camera_id,
+ model=model_name,
+ width=width,
+ height=height,
+ params=np.array(params),
+ )
+ assert len(cameras) == num_cameras
+ return cameras
+
+
+def write_cameras_text(cameras, path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::WriteCamerasText(const std::string& path)
+ void Reconstruction::ReadCamerasText(const std::string& path)
+ """
+ HEADER = (
+ "# Camera list with one line of data per camera:\n"
+ + "# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n"
+ + f"# Number of cameras: {len(cameras)}\n"
+ )
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, cam in cameras.items():
+ to_write = [cam.id, cam.model, cam.width, cam.height, *cam.params]
+ line = " ".join([str(elem) for elem in to_write])
+ fid.write(line + "\n")
+
+
+def write_cameras_binary(cameras, path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(cameras), "Q")
+ for _, cam in cameras.items():
+ model_id = CAMERA_MODEL_NAMES[cam.model].model_id
+ camera_properties = [cam.id, model_id, cam.width, cam.height]
+ write_next_bytes(fid, camera_properties, "iiQQ")
+ for p in cam.params:
+ write_next_bytes(fid, float(p), "d")
+ return cameras
+
+
+def read_images_text(path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ images = {}
+ with open(path) as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ image_id = int(elems[0])
+ qvec = np.array(tuple(map(float, elems[1:5])))
+ tvec = np.array(tuple(map(float, elems[5:8])))
+ camera_id = int(elems[8])
+ image_name = elems[9]
+ elems = fid.readline().split()
+ xys = np.column_stack(
+ [
+ tuple(map(float, elems[0::3])),
+ tuple(map(float, elems[1::3])),
+ ]
+ )
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
+ images[image_id] = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=tvec,
+ camera_id=camera_id,
+ name=image_name,
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ return images
+
+
+def read_images_binary(path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ images = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_reg_images):
+ binary_image_properties = read_next_bytes(
+ fid, num_bytes=64, format_char_sequence="idddddddi"
+ )
+ image_id = binary_image_properties[0]
+ qvec = np.array(binary_image_properties[1:5])
+ tvec = np.array(binary_image_properties[5:8])
+ camera_id = binary_image_properties[8]
+ binary_image_name = b""
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ while current_char != b"\x00": # look for the ASCII 0 entry
+ binary_image_name += current_char
+ current_char = read_next_bytes(fid, 1, "c")[0]
+ image_name = binary_image_name.decode("utf-8")
+ num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0]
+ x_y_id_s = read_next_bytes(
+ fid,
+ num_bytes=24 * num_points2D,
+ format_char_sequence="ddq" * num_points2D,
+ )
+ xys = np.column_stack(
+ [
+ tuple(map(float, x_y_id_s[0::3])),
+ tuple(map(float, x_y_id_s[1::3])),
+ ]
+ )
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
+ images[image_id] = Image(
+ id=image_id,
+ qvec=qvec,
+ tvec=tvec,
+ camera_id=camera_id,
+ name=image_name,
+ xys=xys,
+ point3D_ids=point3D_ids,
+ )
+ return images
+
+
+def write_images_text(images, path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadImagesText(const std::string& path)
+ void Reconstruction::WriteImagesText(const std::string& path)
+ """
+ if len(images) == 0:
+ mean_observations = 0
+ else:
+ mean_observations = sum((len(img.point3D_ids) for _, img in images.items())) / len(images)
+ HEADER = (
+ "# Image list with two lines of data per image:\n"
+ + "# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME\n"
+ + "# POINTS2D[] as (X, Y, POINT3D_ID)\n"
+ + "# Number of images: {}, mean observations per image: {}\n".format(
+ len(images), mean_observations
+ )
+ )
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, img in images.items():
+ image_header = [
+ img.id,
+ *img.qvec,
+ *img.tvec,
+ img.camera_id,
+ img.name,
+ ]
+ first_line = " ".join(map(str, image_header))
+ fid.write(first_line + "\n")
+
+ points_strings = []
+ for xy, point3D_id in zip(img.xys, img.point3D_ids):
+ points_strings.append(" ".join(map(str, [*xy, point3D_id])))
+ fid.write(" ".join(points_strings) + "\n")
+
+
+def write_images_binary(images, path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadImagesBinary(const std::string& path)
+ void Reconstruction::WriteImagesBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(images), "Q")
+ for _, img in images.items():
+ write_next_bytes(fid, img.id, "i")
+ write_next_bytes(fid, img.qvec.tolist(), "dddd")
+ write_next_bytes(fid, img.tvec.tolist(), "ddd")
+ write_next_bytes(fid, img.camera_id, "i")
+ for char in img.name:
+ write_next_bytes(fid, char.encode("utf-8"), "c")
+ write_next_bytes(fid, b"\x00", "c")
+ write_next_bytes(fid, len(img.point3D_ids), "Q")
+ for xy, p3d_id in zip(img.xys, img.point3D_ids):
+ write_next_bytes(fid, [*xy, p3d_id], "ddq")
+
+
+def read_points3D_text(path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ points3D = {}
+ with open(path) as fid:
+ while True:
+ line = fid.readline()
+ if not line:
+ break
+ line = line.strip()
+ if len(line) > 0 and line[0] != "#":
+ elems = line.split()
+ point3D_id = int(elems[0])
+ xyz = np.array(tuple(map(float, elems[1:4])))
+ rgb = np.array(tuple(map(int, elems[4:7])))
+ error = float(elems[7])
+ image_ids = np.array(tuple(map(int, elems[8::2])))
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id,
+ xyz=xyz,
+ rgb=rgb,
+ error=error,
+ image_ids=image_ids,
+ point2D_idxs=point2D_idxs,
+ )
+ return points3D
+
+
+def read_points3D_binary(path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ points3D = {}
+ with open(path_to_model_file, "rb") as fid:
+ num_points = read_next_bytes(fid, 8, "Q")[0]
+ for _ in range(num_points):
+ binary_point_line_properties = read_next_bytes(
+ fid, num_bytes=43, format_char_sequence="QdddBBBd"
+ )
+ point3D_id = binary_point_line_properties[0]
+ xyz = np.array(binary_point_line_properties[1:4])
+ rgb = np.array(binary_point_line_properties[4:7])
+ error = np.array(binary_point_line_properties[7])
+ track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[0]
+ track_elems = read_next_bytes(
+ fid,
+ num_bytes=8 * track_length,
+ format_char_sequence="ii" * track_length,
+ )
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
+ points3D[point3D_id] = Point3D(
+ id=point3D_id,
+ xyz=xyz,
+ rgb=rgb,
+ error=error,
+ image_ids=image_ids,
+ point2D_idxs=point2D_idxs,
+ )
+ return points3D
+
+
+def write_points3D_text(points3D, path):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadPoints3DText(const std::string& path)
+ void Reconstruction::WritePoints3DText(const std::string& path)
+ """
+ if len(points3D) == 0:
+ mean_track_length = 0
+ else:
+ mean_track_length = sum((len(pt.image_ids) for _, pt in points3D.items())) / len(points3D)
+ HEADER = (
+ "# 3D point list with one line of data per point:\n"
+ + "# POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n"
+ + "# Number of points: {}, mean track length: {}\n".format(
+ len(points3D), mean_track_length
+ )
+ )
+
+ with open(path, "w") as fid:
+ fid.write(HEADER)
+ for _, pt in points3D.items():
+ point_header = [pt.id, *pt.xyz, *pt.rgb, pt.error]
+ fid.write(" ".join(map(str, point_header)) + " ")
+ track_strings = []
+ for image_id, point2D in zip(pt.image_ids, pt.point2D_idxs):
+ track_strings.append(" ".join(map(str, [image_id, point2D])))
+ fid.write(" ".join(track_strings) + "\n")
+
+
+def write_points3D_binary(points3D, path_to_model_file):
+ """
+ see: src/colmap/scene/reconstruction.cc
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
+ """
+ with open(path_to_model_file, "wb") as fid:
+ write_next_bytes(fid, len(points3D), "Q")
+ for _, pt in points3D.items():
+ write_next_bytes(fid, pt.id, "Q")
+ write_next_bytes(fid, pt.xyz.tolist(), "ddd")
+ write_next_bytes(fid, pt.rgb.tolist(), "BBB")
+ write_next_bytes(fid, pt.error, "d")
+ track_length = pt.image_ids.shape[0]
+ write_next_bytes(fid, track_length, "Q")
+ for image_id, point2D_id in zip(pt.image_ids, pt.point2D_idxs):
+ write_next_bytes(fid, [image_id, point2D_id], "ii")
+
+
+def detect_model_format(path, ext):
+ if (
+ os.path.isfile(os.path.join(path, "cameras" + ext))
+ and os.path.isfile(os.path.join(path, "images" + ext))
+ and os.path.isfile(os.path.join(path, "points3D" + ext))
+ ):
+ print("Detected model format: '" + ext + "'")
+ return True
+
+ return False
+
+
+def read_model(path, ext=""):
+ # try to detect the extension automatically
+ if ext == "":
+ if detect_model_format(path, ".bin"):
+ ext = ".bin"
+ elif detect_model_format(path, ".txt"):
+ ext = ".txt"
+        else:
+            # A bare `return` would make callers crash while unpacking None.
+            raise ValueError("Could not detect model format; pass ext='.bin' or ext='.txt'.")
+
+ if ext == ".txt":
+ cameras = read_cameras_text(os.path.join(path, "cameras" + ext))
+ images = read_images_text(os.path.join(path, "images" + ext))
+        points3D = read_points3D_text(os.path.join(path, "points3D" + ext))
+ else:
+ cameras = read_cameras_binary(os.path.join(path, "cameras" + ext))
+ images = read_images_binary(os.path.join(path, "images" + ext))
+        points3D = read_points3D_binary(os.path.join(path, "points3D" + ext))
+ return cameras, images, points3D
+
+
+def write_model(cameras, images, points3D, path, ext=".bin"):
+ if ext == ".txt":
+ write_cameras_text(cameras, os.path.join(path, "cameras" + ext))
+ write_images_text(images, os.path.join(path, "images" + ext))
+        write_points3D_text(points3D, os.path.join(path, "points3D" + ext))
+ else:
+ write_cameras_binary(cameras, os.path.join(path, "cameras" + ext))
+ write_images_binary(images, os.path.join(path, "images" + ext))
+        write_points3D_binary(points3D, os.path.join(path, "points3D" + ext))
+ return cameras, images, points3D
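+
+# Usage sketch (paths are illustrative):
+#   cameras, images, points3D = read_model("sparse/0")            # auto-detects .bin/.txt
+#   write_model(cameras, images, points3D, "sparse_txt", ".txt")  # convert to text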
+
+
+def qvec2rotmat(qvec):
+ return np.array(
+ [
+ [
+ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
+ ],
+ [
+ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
+ ],
+ [
+ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
+ ],
+ ]
+ )
+
+
+def rotmat2qvec(R):
+ Rxx, Ryx, Rzx, Rxy, Ryy, Rzy, Rxz, Ryz, Rzz = R.flat
+ K = (
+ np.array(
+ [
+ [Rxx - Ryy - Rzz, 0, 0, 0],
+ [Ryx + Rxy, Ryy - Rxx - Rzz, 0, 0],
+ [Rzx + Rxz, Rzy + Ryz, Rzz - Rxx - Ryy, 0],
+ [Ryz - Rzy, Rzx - Rxz, Rxy - Ryx, Rxx + Ryy + Rzz],
+ ]
+ )
+ / 3.0
+ )
+ eigvals, eigvecs = np.linalg.eigh(K)
+ qvec = eigvecs[[3, 0, 1, 2], np.argmax(eigvals)]
+ if qvec[0] < 0:
+ qvec *= -1
+ return qvec
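+
+# Round-trip sanity check (illustrative): the identity quaternion
+# [w, x, y, z] = [1, 0, 0, 0] maps to the identity rotation and back.
+#   R = qvec2rotmat(np.array([1.0, 0.0, 0.0, 0.0]))
+#   assert np.allclose(R, np.eye(3))
+#   assert np.allclose(rotmat2qvec(R), [1.0, 0.0, 0.0, 0.0])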
+
+
+def main():
+ parser = argparse.ArgumentParser(description="Read and write COLMAP binary and text models")
+ parser.add_argument("--input_model", help="path to input model folder")
+ parser.add_argument(
+ "--input_format",
+ choices=[".bin", ".txt"],
+ help="input model format",
+ default="",
+ )
+ parser.add_argument("--output_model", help="path to output model folder")
+ parser.add_argument(
+ "--output_format",
+ choices=[".bin", ".txt"],
+ help="output model format",
+ default=".txt",
+ )
+ args = parser.parse_args()
+
+ cameras, images, points3D = read_model(path=args.input_model, ext=args.input_format)
+
+ print("num_cameras:", len(cameras))
+ print("num_images:", len(images))
+ print("num_points3D:", len(points3D))
+
+ if args.output_model is not None:
+ write_model(
+ cameras,
+ images,
+ points3D,
+ path=args.output_model,
+ ext=args.output_format,
+ )
+
+
+if __name__ == "__main__":
+ main()
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/registry.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..7db16d525e0417110d87aa5e621b792d8bd95596
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/registry.py
@@ -0,0 +1,36 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any
+from addict import Dict
+
+
+class Registry(Dict[str, Any]):
+ def __init__(self):
+ super().__init__()
+ self._map = Dict({})
+
+ def register(self, name=None):
+ def decorator(cls):
+ key = name or cls.__name__
+ self._map[key] = cls
+ return cls
+
+ return decorator
+
+    def get(self, name):
+        # addict.Dict silently creates missing keys, so guard explicitly.
+        if name not in self._map:
+            raise KeyError(f"No entry registered under '{name}'.")
+        return self._map[name]
+
+ def all(self):
+ return self._map
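+
+
+# Usage sketch (illustrative; `MODELS` and `MyHead` are hypothetical names):
+#   MODELS = Registry()
+#
+#   @MODELS.register("my_head")
+#   class MyHead: ...
+#
+#   head_cls = MODELS.get("my_head")  # -> MyHead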
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/sh_helpers.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/sh_helpers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a1a4ca8204eb8d858351afb253e41e77dfa2f74
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/sh_helpers.py
@@ -0,0 +1,88 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from math import isqrt
+import torch
+from einops import einsum
+
+try:
+ from e3nn.o3 import matrix_to_angles, wigner_D
+except ImportError:
+ from depth_anything_3.utils.logger import logger
+
+    logger.warning("Dependency 'e3nn' not found; it is required for rotating camera-space SH coefficients.")
+
+
+def project_to_so3_strict(M: torch.Tensor) -> torch.Tensor:
+ if M.shape[-2:] != (3, 3):
+ raise ValueError("Input must be a batch of 3x3 matrices (i.e., shape [..., 3, 3]).")
+
+ # 1. Compute SVD
+ U, S, Vh = torch.linalg.svd(M)
+ V = Vh.mH
+
+ # 2. Handle reflection case (det = -1)
+ det_U = torch.det(U)
+ det_V = torch.det(V)
+ is_reflection = (det_U * det_V) < 0
+ correction_sign = torch.where(
+ is_reflection[..., None],
+ torch.tensor([1, 1, -1.0], device=M.device, dtype=M.dtype),
+ torch.tensor([1, 1, 1.0], device=M.device, dtype=M.dtype),
+ )
+ correction_matrix = torch.diag_embed(correction_sign)
+ U_corrected = U @ correction_matrix
+ R_so3_initial = U_corrected @ V.transpose(-2, -1)
+
+ # 3. Explicitly ensure determinant is 1 (or extremely close)
+ current_det = torch.det(R_so3_initial)
+ det_correction_factor = torch.pow(current_det, -1 / 3)[..., None, None]
+ R_so3_final = R_so3_initial * det_correction_factor
+
+ return R_so3_final
+
+
+def rotate_sh(
+ sh_coefficients: torch.Tensor, # "*#batch n"
+ rotations: torch.Tensor, # "*#batch 3 3"
+) -> torch.Tensor: # "*batch n"
+ # https://github.com/graphdeco-inria/gaussian-splatting/issues/176#issuecomment-2452412653
+ device = sh_coefficients.device
+ dtype = sh_coefficients.dtype
+
+ *_, n = sh_coefficients.shape
+
+ with torch.autocast(device_type=rotations.device.type, enabled=False):
+ rotations_float32 = rotations.to(torch.float32)
+
+ # switch axes: yzx -> xyz
+ P = torch.tensor([[0, 0, 1], [1, 0, 0], [0, 1, 0]]).unsqueeze(0).to(rotations_float32)
+ permuted_rotations = torch.linalg.inv(P) @ rotations_float32 @ P
+
+ # ensure rotation has det == 1 in float32 type
+ permuted_rotations_so3 = project_to_so3_strict(permuted_rotations)
+
+ alpha, beta, gamma = matrix_to_angles(permuted_rotations_so3)
+ result = []
+ for degree in range(isqrt(n)):
+ with torch.device(device):
+ sh_rotations = wigner_D(degree, alpha, -beta, gamma).type(dtype)
+ sh_rotated = einsum(
+ sh_rotations,
+ sh_coefficients[..., degree**2 : (degree + 1) ** 2],
+ "... i j, ... j -> ... i",
+ )
+ result.append(sh_rotated)
+
+ return torch.cat(result, dim=-1)
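+
+
+# Self-check sketch (assumes e3nn is installed): rotating by the identity
+# should leave the coefficients (numerically) unchanged.
+#   R = torch.eye(3).unsqueeze(0)   # (1, 3, 3)
+#   sh = torch.randn(1, 9)          # degrees 0..2, so n = 9
+#   assert torch.allclose(rotate_sh(sh, R), sh, atol=1e-5)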
diff --git a/Depth-Anything-3-anysize/src/depth_anything_3/utils/visualize.py b/Depth-Anything-3-anysize/src/depth_anything_3/utils/visualize.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fd32bddf00e5461f674525e73653409270c0227
--- /dev/null
+++ b/Depth-Anything-3-anysize/src/depth_anything_3/utils/visualize.py
@@ -0,0 +1,120 @@
+# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import matplotlib
+import numpy as np
+import torch
+from einops import rearrange
+
+from depth_anything_3.utils.logger import logger
+
+
+def visualize_depth(
+ depth: np.ndarray,
+ depth_min=None,
+ depth_max=None,
+ percentile=2,
+ ret_minmax=False,
+ ret_type=np.uint8,
+ cmap="Spectral",
+):
+ """
+ Visualize a depth map using a colormap.
+
+ Args:
+ depth: Input depth map array
+ depth_min: Minimum depth value for normalization. If None, uses percentile
+ depth_max: Maximum depth value for normalization. If None, uses percentile
+ percentile: Percentile for min/max computation if not provided
+ ret_minmax: Whether to return min/max depth values
+ ret_type: Return array type (uint8 or float)
+ cmap: Matplotlib colormap name to use
+
+ Returns:
+ Colored depth visualization as numpy array
+ If ret_minmax=True, also returns depth_min and depth_max
+ """
+    depth = depth.copy()
+ valid_mask = depth > 0
+ depth[valid_mask] = 1 / depth[valid_mask]
+ if depth_min is None:
+ if valid_mask.sum() <= 10:
+ depth_min = 0
+ else:
+ depth_min = np.percentile(depth[valid_mask], percentile)
+ if depth_max is None:
+ if valid_mask.sum() <= 10:
+ depth_max = 0
+ else:
+ depth_max = np.percentile(depth[valid_mask], 100 - percentile)
+ if depth_min == depth_max:
+ depth_min = depth_min - 1e-6
+ depth_max = depth_max + 1e-6
+ cm = matplotlib.colormaps[cmap]
+ depth = ((depth - depth_min) / (depth_max - depth_min)).clip(0, 1)
+ depth = 1 - depth
+ img_colored_np = cm(depth[None], bytes=False)[:, :, :, 0:3] # value from 0 to 1
+ if ret_type == np.uint8:
+ img_colored_np = (img_colored_np[0] * 255.0).astype(np.uint8)
+ elif ret_type == np.float32 or ret_type == np.float64:
+ img_colored_np = img_colored_np[0]
+ else:
+ raise ValueError(f"Invalid return type: {ret_type}")
+ if ret_minmax:
+ return img_colored_np, depth_min, depth_max
+ else:
+ return img_colored_np
+
+
+# GS video rendering visualization helpers; these operate in Tensor space.
+
+
+def vis_depth_map_tensor(
+ result: torch.Tensor, # "*batch height width"
+ color_map: str = "Spectral",
+) -> torch.Tensor: # "*batch 3 height with"
+ """
+ Color-map the depth map.
+ """
+ far = result.reshape(-1)[:16_000_000].float().quantile(0.99).log().to(result)
+ try:
+ near = result[result > 0][:16_000_000].float().quantile(0.01).log().to(result)
+ except (RuntimeError, ValueError) as e:
+ logger.error(f"No valid depth values found. Reason: {e}")
+ near = torch.zeros_like(far)
+ result = result.log()
+ result = (result - near) / (far - near)
+ return apply_color_map_to_image(result, color_map)
+
+
+def apply_color_map(
+ x: torch.Tensor, # " *batch"
+ color_map: str = "inferno",
+) -> torch.Tensor: # "*batch 3"
+    cmap = matplotlib.colormaps[color_map]  # cm.get_cmap is deprecated/removed in newer Matplotlib
+
+ # Convert to NumPy so that Matplotlib color maps can be used.
+ mapped = cmap(x.float().detach().clip(min=0, max=1).cpu().numpy())[..., :3]
+
+ # Convert back to the original format.
+ return torch.tensor(mapped, device=x.device, dtype=torch.float32)
+
+
+def apply_color_map_to_image(
+ image: torch.Tensor, # "*batch height width"
+ color_map: str = "inferno",
+) -> torch.Tensor: # "*batch 3 height with"
+ image = apply_color_map(image, color_map)
+ return rearrange(image, "... h w c -> ... c h w")
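+
+
+# Usage sketch (illustrative):
+#   depth = torch.rand(2, 240, 320) + 0.1    # "*batch height width"
+#   colored = vis_depth_map_tensor(depth)    # "*batch 3 height width" in [0, 1]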
diff --git a/README.md b/README.md
index c0bef943eb50179a725be8ffb1666694d21f8849..c0f703a2800bf56b28bb69a60ff2fef4f865e3f0 100644
--- a/README.md
+++ b/README.md
@@ -4,31 +4,34 @@ emoji: π
colorFrom: indigo
colorTo: indigo
sdk: gradio
-sdk_version: 5.49.1
+sdk_version: 6.0.0
app_file: app.py
pinned: false
---
# Depth Estimation Comparison Demo
-A ZeroGPU-friendly Gradio interface for comparing **Depth Anything v1**, **Depth Anything v2**, and **Pixel-Perfect Depth (PPD)** on the same image. Switch between side-by-side layouts, a slider overlay, or single-model inspection to understand how different pipelines perceive scene geometry.
+A Gradio interface for comparing **Depth Anything v1**, **Depth Anything v2**, **Depth Anything v3 (AnySize)**, and **Pixel-Perfect Depth (PPD)** on the same image. Switch between side-by-side layouts, a slider overlay, single-model inspection, or a dedicated v3 tab to understand how different pipelines perceive scene geometry. Two entrypoints are provided:
+
+- `app_local.py`: full-featured local runner with minimal memory constraints.
+- `app.py`: ZeroGPU-aware build tuned for HuggingFace Spaces with aggressive cache management.
## π Highlights
-- **Three interactive views**: draggable slider, labeled side-by-side comparison, and original vs depth for any single model.
-- **Multi-family depth models**: run ViT variants from Depth Anything v1/v2 alongside Pixel-Perfect Depth with MoGe metric alignment.
-- **ZeroGPU aware**: on-demand loading, model cache clearing, and torch CUDA cleanup keep GPU usage inside HuggingFace Spaces limits.
-- **Curated examples**: reusable demo images sourced from each model family plus local assets to quickly validate behaviour.
+- **Four interactive experiences**: draggable slider, labeled side-by-side comparison, original-vs-depth slider, and a Depth Anything v3 tab with RGB vs depth visualization + metadata.
+- **Multi-family depth models**: run ViT variants from Depth Anything v1/v2/v3 alongside Pixel-Perfect Depth with MoGe metric alignment.
+- **ZeroGPU aware**: `app.py` performs on-demand loading, cache clearing, and CUDA cleanup to stay within HuggingFace Spaces limits, while `app_local.py` keeps models warm for faster iteration.
+- **Curated examples**: reusable demo images sourced from each model family (`assets/examples`, `Depth-Anything*/assets/examples`, `Depth-Anything-3-anysize/assets/examples`, `Pixel-Perfect-Depth/assets/examples`).
## π Supported Pipelines
- **Depth Anything v1** (`LiheYoung/depth_anything_*`): ViT-S/B/L with fast transformer backbones and colorized outputs via `Spectral_r` colormap.
-- **Depth Anything v2** (`Depth-Anything-V2/checkpoints/*.pth`): ViT-Small/Base/Large with HF Hub fallback, configurable feature channels, and improved edge handling.
+- **Depth Anything v2** (`Depth-Anything-V2/checkpoints/*.pth` or HF Hub mirrors): ViT-Small/Base/Large with configurable feature channels and improved edge handling.
+- **Depth Anything v3 (AnySize)** (`depth-anything/DA3*` via bundled AnySize fork): Nested, giant, large, base, small, mono, and metric variants with native-resolution inference and automatic padding/cropping.
- **Pixel-Perfect Depth**: Diffusion-based relative depth refined by the **MoGe** metric surface model and RANSAC alignment to recover metric depth; customizable denoising steps.
## π₯οΈ App Experience
-- **Slider Comparison**: drag between two predictions with automatically labeled overlays.
+- **Slider Comparison**: drag between any two predictions with automatically labeled overlays.
- **Method Comparison**: view models side-by-side with synchronized layout and captions rendered in OpenCV.
- **Single Model**: inspect the RGB input versus one model output using the Gradio `ImageSlider` component.
-- **Example Gallery**: natural-number sorting across `assets/examples`, `Depth-Anything/assets/examples`, `Depth-Anything-V2/assets/examples`, and `Pixel-Perfect-Depth/assets/examples`.
## π¦ Installation & Setup
@@ -42,44 +45,57 @@ A ZeroGPU-friendly Gradio interface for comparing **Depth Anything v1**, **Depth
```bash
pip install -r requirements.txt
```
-3. **Model assets**:
+3. **Install the AnySize fork** (required for Depth Anything v3 tab):
+ ```bash
+ pip install -e Depth-Anything-3-anysize/.[all]
+ ```
+4. **Model assets**:
- Depth Anything v1 checkpoints stream automatically from the HuggingFace Hub.
- Download Depth Anything v2 weights into `Depth-Anything-V2/checkpoints/` if they are not already present (`depth_anything_v2_vits.pth`, `depth_anything_v2_vitb.pth`, `depth_anything_v2_vitl.pth`).
+ - Depth Anything v3 models download via the bundled AnySize API from `depth-anything/*` repositories at inference time; no manual checkpoints required.
- Pixel-Perfect Depth pulls the diffusion checkpoint (`ppd.pth`) from `gangweix/Pixel-Perfect-Depth` on first use and loads MoGe weights (`Ruicheng/moge-2-vitl-normal`).
-4. **Run the app**:
+5. **Run the app**:
```bash
- python app_local.py # Local UI with live reload tweaks
- python app.py # ZeroGPU-ready launch script
+ python app_local.py # Local UI with v3 tab and warm caches
+ python app.py # ZeroGPU-ready launch script (loads models on demand)
```
### HuggingFace Spaces (ZeroGPU)
1. Push the repository contents to a Gradio Space.
2. Select the **ZeroGPU** hardware preset.
-3. The app will download required checkpoints on demand and aggressively free memory after each inference via `clear_model_cache()`.
+3. The app downloads required checkpoints (Depth Anything v1/v2/v3, PPD, MoGe) on demand and aggressively frees memory via `clear_model_cache()` between requests.
## π Project Structure
```
Depth-Estimation-Compare-demo/
-βββ app.py # ZeroGPU deployment entrypoint
-βββ app_local.py # Local-friendly launch script
-βββ requirements.txt # Python dependencies (Gradio, Torch, PPD stack)
+βββ app.py # ZeroGPU deployment entrypoint (includes v3 tab)
+βββ app_local.py # Local-friendly launch script (full feature set)
+βββ requirements.txt # Python dependencies (Gradio, Torch, PPD stack)
βββ assets/
-β βββ examples/ # Shared demo imagery
-βββ Depth-Anything/ # Depth Anything v1 implementation + utilities
-βββ Depth-Anything-V2/ # Depth Anything v2 implementation & checkpoints
-βββ Pixel-Perfect-Depth/ # Pixel-Perfect Depth diffusion + MoGe helpers
-βββ README.md # You are here
+β βββ examples/ # Shared demo imagery
+βββ Depth-Anything/ # Depth Anything v1 implementation + utilities
+βββ Depth-Anything-V2/ # Depth Anything v2 implementation & checkpoints
+βββ Depth-Anything-3-anysize/ # Bundled AnySize fork powering Depth Anything v3 tab
+β βββ app.py # Standalone AnySize Gradio demo (optional)
+β βββ depth3_anysize.py # Scripted inference example
+β βββ pyproject.toml # Editable install metadata
+β βββ requirements.txt # AnySize-specific dependencies
+β βββ src/depth_anything_3/ # AnySize API, configs, and model code
+βββ Pixel-Perfect-Depth/ # Pixel-Perfect Depth diffusion + MoGe helpers
+βββ README.md # You are here
```
## βοΈ Configuration Notes
-- Model dropdown labels come from `V1_MODEL_CONFIGS`, `V2_MODEL_CONFIGS`, and the PPD entry in `app.py`.
-- `clear_model_cache()` resets every model and flushes CUDA to respect ZeroGPU constraints.
+- Model dropdown labels come from `V1_MODEL_CONFIGS`, `V2_MODEL_CONFIGS`, and `DA3_MODEL_SOURCES` plus the PPD entry in both apps.
+- `clear_model_cache()` resets every model family (v1/v2/v3/PPD) and flushes CUDA to respect ZeroGPU constraints in `app.py`.
+- Depth Anything v3 inference leverages the AnySize API (`process_res=None`, `process_res_method="keep"`) to preserve native resolution and returns processed RGB/depth pairs; see the sketch below.
- Pixel-Perfect Depth inference aligns relative depth to metric scale through `recover_metric_depth_ransac()` for consistent visualization.
- Depth visualizations use a normalized `Spectral_r` colormap; PPD uses a dedicated matplotlib colormap for metric maps.
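+
+A minimal sketch of the v3 call path described above (it mirrors `run_da3_inference` in `app.py`; the model repo and image path are illustrative):
+
+```python
+from PIL import Image
+
+from depth_anything_3.api import DepthAnything3
+
+model = DepthAnything3.from_pretrained("depth-anything/DA3-LARGE").eval()
+prediction = model.inference(
+    image=[Image.open("photo.jpg")],
+    process_res=None,            # keep native resolution
+    process_res_method="keep",
+)
+depth = prediction.depth[0]      # (H, W) depth map for the first image
+```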
## π Performance Expectations
- **Depth Anything v1**: ViT-S ~1–2 s, ViT-B ~2–4 s, ViT-L ~4–8 s (image dependent).
- **Depth Anything v2**: similar to v1 with improved sharpness; HF downloads add one-time startup overhead.
+- **Depth Anything v3**: nested/giant models are heavier (expect longer cold starts), while base/small options are close to v2 latency when running at native resolution.
- **Pixel-Perfect Depth**: diffusion + metric refinement typically takes longer (10–20 denoise steps) but returns metrically-aligned depth suitable for downstream 3D tasks.
## π― Usage Tips
@@ -95,6 +111,7 @@ Enhancements are welcomeβnew model backends, visualization modes, or memory op
- [Depth Anything v2](https://github.com/DepthAnything/Depth-Anything-V2)
- [Pixel-Perfect Depth](https://github.com/gangweix/pixel-perfect-depth)
- [MoGe](https://huggingface.co/Ruicheng/moge-2-vitl-normal)
+- [Depth Anything 3 AnySize Fork](https://github.com/ByteDance-Seed/Depth-Anything-3) (see bundled `Depth-Anything-3-anysize` directory)
## π License
- Depth Anything v1: MIT License
@@ -104,4 +121,4 @@ Enhancements are welcomeβnew model backends, visualization modes, or memory op
---
-Built as a hands-on playground for exploring modern monocular depth estimators. Adjust tabs, compare outputs, and plug results into your 3D workflows.
+Built as a hands-on playground for exploring modern monocular depth estimators. Adjust tabs, compare outputs, and plug results into your 3D workflows.
\ No newline at end of file
diff --git a/app.py b/app.py
index 2b70e7e5372d9cf7cbcbbc707dee4f90bcd017f0..116603faaf24f25d6b9c0fe202f94fcc45d6d66b 100644
--- a/app.py
+++ b/app.py
@@ -1,7 +1,7 @@
"""
Depth Estimation Comparison Demo (ZeroGPU)
-Compare Depth Anything v1, Depth Anything v2, and Pixel-Perfect Depth side-by-side or with a slider using Gradio.
+Compare Depth Anything v1, Depth Anything v2, Depth Anything v3, and Pixel-Perfect Depth side-by-side or with a slider using Gradio.
Optimized for HuggingFace Spaces with ZeroGPU support.
"""
@@ -9,16 +9,19 @@ import os
import sys
import logging
import gc
-from typing import Optional, Tuple, List
+import inspect
+from typing import Optional, Tuple, List, Dict
import numpy as np
import cv2
import gradio as gr
from huggingface_hub import hf_hub_download
import spaces
+from PIL import Image
# Import v1 and v2 model code
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything-V2"))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything-3-anysize", "src"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Pixel-Perfect-Depth"))
# v1 imports
@@ -33,6 +36,10 @@ from depth_anything_v2.dpt import DepthAnythingV2
import matplotlib
+# Depth Anything v3 imports
+from depth_anything_3.api import DepthAnything3
+from depth_anything_3.utils.visualize import visualize_depth
+
# Pixel-Perfect Depth imports
from ppd.utils.set_seed import set_seed
from ppd.utils.align_depth_func import recover_metric_depth_ransac
@@ -82,9 +89,41 @@ V2_MODEL_CONFIGS = {
}
}
+DA3_MODEL_SOURCES: Dict[str, Dict[str, str]] = {
+ "nested_giant_large": {
+ "display_name": "Depth Anything v3 Nested Giant Large",
+ "repo_id": "depth-anything/DA3NESTED-GIANT-LARGE",
+ },
+ "giant": {
+ "display_name": "Depth Anything v3 Giant",
+ "repo_id": "depth-anything/DA3-GIANT",
+ },
+ "large": {
+ "display_name": "Depth Anything v3 Large",
+ "repo_id": "depth-anything/DA3-LARGE",
+ },
+ "base": {
+ "display_name": "Depth Anything v3 Base",
+ "repo_id": "depth-anything/DA3-BASE",
+ },
+ "small": {
+ "display_name": "Depth Anything v3 Small",
+ "repo_id": "depth-anything/DA3-SMALL",
+ },
+ "metric_large": {
+ "display_name": "Depth Anything v3 Metric Large",
+ "repo_id": "depth-anything/DA3METRIC-LARGE",
+ },
+ "mono_large": {
+ "display_name": "Depth Anything v3 Mono Large",
+ "repo_id": "depth-anything/DA3MONO-LARGE",
+ },
+}
+
# Model cache - cleared after each inference for ZeroGPU
_v1_models = {}
_v2_models = {}
+_da3_models: Dict[str, DepthAnything3] = {}
_ppd_model: Optional[PixelPerfectDepth] = None
_moge_model: Optional[MoGeModel] = None
@@ -160,15 +199,83 @@ def load_v2_model(key: str):
_v2_models[key] = model
return model
+
+def load_da3_model(key: str) -> DepthAnything3:
+ if key in _da3_models:
+ return _da3_models[key]
+
+ clear_model_cache()
+
+ repo_id = DA3_MODEL_SOURCES[key]["repo_id"]
+ model = DepthAnything3.from_pretrained(repo_id)
+ model = model.to(device=TORCH_DEVICE)
+ model.eval()
+ _da3_models[key] = model
+ return model
+
+
+def _prep_da3_image(image: np.ndarray) -> np.ndarray:
+ if image.ndim == 2:
+ image = np.stack([image] * 3, axis=-1)
+ if image.dtype != np.uint8:
+ image = np.clip(image, 0, 255).astype(np.uint8)
+ return image
+
+
+def run_da3_inference(model_key: str, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray, str, str]:
+ model = load_da3_model(model_key)
+ if image.ndim == 2:
+ rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+ else:
+ rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ rgb = _prep_da3_image(rgb)
+ prediction = model.inference(
+ image=[Image.fromarray(rgb)],
+ process_res=None,
+ process_res_method="keep",
+ )
+
+ depth_map = prediction.depth[0]
+ depth_vis = visualize_depth(depth_map, cmap="Spectral")
+ processed_rgb = (
+ prediction.processed_images[0]
+ if getattr(prediction, "processed_images", None) is not None
+ else rgb
+ )
+ processed_rgb = np.clip(processed_rgb, 0, 255).astype(np.uint8)
+
+ target_h, target_w = image.shape[:2]
+ if depth_vis.shape[:2] != (target_h, target_w):
+ depth_vis = cv2.resize(depth_vis, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
+ if processed_rgb.shape[:2] != (target_h, target_w):
+ processed_rgb = cv2.resize(processed_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
+
+ label = DA3_MODEL_SOURCES[model_key]["display_name"]
+ info_lines = [
+ f"**Model:** `{label}`",
+ f"**Repo:** `{DA3_MODEL_SOURCES[model_key]['repo_id']}`",
+ f"**Device:** `{str(TORCH_DEVICE)}`",
+ f"**Depth shape:** `{tuple(prediction.depth.shape)}`",
+ ]
+ if getattr(prediction, "extrinsics", None) is not None:
+ info_lines.append(f"**Extrinsics shape:** `{prediction.extrinsics.shape}`")
+ if getattr(prediction, "intrinsics", None) is not None:
+ info_lines.append(f"**Intrinsics shape:** `{prediction.intrinsics.shape}`")
+
+ return depth_vis, processed_rgb, "\n".join(info_lines), label
+
def clear_model_cache():
"""Clear model cache to free GPU memory for ZeroGPU"""
- global _v1_models, _v2_models, _ppd_model, _moge_model
+ global _v1_models, _v2_models, _da3_models, _ppd_model, _moge_model
for model in _v1_models.values():
del model
for model in _v2_models.values():
del model
+ for model in _da3_models.values():
+ del model
_v1_models.clear()
_v2_models.clear()
+ _da3_models.clear()
_ppd_model = None
_moge_model = None
gc.collect()
@@ -266,6 +373,8 @@ def get_model_choices() -> List[Tuple[str, str]]:
choices.append((v['display_name'], f'v1_{k}'))
for k, v in V2_MODEL_CONFIGS.items():
choices.append((v['display_name'], f'v2_{k}'))
+ for k, v in DA3_MODEL_SOURCES.items():
+ choices.append((v['display_name'], f'da3_{k}'))
choices.append(("Pixel-Perfect Depth", "ppd"))
return choices
@@ -287,6 +396,10 @@ def run_model(model_key: str, image: np.ndarray) -> Tuple[np.ndarray, str]:
label = V2_MODEL_CONFIGS[key]['display_name']
colored = colorize_depth(depth)
return colored, label
+ elif model_key.startswith('da3_'):
+ key = model_key[4:]
+ depth_vis, _, _, label = run_da3_inference(key, image)
+ return depth_vis, label
elif model_key == 'ppd':
clear_model_cache()
_, colored = pixel_perfect_depth_inference(image)
@@ -429,6 +542,37 @@ def single_inference(image, model: str, progress=gr.Progress()):
# Clean up GPU memory after inference
clear_model_cache()
+
+@spaces.GPU
+def da3_single_inference(image, model: str, progress=gr.Progress()):
+ if image is None:
+ return None, "β Please upload an image."
+
+ try:
+ if isinstance(image, str):
+ np_image = cv2.imread(image)
+ elif hasattr(image, "save"):
+ np_image = np.array(image)
+ if len(np_image.shape) == 3 and np_image.shape[2] == 3:
+ np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
+ else:
+ np_image = np.array(image)
+ if len(np_image.shape) == 3 and np_image.shape[2] == 3:
+ np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
+
+ if np_image is None:
+ raise gr.Error("Invalid image input.")
+
+ key = model[4:] if model.startswith("da3_") else model
+
+ progress(0.1, desc=f"Running {model}")
+ depth_vis, processed_rgb, info_text, _ = run_da3_inference(key, np_image)
+ progress(1.0, desc="Done")
+ return (processed_rgb, depth_vis), info_text
+
+ finally:
+ clear_model_cache()
+
def get_example_images() -> List[str]:
import re
@@ -443,6 +587,7 @@ def get_example_images() -> List[str]:
"assets/examples",
"Depth-Anything/assets/examples",
"Depth-Anything-V2/assets/examples",
+ "Depth-Anything-3-anysize/assets/examples",
"Pixel-Perfect-Depth/assets/examples",
]:
ex_path = os.path.join(os.path.dirname(__file__), ex_dir)
@@ -474,8 +619,19 @@ def create_app():
default2 = next((value for _, value in model_choices if value.startswith('v2_') and value != default1), model_choices[min(1, len(model_choices) - 1)][1])
example_images = get_example_images()
+ da3_choices = [(cfg['display_name'], f"da3_{key}") for key, cfg in DA3_MODEL_SOURCES.items()]
+ if not da3_choices:
+ raise ValueError("Depth Anything v3 models are not configured.")
+ da3_default = next((value for _, value in da3_choices if value == "da3_large"), da3_choices[0][1])
+
+ blocks_kwargs = {"title": "Depth Estimation Comparison"}
+ try:
+ if "theme" in inspect.signature(gr.Blocks.__init__).parameters and hasattr(gr, "themes"):
+ blocks_kwargs["theme"] = gr.themes.Soft()
+ except (ValueError, TypeError):
+ pass
- with gr.Blocks(title="Depth Estimation Comparison", theme=gr.themes.Soft()) as app:
+ with gr.Blocks(**blocks_kwargs) as app:
gr.Markdown("""
# Depth Estimation Comparison
- Compare Depth Anything v1, Depth Anything v2, and Pixel-Perfect Depth side-by-side or with a slider.
+ Compare Depth Anything v1, v2, v3, and Pixel-Perfect Depth side-by-side or with a slider.
@@ -539,6 +695,7 @@ def create_app():
**References:**
- **v1**: [Depth Anything v1](https://github.com/LiheYoung/Depth-Anything)
- **v2**: [Depth Anything v2](https://github.com/DepthAnything/Depth-Anything-V2)
+ - **v3**: [Depth Anything v3](https://github.com/ByteDance-Seed/Depth-Anything-3) & [Depth-Anything-3-anysize](https://github.com/shriarul5273/Depth-Anything-3-anysize)
- **PPD**: [Pixel-Perfect Depth](https://github.com/gangweix/pixel-perfect-depth)
**Note**: This app uses ZeroGPU for efficient GPU resource management. Models are loaded on-demand and GPU memory is automatically cleaned up after each inference.
diff --git a/app_local.py b/app_local.py
index 574564cf83376700e819f6a61b4af18a63af6561..444af91a886c704efb2427b8b1a0b6c3b8e08291 100644
--- a/app_local.py
+++ b/app_local.py
@@ -5,11 +5,14 @@ Compare Depth Anything models (v1 and v2) and Pixel-Perfect Depth side-by-side o
Inspired by the Stereo Matching Methods Comparison Demo.
"""
+from __future__ import annotations
+
import os
import sys
import logging
import tempfile
import shutil
+import inspect
from pathlib import Path
from typing import Optional, Tuple, Dict, List
import numpy as np
@@ -18,10 +21,12 @@ import gradio as gr
from huggingface_hub import hf_hub_download
import open3d as o3d
import trimesh
+from PIL import Image
# Import v1 and v2 model code
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything"))
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything-V2"))
+sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything-3-anysize", "src"))
# v1 imports
from depth_anything.dpt import DepthAnything as DepthAnythingV1
@@ -35,6 +40,10 @@ from depth_anything_v2.dpt import DepthAnythingV2
import matplotlib
+# Depth Anything v3 imports
+from depth_anything_3.api import DepthAnything3
+from depth_anything_3.utils.visualize import visualize_depth
+
# Pixel-Perfect Depth imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Pixel-Perfect-Depth"))
from ppd.utils.set_seed import set_seed
@@ -48,6 +57,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
# Device selection
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
+TORCH_DEVICE = torch.device(DEVICE)
# Model configs
V1_MODEL_CONFIGS = {
@@ -83,9 +93,41 @@ V2_MODEL_CONFIGS = {
}
}
+DA3_MODEL_SOURCES = {
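+ # Hugging Face checkpoints; each key is exposed in the UI as "da3_<key>"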
+ "nested_giant_large": {
+ "display_name": "Depth Anything v3 Nested Giant Large",
+ "repo_id": "depth-anything/DA3NESTED-GIANT-LARGE",
+ },
+ "giant": {
+ "display_name": "Depth Anything v3 Giant",
+ "repo_id": "depth-anything/DA3-GIANT",
+ },
+ "large": {
+ "display_name": "Depth Anything v3 Large",
+ "repo_id": "depth-anything/DA3-LARGE",
+ },
+ "base": {
+ "display_name": "Depth Anything v3 Base",
+ "repo_id": "depth-anything/DA3-BASE",
+ },
+ "small": {
+ "display_name": "Depth Anything v3 Small",
+ "repo_id": "depth-anything/DA3-SMALL",
+ },
+ "metric_large": {
+ "display_name": "Depth Anything v3 Metric Large",
+ "repo_id": "depth-anything/DA3METRIC-LARGE",
+ },
+ "mono_large": {
+ "display_name": "Depth Anything v3 Mono Large",
+ "repo_id": "depth-anything/DA3MONO-LARGE",
+ },
+}
+
# Model cache
_v1_models = {}
_v2_models = {}
+_da3_models: Dict[str, DepthAnything3] = {}
# v1 transform
v1_transform = Compose([
@@ -146,6 +188,91 @@ def load_v2_model(key: str):
_v2_models[key] = model
return model
+
+def load_da3_model(key: str) -> DepthAnything3:
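+ """Load a DA3 checkpoint from the Hugging Face Hub, move it to TORCH_DEVICE, and cache it per key."""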
+ if key in _da3_models:
+ return _da3_models[key]
+ repo_id = DA3_MODEL_SOURCES[key]["repo_id"]
+ model = DepthAnything3.from_pretrained(repo_id)
+ model = model.to(device=TORCH_DEVICE)
+ model.eval()
+ _da3_models[key] = model
+ return model
+
+
+def _prep_da3_image(image: np.ndarray) -> np.ndarray:
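+ """Ensure a 3-channel uint8 array before PIL conversion."""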
+ if image.ndim == 2:
+ image = np.stack([image] * 3, axis=-1)
+ if image.dtype != np.uint8:
+ image = np.clip(image, 0, 255).astype(np.uint8)
+ return image
+
+
+def run_da3_inference(model_key: str, image: np.ndarray) -> Tuple[np.ndarray, np.ndarray, str, str]:
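+ """Run DA3 on a BGR (or grayscale) array; return colorized depth, processed RGB, info markdown, and the display label."""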
+ model = load_da3_model(model_key)
+ if image.ndim == 2:
+ rgb = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
+ else:
+ rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+ rgb = _prep_da3_image(rgb)
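+ # process_res_method="keep" is assumed to run at the native input resolution (anysize fork)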
+ prediction = model.inference(
+ image=[Image.fromarray(rgb)],
+ process_res=None,
+ process_res_method="keep",
+ )
+ depth_map = prediction.depth[0]
+ depth_vis = visualize_depth(depth_map, cmap="Spectral")
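+ # Prefer the model's processed frame for display; fall back to the raw RGB input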
+ processed_rgb = (
+ prediction.processed_images[0]
+ if getattr(prediction, "processed_images", None) is not None
+ else rgb
+ )
+ processed_rgb = np.clip(processed_rgb, 0, 255).astype(np.uint8)
+ target_h, target_w = image.shape[:2]
+ if depth_vis.shape[:2] != (target_h, target_w):
+ depth_vis = cv2.resize(depth_vis, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
+ if processed_rgb.shape[:2] != (target_h, target_w):
+ processed_rgb = cv2.resize(processed_rgb, (target_w, target_h), interpolation=cv2.INTER_LINEAR)
+ label = DA3_MODEL_SOURCES[model_key]["display_name"]
+ info_lines = [
+ f"**Model:** `{label}`",
+ f"**Repo:** `{DA3_MODEL_SOURCES[model_key]['repo_id']}`",
+ f"**Device:** `{str(TORCH_DEVICE)}`",
+ f"**Depth shape:** `{tuple(prediction.depth.shape)}`",
+ ]
+ if getattr(prediction, "extrinsics", None) is not None:
+ info_lines.append(f"**Extrinsics shape:** `{prediction.extrinsics.shape}`")
+ if getattr(prediction, "intrinsics", None) is not None:
+ info_lines.append(f"**Intrinsics shape:** `{prediction.intrinsics.shape}`")
+ info_text = "\n".join(info_lines)
+ return depth_vis, processed_rgb, info_text, label
+
+
+def da3_single_inference(image, model: str, progress=gr.Progress()):
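+ """Gradio handler: normalize the input, run DA3, and return a (processed RGB, depth) slider pair plus info markdown."""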
+ if image is None:
+ return None, "β Please upload an image."
+
+ if isinstance(image, str):
+ np_image = cv2.imread(image)  # cv2 loads file paths in BGR order
+ else:
+ # PIL images and numpy arrays from Gradio both arrive in RGB order
+ np_image = np.array(image)
+ if np_image.ndim == 3 and np_image.shape[2] == 3:
+ np_image = cv2.cvtColor(np_image, cv2.COLOR_RGB2BGR)
+
+ if np_image is None:
+ raise gr.Error("Invalid image input.")
+
+ key = model[4:] if model.startswith("da3_") else model
+
+ progress(0.1, desc=f"Running {model}")
+ depth_vis, processed_rgb, info_text, _ = run_da3_inference(key, np_image)
+ progress(1.0, desc="Done")
+ return (processed_rgb, depth_vis), info_text
+
def predict_v1(model, image: np.ndarray) -> np.ndarray:
h, w = image.shape[:2]
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) / 255.0
@@ -171,8 +298,6 @@ def colorize_depth(depth: np.ndarray) -> np.ndarray:
# Pixel-Perfect Depth setup -------------------------------------------------
set_seed(666)
-
-TORCH_DEVICE = torch.device(DEVICE)
PPD_DEFAULT_STEPS = 20
PPD_TEMP_ROOT = Path(tempfile.gettempdir()) / "ppd"
@@ -308,6 +433,8 @@ def get_model_choices() -> List[Tuple[str, str]]:
choices.append((v['display_name'], f'v1_{k}'))
for k, v in V2_MODEL_CONFIGS.items():
choices.append((v['display_name'], f'v2_{k}'))
+ for k, v in DA3_MODEL_SOURCES.items():
+ choices.append((v['display_name'], f'da3_{k}'))
choices.append(("Pixel-Perfect Depth", "ppd"))
return choices
@@ -322,6 +449,10 @@ def run_model(model_key: str, image: np.ndarray) -> Tuple[np.ndarray, str]:
model = load_v2_model(key)
depth = predict_v2(model, image)
label = V2_MODEL_CONFIGS[key]['display_name']
+ elif model_key.startswith('da3_'):
+ key = model_key[4:]
+ depth_vis, _, _, label = run_da3_inference(key, image)
+ return depth_vis, label
elif model_key == 'ppd':
slider_data, _, _ = pixel_perfect_depth_inference(
image,
@@ -449,6 +580,7 @@ def get_example_images() -> List[str]:
"assets/examples",
"Depth-Anything/assets/examples",
"Depth-Anything-V2/assets/examples",
+ "Depth-Anything-3-anysize/assets/examples",
"Pixel-Perfect-Depth/assets/examples",
]:
ex_path = os.path.join(os.path.dirname(__file__), ex_dir)
@@ -476,8 +608,19 @@ def create_app():
model_choices = get_model_choices()
default1 = model_choices[0][1]
default2 = model_choices[1][1]
+ da3_choices = [(cfg['display_name'], f"da3_{key}") for key, cfg in DA3_MODEL_SOURCES.items()]
+ if not da3_choices:
+ raise ValueError("Depth Anything v3 models are not configured.")
+ da3_default = next((value for _, value in da3_choices if value == "da3_large"), da3_choices[0][1])
example_images = get_example_images()
- with gr.Blocks(title="Depth Anything v1 vs v2 Comparison", theme=gr.themes.Soft()) as app:
+ blocks_kwargs = {"title": "Depth Anything v1 vs v2 Comparison"}
+ try:
+ if "theme" in inspect.signature(gr.Blocks.__init__).parameters and hasattr(gr, "themes"):
+ # Use theme only when the installed gradio version accepts it.
+ blocks_kwargs["theme"] = gr.themes.Soft()
+ except (ValueError, TypeError):
+ pass
+ with gr.Blocks(**blocks_kwargs) as app:
gr.Markdown("""
# Depth Estimation Comparison
- Compare Depth Anything v1, Depth Anything v2, and Pixel-Perfect Depth side-by-side or with a slider.
+ Compare Depth Anything v1, v2, v3, and Pixel-Perfect Depth side-by-side or with a slider.
@@ -530,6 +673,7 @@ def create_app():
---
- **v1**: [Depth Anything v1](https://github.com/LiheYoung/Depth-Anything)
- **v2**: [Depth Anything v2](https://github.com/DepthAnything/Depth-Anything-V2)
+ - **v3**: [Depth Anything v3](https://github.com/ByteDance-Seed/Depth-Anything-3) & [Depth-Anything-3-anysize](https://github.com/shriarul5273/Depth-Anything-3-anysize)
- **PPD**: [Pixel-Perfect Depth](https://github.com/gangweix/pixel-perfect-depth)
""")
return app
diff --git a/requirements.txt b/requirements.txt
index 4f97e9328543fb854492fd94d81039ecc620f543..0582f8557515a6efa836fdd753061e9161c0aad6 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,5 @@ open3d
scikit-learn
git+https://github.com/EasternJournalist/utils3d.git@c5daf6f6c244d251f252102d09e9b7bcef791a38
click # ==8.1.7
-trimesh # ==4.5.1
\ No newline at end of file
+trimesh # ==4.5.1
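+# editable install of the bundled anysize fork with its optional [all] extras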
+-e Depth-Anything-3-anysize/.[all]
\ No newline at end of file