wop commited on
Commit
6acdd5b
·
verified ·
1 Parent(s): cfc7d65

Update depth_anything_3/api.py

Browse files
Files changed (1) hide show
  1. depth_anything_3/api.py +49 -264
depth_anything_3/api.py CHANGED
@@ -1,25 +1,3 @@
1
- # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
- """
15
- Depth Anything 3 API module.
16
-
17
- This module provides the main API for Depth Anything 3, including model loading,
18
- inference, and export capabilities. It supports both single and nested model architectures.
19
- """
20
-
21
- from __future__ import annotations
22
-
23
  import time
24
  from typing import Optional, Sequence
25
  import numpy as np
@@ -39,63 +17,30 @@ from depth_anything_3.utils.logger import logger
39
  from depth_anything_3.utils.pose_align import align_poses_umeyama
40
 
41
  torch.backends.cudnn.benchmark = False
42
- # logger.info("CUDNN Benchmark Disabled")
43
-
44
- SAFETENSORS_NAME = "model.safetensors"
45
- CONFIG_NAME = "config.json"
46
-
47
 
48
  class DepthAnything3(nn.Module, PyTorchModelHubMixin):
49
  """
50
- Depth Anything 3 main API class.
51
-
52
- This class provides a high-level interface for depth estimation using Depth Anything 3.
53
- It supports both single and nested model architectures with metric scaling capabilities.
54
-
55
- Features:
56
- - Hugging Face Hub integration via PyTorchModelHubMixin
57
- - Support for multiple model presets (vitb, vitg, nested variants)
58
- - Automatic mixed precision inference
59
- - Export capabilities for various formats (GLB, PLY, NPZ, etc.)
60
- - Camera pose estimation and metric depth scaling
61
-
62
- Usage:
63
- # Load from Hugging Face Hub
64
- model = DepthAnything3.from_pretrained("huggingface/model-name")
65
-
66
- # Or create with specific preset
67
- model = DepthAnything3(preset="vitg")
68
-
69
- # Run inference
70
- prediction = model.inference(images, export_dir="output", export_format="glb")
71
  """
72
 
73
  _commit_hash: str | None = None # Set by mixin when loading from Hub
74
 
75
  def __init__(self, model_name: str = "da3-large", **kwargs):
76
- """
77
- Initialize DepthAnything3 with specified preset.
78
-
79
- Args:
80
- model_name: The name of the model preset to use.
81
- Examples: 'da3-giant', 'da3-large', 'da3metric-large', 'da3nested-giant-large'.
82
- **kwargs: Additional keyword arguments (currently unused).
83
- """
84
  super().__init__()
85
  self.model_name = model_name
86
 
87
- # Build the underlying network
88
  self.config = load_config(MODEL_REGISTRY[self.model_name])
89
  self.model = create_object(self.config)
90
  self.model.eval()
 
 
91
 
92
  # Initialize processors
93
  self.input_processor = InputProcessor()
94
  self.output_processor = OutputProcessor()
95
 
96
- # Device management (set by user)
97
- self.device = None
98
-
99
  @torch.inference_mode()
100
  def forward(
101
  self,
@@ -105,23 +50,9 @@ class DepthAnything3(nn.Module, PyTorchModelHubMixin):
105
  export_feat_layers: list[int] | None = None,
106
  infer_gs: bool = False,
107
  ) -> dict[str, torch.Tensor]:
108
- """
109
- Forward pass through the model.
110
-
111
- Args:
112
- image: Input batch with shape ``(B, N, 3, H, W)`` on the model device.
113
- extrinsics: Optional camera extrinsics with shape ``(B, N, 4, 4)``.
114
- intrinsics: Optional camera intrinsics with shape ``(B, N, 3, 3)``.
115
- export_feat_layers: Layer indices to return intermediate features for.
116
-
117
- Returns:
118
- Dictionary containing model predictions
119
- """
120
- # Determine optimal autocast dtype
121
- autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
122
- with torch.no_grad():
123
- with torch.autocast(device_type=image.device.type, dtype=autocast_dtype):
124
- return self.model(image, extrinsics, intrinsics, export_feat_layers, infer_gs)
125
 
126
  def inference(
127
  self,
@@ -138,119 +69,57 @@ class DepthAnything3(nn.Module, PyTorchModelHubMixin):
138
  export_dir: str | None = None,
139
  export_format: str = "mini_npz",
140
  export_feat_layers: Sequence[int] | None = None,
141
- # GLB export parameters
142
  conf_thresh_percentile: float = 40.0,
143
  num_max_points: int = 1_000_000,
144
  show_cameras: bool = True,
145
- # Feat_vis export parameters
146
  feat_vis_fps: int = 15,
147
  export_kwargs: Optional[dict] = {},
148
  ) -> Prediction:
149
- """
150
- Run inference on input images.
151
-
152
- Args:
153
- image: List of input images (numpy arrays, PIL Images, or file paths)
154
- extrinsics: Camera extrinsics (N, 4, 4)
155
- intrinsics: Camera intrinsics (N, 3, 3)
156
- align_to_input_ext_scale: whether to align the input pose scale to the prediction
157
- infer_gs: Enable the 3D Gaussian branch (needed for `gs_ply`/`gs_video` exports)
158
- render_exts: Optional render extrinsics for Gaussian video export
159
- render_ixts: Optional render intrinsics for Gaussian video export
160
- render_hw: Optional render resolution for Gaussian video export
161
- process_res: Processing resolution
162
- process_res_method: Resize method for processing
163
- export_dir: Directory to export results
164
- export_format: Export format (mini_npz, npz, glb, ply, gs, gs_video)
165
- export_feat_layers: Layer indices to export intermediate features from
166
- conf_thresh_percentile: [GLB] Lower percentile for adaptive confidence threshold (default: 40.0) # noqa: E501
167
- num_max_points: [GLB] Maximum number of points in the point cloud (default: 1,000,000)
168
- show_cameras: [GLB] Show camera wireframes in the exported scene (default: True)
169
- feat_vis_fps: [FEAT_VIS] Frame rate for output video (default: 15)
170
- export_kwargs: additional arguments to export functions.
171
-
172
- Returns:
173
- Prediction object containing depth maps and camera parameters
174
- """
175
  if "gs" in export_format:
176
  assert infer_gs, "must set `infer_gs=True` to perform gs-related export."
177
 
178
- # Preprocess images
179
  imgs_cpu, extrinsics, intrinsics = self._preprocess_inputs(
180
  image, extrinsics, intrinsics, process_res, process_res_method
181
  )
182
-
183
- # Prepare tensors for model
184
  imgs, ex_t, in_t = self._prepare_model_inputs(imgs_cpu, extrinsics, intrinsics)
185
-
186
- # Normalize extrinsics
187
  ex_t_norm = self._normalize_extrinsics(ex_t.clone() if ex_t is not None else None)
188
-
189
- # Run model forward pass
190
  export_feat_layers = list(export_feat_layers) if export_feat_layers is not None else []
191
 
192
  raw_output = self._run_model_forward(imgs, ex_t_norm, in_t, export_feat_layers, infer_gs)
193
-
194
- # Convert raw output to prediction
195
  prediction = self._convert_to_prediction(raw_output)
196
-
197
- # Align prediction to extrinsincs
198
- prediction = self._align_to_input_extrinsics_intrinsics(
199
- extrinsics, intrinsics, prediction, align_to_input_ext_scale
200
- )
201
-
202
- # Add processed images for visualization
203
  prediction = self._add_processed_images(prediction, imgs_cpu)
204
 
205
- # Export if requested
206
  if export_dir is not None:
207
-
208
- if "gs" in export_format:
209
- if infer_gs and "gs_video" not in export_format:
210
  export_format = f"{export_format}-gs_video"
211
- if "gs_video" in export_format:
212
- if "gs_video" not in export_kwargs:
213
- export_kwargs["gs_video"] = {}
214
- export_kwargs["gs_video"].update(
215
- {
216
- "extrinsics": render_exts,
217
- "intrinsics": render_ixts,
218
- "out_image_hw": render_hw,
219
- }
220
- )
221
- # Add GLB export parameters
222
  if "glb" in export_format:
223
  if "glb" not in export_kwargs:
224
  export_kwargs["glb"] = {}
225
- export_kwargs["glb"].update(
226
- {
227
- "conf_thresh_percentile": conf_thresh_percentile,
228
- "num_max_points": num_max_points,
229
- "show_cameras": show_cameras,
230
- }
231
- )
232
- # Add Feat_vis export parameters
233
  if "feat_vis" in export_format:
234
  if "feat_vis" not in export_kwargs:
235
  export_kwargs["feat_vis"] = {}
236
- export_kwargs["feat_vis"].update(
237
- {
238
- "fps": feat_vis_fps,
239
- }
240
- )
241
  self._export_results(prediction, export_format, export_dir, **export_kwargs)
242
 
243
  return prediction
244
 
245
- def _preprocess_inputs(
246
- self,
247
- image: list[np.ndarray | Image.Image | str],
248
- extrinsics: np.ndarray | None = None,
249
- intrinsics: np.ndarray | None = None,
250
- process_res: int = 504,
251
- process_res_method: str = "upper_bound_resize",
252
- ) -> torch.Tensor:
253
- """Preprocess input images using input processor."""
254
  start_time = time.time()
255
  imgs_cpu, extrinsics, intrinsics = self.input_processor(
256
  image,
@@ -259,43 +128,17 @@ class DepthAnything3(nn.Module, PyTorchModelHubMixin):
259
  process_res,
260
  process_res_method,
261
  )
262
- end_time = time.time()
263
- logger.info(
264
- "Processed Images Done taking",
265
- end_time - start_time,
266
- "seconds. Shape: ",
267
- imgs_cpu.shape,
268
- )
269
  return imgs_cpu, extrinsics, intrinsics
270
 
271
- def _prepare_model_inputs(
272
- self,
273
- imgs_cpu: torch.Tensor,
274
- extrinsics: torch.tensor | None,
275
- intrinsics: torch.tensor | None,
276
- ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
277
- """Prepare tensors for model input."""
278
- device = self._get_model_device()
279
-
280
- # Move images to model device
281
- imgs = imgs_cpu.to(device, non_blocking=True)[None].float()
282
-
283
- # Convert camera parameters to tensors
284
- ex_t = (
285
- extrinsics.to(device, non_blocking=True)[None].float()
286
- if extrinsics is not None
287
- else None
288
- )
289
- in_t = (
290
- intrinsics.to(device, non_blocking=True)[None].float()
291
- if intrinsics is not None
292
- else None
293
- )
294
-
295
  return imgs, ex_t, in_t
296
 
297
- def _normalize_extrinsics(self, ex_t: torch.Tensor) -> torch.Tensor:
298
- """Normalize extrinsics"""
299
  if ex_t is None:
300
  return None
301
  transform = affine_inverse(ex_t[:, :1])
@@ -303,20 +146,11 @@ class DepthAnything3(nn.Module, PyTorchModelHubMixin):
303
  c2ws = affine_inverse(ex_t_norm)
304
  translations = c2ws[..., :3, 3]
305
  dists = translations.norm(dim=-1)
306
- median_dist = torch.median(dists)
307
- median_dist = torch.clamp(median_dist, min=1e-1)
308
  ex_t_norm[..., :3, 3] = ex_t_norm[..., :3, 3] / median_dist
309
  return ex_t_norm
310
 
311
- def _align_to_input_extrinsics_intrinsics(
312
- self,
313
- extrinsics: torch.Tensor,
314
- intrinsics: torch.Tensor,
315
- prediction: Prediction,
316
- align_to_input_ext_scale: bool = True,
317
- ransac_view_thresh: int = 10,
318
- ) -> Prediction:
319
- """Align depth map to input extrinsics"""
320
  if extrinsics is None:
321
  return prediction
322
  prediction.intrinsics = intrinsics.numpy()
@@ -334,81 +168,32 @@ class DepthAnything3(nn.Module, PyTorchModelHubMixin):
334
  prediction.extrinsics = aligned_extrinsics
335
  return prediction
336
 
337
- def _run_model_forward(
338
- self,
339
- imgs: torch.Tensor,
340
- ex_t: torch.Tensor | None,
341
- in_t: torch.Tensor | None,
342
- export_feat_layers: Sequence[int] | None = None,
343
- infer_gs: bool = False,
344
- ) -> dict[str, torch.Tensor]:
345
- """Run model forward pass."""
346
- device = imgs.device
347
- need_sync = device.type == "cuda"
348
- if need_sync:
349
- torch.cuda.synchronize(device)
350
  start_time = time.time()
351
- feat_layers = list(export_feat_layers) if export_feat_layers is not None else None
352
- output = self.forward(imgs, ex_t, in_t, feat_layers, infer_gs)
353
- if need_sync:
354
- torch.cuda.synchronize(device)
355
- end_time = time.time()
356
- logger.info(f"Model Forward Pass Done. Time: {end_time - start_time} seconds")
357
  return output
358
 
359
- def _convert_to_prediction(self, raw_output: dict[str, torch.Tensor]) -> Prediction:
360
- """Convert raw model output to Prediction object."""
361
  start_time = time.time()
362
  output = self.output_processor(raw_output)
363
- end_time = time.time()
364
- logger.info(f"Conversion to Prediction Done. Time: {end_time - start_time} seconds")
365
  return output
366
 
367
- def _add_processed_images(self, prediction: Prediction, imgs_cpu: torch.Tensor) -> Prediction:
368
- """Add processed images to prediction for visualization."""
369
- # Convert from (N, 3, H, W) to (N, H, W, 3) and denormalize
370
- processed_imgs = imgs_cpu.permute(0, 2, 3, 1).cpu().numpy() # (N, H, W, 3)
371
-
372
- # Denormalize from ImageNet normalization
373
  mean = np.array([0.485, 0.456, 0.406])
374
  std = np.array([0.229, 0.224, 0.225])
375
- processed_imgs = processed_imgs * std + mean
376
- processed_imgs = np.clip(processed_imgs, 0, 1)
377
  processed_imgs = (processed_imgs * 255).astype(np.uint8)
378
-
379
  prediction.processed_images = processed_imgs
380
  return prediction
381
 
382
- def _export_results(
383
- self, prediction: Prediction, export_format: str, export_dir: str, **kwargs
384
- ) -> None:
385
- """Export results to specified format and directory."""
386
  start_time = time.time()
387
  export(prediction, export_format, export_dir, **kwargs)
388
- end_time = time.time()
389
- logger.info(f"Export Results Done. Time: {end_time - start_time} seconds")
390
-
391
- def _get_model_device(self) -> torch.device:
392
- """
393
- Get the device where the model is located.
394
-
395
- Returns:
396
- Device where the model parameters are located
397
-
398
- Raises:
399
- ValueError: If no tensors are found in the model
400
- """
401
- if self.device is not None:
402
- return self.device
403
-
404
- # Find device from parameters
405
- for param in self.parameters():
406
- self.device = param.device
407
- return param.device
408
-
409
- # Find device from buffers
410
- for buffer in self.buffers():
411
- self.device = buffer.device
412
- return buffer.device
413
 
414
- raise ValueError("No tensor found in model")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import time
2
  from typing import Optional, Sequence
3
  import numpy as np
 
17
  from depth_anything_3.utils.pose_align import align_poses_umeyama
18
 
19
  torch.backends.cudnn.benchmark = False
 
 
 
 
 
20
 
21
  class DepthAnything3(nn.Module, PyTorchModelHubMixin):
22
  """
23
+ CPU-only Depth Anything 3 API class.
24
+ This class provides depth estimation with all tensors and models on CPU.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  """
26
 
27
  _commit_hash: str | None = None # Set by mixin when loading from Hub
28
 
29
  def __init__(self, model_name: str = "da3-large", **kwargs):
 
 
 
 
 
 
 
 
30
  super().__init__()
31
  self.model_name = model_name
32
 
33
+ # Build network and force CPU
34
  self.config = load_config(MODEL_REGISTRY[self.model_name])
35
  self.model = create_object(self.config)
36
  self.model.eval()
37
+ self.device = torch.device("cpu")
38
+ self.model.to(self.device)
39
 
40
  # Initialize processors
41
  self.input_processor = InputProcessor()
42
  self.output_processor = OutputProcessor()
43
 
 
 
 
44
  @torch.inference_mode()
45
  def forward(
46
  self,
 
50
  export_feat_layers: list[int] | None = None,
51
  infer_gs: bool = False,
52
  ) -> dict[str, torch.Tensor]:
53
+ """Forward pass on CPU."""
54
+ image = image.to(self.device)
55
+ return self.model(image, extrinsics, intrinsics, export_feat_layers, infer_gs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  def inference(
58
  self,
 
69
  export_dir: str | None = None,
70
  export_format: str = "mini_npz",
71
  export_feat_layers: Sequence[int] | None = None,
 
72
  conf_thresh_percentile: float = 40.0,
73
  num_max_points: int = 1_000_000,
74
  show_cameras: bool = True,
 
75
  feat_vis_fps: int = 15,
76
  export_kwargs: Optional[dict] = {},
77
  ) -> Prediction:
78
+ """Run inference on input images (CPU)."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  if "gs" in export_format:
80
  assert infer_gs, "must set `infer_gs=True` to perform gs-related export."
81
 
 
82
  imgs_cpu, extrinsics, intrinsics = self._preprocess_inputs(
83
  image, extrinsics, intrinsics, process_res, process_res_method
84
  )
 
 
85
  imgs, ex_t, in_t = self._prepare_model_inputs(imgs_cpu, extrinsics, intrinsics)
 
 
86
  ex_t_norm = self._normalize_extrinsics(ex_t.clone() if ex_t is not None else None)
 
 
87
  export_feat_layers = list(export_feat_layers) if export_feat_layers is not None else []
88
 
89
  raw_output = self._run_model_forward(imgs, ex_t_norm, in_t, export_feat_layers, infer_gs)
 
 
90
  prediction = self._convert_to_prediction(raw_output)
91
+ prediction = self._align_to_input_extrinsics_intrinsics(extrinsics, intrinsics, prediction, align_to_input_ext_scale)
 
 
 
 
 
 
92
  prediction = self._add_processed_images(prediction, imgs_cpu)
93
 
 
94
  if export_dir is not None:
95
+ if "gs" in export_format and infer_gs:
96
+ if "gs_video" not in export_format:
 
97
  export_format = f"{export_format}-gs_video"
98
+ if "gs_video" in export_format and "gs_video" not in export_kwargs:
99
+ export_kwargs["gs_video"] = {}
100
+ export_kwargs["gs_video"].update({
101
+ "extrinsics": render_exts,
102
+ "intrinsics": render_ixts,
103
+ "out_image_hw": render_hw,
104
+ })
 
 
 
 
105
  if "glb" in export_format:
106
  if "glb" not in export_kwargs:
107
  export_kwargs["glb"] = {}
108
+ export_kwargs["glb"].update({
109
+ "conf_thresh_percentile": conf_thresh_percentile,
110
+ "num_max_points": num_max_points,
111
+ "show_cameras": show_cameras,
112
+ })
 
 
 
113
  if "feat_vis" in export_format:
114
  if "feat_vis" not in export_kwargs:
115
  export_kwargs["feat_vis"] = {}
116
+ export_kwargs["feat_vis"].update({"fps": feat_vis_fps})
 
 
 
 
117
  self._export_results(prediction, export_format, export_dir, **export_kwargs)
118
 
119
  return prediction
120
 
121
+ def _preprocess_inputs(self, image, extrinsics=None, intrinsics=None, process_res=504, process_res_method="upper_bound_resize"):
122
+ """Preprocess input images on CPU."""
 
 
 
 
 
 
 
123
  start_time = time.time()
124
  imgs_cpu, extrinsics, intrinsics = self.input_processor(
125
  image,
 
128
  process_res,
129
  process_res_method,
130
  )
131
+ logger.info("Processed Images Done taking", time.time() - start_time, "seconds. Shape:", imgs_cpu.shape)
 
 
 
 
 
 
132
  return imgs_cpu, extrinsics, intrinsics
133
 
134
+ def _prepare_model_inputs(self, imgs_cpu, extrinsics, intrinsics):
135
+ """Prepare tensors for model input (CPU-only)."""
136
+ imgs = imgs_cpu[None].float().to(self.device)
137
+ ex_t = extrinsics[None].float().to(self.device) if extrinsics is not None else None
138
+ in_t = intrinsics[None].float().to(self.device) if intrinsics is not None else None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  return imgs, ex_t, in_t
140
 
141
+ def _normalize_extrinsics(self, ex_t):
 
142
  if ex_t is None:
143
  return None
144
  transform = affine_inverse(ex_t[:, :1])
 
146
  c2ws = affine_inverse(ex_t_norm)
147
  translations = c2ws[..., :3, 3]
148
  dists = translations.norm(dim=-1)
149
+ median_dist = torch.clamp(torch.median(dists), min=1e-1)
 
150
  ex_t_norm[..., :3, 3] = ex_t_norm[..., :3, 3] / median_dist
151
  return ex_t_norm
152
 
153
+ def _align_to_input_extrinsics_intrinsics(self, extrinsics, intrinsics, prediction, align_to_input_ext_scale=True, ransac_view_thresh=10):
 
 
 
 
 
 
 
 
154
  if extrinsics is None:
155
  return prediction
156
  prediction.intrinsics = intrinsics.numpy()
 
168
  prediction.extrinsics = aligned_extrinsics
169
  return prediction
170
 
171
+ def _run_model_forward(self, imgs, ex_t, in_t, export_feat_layers=None, infer_gs=False):
 
 
 
 
 
 
 
 
 
 
 
 
172
  start_time = time.time()
173
+ output = self.forward(imgs, ex_t, in_t, export_feat_layers, infer_gs)
174
+ logger.info(f"Model Forward Pass Done (CPU). Time: {time.time() - start_time} seconds")
 
 
 
 
175
  return output
176
 
177
+ def _convert_to_prediction(self, raw_output):
 
178
  start_time = time.time()
179
  output = self.output_processor(raw_output)
180
+ logger.info(f"Conversion to Prediction Done. Time: {time.time() - start_time} seconds")
 
181
  return output
182
 
183
+ def _add_processed_images(self, prediction, imgs_cpu):
184
+ processed_imgs = imgs_cpu.permute(0, 2, 3, 1).cpu().numpy()
 
 
 
 
185
  mean = np.array([0.485, 0.456, 0.406])
186
  std = np.array([0.229, 0.224, 0.225])
187
+ processed_imgs = np.clip(processed_imgs * std + mean, 0, 1)
 
188
  processed_imgs = (processed_imgs * 255).astype(np.uint8)
 
189
  prediction.processed_images = processed_imgs
190
  return prediction
191
 
192
+ def _export_results(self, prediction, export_format, export_dir, **kwargs):
 
 
 
193
  start_time = time.time()
194
  export(prediction, export_format, export_dir, **kwargs)
195
+ logger.info(f"Export Results Done. Time: {time.time() - start_time} seconds")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
196
 
197
+ def _get_model_device(self):
198
+ """Always return CPU device."""
199
+ return torch.device("cpu")