depthpro-wrapper / depthpro_wrapper /depth_estimator.py
bdck's picture
Upload depthpro_wrapper/depth_estimator.py
99a77de verified
Raw
History Blame Contribute Delete
8.45 kB
"""
Core DepthPro depth estimation wrapper.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
import numpy as np
import torch
from PIL import Image
try:
    from transformers import DepthProImageProcessorFast, DepthProForDepthEstimation
except ImportError as exc:
    # The DepthPro model classes are only present in recent transformers
    # releases (4.49.0+); on older releases this import fails too, so the
    # message names the minimum version rather than just "install transformers".
    # The install hint is quoted so the `[torch]` extra survives shells like zsh.
    raise ImportError(
        "transformers>=4.49.0 is required for DepthPro. "
        "Install it with: pip install 'transformers[torch]>=4.49.0'"
    ) from exc
@dataclass
class DepthResult:
    """Bundle of outputs produced by a single DepthPro inference.

    Fields
    ------
    depth : np.ndarray
        (H, W) float32 metric depth map in meters.
    focal_length : float
        Estimated focal length in pixels, for the original image resolution.
    field_of_view : float
        Estimated horizontal field of view in degrees.
    image : np.ndarray
        (H, W, 3) uint8 RGB image at original resolution.
    confidence : np.ndarray or None
        Optional (H, W) confidence / uncertainty map; ``None`` when absent.
    """

    depth: np.ndarray
    focal_length: float
    field_of_view: float
    image: np.ndarray
    confidence: Optional[np.ndarray] = None

    @property
    def height(self) -> int:
        """Number of rows (first axis) of the depth map."""
        n_rows = self.depth.shape[0]
        return n_rows

    @property
    def width(self) -> int:
        """Number of columns (second axis) of the depth map."""
        n_cols = self.depth.shape[1]
        return n_cols
class DepthProEstimator:
    """
    High-level wrapper around Apple's DepthPro for metric depth estimation.

    DepthPro is a single-image zero-shot metric depth estimator. Unlike
    relative-depth models (e.g. Depth Anything), it outputs **absolute-scale
    metric depth in meters** and also estimates the **camera focal length**
    and **field of view** automatically — no calibration required.

    Parameters
    ----------
    model_name : str, default "apple/DepthPro-hf"
        HuggingFace model id or local path.
    device : str or torch.device, default "cuda:0"
        PyTorch device. CUDA is strongly recommended; CPU inference of this
        large ViT-based model is very slow.
    dtype : torch.dtype, default torch.float16
        Inference dtype for the model weights and for floating-point inputs.
        fp16 roughly halves memory versus fp32; fp32 may give marginally
        higher precision.

    Raises
    ------
    RuntimeError
        If a CUDA device is requested but CUDA is not available.
    """

    # Side length of the square image DepthPro processes internally. It is no
    # longer used to rescale the focal length: the HF post-processor already
    # derives focal length from the predicted FOV and the requested target
    # width. Kept for callers that may read it.
    _MODEL_INPUT_SIZE: int = 1536  # DepthPro always processes at 1536×1536

    def __init__(
        self,
        model_name: str = "apple/DepthPro-hf",
        device: Union[str, torch.device] = "cuda:0",
        dtype: torch.dtype = torch.float16,
    ):
        self.device = torch.device(device)
        self.dtype = dtype
        self.model_name = model_name
        # Fail fast with an actionable message instead of an opaque error
        # from `.to(device)` during model loading.
        if self.device.type == "cuda" and not torch.cuda.is_available():
            raise RuntimeError(
                "CUDA is not available but device='cuda' was requested. "
                "DepthPro is a 952M ViT-L model; CPU inference will be extremely slow. "
                "Pass device='cpu' explicitly if you really want this."
            )
        self._load_model()

    # ------------------------------------------------------------------ #
    # Loading                                                            #
    # ------------------------------------------------------------------ #
    def _load_model(self) -> None:
        """Load the processor and model from HF and put the model in eval mode."""
        self.processor = DepthProImageProcessorFast.from_pretrained(self.model_name)
        self.model = DepthProForDepthEstimation.from_pretrained(
            self.model_name,
            torch_dtype=self.dtype,
        ).to(self.device)
        self.model.eval()

    # ------------------------------------------------------------------ #
    # Public API                                                         #
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def estimate(
        self,
        image: Union[str, Path, Image.Image, np.ndarray],
        *,
        return_confidence: bool = False,
    ) -> DepthResult:
        """
        Run metric depth estimation on a single RGB image.

        Parameters
        ----------
        image : str, Path, PIL.Image, or np.ndarray
            Input RGB image. If a path, loaded with PIL.
        return_confidence : bool, default False
            If True and the post-processed output contains a non-None
            ``"confidence"`` entry, include it in the result.
            NOTE(review): the stock HF post-processor may never provide one —
            confirm against the installed transformers version.

        Returns
        -------
        DepthResult
            Dataclass containing:
            * ``depth`` — (H, W) metric depth in meters
            * ``focal_length`` — estimated focal length (px) at original res
            * ``field_of_view`` — estimated horizontal FOV in degrees
            * ``image`` — original RGB image as (H, W, 3) uint8

        Raises
        ------
        TypeError
            If ``image`` is not a supported type.
        RuntimeError
            If post-processing returned no focal length / field of view.
        """
        pil_image = self._to_pil(image)
        return self._infer([pil_image], return_confidence=return_confidence)[0]

    @torch.no_grad()
    def estimate_batch(
        self,
        images: list,
        *,
        return_confidence: bool = False,
    ) -> list[DepthResult]:
        """
        Run depth estimation on a batch of images in a single forward pass.

        Parameters
        ----------
        images : list of str, Path, PIL.Image, or np.ndarray
            Inputs; each is normalised exactly as in :meth:`estimate`.
        return_confidence : bool
            Same meaning as in :meth:`estimate`.

        Returns
        -------
        list[DepthResult]
            One result per input, in input order. An empty input gives ``[]``.
        """
        pil_images = [self._to_pil(img) for img in images]
        return self._infer(pil_images, return_confidence=return_confidence)

    # ------------------------------------------------------------------ #
    # Inference internals                                                #
    # ------------------------------------------------------------------ #
    @torch.no_grad()
    def _infer(
        self,
        pil_images: list[Image.Image],
        *,
        return_confidence: bool,
    ) -> list[DepthResult]:
        """Preprocess, run the model once, and post-process RGB PIL images."""
        if not pil_images:
            # The processor cannot build a batch from zero images.
            return []
        inputs = self.processor(images=pil_images, return_tensors="pt")
        outputs = self.model(**self._prepare_inputs(inputs))
        posts = self.processor.post_process_depth_estimation(
            outputs,
            # (height, width) per image, so depth (and focal length) come back
            # at each input's original resolution.
            target_sizes=[(p.height, p.width) for p in pil_images],
        )
        return [
            self._build_result(post, pil_img, return_confidence=return_confidence)
            for post, pil_img in zip(posts, pil_images)
        ]

    def _prepare_inputs(self, inputs) -> dict:
        """Move processor outputs to the model device, casting float tensors to the model dtype."""
        prepared = {}
        for key, value in inputs.items():
            if isinstance(value, torch.Tensor) and value.is_floating_point():
                # Processor outputs are typically float32; match the model
                # weights (fp16 by default) so the forward pass does not hit
                # a dtype mismatch.
                prepared[key] = value.to(device=self.device, dtype=self.dtype)
            elif isinstance(value, torch.Tensor):
                prepared[key] = value.to(self.device)
            else:
                prepared[key] = value
        return prepared

    @staticmethod
    def _camera_params(post: dict) -> tuple[float, float]:
        """Return ``(focal_length_px, field_of_view_deg)`` from one post-processed entry.

        The HF post-processor computes focal length from the predicted FOV and
        the requested target width, so the value is already expressed at the
        original image resolution and needs no further rescaling.

        Raises
        ------
        RuntimeError
            If either value is missing (e.g. a model without a FOV head).
        """
        focal = post.get("focal_length")
        fov = post.get("field_of_view")
        if focal is None or fov is None:
            raise RuntimeError(
                "DepthPro post-processing returned no focal length / field of view; "
                "the loaded model may lack a FOV head, so metric scale is unavailable."
            )
        return float(focal), float(fov)

    def _build_result(
        self,
        post: dict,
        pil_image: Image.Image,
        *,
        return_confidence: bool,
    ) -> DepthResult:
        """Convert one post-processed output dict into a ``DepthResult``."""
        focal_length, field_of_view = self._camera_params(post)
        confidence = None
        if return_confidence and post.get("confidence") is not None:
            confidence = post["confidence"].cpu().float().numpy()
        return DepthResult(
            depth=post["predicted_depth"].cpu().float().numpy(),
            focal_length=focal_length,
            field_of_view=field_of_view,
            # _to_pil already guarantees RGB mode, so no extra convert() copy.
            image=np.array(pil_image, dtype=np.uint8),
            confidence=confidence,
        )

    # ------------------------------------------------------------------ #
    # Helpers                                                            #
    # ------------------------------------------------------------------ #
    @staticmethod
    def _array_to_uint8(array: np.ndarray) -> np.ndarray:
        """Coerce an image array to uint8 pixel values for PIL.

        * uint8 arrays are returned unchanged.
        * bool arrays map False -> 0 and True -> 255.
        * Other integer dtypes are assumed to already hold 0-255 values and
          are clipped into range.
        * Remaining (floating-point) arrays are assumed normalised to [0, 1],
          scaled by 255, then clipped.
        """
        if array.dtype == np.uint8:
            return array
        if array.dtype == np.bool_:
            return array.astype(np.uint8) * np.uint8(255)
        if np.issubdtype(array.dtype, np.integer):
            # Scaling integer pixels by 255 (as for floats) would saturate
            # every non-zero pixel to white.
            return np.clip(array, 0, 255).astype(np.uint8)
        return (array * 255).clip(0, 255).astype(np.uint8)

    @staticmethod
    def _to_pil(image: Union[str, Path, Image.Image, np.ndarray]) -> Image.Image:
        """Normalise a supported input to a PIL RGB image.

        Raises
        ------
        TypeError
            If ``image`` is not a str, Path, np.ndarray, or PIL.Image.
        """
        if isinstance(image, (str, Path)):
            # The context manager releases the file handle; convert() loads
            # the pixel data before the file is closed.
            with Image.open(str(image)) as opened:
                return opened.convert("RGB")
        if isinstance(image, np.ndarray):
            pixels = DepthProEstimator._array_to_uint8(image)
            return Image.fromarray(pixels).convert("RGB")
        if isinstance(image, Image.Image):
            return image.convert("RGB")
        raise TypeError(f"Unsupported image type: {type(image)}")