File size: 5,772 Bytes
05c6078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcfd69e
05c6078
 
bcfd69e
 
 
05c6078
 
 
 
bcfd69e
 
 
 
 
 
 
 
 
 
05c6078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bcfd69e
 
 
 
 
05c6078
 
bcfd69e
05c6078
 
 
 
 
 
 
bcfd69e
05c6078
bcfd69e
 
 
 
 
 
 
 
05c6078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
from __future__ import annotations

from pathlib import Path

import cv2
import numpy as np
import torch
from PIL import Image

try:  # pragma: no cover - optional dependency resolution
    from depth_anything_3.api import DepthAnything3  # type: ignore
    from depth_anything_3.utils.visualize import visualize_depth  # type: ignore
except ModuleNotFoundError:  # pragma: no cover
    import sys

    ROOT = Path(__file__).resolve().parents[1]
    sys.path.append(str(ROOT / "src"))
    from depth_anything_3.api import DepthAnything3  # type: ignore  # noqa: E402
    from depth_anything_3.utils.visualize import visualize_depth  # type: ignore  # noqa: E402


def crop_nonblack(img: Image.Image, frac: float = 0.05) -> Image.Image:
    """Trim a fractional margin from every side of ``img``.

    Despite the name, no pixel values are inspected: the crop box is derived
    only from the image size and ``frac``.

    Args:
        img: Image to crop. Only its ``size`` attribute and ``crop`` method
            are used.
        frac: Fraction of the width removed from the left and from the right,
            and of the height removed from the top and from the bottom.
            Negative values are treated as ``0``; values that would remove
            the whole image are reduced so that at least one pixel per axis
            is kept.

    Returns:
        Whatever ``img.crop`` returns for the computed box.
    """
    w, h = img.size
    # Clamp each margin to [0, (side - 1) // 2]: without the upper bound a
    # ``frac`` of 0.5 or more yields an empty or inverted box, and without the
    # lower bound a negative ``frac`` would enlarge the image instead.
    dx = max(0, min(int(round(w * frac)), (w - 1) // 2))
    dy = max(0, min(int(round(h * frac)), (h - 1) // 2))
    return img.crop((dx, dy, w - dx, h - dy))


def remove_global_plane(depth: np.ndarray, method: str = "least_squares") -> np.ndarray:
    """Subtract the best-fit plane ``a*x + b*y + c`` from a 2D depth map.

    The plane is fitted by least squares over pixel coordinates (``x`` is the
    column index, ``y`` the row index). Only finite pixels take part in the
    fit, so isolated NaN/inf values no longer corrupt the result; such pixels
    stay non-finite in the output.

    Args:
        depth: Depth map. Anything that is not 2D is returned unchanged.
        method: ``"ls"``, ``"least_squares"`` or ``"lstsq"`` to fit and remove
            the plane (case-insensitive; ``None``/empty means least squares).
            ``"none"``, ``"off"`` and any unrecognised name return ``depth``
            unchanged.

    Returns:
        ``depth - plane`` with the same shape as ``depth``, or ``depth`` itself
        when no plane could be fitted (non-2D input, disabled method, no
        finite pixels, or a failed/non-finite solve).
    """
    if depth.ndim != 2:
        return depth
    method = (method or "least_squares").lower()
    if method not in {"ls", "least_squares", "lstsq"}:
        # Covers the explicit "none"/"off" switches as well as unknown names.
        return depth

    h, w = depth.shape
    yy, xx = np.mgrid[0:h, 0:w].astype(np.float32)
    points = np.stack((xx.ravel(), yy.ravel()), axis=1)
    values = depth.astype(np.float32).reshape(-1, 1)

    # Non-finite samples would propagate into every coefficient (or make the
    # SVD fail), so they are excluded from the fit.
    usable = np.isfinite(values[:, 0])
    if not usable.any():
        return depth

    design = np.concatenate(
        [points, np.ones((points.shape[0], 1), dtype=np.float32)], axis=1
    )
    if not usable.all():
        fit_design, fit_values = design[usable], values[usable]
    else:
        # Common case: avoid copying the full design matrix.
        fit_design, fit_values = design, values

    try:
        coef, *_ = np.linalg.lstsq(fit_design, fit_values, rcond=None)
    except np.linalg.LinAlgError:
        return depth
    if not np.isfinite(coef).all():
        return depth

    # Evaluate the plane on the full grid, including pixels skipped in the fit.
    plane = (points @ coef[:2] + coef[2]).reshape(h, w)
    return depth - plane


def pick_flat_patch(
    depth: np.ndarray,
    patch: int = 96,
    std_thresh: float = 0.03,
    grad_thresh: float = 0.35,
    water_mask: np.ndarray | None = None,
) -> tuple[tuple[int, int, int, int], np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Locate the flattest square window of a depth map.

    The depth map is min-max normalised, a local standard deviation is
    computed over a ``patch`` x ``patch`` window around every pixel, and the
    pixel with the smallest deviation among the allowed pixels becomes the
    centre of the returned box.

    Args:
        depth: 2D depth map of shape ``(H, W)``.
        patch: Requested window side in pixels. It is limited to the smaller
            image side, raised to at least 3, and bumped to the next odd
            number so that the window has a centre pixel.
        std_thresh: Accepted for interface compatibility; not used by the
            current implementation.
        grad_thresh: Upper bound on the normalised gradient magnitude
            (gradient divided by its 95th percentile, clipped to ``[0, 1]``)
            for a pixel to be an allowed centre.
        water_mask: Optional mask of pixels to exclude; any non-zero value
            marks an excluded pixel. It is ignored when its shape differs
            from the depth map's.

    Returns:
        A tuple ``(box, std_map, grad_norm, grad_mask, landing_mask)`` where
        ``box`` is ``(x0, y0, x1, y1)`` clipped to valid pixel indices,
        ``std_map`` is the local standard deviation of the normalised depth,
        ``grad_norm`` the normalised gradient magnitude, ``grad_mask`` the
        pixels below ``grad_thresh`` and ``landing_mask`` the allowed centre
        pixels (``grad_mask`` minus the water mask).

    Raises:
        ValueError: If ``depth`` is not 2D.
    """
    depth = depth.astype(np.float32)
    if depth.ndim != 2:
        raise ValueError("Depth map must be 2D (H, W)")

    patch = max(3, min(patch, min(depth.shape)))
    if patch % 2 == 0:
        patch += 1
    # Min-max normalisation; the epsilon keeps a constant map from dividing by zero.
    depth_norm = (depth - depth.min()) / (np.ptp(depth) + 1e-6)

    import torch.nn.functional as F

    def box_mean(arr, k):
        # Mean over a k x k window with reflected borders, same size as ``arr``.
        pad = k // 2
        t = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)
        t = F.pad(t, (pad, pad, pad, pad), mode="reflect")
        mean = F.avg_pool2d(t, kernel_size=k, stride=1, padding=0, count_include_pad=False)
        return mean.squeeze(0).squeeze(0).numpy()

    # Local variance via E[x^2] - E[x]^2, clamped because rounding can push it
    # slightly below zero.
    mean = box_mean(depth_norm, patch)
    mean_sq = box_mean(depth_norm * depth_norm, patch)
    var = np.maximum(mean_sq - mean * mean, 0.0)
    std_map = np.sqrt(var)

    dy, dx = np.gradient(depth_norm)
    grad = np.sqrt(dx * dx + dy * dy)
    grad_ref = np.percentile(grad, 95) + 1e-6
    grad_norm = np.clip(grad / grad_ref, 0.0, 1.0)
    grad_mask = grad_norm < grad_thresh

    landing_mask = grad_mask
    if water_mask is not None:
        water = np.asarray(water_mask)
        if water.shape == grad_mask.shape:
            # Coerce to bool before inverting: ``~`` on an integer mask is a
            # bitwise NOT, which leaves values such as 128 marked as allowed
            # and fails outright on float masks.
            landing_mask = landing_mask & ~water.astype(bool)

    masked_std = np.where(landing_mask, std_map, np.inf)
    if not np.isfinite(masked_std).any():
        # Nothing passes the masks: fall back to the unrestricted minimum.
        masked_std = std_map
    y, x = np.unravel_index(np.argmin(masked_std), masked_std.shape)
    half = patch // 2
    y0, y1 = max(y - half, 0), min(y + half, depth.shape[0] - 1)
    x0, x1 = max(x - half, 0), min(x + half, depth.shape[1] - 1)
    return (x0, y0, x1, y1), std_map, grad_norm, grad_mask, landing_mask


class DepthEngine:
    """Loads DepthAnything3 models on demand, keeps them cached, and runs depth inference.

    Models are keyed by their identifier; each cache entry stores the model
    together with the device it was moved to. Inference resolution is limited
    by a caller-supplied cap.
    """

    def __init__(self):
        # model id -> (model in eval mode, device it lives on)
        self._model_cache: dict[str, tuple[DepthAnything3, torch.device]] = {}

    def _load_model(self, model_id: str) -> tuple[DepthAnything3, torch.device]:
        """Instantiate ``model_id`` on CUDA when available (CPU otherwise), in eval mode."""
        backend = "cuda" if torch.cuda.is_available() else "cpu"
        target = torch.device(backend)
        net = DepthAnything3.from_pretrained(model_id).to(target)
        net.eval()
        return net, target

    def get_model(self, model_id: str) -> tuple[DepthAnything3, torch.device]:
        """Return the cached ``(model, device)`` pair for ``model_id``, loading it on first use."""
        entry = self._model_cache.get(model_id)
        if entry is None:
            entry = self._load_model(model_id)
            self._model_cache[model_id] = entry
        return entry

    def predict_depth(
        self, image: np.ndarray, model_id: str, process_res_cap: int, plane_method: str = "least_squares"
    ) -> tuple[np.ndarray, np.ndarray, int, dict[str, float]]:
        """Run depth inference on one image and strip the global plane from the result.

        Args:
            image: Input image; its first two ``shape`` entries determine the
                processing resolution.
            model_id: Identifier handed to ``get_model``.
            process_res_cap: Upper limit for the processing resolution
                (converted with ``int``).
            plane_method: Forwarded to ``remove_global_plane`` as ``method``.

        Returns:
            ``(depth_raw, depth, process_res, timings)``: the model's depth
            output as an array, the same map after ``remove_global_plane``,
            the resolution passed to the model, and wall-clock timings in
            milliseconds under the keys ``"prep_ms"``, ``"model_ms"`` and
            ``"plane_ms"``.
        """
        from time import perf_counter as _clock

        started = _clock()
        model, _device = self.get_model(model_id)
        # Longest image side, but never above the requested cap.
        longest_side = max(image.shape[0], image.shape[1])
        process_res = min(longest_side, int(process_res_cap))
        ready = _clock()

        with torch.inference_mode():
            pred = model.inference(
                image=[image],
                process_res=process_res,
                process_res_method="upper_bound_resize",
                export_dir=None,
            )
        inferred = _clock()

        depth_raw = np.array(pred.depth[0])
        depth = remove_global_plane(depth_raw, method=plane_method)
        finished = _clock()

        # "prep_ms" includes fetching the model, i.e. the load on a cache miss.
        timings = {
            "prep_ms": (ready - started) * 1000.0,
            "model_ms": (inferred - ready) * 1000.0,
            "plane_ms": (finished - inferred) * 1000.0,
        }
        return depth_raw, depth, process_res, timings


def smooth_depth(depth: np.ndarray, sigma: float) -> np.ndarray:
    """Gaussian-blur a depth map, best effort.

    Args:
        depth: Depth map passed to ``cv2.GaussianBlur``.
        sigma: Standard deviation used for both axes. Values ``<= 0`` disable
            smoothing.

    Returns:
        The blurred map, or ``depth`` itself when smoothing is disabled or
        OpenCV rejects the input (in which case a ``RuntimeWarning`` is
        issued instead of failing silently).
    """
    if sigma <= 0:
        return depth

    import warnings

    # Odd kernel side spanning roughly three sigmas each way, never below 3.
    k = max(3, int(round(sigma * 3)) * 2 + 1)
    try:
        return cv2.GaussianBlur(depth, (k, k), sigmaX=sigma, sigmaY=sigma)
    except (cv2.error, TypeError) as exc:
        # Smoothing stays optional, but the failure is reported so that an
        # unsmoothed result is not mistaken for a smoothed one.
        warnings.warn(
            f"smooth_depth: Gaussian blur failed, returning input unchanged ({exc})",
            RuntimeWarning,
            stacklevel=2,
        )
        return depth


# Names exported by a star-import of this module. ``visualize_depth`` is not
# defined here; it is re-exported from ``depth_anything_3.utils.visualize``.
__all__ = [
    "DepthEngine",
    "crop_nonblack",
    "pick_flat_patch",
    "remove_global_plane",
    "smooth_depth",
    "visualize_depth",
]