|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Depth Anything 3 API module. |
|
|
|
|
|
This module provides the main API for Depth Anything 3, including model loading, |
|
|
inference, and export capabilities. It supports both single and nested model architectures. |
|
|
""" |
|
|
|
|
|
from __future__ import annotations |
|
|
|
|
|
import time |
|
|
from typing import Optional, Sequence |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from huggingface_hub import PyTorchModelHubMixin |
|
|
from PIL import Image |
|
|
|
|
|
from depth_anything_3.cache import get_model_cache |
|
|
from depth_anything_3.cfg import create_object, load_config |
|
|
from depth_anything_3.registry import MODEL_REGISTRY |
|
|
from depth_anything_3.specs import Prediction |
|
|
from depth_anything_3.utils.adaptive_batching import ( |
|
|
AdaptiveBatchConfig, |
|
|
AdaptiveBatchSizeCalculator, |
|
|
adaptive_batch_iterator, |
|
|
estimate_max_batch_size, |
|
|
) |
|
|
from depth_anything_3.utils.export import export |
|
|
from depth_anything_3.utils.geometry import affine_inverse |
|
|
from depth_anything_3.utils.io.gpu_input_processor import GPUInputProcessor |
|
|
from depth_anything_3.utils.io.input_processor import InputProcessor |
|
|
from depth_anything_3.utils.io.output_processor import OutputProcessor |
|
|
from depth_anything_3.utils.logger import logger |
|
|
from depth_anything_3.utils.pose_align import align_poses_umeyama |
|
|
|
|
|
# Disable cuDNN autotuning. NOTE(review): presumably because inference input
# shapes vary between calls, so re-benchmarking kernels would cost more than
# it saves — confirm against the intended deployment workloads.
torch.backends.cudnn.benchmark = False
|
|
|
|
|
|
|
|
# Filenames used by the Hugging Face Hub integration (PyTorchModelHubMixin):
# weight file and model configuration file inside a repo/snapshot.
SAFETENSORS_NAME = "model.safetensors"
CONFIG_NAME = "config.json"
|
|
|
|
|
|
|
|
class DepthAnything3(nn.Module, PyTorchModelHubMixin):
    """
    Depth Anything 3 main API class.

    This class provides a high-level interface for depth estimation using Depth Anything 3.
    It supports both single and nested model architectures with metric scaling capabilities.

    Features:
    - Hugging Face Hub integration via PyTorchModelHubMixin
    - Support for multiple model presets (vitb, vitg, nested variants)
    - Automatic mixed precision inference
    - Export capabilities for various formats (GLB, PLY, NPZ, etc.)
    - Camera pose estimation and metric depth scaling

    Usage:
        # Load from Hugging Face Hub
        model = DepthAnything3.from_pretrained("huggingface/model-name")

        # Or create with specific preset
        model = DepthAnything3(preset="vitg")

        # Run inference
        prediction = model.inference(images, export_dir="output", export_format="glb")
    """

    # Checkpoint commit hash. NOTE(review): presumably populated by the Hub
    # mixin during `from_pretrained`; stays None for locally built models —
    # confirm against PyTorchModelHubMixin's behavior.
    _commit_hash: str | None = None
|
|
|
|
|
def __init__(self, model_name: str = "da3-large", device: str | torch.device | None = None, use_cache: bool = True, **kwargs): |
|
|
""" |
|
|
Initialize DepthAnything3 with specified preset. |
|
|
|
|
|
Args: |
|
|
model_name: The name of the model preset to use. |
|
|
Examples: 'da3-giant', 'da3-large', 'da3metric-large', 'da3nested-giant-large'. |
|
|
device: Target device ('cuda', 'mps', 'cpu'). If None, auto-detect. |
|
|
use_cache: Whether to use model caching (default: True). |
|
|
Set to False to force reload model from disk. |
|
|
**kwargs: Additional keyword arguments (currently unused). |
|
|
""" |
|
|
super().__init__() |
|
|
self.model_name = model_name |
|
|
self.use_cache = use_cache |
|
|
|
|
|
|
|
|
if device is None: |
|
|
device = self._auto_detect_device() |
|
|
self.device = torch.device(device) if isinstance(device, str) else device |
|
|
|
|
|
|
|
|
self.config = load_config(MODEL_REGISTRY[self.model_name]) |
|
|
|
|
|
|
|
|
if use_cache: |
|
|
cache = get_model_cache() |
|
|
self.model = cache.get( |
|
|
model_name=self.model_name, |
|
|
device=self.device, |
|
|
loader_fn=lambda: self._create_model() |
|
|
) |
|
|
else: |
|
|
logger.info(f"Model cache disabled, loading {self.model_name} from disk") |
|
|
self.model = self._create_model() |
|
|
|
|
|
|
|
|
self.model = self.model.to(self.device) |
|
|
self.model.eval() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if self.device.type in ("cuda", "mps"): |
|
|
self.input_processor = GPUInputProcessor(device=self.device) |
|
|
decoding_info = "NVJPEG support enabled" if self.device.type == "cuda" else "TorchVision decoding" |
|
|
logger.info(f"Using GPUInputProcessor ({decoding_info} on {self.device})") |
|
|
else: |
|
|
self.input_processor = InputProcessor() |
|
|
logger.info("Using standard InputProcessor (optimized CPU pipeline)") |
|
|
|
|
|
self.output_processor = OutputProcessor() |
|
|
|
|
|
def _auto_detect_device(self) -> torch.device: |
|
|
"""Auto-detect best available device.""" |
|
|
if torch.cuda.is_available(): |
|
|
return torch.device("cuda") |
|
|
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): |
|
|
return torch.device("mps") |
|
|
else: |
|
|
return torch.device("cpu") |
|
|
|
|
|
def _create_model(self) -> nn.Module: |
|
|
"""Create and return new model instance on correct device.""" |
|
|
model = create_object(self.config) |
|
|
model = model.to(self.device) |
|
|
model.eval() |
|
|
return model |
|
|
|
|
|
@torch.inference_mode() |
|
|
def forward( |
|
|
self, |
|
|
image: torch.Tensor, |
|
|
extrinsics: torch.Tensor | None = None, |
|
|
intrinsics: torch.Tensor | None = None, |
|
|
export_feat_layers: list[int] | None = None, |
|
|
infer_gs: bool = False, |
|
|
use_ray_pose: bool = False, |
|
|
ref_view_strategy: str = "saddle_balanced", |
|
|
) -> dict[str, torch.Tensor]: |
|
|
""" |
|
|
Forward pass through the model. |
|
|
|
|
|
Args: |
|
|
image: Input batch with shape ``(B, N, 3, H, W)`` on the model device. |
|
|
extrinsics: Optional camera extrinsics with shape ``(B, N, 4, 4)``. |
|
|
intrinsics: Optional camera intrinsics with shape ``(B, N, 3, 3)``. |
|
|
export_feat_layers: Layer indices to return intermediate features for. |
|
|
infer_gs: Enable Gaussian Splatting branch. |
|
|
use_ray_pose: Use ray-based pose estimation instead of camera decoder. |
|
|
ref_view_strategy: Strategy for selecting reference view from multiple views. |
|
|
|
|
|
Returns: |
|
|
Dictionary containing model predictions |
|
|
""" |
|
|
with torch.no_grad(): |
|
|
|
|
|
if image.device.type == "mps": |
|
|
return self.model( |
|
|
image, extrinsics, intrinsics, export_feat_layers, infer_gs, use_ray_pose, ref_view_strategy |
|
|
) |
|
|
else: |
|
|
|
|
|
autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 |
|
|
with torch.autocast(device_type=image.device.type, dtype=autocast_dtype): |
|
|
return self.model( |
|
|
image, extrinsics, intrinsics, export_feat_layers, infer_gs, use_ray_pose, ref_view_strategy |
|
|
) |
|
|
|
|
|
def inference( |
|
|
self, |
|
|
image: list[np.ndarray | Image.Image | str], |
|
|
extrinsics: np.ndarray | None = None, |
|
|
intrinsics: np.ndarray | None = None, |
|
|
align_to_input_ext_scale: bool = True, |
|
|
infer_gs: bool = False, |
|
|
use_ray_pose: bool = False, |
|
|
ref_view_strategy: str = "saddle_balanced", |
|
|
render_exts: np.ndarray | None = None, |
|
|
render_ixts: np.ndarray | None = None, |
|
|
render_hw: tuple[int, int] | None = None, |
|
|
process_res: int = 504, |
|
|
process_res_method: str = "upper_bound_resize", |
|
|
export_dir: str | None = None, |
|
|
export_format: str = "mini_npz", |
|
|
export_feat_layers: Sequence[int] | None = None, |
|
|
|
|
|
conf_thresh_percentile: float = 40.0, |
|
|
num_max_points: int = 1_000_000, |
|
|
show_cameras: bool = True, |
|
|
|
|
|
feat_vis_fps: int = 15, |
|
|
|
|
|
export_kwargs: Optional[dict] = {}, |
|
|
) -> Prediction: |
|
|
""" |
|
|
Run inference on input images. |
|
|
|
|
|
Args: |
|
|
image: List of input images (numpy arrays, PIL Images, or file paths) |
|
|
extrinsics: Camera extrinsics (N, 4, 4) |
|
|
intrinsics: Camera intrinsics (N, 3, 3) |
|
|
align_to_input_ext_scale: whether to align the input pose scale to the prediction |
|
|
infer_gs: Enable the 3D Gaussian branch (needed for `gs_ply`/`gs_video` exports) |
|
|
use_ray_pose: Use ray-based pose estimation instead of camera decoder (default: False) |
|
|
ref_view_strategy: Strategy for selecting reference view from multiple views. |
|
|
Options: "first", "middle", "saddle_balanced", "saddle_sim_range". |
|
|
Default: "saddle_balanced". For single view input (S ≤ 2), no reordering is performed. |
|
|
render_exts: Optional render extrinsics for Gaussian video export |
|
|
render_ixts: Optional render intrinsics for Gaussian video export |
|
|
render_hw: Optional render resolution for Gaussian video export |
|
|
process_res: Processing resolution |
|
|
process_res_method: Resize method for processing |
|
|
export_dir: Directory to export results |
|
|
export_format: Export format (mini_npz, npz, glb, ply, gs, gs_video) |
|
|
export_feat_layers: Layer indices to export intermediate features from |
|
|
conf_thresh_percentile: [GLB] Lower percentile for adaptive confidence threshold (default: 40.0) # noqa: E501 |
|
|
num_max_points: [GLB] Maximum number of points in the point cloud (default: 1,000,000) |
|
|
show_cameras: [GLB] Show camera wireframes in the exported scene (default: True) |
|
|
feat_vis_fps: [FEAT_VIS] Frame rate for output video (default: 15) |
|
|
export_kwargs: additional arguments to export functions. |
|
|
|
|
|
Returns: |
|
|
Prediction object containing depth maps and camera parameters |
|
|
""" |
|
|
if "gs" in export_format: |
|
|
assert infer_gs, "must set `infer_gs=True` to perform gs-related export." |
|
|
|
|
|
if "colmap" in export_format: |
|
|
assert isinstance(image[0], str), "`image` must be image paths for COLMAP export." |
|
|
|
|
|
|
|
|
imgs_cpu, extrinsics, intrinsics = self._preprocess_inputs( |
|
|
image, extrinsics, intrinsics, process_res, process_res_method |
|
|
) |
|
|
|
|
|
|
|
|
imgs, ex_t, in_t = self._prepare_model_inputs(imgs_cpu, extrinsics, intrinsics) |
|
|
|
|
|
|
|
|
ex_t_norm = self._normalize_extrinsics(ex_t.clone() if ex_t is not None else None) |
|
|
|
|
|
|
|
|
export_feat_layers = list(export_feat_layers) if export_feat_layers is not None else [] |
|
|
|
|
|
raw_output = self._run_model_forward( |
|
|
imgs, ex_t_norm, in_t, export_feat_layers, infer_gs, use_ray_pose, ref_view_strategy |
|
|
) |
|
|
|
|
|
|
|
|
prediction = self._convert_to_prediction(raw_output) |
|
|
|
|
|
|
|
|
prediction = self._align_to_input_extrinsics_intrinsics( |
|
|
extrinsics, intrinsics, prediction, align_to_input_ext_scale |
|
|
) |
|
|
|
|
|
|
|
|
prediction = self._add_processed_images(prediction, imgs_cpu) |
|
|
|
|
|
|
|
|
if export_dir is not None: |
|
|
|
|
|
if "gs" in export_format: |
|
|
if infer_gs and "gs_video" not in export_format: |
|
|
export_format = f"{export_format}-gs_video" |
|
|
if "gs_video" in export_format: |
|
|
if "gs_video" not in export_kwargs: |
|
|
export_kwargs["gs_video"] = {} |
|
|
export_kwargs["gs_video"].update( |
|
|
{ |
|
|
"extrinsics": render_exts, |
|
|
"intrinsics": render_ixts, |
|
|
"out_image_hw": render_hw, |
|
|
} |
|
|
) |
|
|
|
|
|
if "glb" in export_format: |
|
|
if "glb" not in export_kwargs: |
|
|
export_kwargs["glb"] = {} |
|
|
export_kwargs["glb"].update( |
|
|
{ |
|
|
"conf_thresh_percentile": conf_thresh_percentile, |
|
|
"num_max_points": num_max_points, |
|
|
"show_cameras": show_cameras, |
|
|
} |
|
|
) |
|
|
|
|
|
if "feat_vis" in export_format: |
|
|
if "feat_vis" not in export_kwargs: |
|
|
export_kwargs["feat_vis"] = {} |
|
|
export_kwargs["feat_vis"].update( |
|
|
{ |
|
|
"fps": feat_vis_fps, |
|
|
} |
|
|
) |
|
|
|
|
|
if "colmap" in export_format: |
|
|
if "colmap" not in export_kwargs: |
|
|
export_kwargs["colmap"] = {} |
|
|
export_kwargs["colmap"].update( |
|
|
{ |
|
|
"image_paths": image, |
|
|
"conf_thresh_percentile": conf_thresh_percentile, |
|
|
"process_res_method": process_res_method, |
|
|
} |
|
|
) |
|
|
self._export_results(prediction, export_format, export_dir, **export_kwargs) |
|
|
|
|
|
return prediction |
|
|
|
|
|
def _preprocess_inputs( |
|
|
self, |
|
|
image: list[np.ndarray | Image.Image | str], |
|
|
extrinsics: np.ndarray | None = None, |
|
|
intrinsics: np.ndarray | None = None, |
|
|
process_res: int = 504, |
|
|
process_res_method: str = "upper_bound_resize", |
|
|
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]: |
|
|
"""Preprocess input images using input processor.""" |
|
|
start_time = time.time() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
perform_norm = True |
|
|
if self.device.type in ("cuda", "mps") and not isinstance(self.input_processor, GPUInputProcessor): |
|
|
perform_norm = False |
|
|
|
|
|
imgs_cpu, extrinsics, intrinsics = self.input_processor( |
|
|
image, |
|
|
extrinsics.copy() if extrinsics is not None else None, |
|
|
intrinsics.copy() if intrinsics is not None else None, |
|
|
process_res, |
|
|
process_res_method, |
|
|
perform_normalization=perform_norm, |
|
|
) |
|
|
end_time = time.time() |
|
|
logger.info( |
|
|
"Processed Images Done taking", |
|
|
end_time - start_time, |
|
|
"seconds. Shape: ", |
|
|
imgs_cpu.shape, |
|
|
) |
|
|
return imgs_cpu, extrinsics, intrinsics |
|
|
|
|
|
    def _prepare_model_inputs(
        self,
        imgs_cpu: torch.Tensor,
        extrinsics: torch.Tensor | None,
        intrinsics: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """
        Prepare tensors for model input with optimized device transfer.

        Moves images/extrinsics/intrinsics onto the model device, adds the
        leading batch dimension, converts uint8 images to normalized floats,
        and uses pinned memory for async host-to-device copies on CUDA.

        Args:
            imgs_cpu: Image tensor; either already on the target device or on
                the host. dtype uint8 (un-normalized) or float (assumed to be
                already normalized by the input processor — NOTE(review):
                confirm this invariant against _preprocess_inputs).
            extrinsics: Optional per-view extrinsics tensor, shape (N, 4, 4).
            intrinsics: Optional per-view intrinsics tensor, shape (N, 3, 3).

        Returns:
            Tuple of (imgs, ex_t, in_t) with a leading batch dimension, on the
            model device, in float32; ex_t/in_t are None when not provided.
        """
        device = self._get_model_device()

        # Compare by device *type* only (e.g. "cuda" matches "cuda:0").
        imgs_on_target_device = (imgs_cpu.device.type == device.type)
        if imgs_on_target_device:
            # Already on device: only fix up rank and dtype in place.
            imgs = imgs_cpu
            if imgs.dim() == 3:
                # (C, H, W) -> (1, 1, C, H, W): add view and batch dims.
                imgs = imgs.unsqueeze(0).unsqueeze(0)
            elif imgs.dim() == 4:
                # (N, C, H, W) -> (1, N, C, H, W): add batch dim.
                imgs = imgs.unsqueeze(0)

            if imgs.dtype == torch.uint8:
                # Raw bytes: scale to [0, 1] then apply ImageNet normalization.
                imgs = imgs.float() / 255.0
                imgs = InputProcessor.normalize_tensor(
                    imgs,
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
        else:
            # Host-to-device transfer path.
            if imgs_cpu.dtype == torch.uint8:
                if device.type == "cuda":
                    # Pinned memory enables a truly asynchronous copy below.
                    imgs_cpu = imgs_cpu.pin_memory()

                imgs = imgs_cpu.to(device, non_blocking=True).float() / 255.0
                imgs = InputProcessor.normalize_tensor(
                    imgs,
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225]
                )
                # Add the leading batch dimension.
                imgs = imgs[None]
            else:
                # Float input: assumed pre-normalized; just transfer and batch.
                if device.type == "cuda":
                    imgs_cpu = imgs_cpu.pin_memory()
                imgs = imgs_cpu.to(device, non_blocking=True)[None].float()

        # For each of extrinsics/intrinsics: pinned async copy when going
        # CPU -> CUDA, plain transfer when on the wrong device, otherwise just
        # batch + cast; None passes through unchanged.
        ex_t = (
            extrinsics.pin_memory().to(device, non_blocking=True)[None].float()
            if extrinsics is not None and device.type == "cuda" and extrinsics.device.type == "cpu"
            else extrinsics.to(device, non_blocking=True)[None].float()
            if extrinsics is not None and extrinsics.device != device
            else extrinsics[None].float()
            if extrinsics is not None
            else None
        )
        in_t = (
            intrinsics.pin_memory().to(device, non_blocking=True)[None].float()
            if intrinsics is not None and device.type == "cuda" and intrinsics.device.type == "cpu"
            else intrinsics.to(device, non_blocking=True)[None].float()
            if intrinsics is not None and intrinsics.device != device
            else intrinsics[None].float()
            if intrinsics is not None
            else None
        )

        return imgs, ex_t, in_t
|
|
|
|
|
def _normalize_extrinsics(self, ex_t: torch.Tensor | None) -> torch.Tensor | None: |
|
|
"""Normalize extrinsics""" |
|
|
if ex_t is None: |
|
|
return None |
|
|
transform = affine_inverse(ex_t[:, :1]) |
|
|
ex_t_norm = ex_t @ transform |
|
|
c2ws = affine_inverse(ex_t_norm) |
|
|
translations = c2ws[..., :3, 3] |
|
|
dists = translations.norm(dim=-1) |
|
|
median_dist = torch.median(dists) |
|
|
median_dist = torch.clamp(median_dist, min=1e-1) |
|
|
ex_t_norm[..., :3, 3] = ex_t_norm[..., :3, 3] / median_dist |
|
|
return ex_t_norm |
|
|
|
|
|
def _align_to_input_extrinsics_intrinsics( |
|
|
self, |
|
|
extrinsics: torch.Tensor | None, |
|
|
intrinsics: torch.Tensor | None, |
|
|
prediction: Prediction, |
|
|
align_to_input_ext_scale: bool = True, |
|
|
ransac_view_thresh: int = 10, |
|
|
) -> Prediction: |
|
|
"""Align depth map to input extrinsics""" |
|
|
if extrinsics is None: |
|
|
return prediction |
|
|
prediction.intrinsics = intrinsics.numpy() |
|
|
_, _, scale, aligned_extrinsics = align_poses_umeyama( |
|
|
prediction.extrinsics, |
|
|
extrinsics.numpy(), |
|
|
ransac=len(extrinsics) >= ransac_view_thresh, |
|
|
return_aligned=True, |
|
|
random_state=42, |
|
|
) |
|
|
if align_to_input_ext_scale: |
|
|
prediction.extrinsics = extrinsics[..., :3, :].numpy() |
|
|
prediction.depth /= scale |
|
|
else: |
|
|
prediction.extrinsics = aligned_extrinsics |
|
|
return prediction |
|
|
|
|
|
def _run_model_forward( |
|
|
self, |
|
|
imgs: torch.Tensor, |
|
|
ex_t: torch.Tensor | None, |
|
|
in_t: torch.Tensor | None, |
|
|
export_feat_layers: Sequence[int] | None = None, |
|
|
infer_gs: bool = False, |
|
|
use_ray_pose: bool = False, |
|
|
ref_view_strategy: str = "saddle_balanced", |
|
|
) -> dict[str, torch.Tensor]: |
|
|
"""Run model forward pass.""" |
|
|
device = imgs.device |
|
|
need_sync = device.type == "cuda" |
|
|
if need_sync: |
|
|
torch.cuda.synchronize(device) |
|
|
start_time = time.time() |
|
|
feat_layers = list(export_feat_layers) if export_feat_layers is not None else None |
|
|
output = self.forward(imgs, ex_t, in_t, feat_layers, infer_gs, use_ray_pose, ref_view_strategy) |
|
|
if need_sync: |
|
|
torch.cuda.synchronize(device) |
|
|
end_time = time.time() |
|
|
logger.info(f"Model Forward Pass Done. Time: {end_time - start_time} seconds") |
|
|
return output |
|
|
|
|
|
def _convert_to_prediction(self, raw_output: dict[str, torch.Tensor]) -> Prediction: |
|
|
"""Convert raw model output to Prediction object.""" |
|
|
start_time = time.time() |
|
|
output = self.output_processor(raw_output) |
|
|
end_time = time.time() |
|
|
logger.info(f"Conversion to Prediction Done. Time: {end_time - start_time} seconds") |
|
|
return output |
|
|
|
|
|
def _add_processed_images(self, prediction: Prediction, imgs_cpu: torch.Tensor) -> Prediction: |
|
|
"""Add processed images to prediction for visualization.""" |
|
|
|
|
|
processed_imgs = imgs_cpu.permute(0, 2, 3, 1).cpu().numpy() |
|
|
|
|
|
if imgs_cpu.dtype == torch.uint8: |
|
|
|
|
|
pass |
|
|
else: |
|
|
|
|
|
mean = np.array([0.485, 0.456, 0.406]) |
|
|
std = np.array([0.229, 0.224, 0.225]) |
|
|
processed_imgs = processed_imgs * std + mean |
|
|
processed_imgs = np.clip(processed_imgs, 0, 1) |
|
|
processed_imgs = (processed_imgs * 255).astype(np.uint8) |
|
|
|
|
|
prediction.processed_images = processed_imgs |
|
|
return prediction |
|
|
|
|
|
def _export_results( |
|
|
self, prediction: Prediction, export_format: str, export_dir: str, **kwargs |
|
|
) -> None: |
|
|
"""Export results to specified format and directory.""" |
|
|
start_time = time.time() |
|
|
export(prediction, export_format, export_dir, **kwargs) |
|
|
end_time = time.time() |
|
|
logger.info(f"Export Results Done. Time: {end_time - start_time} seconds") |
|
|
|
|
|
def _get_model_device(self) -> torch.device: |
|
|
""" |
|
|
Get the device where the model is located. |
|
|
|
|
|
Returns: |
|
|
Device where the model parameters are located |
|
|
|
|
|
Raises: |
|
|
ValueError: If no tensors are found in the model |
|
|
""" |
|
|
if self.device is not None: |
|
|
return self.device |
|
|
|
|
|
|
|
|
for param in self.parameters(): |
|
|
self.device = param.device |
|
|
return param.device |
|
|
|
|
|
|
|
|
for buffer in self.buffers(): |
|
|
self.device = buffer.device |
|
|
return buffer.device |
|
|
|
|
|
raise ValueError("No tensor found in model") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def batch_inference( |
|
|
self, |
|
|
images: list[np.ndarray | Image.Image | str], |
|
|
process_res: int = 504, |
|
|
batch_size: int | str = "auto", |
|
|
max_batch_size: int = 64, |
|
|
target_memory_utilization: float = 0.85, |
|
|
progress_callback: callable | None = None, |
|
|
) -> list[Prediction]: |
|
|
""" |
|
|
Run inference on multiple images with adaptive batching. |
|
|
|
|
|
This method automatically determines optimal batch sizes based on |
|
|
available GPU memory, maximizing throughput while preventing OOM errors. |
|
|
|
|
|
Args: |
|
|
images: List of input images (numpy arrays, PIL Images, or file paths) |
|
|
process_res: Processing resolution (default: 504) |
|
|
batch_size: Batch size or "auto" for adaptive batching (default: "auto") |
|
|
max_batch_size: Maximum batch size when using adaptive batching (default: 64) |
|
|
target_memory_utilization: Target GPU memory usage 0.0-1.0 (default: 0.85) |
|
|
progress_callback: Optional callback(processed, total) for progress updates |
|
|
|
|
|
Returns: |
|
|
List of Prediction objects, one per batch |
|
|
|
|
|
Example: |
|
|
>>> model = DepthAnything3(model_name="da3-large") |
|
|
>>> images = ["img1.jpg", "img2.jpg", ..., "img100.jpg"] |
|
|
>>> |
|
|
>>> # Adaptive batching (recommended) |
|
|
>>> results = model.batch_inference(images, process_res=518) |
|
|
>>> |
|
|
>>> # Fixed batch size |
|
|
>>> results = model.batch_inference(images, batch_size=4) |
|
|
>>> |
|
|
>>> # With progress callback |
|
|
>>> def on_progress(done, total): |
|
|
... print(f"Processed {done}/{total}") |
|
|
>>> results = model.batch_inference(images, progress_callback=on_progress) |
|
|
""" |
|
|
import gc |
|
|
|
|
|
num_images = len(images) |
|
|
if num_images == 0: |
|
|
return [] |
|
|
|
|
|
results: list[Prediction] = [] |
|
|
|
|
|
|
|
|
if batch_size == "auto": |
|
|
config = AdaptiveBatchConfig( |
|
|
max_batch_size=max_batch_size, |
|
|
target_memory_utilization=target_memory_utilization, |
|
|
) |
|
|
calculator = AdaptiveBatchSizeCalculator( |
|
|
model_name=self.model_name, |
|
|
device=self.device, |
|
|
config=config, |
|
|
) |
|
|
|
|
|
for batch_info in adaptive_batch_iterator(images, calculator, process_res): |
|
|
|
|
|
prediction = self.inference( |
|
|
image=batch_info.items, |
|
|
process_res=process_res, |
|
|
) |
|
|
results.append(prediction) |
|
|
|
|
|
|
|
|
if progress_callback: |
|
|
progress_callback(batch_info.end_idx, num_images) |
|
|
|
|
|
|
|
|
if not batch_info.is_last: |
|
|
gc.collect() |
|
|
if self.device.type == "cuda": |
|
|
torch.cuda.empty_cache() |
|
|
elif self.device.type == "mps": |
|
|
torch.mps.empty_cache() |
|
|
|
|
|
|
|
|
if calculator.config.enable_profiling and self.device.type == "cuda": |
|
|
memory_used = torch.cuda.max_memory_allocated(self.device) / (1024 * 1024) |
|
|
calculator.update_from_profiling( |
|
|
batch_size=batch_info.batch_size, |
|
|
memory_used_mb=memory_used, |
|
|
process_res=process_res, |
|
|
) |
|
|
torch.cuda.reset_peak_memory_stats(self.device) |
|
|
|
|
|
else: |
|
|
|
|
|
fixed_batch_size = int(batch_size) |
|
|
for i in range(0, num_images, fixed_batch_size): |
|
|
end_idx = min(i + fixed_batch_size, num_images) |
|
|
batch_images = images[i:end_idx] |
|
|
|
|
|
prediction = self.inference( |
|
|
image=batch_images, |
|
|
process_res=process_res, |
|
|
) |
|
|
results.append(prediction) |
|
|
|
|
|
if progress_callback: |
|
|
progress_callback(end_idx, num_images) |
|
|
|
|
|
|
|
|
if end_idx < num_images: |
|
|
gc.collect() |
|
|
if self.device.type == "cuda": |
|
|
torch.cuda.empty_cache() |
|
|
elif self.device.type == "mps": |
|
|
torch.mps.empty_cache() |
|
|
|
|
|
return results |
|
|
|
|
|
def get_optimal_batch_size( |
|
|
self, |
|
|
process_res: int = 504, |
|
|
target_utilization: float = 0.85, |
|
|
) -> int: |
|
|
""" |
|
|
Get the optimal batch size for current GPU memory state. |
|
|
|
|
|
Args: |
|
|
process_res: Processing resolution (default: 504) |
|
|
target_utilization: Target GPU memory usage 0.0-1.0 (default: 0.85) |
|
|
|
|
|
Returns: |
|
|
Recommended batch size |
|
|
|
|
|
Example: |
|
|
>>> model = DepthAnything3(model_name="da3-large") |
|
|
>>> batch_size = model.get_optimal_batch_size(process_res=518) |
|
|
>>> print(f"Optimal batch size: {batch_size}") |
|
|
""" |
|
|
return estimate_max_batch_size( |
|
|
model_name=self.model_name, |
|
|
device=self.device, |
|
|
process_res=process_res, |
|
|
target_utilization=target_utilization, |
|
|
) |
|
|
|