from pathlib import Path import numpy as np import torch import torch.nn.functional as F from PIL import Image _PACKAGE_ROOT = Path(__file__).resolve().parent DA3_BF16_MODEL = "depth/depth_anything_v3_vitl_bf16.safetensors" DA3_METRIC_BF16_MODEL = "depth/depth_anything_v3_metric_large_bf16.safetensors" def resolve_da3_chunk_size(chunk_size=-1, device=None): chunk_size = int(chunk_size if chunk_size is not None else -1) if chunk_size != -1: return chunk_size if not torch.cuda.is_available(): return 33 device = torch.device("cuda" if device is None else device) if device.type != "cuda": return 33 device_index = torch.cuda.current_device() if device.index is None else device.index vram_gb = torch.cuda.get_device_properties(device_index).total_memory / 1_000_000_000 if vram_gb < 8: return 33 if vram_gb < 24: return 65 return 97 def _load_da3(pretrained_model, device, model_name="da3-large"): from mmgp import offload from safetensors import safe_open from .api import DepthAnything3 model = DepthAnything3(model_name=model_name) pretrained_model = str(pretrained_model) if not pretrained_model.endswith(".safetensors"): raise ValueError(f"Depth Anything 3 now expects the bf16 safetensors checkpoint, got: {pretrained_model}") model_keys = set(model.state_dict().keys()) with safe_open(pretrained_model, framework="pt", device="cpu") as f: checkpoint_keys = set(f.keys()) missing = sorted(model_keys - checkpoint_keys) unexpected = sorted(checkpoint_keys - model_keys) allowed_missing = tuple(f"model.head.scratch.output_conv2_aux.{idx}.2." for idx in range(1, 4)) unsupported_missing = [key for key in missing if not key.startswith(allowed_missing)] if unexpected or unsupported_missing: raise RuntimeError(f"Unexpected DA3 checkpoint keys: unexpected={unexpected}, missing={unsupported_missing}") offload.load_model_data(model, pretrained_model, writable_tensors=False, default_dtype=torch.bfloat16, ignore_missing_keys=True) model.requires_grad_(False) model.to(device=device, dtype=torch.bfloat16) model.eval() return model def _resize_2d(array, height, width, mode="bilinear", inverse=False): if array.shape[-2:] == (height, width): return array.copy() dtype = array.dtype tensor = torch.from_numpy(array).to(torch.float64) leading = tensor.shape[:-2] tensor = tensor.reshape(-1, *tensor.shape[-2:]) if inverse: tensor = 1 / tensor tensor = F.interpolate(tensor[:, None], size=(height, width), mode=mode)[:, 0] if inverse: tensor = 1 / tensor tensor = tensor.reshape(*leading, height, width) if dtype == np.bool_: tensor = tensor >= 0.5 return tensor.numpy().astype(dtype) def _k_to_intrinsics(k): intrinsics = np.zeros((k.shape[0], 4), dtype=np.float32) intrinsics[:, 0] = k[:, 0, 0] intrinsics[:, 1] = k[:, 1, 1] intrinsics[:, 2] = k[:, 0, 2] intrinsics[:, 3] = k[:, 1, 2] return intrinsics def _prediction_to_arrays(prediction, height, width): depths = prediction.depth.astype(np.float32) sky = getattr(prediction, "sky", None) if sky is None: sky = np.zeros_like(depths, dtype=np.bool_) else: sky = sky.astype(np.bool_) cam_w2c = prediction.extrinsics.astype(np.float32) intrinsics = _k_to_intrinsics(prediction.intrinsics.astype(np.float32)) processed = prediction.processed_images proc_h, proc_w = processed.shape[1:3] depths = _resize_2d(depths, height, width, mode="bilinear", inverse=True) sky = _resize_2d(sky, height, width, mode="nearest", inverse=False) intrinsics[:, 0::2] *= width / proc_w intrinsics[:, 1::2] *= height / proc_h return depths, sky, cam_w2c, intrinsics def _camera_w2c_to_c2w(cam_w2c): cam_w2c_44 = np.zeros((cam_w2c.shape[0], 4, 4), dtype=np.float32) cam_w2c_44[:, :3, :4] = cam_w2c cam_w2c_44[:, 3, 3] = 1.0 cam_c2w = np.linalg.inv(cam_w2c_44) return (np.linalg.inv(cam_c2w[0])[None] @ cam_c2w).astype(np.float32) def _w2c_to_pose(cam_w2c): cam_w2c_44 = np.zeros((cam_w2c.shape[0], 4, 4), dtype=np.float64) cam_w2c_44[:, :3, :4] = cam_w2c.astype(np.float64) cam_w2c_44[:, 3, 3] = 1.0 return np.linalg.inv(cam_w2c_44) def _closest_rotation(matrix): u, _, vh = np.linalg.svd(matrix) rotation = u @ vh if np.linalg.det(rotation) < 0: u[:, -1] *= -1 rotation = u @ vh return rotation def _pose_based_chunk_alignment(ref_w2c, est_w2c): ref_pose = _w2c_to_pose(ref_w2c) est_pose = _w2c_to_pose(est_w2c) rotation = _closest_rotation(np.mean(ref_pose[:, :3, :3] @ np.swapaxes(est_pose[:, :3, :3], -1, -2), axis=0)) ref_centers = ref_pose[:, :3, 3] est_centers = est_pose[:, :3, 3] pair_i, pair_j = np.triu_indices(ref_centers.shape[0], k=1) ref_dists = np.linalg.norm(ref_centers[pair_i] - ref_centers[pair_j], axis=1) est_dists = np.linalg.norm(est_centers[pair_i] - est_centers[pair_j], axis=1) valid = est_dists > np.finfo(np.float64).eps scale = float(np.median(ref_dists[valid] / est_dists[valid])) if valid.any() else 1.0 est_mean = est_centers.mean(axis=0) ref_mean = ref_centers.mean(axis=0) translation = ref_mean - scale * (rotation @ est_mean) return rotation.astype(np.float32), translation.astype(np.float32), np.float32(scale) def _apply_sim3_to_w2c(cam_w2c, rotation, translation, scale): cam_w2c_44 = np.zeros((cam_w2c.shape[0], 4, 4), dtype=np.float32) cam_w2c_44[:, :3, :4] = cam_w2c cam_w2c_44[:, 3, 3] = 1.0 poses = np.linalg.inv(cam_w2c_44) aligned = poses.copy() aligned[:, :3, :3] = rotation @ poses[:, :3, :3] aligned[:, :3, 3] = (rotation @ (scale * poses[:, :3, 3]).T).T + translation return np.linalg.inv(aligned)[:, :3, :4].astype(np.float32) def _chunk_ranges(frame_count, chunk_size, overlap): if chunk_size <= 0 or chunk_size >= frame_count: return [(0, frame_count)] if overlap < 8: raise ValueError("DA3 temporal chunking requires at least 8 overlap frames") if overlap >= chunk_size: raise ValueError("DA3 temporal chunk overlap must be smaller than the chunk size") ranges, start, step = [], 0, chunk_size - overlap while True: end = start + chunk_size if end >= frame_count: ranges.append((frame_count - chunk_size, frame_count)) break ranges.append((start, end)) next_start = start + step final_start = frame_count - chunk_size start = final_start if end - final_start >= overlap else next_start return ranges def _infer_da3_prediction(model, video, frame_indices, process_res): frames = [Image.fromarray(video[i]) for i in frame_indices] return model.inference(frames, process_res=process_res, export_format="npz") def _infer_da3_depth_prediction(model, video, frame_indices, process_res): frames = [Image.fromarray(video[i]) for i in frame_indices] prediction = model.inference(frames, process_res=process_res, export_format="npz") return _resize_2d(prediction.depth.astype(np.float32), video.shape[1], video.shape[2], mode="bilinear", inverse=True) def _run_da3_prediction(model, video, process_res, chunk_size=0, chunk_overlap=8): frame_count, height, width = video.shape[:3] chunk_size = resolve_da3_chunk_size(chunk_size) ranges = _chunk_ranges(frame_count, chunk_size, chunk_overlap) if len(ranges) == 1: prediction = _infer_da3_prediction(model, video, range(frame_count), process_res) depths, sky, cam_w2c, intrinsics = _prediction_to_arrays(prediction, height, width) return depths, sky, _camera_w2c_to_c2w(cam_w2c), intrinsics depths_all = np.empty((frame_count, height, width), dtype=np.float32) sky_all = np.empty((frame_count, height, width), dtype=np.bool_) cam_w2c_all = np.empty((frame_count, 3, 4), dtype=np.float32) intrinsics_all = np.empty((frame_count, 4), dtype=np.float32) filled = np.zeros(frame_count, dtype=np.bool_) for start, end in ranges: indices = np.arange(start, end) prediction = _infer_da3_prediction(model, video, indices, process_res) depths, sky, cam_w2c, intrinsics = _prediction_to_arrays(prediction, height, width) overlap_mask = filled[indices] if overlap_mask.any(): if int(overlap_mask.sum()) < 3: raise ValueError("DA3 temporal chunking produced fewer than 3 overlap frames for alignment") ref_w2c = cam_w2c_all[indices[overlap_mask]] est_w2c = cam_w2c[overlap_mask] rotation, translation, scale = _pose_based_chunk_alignment(ref_w2c, est_w2c) cam_w2c = _apply_sim3_to_w2c(cam_w2c, rotation, translation, scale) depths *= np.float32(scale) keep_mask = ~filled[indices] keep_indices = indices[keep_mask] depths_all[keep_indices] = depths[keep_mask] sky_all[keep_indices] = sky[keep_mask] cam_w2c_all[keep_indices] = cam_w2c[keep_mask] intrinsics_all[keep_indices] = intrinsics[keep_mask] filled[keep_indices] = True del prediction, depths, sky, cam_w2c, intrinsics if torch.cuda.is_available(): torch.cuda.empty_cache() if not filled.all(): missing = np.flatnonzero(~filled).tolist() raise RuntimeError(f"DA3 temporal chunking failed to fill frames: {missing}") return depths_all, sky_all, _camera_w2c_to_c2w(cam_w2c_all), intrinsics_all def _run_da3_depth_prediction(model, video, process_res, chunk_size=0): frame_count, height, width = video.shape[:3] chunk_size = resolve_da3_chunk_size(chunk_size) if chunk_size <= 0 or chunk_size >= frame_count: return _infer_da3_depth_prediction(model, video, range(frame_count), process_res) depth_all = np.empty((frame_count, height, width), dtype=np.float32) for start in range(0, frame_count, chunk_size): end = min(frame_count, start + chunk_size) depth_all[start:end] = _infer_da3_depth_prediction(model, video, range(start, end), process_res) if torch.cuda.is_available(): torch.cuda.empty_cache() return depth_all @torch.inference_mode() def run_da3_reconstruction(video, pretrained_model=None, process_res=0, device=None, chunk_size=0, chunk_overlap=8): from shared.utils import files_locator as fl device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device chunk_size = resolve_da3_chunk_size(chunk_size, device) pretrained_model = pretrained_model or fl.locate_file(DA3_BF16_MODEL) model = _load_da3(pretrained_model, device, model_name="da3-large") height, width = video.shape[1:3] if process_res <= 0: process_res = width depths, sky, cam_c2w, intrinsics = _run_da3_prediction(model, video, process_res, chunk_size=chunk_size, chunk_overlap=chunk_overlap) model.to("cpu") del model if torch.cuda.is_available(): torch.cuda.empty_cache() return depths, sky, cam_c2w.astype(np.float32), intrinsics.astype(np.float32) class DepthV3VideoAnnotator: def __init__(self, cfg, device=None): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") if device is None else device self.process_res = int(cfg.get("PROCESS_RES", 0) or 0) self.chunk_size = resolve_da3_chunk_size(cfg.get("CHUNK_SIZE", -1), self.device) self.chunk_overlap = int(cfg.get("CHUNK_OVERLAP", 8) or 8) self.model_name = cfg.get("MODEL_NAME", "da3-large") self.model = _load_da3(cfg["PRETRAINED_MODEL"], self.device, model_name=self.model_name) @torch.inference_mode() def forward(self, frames): video = np.stack([np.asarray(frame) for frame in frames], axis=0) if self.model_name == "da3metric-large": depth = _run_da3_depth_prediction(self.model, video, self.process_res or video.shape[2], chunk_size=self.chunk_size) else: depth, _, _, _ = _run_da3_prediction(self.model, video, self.process_res or video.shape[2], chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap) disp = 1.0 / np.maximum(depth, 1e-6) disp -= disp.min() disp /= max(float(disp.max()), 1e-6) depth_video = (disp * 255.0).clip(0, 255).astype(np.uint8) return [np.repeat(frame[..., None], 3, axis=2) for frame in depth_video] def close(self): if self.model is not None: self.model.to("cpu") self.model = None if torch.cuda.is_available(): torch.cuda.empty_cache() def __del__(self): try: self.close() except Exception: pass