"""
Depth estimation using Depth Anything V2 Small.

Loads the model via the HuggingFace transformers depth-estimation pipeline.
Returns uint8 depth maps normalised to 0-255 and resized to match the input.
"""

import numpy as np
import torch
from PIL import Image
from transformers import pipeline

from ..config import DEPTH_MODEL


class DepthEstimator:
    """Depth Anything V2 Small wrapper around the HuggingFace pipeline."""

    def __init__(self) -> None:
        """Load the depth-estimation pipeline onto the available device."""
        print("Loading Depth Anything V2 Small...")

        # pipeline() uses a plain int device (0 = first CUDA GPU, -1 = CPU).
        # device_map={"": "cuda"} is for from_pretrained, not pipeline — using
        # it here leaves the pipeline's internal device as -1 (CPU) and causes
        # a device mismatch when moving input tensors.
        device: int = 0 if torch.cuda.is_available() else -1

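        # float16 on GPU roughly halves the memory footprint of the ViT-S
        # weights; float32 is kept on CPU, where half precision is slow and
        # some kernels are unimplemented.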
        self.pipe = pipeline(
            task="depth-estimation",
            model=DEPTH_MODEL,
            device=device,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
        )

        if torch.cuda.is_available():
            print(
                f"  GPU memory allocated: "
                f"{torch.cuda.memory_allocated() / 1024**2:.0f} MB"
            )

    def estimate_depth(self, image: np.ndarray) -> np.ndarray:
        """Estimate a depth map from an RGB image.

        Args:
            image: uint8 RGB numpy array of shape (H, W, 3).

        Returns:
            uint8 numpy array of shape (H, W) with values in [0, 255].
            Higher values indicate objects closer to the camera.
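            Depth Anything V2 predicts relative (inverse) depth, so values
            are comparable within a single image but are not metric and not
            comparable across images.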
        """
        h, w = image.shape[:2]
        pil_image = Image.fromarray(image)

        with torch.inference_mode():
            result = self.pipe(pil_image)

        # The pipeline returns a dict; "depth" is a PIL Image.  In current
        # transformers versions it is an 8-bit "L" image already interpolated
        # to the input resolution; the raw float map is available under
        # result["predicted_depth"].
        depth_pil: Image.Image = result["depth"]

        # Resize defensively in case the pipeline's output size ever differs
        # from the input; when the sizes already match this is effectively a
        # copy.
        depth_resized = depth_pil.resize((w, h), Image.BILINEAR)
        depth_array = np.array(depth_resized, dtype=np.float32)

        # Normalise to uint8.  Guard against a flat depth map (max == min) by
        # falling back to a range of 1.0 to avoid divide-by-zero.
        d_min = float(depth_array.min())
        d_max = float(depth_array.max())
        d_range = d_max - d_min if d_max > d_min else 1.0
        depth_uint8 = ((depth_array - d_min) / d_range * 255).astype(np.uint8)

        return depth_uint8
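

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the pipeline proper.  "test.jpg"
    # and "depth.png" are placeholder paths; the relative ..config import
    # means this must be run as a module (python -m <package>.<this_module>).
    estimator = DepthEstimator()
    rgb = np.array(Image.open("test.jpg").convert("RGB"))
    depth = estimator.estimate_depth(rgb)
    print(f"Depth map: shape={depth.shape}, range=[{depth.min()}, {depth.max()}]")
    Image.fromarray(depth).save("depth.png")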