""" Core DepthPro depth estimation wrapper. """ from __future__ import annotations from dataclasses import dataclass from pathlib import Path from typing import Optional, Union import numpy as np import torch from PIL import Image try: from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation except ImportError as exc: raise ImportError( "transformers>=4.40.0 is required. " "Install it with: pip install transformers[torch]" ) from exc @dataclass class DepthResult: """Container for depth estimation results.""" depth: np.ndarray """(H, W) float32 metric depth map in meters.""" focal_length: float """Estimated focal length in pixels (for the original image resolution).""" field_of_view: float """Estimated horizontal field of view in degrees.""" image: np.ndarray """(H, W, 3) uint8 RGB image at original resolution.""" confidence: Optional[np.ndarray] = None """(H, W) optional confidence / uncertainty map.""" @property def height(self) -> int: return self.depth.shape[0] @property def width(self) -> int: return self.depth.shape[1] class DepthProEstimator: """ High-level wrapper around Apple's DepthPro for metric depth estimation. DepthPro is a single-image zero-shot metric depth estimator. Unlike relative-depth models (e.g. Depth Anything), it outputs **absolute-scale metric depth in meters** and also estimates the **camera focal length** and **field of view** automatically — no calibration required. Parameters ---------- model_name : str, default "apple/DepthPro-hf" HuggingFace model id or local path. device : str or torch.device, default "cuda:0" PyTorch device. CUDA strongly recommended for 952M-parameter ViT-L. dtype : torch.dtype, default torch.float16 Inference dtype. fp16 halves memory and is the default the model was trained with; fp32 gives marginally higher precision. """ _MODEL_INPUT_SIZE: int = 1536 # DepthPro always processes at 1536×1536 def __init__( self, model_name: str = "apple/DepthPro-hf", device: Union[str, torch.device] = "cuda:0", dtype: torch.dtype = torch.float16, ): self.device = torch.device(device) self.dtype = dtype self.model_name = model_name if not torch.cuda.is_available() and self.device.type == "cuda": raise RuntimeError( "CUDA is not available but device='cuda' was requested. " "DepthPro is a 952M ViT-L model; CPU inference will be extremely slow. " "Pass device='cpu' explicitly if you really want this." ) self._load_model() # ------------------------------------------------------------------ # # Loading # # ------------------------------------------------------------------ # def _load_model(self) -> None: """Load the processor and model from HF.""" self.processor = DepthProImageProcessorFast.from_pretrained(self.model_name) self.model = DepthProForDepthEstimation.from_pretrained( self.model_name, torch_dtype=self.dtype, ).to(self.device) self.model.eval() # ------------------------------------------------------------------ # # Public API # # ------------------------------------------------------------------ # @torch.no_grad() def estimate( self, image: Union[str, Path, Image.Image, np.ndarray], *, return_confidence: bool = False, ) -> DepthResult: """ Run metric depth estimation on a single RGB image. Parameters ---------- image : str, Path, PIL.Image, or np.ndarray Input RGB image. If a path, loaded with PIL. return_confidence : bool, default False If True and the model provides a confidence map, include it in the result. Returns ------- DepthResult Dataclass containing: * ``depth`` — (H, W) metric depth in meters * ``focal_length`` — estimated focal length (px) at original res * ``field_of_view`` — estimated horizontal FOV in degrees * ``image`` — original RGB image as (H, W, 3) uint8 """ # ---- load / normalise image --------------------------------------- pil_image = self._to_pil(image) rgb_array = np.array(pil_image.convert("RGB"), dtype=np.uint8) # ---- preprocess & forward pass ------------------------------------ inputs = self.processor(images=pil_image, return_tensors="pt") inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self.model(**inputs) # ---- post-process back to original resolution --------------------- post = self.processor.post_process_depth_estimation( outputs, target_sizes=[(pil_image.height, pil_image.width)], )[0] depth = post["predicted_depth"] # torch.Tensor [H, W] focal_length = post["focal_length"] fov = post["field_of_view"] # DepthPro returns focal_length for the 1536×1536 processed image. # Scale it to the original image resolution. focal_original = focal_length * (pil_image.width / self._MODEL_INPUT_SIZE) depth_np = depth.cpu().float().numpy() confidence_np = None if return_confidence and "confidence" in post: confidence_np = post["confidence"].cpu().float().numpy() return DepthResult( depth=depth_np, focal_length=focal_original.item(), field_of_view=fov.item(), image=rgb_array, confidence=confidence_np, ) @torch.no_grad() def estimate_batch( self, images: list, *, return_confidence: bool = False, ) -> list[DepthResult]: """ Run depth estimation on a batch of images. Parameters ---------- images : list of str, Path, PIL.Image, or np.ndarray return_confidence : bool Returns ------- list[DepthResult] """ pil_images = [self._to_pil(img) for img in images] rgb_arrays = [np.array(p.convert("RGB"), dtype=np.uint8) for p in pil_images] # Processor batch handling inputs = self.processor(images=pil_images, return_tensors="pt") inputs = {k: v.to(self.device) for k, v in inputs.items()} outputs = self.model(**inputs) posts = self.processor.post_process_depth_estimation( outputs, target_sizes=[(p.height, p.width) for p in pil_images], ) results = [] for post, pil_img, rgb in zip(posts, pil_images, rgb_arrays): depth = post["predicted_depth"].cpu().float().numpy() focal = post["focal_length"] * (pil_img.width / self._MODEL_INPUT_SIZE) fov = post["field_of_view"] conf = None if return_confidence and "confidence" in post: conf = post["confidence"].cpu().float().numpy() results.append(DepthResult( depth=depth, focal_length=focal.item(), field_of_view=fov.item(), image=rgb, confidence=conf, )) return results # ------------------------------------------------------------------ # # Helpers # # ------------------------------------------------------------------ # @staticmethod def _to_pil(image: Union[str, Path, Image.Image, np.ndarray]) -> Image.Image: """Normalise input to a PIL RGB image.""" if isinstance(image, (str, Path)): return Image.open(str(image)).convert("RGB") if isinstance(image, np.ndarray): if image.dtype != np.uint8: image = (image * 255).clip(0, 255).astype(np.uint8) return Image.fromarray(image).convert("RGB") if isinstance(image, Image.Image): return image.convert("RGB") raise TypeError(f"Unsupported image type: {type(image)}")