File size: 12,866 Bytes
7344bef | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 | from pathlib import Path
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
# Absolute directory that contains this module file.
_PACKAGE_ROOT = Path(__file__).resolve().parent
# Relative checkpoint path used by run_da3_reconstruction when the caller does
# not pass `pretrained_model`; it is resolved through files_locator.locate_file.
DA3_BF16_MODEL = "depth/depth_anything_v3_vitl_bf16.safetensors"
# Relative checkpoint path of the metric variant. Not referenced by the code
# visible in this file section — presumably consumed by callers; TODO confirm.
DA3_METRIC_BF16_MODEL = "depth/depth_anything_v3_metric_large_bf16.safetensors"
def resolve_da3_chunk_size(chunk_size=-1, device=None):
    """Return the temporal chunk size to use, auto-selecting it for ``-1``.

    Any value other than ``-1`` is converted with ``int`` and returned as is
    (``None`` is treated as ``-1``). For ``-1`` the size is chosen from the
    total memory of the CUDA device:

    * no CUDA available, or a non-CUDA ``device``: 33
    * less than 8 GB: 33
    * less than 24 GB: 65
    * otherwise: 97

    Memory is measured in decimal gigabytes (bytes / 1e9). When ``device`` is
    ``None`` or carries no index, the current CUDA device is inspected.
    """
    requested = int(-1 if chunk_size is None else chunk_size)
    if requested != -1:
        return requested

    smallest = 33
    if not torch.cuda.is_available():
        return smallest
    target = torch.device("cuda" if device is None else device)
    if target.type != "cuda":
        return smallest

    index = target.index if target.index is not None else torch.cuda.current_device()
    total_gb = torch.cuda.get_device_properties(index).total_memory / 1_000_000_000
    # Ordered (exclusive upper bound in GB, chunk size) tiers.
    for limit_gb, size in ((8, smallest), (24, 65)):
        if total_gb < limit_gb:
            return size
    return 97
def _load_da3(pretrained_model, device, model_name="da3-large"):
    """Build a ``DepthAnything3`` model and load a safetensors checkpoint into it.

    The checkpoint key set is compared with the model's ``state_dict`` keys
    before any weights are loaded. Keys present only in the checkpoint are
    never accepted; keys present only in the model are accepted solely when
    they start with ``model.head.scratch.output_conv2_aux.{1,2,3}.2.``.

    Args:
        pretrained_model: Path (anything ``str()``-able) to a ``.safetensors`` file.
        device: Device the model is moved to after loading.
        model_name: Name forwarded to the ``DepthAnything3`` constructor.

    Returns:
        The model in eval mode, gradients disabled, cast to bfloat16 on ``device``.

    Raises:
        ValueError: If the path does not end with ``.safetensors``.
        RuntimeError: If the checkpoint and model key sets do not match.
    """
    from mmgp import offload
    from safetensors import safe_open

    from .api import DepthAnything3

    model = DepthAnything3(model_name=model_name)
    pretrained_model = str(pretrained_model)
    if not pretrained_model.endswith(".safetensors"):
        raise ValueError(f"Depth Anything 3 now expects the bf16 safetensors checkpoint, got: {pretrained_model}")

    # Only the key names are read here; tensors are loaded further below.
    expected_keys = set(model.state_dict().keys())
    with safe_open(pretrained_model, framework="pt", device="cpu") as handle:
        stored_keys = set(handle.keys())

    tolerated_prefixes = tuple(f"model.head.scratch.output_conv2_aux.{idx}.2." for idx in range(1, 4))
    unexpected = sorted(stored_keys - expected_keys)
    unsupported_missing = sorted(
        key for key in expected_keys - stored_keys if not key.startswith(tolerated_prefixes)
    )
    if unexpected or unsupported_missing:
        raise RuntimeError(f"Unexpected DA3 checkpoint keys: unexpected={unexpected}, missing={unsupported_missing}")

    offload.load_model_data(
        model,
        pretrained_model,
        writable_tensors=False,
        default_dtype=torch.bfloat16,
        ignore_missing_keys=True,
    )
    model.requires_grad_(False)
    model.to(device=device, dtype=torch.bfloat16)
    model.eval()
    return model
def _resize_2d(array, height, width, mode="bilinear", inverse=False):
if array.shape[-2:] == (height, width):
return array.copy()
dtype = array.dtype
tensor = torch.from_numpy(array).to(torch.float64)
leading = tensor.shape[:-2]
tensor = tensor.reshape(-1, *tensor.shape[-2:])
if inverse:
tensor = 1 / tensor
tensor = F.interpolate(tensor[:, None], size=(height, width), mode=mode)[:, 0]
if inverse:
tensor = 1 / tensor
tensor = tensor.reshape(*leading, height, width)
if dtype == np.bool_:
tensor = tensor >= 0.5
return tensor.numpy().astype(dtype)
def _k_to_intrinsics(k):
intrinsics = np.zeros((k.shape[0], 4), dtype=np.float32)
intrinsics[:, 0] = k[:, 0, 0]
intrinsics[:, 1] = k[:, 1, 1]
intrinsics[:, 2] = k[:, 0, 2]
intrinsics[:, 3] = k[:, 1, 2]
return intrinsics
def _prediction_to_arrays(prediction, height, width):
    """Convert a model prediction into arrays at the requested resolution.

    Args:
        prediction: Object exposing ``depth``, ``extrinsics``, ``intrinsics``
            and ``processed_images`` arrays, plus an optional ``sky`` array.
        height: Output frame height.
        width: Output frame width.

    Returns:
        ``(depths, sky, cam_w2c, intrinsics)``: float32 depths resized in
        reciprocal space, a boolean sky mask resized with nearest-neighbour
        (all ``False`` when the prediction has no ``sky``), the float32
        extrinsics unchanged, and ``(N, 4)`` intrinsics rescaled from the
        processed-image size to ``(height, width)``.
    """
    depths = prediction.depth.astype(np.float32)

    raw_sky = getattr(prediction, "sky", None)
    if raw_sky is not None:
        sky = raw_sky.astype(np.bool_)
    else:
        sky = np.zeros_like(depths, dtype=np.bool_)

    cam_w2c = prediction.extrinsics.astype(np.float32)
    intrinsics = _k_to_intrinsics(prediction.intrinsics.astype(np.float32))

    # Size the model actually ran at: axes 1 and 2 of the processed images.
    proc_h, proc_w = prediction.processed_images.shape[1:3]

    depths = _resize_2d(depths, height, width, mode="bilinear", inverse=True)
    sky = _resize_2d(sky, height, width, mode="nearest", inverse=False)

    # Columns 0 and 2 follow the width ratio, columns 1 and 3 the height ratio.
    scale_x = width / proc_w
    scale_y = height / proc_h
    intrinsics[:, 0::2] *= scale_x
    intrinsics[:, 1::2] *= scale_y
    return depths, sky, cam_w2c, intrinsics
def _camera_w2c_to_c2w(cam_w2c):
cam_w2c_44 = np.zeros((cam_w2c.shape[0], 4, 4), dtype=np.float32)
cam_w2c_44[:, :3, :4] = cam_w2c
cam_w2c_44[:, 3, 3] = 1.0
cam_c2w = np.linalg.inv(cam_w2c_44)
return (np.linalg.inv(cam_c2w[0])[None] @ cam_c2w).astype(np.float32)
def _w2c_to_pose(cam_w2c):
cam_w2c_44 = np.zeros((cam_w2c.shape[0], 4, 4), dtype=np.float64)
cam_w2c_44[:, :3, :4] = cam_w2c.astype(np.float64)
cam_w2c_44[:, 3, 3] = 1.0
return np.linalg.inv(cam_w2c_44)
def _closest_rotation(matrix):
u, _, vh = np.linalg.svd(matrix)
rotation = u @ vh
if np.linalg.det(rotation) < 0:
u[:, -1] *= -1
rotation = u @ vh
return rotation
def _pose_based_chunk_alignment(ref_w2c, est_w2c):
    """Estimate a similarity transform mapping estimated poses onto reference poses.

    Both inputs are stacks of ``(3, 4)`` matrices for the same frames; they
    are inverted with ``_w2c_to_pose`` first.

    * Rotation: closest rotation to the mean of ``R_ref @ R_est^T``.
    * Scale: median ratio of pairwise translation distances (reference over
      estimate), ignoring pairs whose estimated distance is at most float64
      epsilon; 1.0 when no pair is usable.
    * Translation: ``mean(ref) - scale * rotation @ mean(est)``.

    Returns:
        ``(rotation, translation, scale)`` as float32 ``(3, 3)``, float32
        ``(3,)`` and ``np.float32``.
    """
    ref_pose = _w2c_to_pose(ref_w2c)
    est_pose = _w2c_to_pose(est_w2c)

    relative = ref_pose[:, :3, :3] @ np.swapaxes(est_pose[:, :3, :3], -1, -2)
    rotation = _closest_rotation(np.mean(relative, axis=0))

    ref_centers = ref_pose[:, :3, 3]
    est_centers = est_pose[:, :3, 3]

    # Every unordered frame pair (i < j) contributes one distance ratio.
    first, second = np.triu_indices(ref_centers.shape[0], k=1)
    ref_dists = np.linalg.norm(ref_centers[first] - ref_centers[second], axis=1)
    est_dists = np.linalg.norm(est_centers[first] - est_centers[second], axis=1)
    usable = est_dists > np.finfo(np.float64).eps
    if usable.any():
        scale = float(np.median(ref_dists[usable] / est_dists[usable]))
    else:
        scale = 1.0

    est_mean = est_centers.mean(axis=0)
    ref_mean = ref_centers.mean(axis=0)
    translation = ref_mean - scale * (rotation @ est_mean)
    return rotation.astype(np.float32), translation.astype(np.float32), np.float32(scale)
def _apply_sim3_to_w2c(cam_w2c, rotation, translation, scale):
cam_w2c_44 = np.zeros((cam_w2c.shape[0], 4, 4), dtype=np.float32)
cam_w2c_44[:, :3, :4] = cam_w2c
cam_w2c_44[:, 3, 3] = 1.0
poses = np.linalg.inv(cam_w2c_44)
aligned = poses.copy()
aligned[:, :3, :3] = rotation @ poses[:, :3, :3]
aligned[:, :3, 3] = (rotation @ (scale * poses[:, :3, 3]).T).T + translation
return np.linalg.inv(aligned)[:, :3, :4].astype(np.float32)
def _chunk_ranges(frame_count, chunk_size, overlap):
if chunk_size <= 0 or chunk_size >= frame_count:
return [(0, frame_count)]
if overlap < 8:
raise ValueError("DA3 temporal chunking requires at least 8 overlap frames")
if overlap >= chunk_size:
raise ValueError("DA3 temporal chunk overlap must be smaller than the chunk size")
ranges, start, step = [], 0, chunk_size - overlap
while True:
end = start + chunk_size
if end >= frame_count:
ranges.append((frame_count - chunk_size, frame_count))
break
ranges.append((start, end))
next_start = start + step
final_start = frame_count - chunk_size
start = final_start if end - final_start >= overlap else next_start
return ranges
def _infer_da3_prediction(model, video, frame_indices, process_res):
    """Run ``model.inference`` on the selected frames of ``video``.

    Each ``video[i]`` for ``i`` in ``frame_indices`` is converted to a PIL
    image; the list is passed to the model with ``export_format="npz"``.

    Returns:
        Whatever ``model.inference`` returns.
    """
    selected = (video[index] for index in frame_indices)
    frames = [Image.fromarray(frame) for frame in selected]
    return model.inference(frames, process_res=process_res, export_format="npz")
def _infer_da3_depth_prediction(model, video, frame_indices, process_res):
    """Run inference on the selected frames and return depth at video resolution.

    The prediction's ``depth`` is cast to float32 and resized in reciprocal
    space to ``(video.shape[1], video.shape[2])``.
    """
    prediction = _infer_da3_prediction(model, video, frame_indices, process_res)
    depth = prediction.depth.astype(np.float32)
    height, width = video.shape[1], video.shape[2]
    return _resize_2d(depth, height, width, mode="bilinear", inverse=True)
def _run_da3_prediction(model, video, process_res, chunk_size=0, chunk_overlap=8):
    """Predict depth, sky mask, camera matrices and intrinsics for a whole video.

    When ``_chunk_ranges`` yields a single window, all frames are processed in
    one inference call. Otherwise the video is processed window by window:
    frames a window shares with earlier windows are used to estimate a
    similarity transform (``_pose_based_chunk_alignment``), which is applied
    to the window's camera matrices and, as a scale factor, to its depths.
    Only frames not yet written are stored, so earlier windows win in the
    overlap.

    Args:
        model: Model passed through to ``_infer_da3_prediction``.
        video: Array whose first three axes are (frames, height, width).
        process_res: Processing resolution forwarded to the model.
        chunk_size: Window length; resolved through ``resolve_da3_chunk_size``.
        chunk_overlap: Number of frames shared by consecutive windows.

    Returns:
        ``(depths, sky, cam_c2w, intrinsics)`` where the camera matrices come
        from ``_camera_w2c_to_c2w``.

    Raises:
        ValueError: If a window shares fewer than 3 frames with earlier ones.
        RuntimeError: If some frame was never written.
    """
    frame_count, height, width = video.shape[:3]
    chunk_size = resolve_da3_chunk_size(chunk_size)
    windows = _chunk_ranges(frame_count, chunk_size, chunk_overlap)

    if len(windows) == 1:
        prediction = _infer_da3_prediction(model, video, range(frame_count), process_res)
        depths, sky, cam_w2c, intrinsics = _prediction_to_arrays(prediction, height, width)
        return depths, sky, _camera_w2c_to_c2w(cam_w2c), intrinsics

    depths_all = np.empty((frame_count, height, width), dtype=np.float32)
    sky_all = np.empty((frame_count, height, width), dtype=np.bool_)
    cam_w2c_all = np.empty((frame_count, 3, 4), dtype=np.float32)
    intrinsics_all = np.empty((frame_count, 4), dtype=np.float32)
    filled = np.zeros(frame_count, dtype=np.bool_)

    for start, end in windows:
        indices = np.arange(start, end)
        prediction = _infer_da3_prediction(model, video, indices, process_res)
        depths, sky, cam_w2c, intrinsics = _prediction_to_arrays(prediction, height, width)

        # Frames of this window that an earlier window already produced.
        already_done = filled[indices]
        if already_done.any():
            if int(already_done.sum()) < 3:
                raise ValueError("DA3 temporal chunking produced fewer than 3 overlap frames for alignment")
            rotation, translation, scale = _pose_based_chunk_alignment(
                cam_w2c_all[indices[already_done]],
                cam_w2c[already_done],
            )
            cam_w2c = _apply_sim3_to_w2c(cam_w2c, rotation, translation, scale)
            depths *= np.float32(scale)

        # Store only the frames that are new in this window.
        fresh = ~already_done
        targets = indices[fresh]
        depths_all[targets] = depths[fresh]
        sky_all[targets] = sky[fresh]
        cam_w2c_all[targets] = cam_w2c[fresh]
        intrinsics_all[targets] = intrinsics[fresh]
        filled[targets] = True

        # Drop per-window results before the next inference call.
        del prediction, depths, sky, cam_w2c, intrinsics
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    if not filled.all():
        missing = np.flatnonzero(~filled).tolist()
        raise RuntimeError(f"DA3 temporal chunking failed to fill frames: {missing}")
    return depths_all, sky_all, _camera_w2c_to_c2w(cam_w2c_all), intrinsics_all
def _run_da3_depth_prediction(model, video, process_res, chunk_size=0):
    """Predict depth only for a whole video, in consecutive non-overlapping chunks.

    A resolved ``chunk_size`` that is non-positive or covers every frame
    results in a single inference call. No alignment is performed between
    chunks.

    Returns:
        Float32 depth at video resolution, one map per frame.
    """
    frame_count, height, width = video.shape[:3]
    chunk_size = resolve_da3_chunk_size(chunk_size)
    if not 0 < chunk_size < frame_count:
        return _infer_da3_depth_prediction(model, video, range(frame_count), process_res)

    depth_all = np.empty((frame_count, height, width), dtype=np.float32)
    start = 0
    while start < frame_count:
        stop = min(frame_count, start + chunk_size)
        depth_all[start:stop] = _infer_da3_depth_prediction(model, video, range(start, stop), process_res)
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        start = stop
    return depth_all
@torch.inference_mode()
def run_da3_reconstruction(video, pretrained_model=None, process_res=0, device=None, chunk_size=0, chunk_overlap=8):
    """Load the ``da3-large`` model, run it on ``video`` and release the model.

    Args:
        video: Array whose first three axes are (frames, height, width).
        pretrained_model: Checkpoint path; when falsy, ``DA3_BF16_MODEL`` is
            resolved through ``files_locator.locate_file``.
        process_res: Processing resolution; values <= 0 fall back to the
            video width.
        device: Target device; defaults to CUDA when available, else CPU.
        chunk_size: Temporal chunk size, resolved by ``resolve_da3_chunk_size``.
        chunk_overlap: Frames shared by consecutive chunks.

    Returns:
        ``(depths, sky, cam_c2w, intrinsics)`` as produced by
        ``_run_da3_prediction``, with the camera matrices and intrinsics cast
        to float32.
    """
    from shared.utils import files_locator as fl

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    chunk_size = resolve_da3_chunk_size(chunk_size, device)
    pretrained_model = pretrained_model or fl.locate_file(DA3_BF16_MODEL)

    # Read the shape before loading weights so a malformed video fails early.
    _, width = video.shape[1:3]
    if process_res <= 0:
        process_res = width

    model = _load_da3(pretrained_model, device, model_name="da3-large")
    try:
        depths, sky, cam_c2w, intrinsics = _run_da3_prediction(
            model,
            video,
            process_res,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap,
        )
    finally:
        # Offload and free cached device memory even when prediction raises;
        # otherwise the traceback would keep the weights on the device.
        model.to("cpu")
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
    return depths, sky, cam_c2w.astype(np.float32), intrinsics.astype(np.float32)
class DepthV3VideoAnnotator:
    """Turn video frames into grayscale disparity visualisations.

    Configuration keys read from ``cfg``: ``PRETRAINED_MODEL`` (required),
    ``PROCESS_RES`` (default 0), ``CHUNK_SIZE`` (default -1),
    ``CHUNK_OVERLAP`` (default 8) and ``MODEL_NAME`` (default ``"da3-large"``).
    """

    def __init__(self, cfg, device=None):
        """Read the configuration and load the model onto ``device``.

        ``device`` defaults to CUDA when available, otherwise CPU.
        """
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device
        self.process_res = int(cfg.get("PROCESS_RES", 0) or 0)
        self.chunk_size = resolve_da3_chunk_size(cfg.get("CHUNK_SIZE", -1), self.device)
        self.chunk_overlap = int(cfg.get("CHUNK_OVERLAP", 8) or 8)
        self.model_name = cfg.get("MODEL_NAME", "da3-large")
        self.model = _load_da3(cfg["PRETRAINED_MODEL"], self.device, model_name=self.model_name)

    @torch.inference_mode()
    def forward(self, frames):
        """Return one 3-channel uint8 disparity image per input frame.

        Depth is predicted for the stacked frames (depth-only path for
        ``"da3metric-large"``, full prediction otherwise), converted to
        disparity ``1 / max(depth, 1e-6)`` and normalised to 0..255 using the
        minimum and maximum over the whole clip.
        """
        video = np.stack([np.asarray(frame) for frame in frames], axis=0)
        # A process_res of 0 means "use the frame width".
        process_res = self.process_res or video.shape[2]
        if self.model_name == "da3metric-large":
            depth = _run_da3_depth_prediction(self.model, video, process_res, chunk_size=self.chunk_size)
        else:
            depth, _sky, _cameras, _intrinsics = _run_da3_prediction(
                self.model,
                video,
                process_res,
                chunk_size=self.chunk_size,
                chunk_overlap=self.chunk_overlap,
            )

        disparity = 1.0 / np.maximum(depth, 1e-6)
        disparity -= disparity.min()
        disparity /= max(float(disparity.max()), 1e-6)
        gray_video = (disparity * 255.0).clip(0, 255).astype(np.uint8)
        return [np.repeat(gray[..., None], 3, axis=2) for gray in gray_video]

    def close(self):
        """Move the model to CPU, drop the reference and clear the CUDA cache.

        Safe to call repeatedly: once ``self.model`` is ``None`` only the
        cache clearing remains.
        """
        if self.model is not None:
            self.model.to("cpu")
            self.model = None
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    def __del__(self):
        """Best-effort cleanup; errors during finalisation are ignored."""
        try:
            self.close()
        except Exception:
            pass
|