yakvrz committed on
Commit
deeabb9
·
1 Parent(s): a463107

Tune defaults for masks and clarify warnings

Browse files
Files changed (3) hide show
  1. app/config.py +61 -0
  2. app/safety.py +463 -0
  3. app/segmentation.py +177 -0
app/config.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+
6
# Dataset and media locations, given as relative paths.
VISLOC_DIR = Path("data/Image/VISLOC")
HAGDAVS_DIR = Path("data/Image/HAGDAVS")
VIDEO_DIR = Path("data/Video")

# Accepted file suffixes. IMAGE_EXTS spells out both lower- and upper-case
# variants; VIDEO_EXTS lists lower-case variants only.
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG")
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm", ".m4v"}

# Default camera geometry; units follow the constant names (metres, degrees).
DEFAULT_ALTITUDE_M = 450.0
ASSUMED_FOV_DEG = 90.0

# Default model identifiers for depth estimation and segmentation.
DEFAULT_MODEL_ID = "depth-anything/DA3MONO-LARGE"
SEGMENTATION_MODEL_ID = "facebook/sam3"

# Segmentation defaults: longest image side fed to the segmenter, and the
# score / mask thresholds used when post-processing its output.
SEGMENTATION_MAX_SIDE = 384
SEGMENTATION_SCORE_THRESH = 0.5
SEGMENTATION_MASK_THRESH = 0.5

# Default text prompts; each string holds several comma-separated phrases.
WATER_PROMPT = "water, river, lake, ocean, sea"
ROAD_PROMPT = "road, highway, street, runway"
20
+
21
+
22
@dataclass(frozen=True)
class AnalyzerSettings:
    """Bundle knobs shared between the UI and the processing pipeline.

    The dataclass is frozen, so instances are immutable once created. Every
    field has a default, so ``AnalyzerSettings()`` yields the tuned defaults.
    """

    # NOTE(review): the fields below share names with AnalysisRequest fields in
    # app/safety.py and are presumably copied into it — verify against caller.

    # Landing footprint size; the name suggests metres.
    footprint_m: float = 15.0
    # Thresholds for depth standard deviation and depth gradient.
    std_thresh: float = 0.005
    grad_thresh: float = 0.1
    # Clearance buffer as a multiple of the footprint; 0.0 means no extra buffer.
    clearance_factor: float = 0.0
    # Cap on the processing resolution of the depth model.
    process_res_cap: int = 1024
    # Base value for depth smoothing.
    depth_smoothing_base: float = 0.8
    # Segmentation settings, defaulting to the module-level constants.
    segmentation_max_side: int = SEGMENTATION_MAX_SIDE
    segmentation_score_thresh: float = SEGMENTATION_SCORE_THRESH
    segmentation_mask_thresh: float = SEGMENTATION_MASK_THRESH
    water_prompt: str = WATER_PROMPT
    road_prompt: str = ROAD_PROMPT
    # Scoring / filtering knobs for the safe-area search.
    coverage_strictness: float = 0.95
    openness_weight: float = 0.3
    texture_threshold: float = 0.5
    # Camera geometry and depth model, defaulting to the module-level constants.
    altitude_m: float = DEFAULT_ALTITUDE_M
    fov_deg: float = ASSUMED_FOV_DEG
    model_id: str = DEFAULT_MODEL_ID
43
+
44
+
45
+ __all__ = [
46
+ "VISLOC_DIR",
47
+ "HAGDAVS_DIR",
48
+ "VIDEO_DIR",
49
+ "IMAGE_EXTS",
50
+ "VIDEO_EXTS",
51
+ "DEFAULT_ALTITUDE_M",
52
+ "ASSUMED_FOV_DEG",
53
+ "DEFAULT_MODEL_ID",
54
+ "SEGMENTATION_MODEL_ID",
55
+ "SEGMENTATION_MAX_SIDE",
56
+ "SEGMENTATION_SCORE_THRESH",
57
+ "SEGMENTATION_MASK_THRESH",
58
+ "WATER_PROMPT",
59
+ "ROAD_PROMPT",
60
+ "AnalyzerSettings",
61
+ ]
app/safety.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, replace
4
+ from pathlib import Path
5
+ from typing import Dict, Optional
6
+ import time
7
+
8
+ import cv2
9
+ import numpy as np
10
+ from PIL import Image
11
+
12
+ from .config import IMAGE_EXTS
13
+ from .depth_pipeline import DepthEngine, compute_roof_mask_depth, crop_nonblack, pick_flat_patch, smooth_depth
14
+ from .segmentation import SegmenterRequest, SegmenterService
15
+ from .visualization import build_result_layers
16
+
17
+
18
@dataclass
class AnalysisRequest:
    """All per-run inputs consumed by ``SafetyAnalyzer.analyze_image``."""

    # Footprint size; converted to depth-map pixels as footprint_m * fx / altitude.
    footprint_m: float
    # Upper bounds on the local depth std map and the depth gradient norm;
    # also forwarded to pick_flat_patch.
    std_thresh: float
    grad_thresh: float
    # Enable the segmenter-based water / road masks and the depth-based roof mask.
    use_water_mask: bool
    use_road_mask: bool
    use_roof_mask: bool
    # Text prompts forwarded to the segmenter.
    water_prompt: str
    road_prompt: str
    # Camera geometry; altitude is clamped to >= 1.0 and the field of view to
    # [10, 170] degrees before use.
    altitude_m: float
    fov_deg: float
    # Multiplied by the footprint size in pixels to get the hazard dilation kernel.
    clearance_factor: float
    # Passed to DepthEngine.predict_depth together with model_id.
    process_res_cap: int
    # Clamped to >= 0 and scaled by the processing resolution to get the
    # sigma handed to smooth_depth.
    depth_smoothing_base: float
    # Forwarded to the segmenter; max side is raised to at least 128.
    segmentation_max_side: int
    segmentation_score_thresh: float
    segmentation_mask_thresh: float
    # Clamped to [0, 1]; minimum fraction of safe pixels inside a footprint
    # window for its center to stay safe.
    coverage_strictness: float
    # Depth model identifier passed to DepthEngine.predict_depth.
    model_id: str
    # Clamped to [0, 1]; weight of the normalised std penalty subtracted from
    # the normalised distance-to-hazard when scoring candidate centers.
    openness_weight: float
    # Clamped to [0, 1]; upper bound on the normalised image-texture map.
    texture_threshold: float
    # Set by SafetyAnalyzer.process_path; forwarded to the segmenter request.
    source_path: Optional[str] = None
41
+
42
+
43
@dataclass
class AnalysisSummary:
    """Numeric summary produced at the end of ``SafetyAnalyzer.analyze_image``."""

    # Echo of the request's model id.
    model_id: str
    # Processing resolution reported by DepthEngine.predict_depth.
    process_resolution: int
    # Wall-clock time of the analysis in milliseconds (time.perf_counter based).
    runtime_ms: float
    # Echo of the requested footprint size.
    footprint_m: float
    # Footprint edge length in depth-map pixels (odd, at least 3).
    footprint_depth_px: int
    # Footprint edge length scaled to input-image pixels (at least 3).
    footprint_image_px: int
    # Chosen landing center as (x, y) in depth-map and input-image coordinates.
    landing_center_depth: tuple[int, int]
    landing_center_image: tuple[int, int]
    # Percentage of the displayed safe mask that is set, and its complement.
    safe_area_pct: float
    hazard_pct: float
    # Percentage of pixels covered by each mask; None when the mask is
    # disabled or was not produced.
    water_mask_pct: Optional[float]
    road_mask_pct: Optional[float]
    roof_mask_pct: Optional[float]
    # Echo of the request's mask flags.
    water_mask_enabled: bool
    road_mask_enabled: bool
    roof_mask_enabled: bool
    # True when at least one center had a fully safe footprint window;
    # False when a fallback center was used.
    used_valid_center: bool
    # Human-readable notes about disabled or unavailable masks and empty results.
    warnings: list[str]
63
+
64
+
65
@dataclass
class AnalysisResult:
    """Return value of ``SafetyAnalyzer.analyze_image``."""

    # Named image layers as returned by build_result_layers.
    images: Dict[str, Image.Image]
    # Numeric summary and warnings for the same run.
    summary: AnalysisSummary
69
+
70
+
71
class SafetyAnalyzer:
    """Pick a landing footprint in an aerial image from depth flatness and hazard masks.

    The analyzer combines a depth engine (depth map prediction), an optional
    text-prompted segmenter (water / road masks) and a depth-derived roof mask
    to produce a safe-area mask, a landing center and visualisation layers.
    """

    def __init__(self, depth_engine: DepthEngine | None = None, segmenter: SegmenterService | None = None):
        """Store the collaborators, creating default ones when none are given.

        Args:
            depth_engine: Depth predictor; a new ``DepthEngine`` is built if None.
            segmenter: Mask provider; a new ``SegmenterService`` is built if None.
        """
        # Compare against None explicitly so that a supplied object is never
        # replaced merely because it happens to evaluate as falsy.
        self.depth_engine = depth_engine if depth_engine is not None else DepthEngine()
        self.segmenter = segmenter if segmenter is not None else SegmenterService()

    @staticmethod
    def build_depth_roof_mask(
        depth: np.ndarray,
        grad_norm: np.ndarray,
        footprint_px: int,
        aggressiveness: float = 1.2,
        grad_threshold: float = 0.35,
        max_area_frac: float = 0.2,
    ) -> np.ndarray | None:
        """Build a boolean roof mask from the depth map, or None if nothing qualifies.

        Args:
            depth: Depth map handed to ``compute_roof_mask_depth``.
            grad_norm: Gradient-norm map; only pixels below ``grad_threshold``
                are kept.
            footprint_px: Footprint edge length in depth pixels; drives the
                morphology kernel sizes and the minimum component area.
            aggressiveness: Forwarded to ``compute_roof_mask_depth``.
            grad_threshold: Upper bound on ``grad_norm`` for a pixel to be kept.
            max_area_frac: Components larger than this fraction of the map are
                dropped; a value <= 0 disables the upper limit.

        Returns:
            Boolean mask of accepted components, or None when no component
            passes the area filters.
        """
        depth_mask = compute_roof_mask_depth(
            depth,
            aggressiveness=aggressiveness,
            morph_kernel=max(3, int(round(max(3, footprint_px * 0.15))) | 1),
        )
        roof_mask = (depth_mask & (grad_norm < grad_threshold)).astype(np.uint8)

        # Odd kernel edge of roughly 10% of the footprint, at least 3 pixels.
        ksize = max(3, int(round(footprint_px * 0.1)) | 1)
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (ksize, ksize))
        roof_mask = cv2.morphologyEx(roof_mask, cv2.MORPH_CLOSE, kernel)
        roof_mask = cv2.morphologyEx(roof_mask, cv2.MORPH_OPEN, kernel)

        min_area = max(footprint_px * footprint_px // 4, 64)
        max_area = max_area_frac * depth_mask.size if max_area_frac > 0 else None
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(roof_mask, connectivity=8)

        # Label 0 is the background; filter the remaining components by area.
        areas = stats[1:num_labels, cv2.CC_STAT_AREA]
        accepted = areas >= min_area
        if max_area is not None:
            # Skip overly large blobs (e.g., entire fields) to avoid over-masking.
            accepted &= areas <= max_area
        refined = np.isin(labels, 1 + np.flatnonzero(accepted))
        return refined if refined.any() else None

    @staticmethod
    def _box_mean(arr: np.ndarray, k: int) -> np.ndarray:
        """Return the k x k sliding mean of ``arr`` with reflect padding (same shape)."""
        # Imported lazily, as in the original code path, so importing this
        # module does not require torch.
        import torch
        import torch.nn.functional as F

        pad = k // 2
        t = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)
        t = F.pad(t, (pad, pad, pad, pad), mode="reflect")
        mean = F.avg_pool2d(t, kernel_size=k, stride=1, padding=0, count_include_pad=False)
        return mean.squeeze(0).squeeze(0).numpy()

    @staticmethod
    def _resize_mask(mask: np.ndarray, shape: tuple) -> np.ndarray:
        """Resize a boolean mask to ``shape`` (rows, cols) with nearest-neighbour sampling."""
        resized = Image.fromarray(mask.astype(np.uint8) * 255).resize(
            (shape[1], shape[0]), resample=Image.NEAREST
        )
        return np.array(resized) > 0

    @staticmethod
    def _expand_mask(mask: np.ndarray | None, patch_px: int) -> np.ndarray | None:
        """Dilate ``mask`` by a patch_px square so any footprint touching it is blocked."""
        if mask is None:
            return None
        if patch_px <= 1:
            return mask.copy()
        try:
            kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (patch_px, patch_px))
        except cv2.error:
            return mask.copy()
        return cv2.dilate(mask.astype(np.uint8), kernel, iterations=1).astype(bool)

    @staticmethod
    def _union_masks(masks) -> np.ndarray | None:
        """Return a fresh OR of all non-None masks, or None when every entry is None."""
        union = None
        for mask in masks:
            if mask is None:
                continue
            if union is None:
                union = mask.copy()
            else:
                union |= mask
        return union

    def analyze_image(self, image: Image.Image, request: AnalysisRequest) -> AnalysisResult:
        """Analyse one image and return visual layers plus a numeric summary.

        Args:
            image: Input image. Non-RGB images are converted to RGB first.
            request: Thresholds, mask switches and camera geometry for this run.

        Returns:
            ``AnalysisResult`` with the layers from ``build_result_layers`` and
            an ``AnalysisSummary``. Stages that fail on a best-effort basis
            are skipped and reported in ``summary.warnings``.
        """
        t0 = time.perf_counter()
        warnings: list[str] = []

        # The grayscale conversion below requires a 3-channel array; normalise
        # greyscale / palette / alpha inputs instead of failing later.
        if image.mode != "RGB":
            image = image.convert("RGB")
        rgb_np = np.array(image)

        depth_raw, depth, process_res = self.depth_engine.predict_depth(
            rgb_np, request.model_id, request.process_res_cap
        )
        res_scale = max(0.5, min(2.5, process_res / 1024))
        sigma = max(0.0, request.depth_smoothing_base) * res_scale
        depth = smooth_depth(depth, sigma)

        # Footprint size in depth pixels from a pinhole model: fx * size / altitude.
        fov = max(10.0, min(170.0, float(request.fov_deg)))
        altitude = max(1.0, float(request.altitude_m))
        fx = (depth.shape[1] / 2.0) / np.tan(np.radians(fov) / 2.0)
        patch_px = request.footprint_m * fx / altitude
        patch_px = max(3, min(int(round(patch_px)), min(depth.shape) - 1))
        if patch_px % 2 == 0:
            patch_px += 1  # odd size so the window has a well-defined center
        half_span = patch_px // 2

        # Local std of the normalised depth, used for visualisation only.
        depth_norm = (depth - depth.min()) / (np.ptp(depth) + 1e-6)
        vis_patch = max(
            5,
            min(
                patch_px,
                max(7, min(depth.shape) // 8),
                min(depth.shape) - 1,
            ),
        )
        if vis_patch % 2 == 0:
            vis_patch += 1
        mean_sq = self._box_mean(depth_norm * depth_norm, vis_patch)
        mean = self._box_mean(depth_norm, vis_patch)
        std_map_vis = np.sqrt(np.maximum(mean_sq - mean ** 2, 0.0))

        # Image texture: blurred Sobel magnitude, normalised to [0, 1] and
        # resampled onto the depth grid.
        gray = cv2.cvtColor(rgb_np, cv2.COLOR_RGB2GRAY).astype(np.float32) / 255.0
        gx = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
        gy = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
        texture = np.sqrt(gx * gx + gy * gy)
        sigma_tex = max(1.0, patch_px / 40.0)
        texture = cv2.GaussianBlur(texture, (0, 0), sigmaX=sigma_tex, sigmaY=sigma_tex)
        if texture.max() > texture.min():
            texture_norm = (texture - texture.min()) / (np.ptp(texture) + 1e-6)
        else:
            texture_norm = np.zeros_like(texture)
        texture_norm = cv2.resize(texture_norm, (depth.shape[1], depth.shape[0]), interpolation=cv2.INTER_LINEAR)

        # "*_resized" masks are on the depth grid; "*_block" masks are the same
        # masks dilated by the footprint so no footprint may overlap them.
        water_mask_resized = None
        road_mask_resized = None
        roof_mask_resized = None
        water_mask_block = None
        road_mask_block = None
        roof_mask_block = None

        if request.use_water_mask or request.use_road_mask:
            masks = self.segmenter.get_masks(
                SegmenterRequest(
                    image=image,
                    source_path=request.source_path,
                    want_water=request.use_water_mask,
                    want_road=request.use_road_mask,
                    max_side=int(max(128, request.segmentation_max_side)),
                    water_prompt=request.water_prompt,
                    road_prompt=request.road_prompt,
                    score_threshold=float(request.segmentation_score_thresh),
                    mask_threshold=float(request.segmentation_mask_thresh),
                )
            )
            if request.use_water_mask and masks.get("water") is not None:
                water_mask_resized = self._resize_mask(masks["water"], depth.shape)
                water_mask_block = self._expand_mask(water_mask_resized, patch_px)
            if request.use_road_mask and masks.get("road") is not None:
                road_mask_resized = self._resize_mask(masks["road"], depth.shape)
                road_mask_block = self._expand_mask(road_mask_resized, patch_px)

        box, std_map, grad_norm, grad_mask, landing_mask = pick_flat_patch(
            depth,
            patch=patch_px,
            std_thresh=request.std_thresh,
            grad_thresh=request.grad_thresh,
            water_mask=water_mask_block,
        )
        if request.use_roof_mask:
            roof_mask_resized = self.build_depth_roof_mask(
                depth=depth,
                grad_norm=grad_norm,
                footprint_px=patch_px,
                max_area_frac=0.2,
            )
            roof_mask_block = self._expand_mask(roof_mask_resized, patch_px)

        seg_block_mask = self._union_masks((water_mask_block, road_mask_block, roof_mask_block))
        if seg_block_mask is not None:
            landing_mask = landing_mask & (~seg_block_mask)

        # Centers closer than half a footprint to the border cannot host a
        # full footprint; exclude them.
        rows, cols = landing_mask.shape
        if half_span <= 0:
            interior_mask = np.ones_like(landing_mask, dtype=bool)
        else:
            interior_mask = np.zeros_like(landing_mask, dtype=bool)
            if rows > 2 * half_span and cols > 2 * half_span:
                interior_mask[half_span : rows - half_span, half_span : cols - half_span] = True
        landing_mask = landing_mask & interior_mask

        texture_mask = texture_norm <= max(0.0, min(1.0, request.texture_threshold))
        safe_mask = (std_map < request.std_thresh) & (grad_norm < request.grad_thresh) & landing_mask & texture_mask

        # Clearance: grow hazards by a footprint-relative margin. Segmentation
        # blocks are excluded from the growth because they are already dilated.
        try:
            clearance_px = max(1, int(round(request.clearance_factor * patch_px)))
            if clearance_px % 2 == 0:
                clearance_px += 1
            kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (clearance_px, clearance_px))
            hazard = ~safe_mask
            if seg_block_mask is not None:
                hazard = hazard & (~seg_block_mask)
            buffered = cv2.dilate(hazard.astype(np.uint8), kernel, iterations=1).astype(bool)
            safe_mask = safe_mask & (~buffered)
            if seg_block_mask is not None:
                safe_mask = safe_mask & (~seg_block_mask)
        except Exception as exc:
            # Best-effort stage: keep going, but do not hide the failure.
            warnings.append(f"Clearance buffering skipped: {exc}")

        # Keep only centers whose footprint window is sufficiently safe.
        try:
            coverage = cv2.boxFilter(
                safe_mask.astype(np.float32),
                ddepth=-1,
                ksize=(patch_px, patch_px),
                normalize=True,
                anchor=(patch_px // 2, patch_px // 2),
            )
            safe_mask = coverage >= max(0.0, min(1.0, request.coverage_strictness))
        except Exception as exc:
            warnings.append(f"Coverage filtering skipped: {exc}")

        # Drop connected safe regions smaller than one footprint.
        area_thresh = max(1, int(patch_px * patch_px))
        num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(safe_mask.astype(np.uint8), connectivity=8)
        if num_labels > 1:
            keep = np.zeros_like(labels, dtype=bool)
            for i in range(1, num_labels):
                if stats[i, cv2.CC_STAT_AREA] >= area_thresh:
                    keep |= labels == i
            safe_mask = keep

        risk_std = np.clip((std_map - request.std_thresh) / (request.std_thresh + 1e-6), 0.0, 1.0)
        risk_grad = np.clip((grad_norm - request.grad_thresh) / (request.grad_thresh + 1e-6), 0.0, 1.0)
        risk_map = np.maximum(risk_std, risk_grad) * (~safe_mask)

        safe_fit = safe_mask.astype(np.float32)
        try:
            distance = cv2.distanceTransform(safe_mask.astype(np.uint8), cv2.DIST_L2, 3)
        except Exception as exc:
            warnings.append(f"Distance transform unavailable: {exc}")
            distance = np.zeros_like(safe_fit)
        try:
            coverage = cv2.boxFilter(
                safe_fit,
                ddepth=-1,
                ksize=(patch_px, patch_px),
                normalize=True,
                anchor=(patch_px // 2, patch_px // 2),
            )
            valid_centers = coverage >= 1.0
        except Exception as exc:
            warnings.append(f"Footprint coverage check unavailable: {exc}")
            valid_centers = safe_fit > 0.5

        used_valid_center = bool(valid_centers.any())
        if used_valid_center:
            # Search the largest connected region of valid centers and pick
            # the point far from hazards with low depth variation.
            num_c, labels_c, stats_c, _ = cv2.connectedComponentsWithStats(
                valid_centers.astype(np.uint8), connectivity=8
            )
            target_mask = valid_centers
            if num_c > 1:
                largest_idx = 1 + int(np.argmax(stats_c[1:, cv2.CC_STAT_AREA]))
                target_mask = labels_c == largest_idx
            # target_mask is non-empty here because valid_centers has at least
            # one set pixel, so the candidate arrays are never empty.
            cand = np.where(target_mask)
            dist_cand = distance[cand]
            std_cand = std_map[cand]
            dist_norm = dist_cand / (dist_cand.max() + 1e-6)
            std_norm = (std_cand - std_cand.min()) / (np.ptp(std_cand) + 1e-6)
            weight = max(0.0, min(1.0, request.openness_weight))
            idx = int(np.argmax(dist_norm - weight * std_norm))
            cy, cx = cand[0][idx], cand[1][idx]
        else:
            # No fully safe footprint: fall back to the flattest allowed pixel,
            # and finally to the box reported by pick_flat_patch.
            fallback_mask = landing_mask.copy()
            if not fallback_mask.any():
                fallback_mask = np.ones_like(landing_mask, dtype=bool)
                if seg_block_mask is not None:
                    fallback_mask &= (~seg_block_mask)
                fallback_mask &= interior_mask
            if fallback_mask.any():
                cand = np.where(fallback_mask)
                idx = int(np.argmin(std_map[cand]))
                cy, cx = cand[0][idx], cand[1][idx]
            else:
                y0, x0, y1, x1 = box[1], box[0], box[3], box[2]
                cy, cx = (y0 + y1) // 2, (x0 + x1) // 2

        # Clamp the center so the footprint stays inside the map when possible.
        if half_span > 0 and depth.shape[0] > 2 * half_span:
            cy = min(max(int(cy), half_span), depth.shape[0] - half_span - 1)
        else:
            cy = min(max(int(cy), 0), depth.shape[0] - 1)
        if half_span > 0 and depth.shape[1] > 2 * half_span:
            cx = min(max(int(cx), half_span), depth.shape[1] - half_span - 1)
        else:
            cx = min(max(int(cx), 0), depth.shape[1] - 1)

        # Map depth-grid coordinates back to input-image pixels.
        scale_x = image.width / depth.shape[1]
        scale_y = image.height / depth.shape[0]
        footprint_img_px = max(3, int(round(patch_px * scale_x)))
        center_img = (int(round(cx * scale_x)), int(round(cy * scale_y)))
        center_depth = (cx, cy)

        # safe_mask marks safe centers; dilate by the footprint to display the
        # area actually covered, then remove the raw (undilated) hazard masks.
        safe_display_mask = safe_mask
        try:
            footprint_kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (patch_px, patch_px))
            safe_display_mask = cv2.dilate(safe_mask.astype(np.uint8), footprint_kernel, iterations=1).astype(bool)
        except Exception as exc:
            warnings.append(f"Safe-area display dilation skipped: {exc}")
            safe_display_mask = safe_mask
        seg_mask_union = self._union_masks((water_mask_resized, road_mask_resized, roof_mask_resized))
        if seg_mask_union is not None:
            safe_display_mask = safe_display_mask & (~seg_mask_union)
        hazard_mask = ~safe_display_mask

        layers = build_result_layers(
            image=image,
            depth_raw=depth_raw,
            std_map_vis=std_map_vis,
            grad_norm=grad_norm,
            grad_thresh=request.grad_thresh,
            safe_mask=safe_display_mask,
            risk_map=risk_map,
            footprint_img_px=footprint_img_px,
            center_img=center_img,
            water_mask=water_mask_resized,
            road_mask=road_mask_resized,
            roof_mask=roof_mask_resized,
            seg_mask_union=seg_mask_union,
            hazard_mask=hazard_mask,
        )
        runtime_ms = (time.perf_counter() - t0) * 1000.0
        safe_area_pct = float(safe_display_mask.mean()) * 100.0
        hazard_pct = 100.0 - safe_area_pct

        def mask_pct(mask: np.ndarray | None) -> Optional[float]:
            if mask is None:
                return None
            return float(mask.mean()) * 100.0

        if not safe_mask.any():
            warnings.append("No regions satisfied safety thresholds; showing flattest candidate.")
        if not request.use_water_mask:
            warnings.append("Water mask disabled.")
        elif water_mask_resized is None:
            warnings.append("No water detected; continuing without a water mask.")
        if not request.use_road_mask:
            warnings.append("Road mask disabled.")
        elif road_mask_resized is None:
            warnings.append("Road segmentation unavailable; continuing without mask.")
        if not request.use_roof_mask:
            warnings.append("Roof mask disabled.")
        elif roof_mask_resized is None:
            warnings.append("Roof segmentation unavailable; continuing without mask.")

        summary = AnalysisSummary(
            model_id=request.model_id,
            process_resolution=process_res,
            runtime_ms=runtime_ms,
            footprint_m=request.footprint_m,
            footprint_depth_px=patch_px,
            footprint_image_px=footprint_img_px,
            landing_center_depth=center_depth,
            landing_center_image=center_img,
            safe_area_pct=safe_area_pct,
            hazard_pct=hazard_pct,
            water_mask_pct=mask_pct(water_mask_resized) if request.use_water_mask else None,
            road_mask_pct=mask_pct(road_mask_resized) if request.use_road_mask else None,
            roof_mask_pct=mask_pct(roof_mask_resized) if request.use_roof_mask else None,
            water_mask_enabled=request.use_water_mask,
            road_mask_enabled=request.use_road_mask,
            roof_mask_enabled=request.use_roof_mask,
            used_valid_center=used_valid_center,
            warnings=warnings,
        )
        return AnalysisResult(images=layers, summary=summary)

    def process_path(self, path: Path, request: AnalysisRequest) -> AnalysisResult:
        """Load an image file, crop it with ``crop_nonblack`` and analyse it.

        Args:
            path: Image file to analyse.
            request: Analysis settings; a copy with ``source_path`` set to
                ``str(path)`` is used for the run.

        Raises:
            ValueError: If the path does not exist or its lower-cased suffix
                is not listed in ``IMAGE_EXTS``.
        """
        if not path.exists():
            raise ValueError(f"Input path not found: {path}")
        if path.suffix.lower() not in IMAGE_EXTS:
            raise ValueError(f"Unsupported image type for path: {path}")
        # Image.open keeps the file open lazily; convert() loads the pixel
        # data, so the handle can be released as soon as the block exits.
        with Image.open(path) as handle:
            image = crop_nonblack(handle.convert("RGB"))
        request_with_source = replace(request, source_path=str(path))
        return self.analyze_image(image, request_with_source)
457
+
458
+
459
def build_request(**kwargs) -> AnalysisRequest:
    """Create an ``AnalysisRequest`` from keyword arguments.

    The keywords are passed straight to the dataclass constructor, so missing
    required fields or unknown names raise ``TypeError`` there.
    """
    request = AnalysisRequest(**kwargs)
    return request
461
+
462
+
463
+ __all__ = ["SafetyAnalyzer", "AnalysisRequest", "AnalysisResult", "AnalysisSummary", "build_request"]
app/segmentation.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Dict, Optional
5
+ import re
6
+
7
+ import numpy as np
8
+ import torch
9
+ from PIL import Image
10
+
11
+ from .config import (
12
+ ROAD_PROMPT,
13
+ SEGMENTATION_MASK_THRESH,
14
+ SEGMENTATION_MAX_SIDE,
15
+ SEGMENTATION_MODEL_ID,
16
+ SEGMENTATION_SCORE_THRESH,
17
+ WATER_PROMPT,
18
+ )
19
+
20
+
21
class SemanticSegmenter:
    """Promptable segmenter backed by SAM3."""

    # Separators between alternative phrases in a prompt: semicolon, comma or
    # newline. In a raw string the escape must be written as "\n"; "\\n" would
    # instead put a literal backslash and the letter "n" into the class and
    # split words such as "ocean" or "runway".
    _PROMPT_SEPARATORS = re.compile(r"[;,\n]")

    def __init__(self, model_id: str):
        """Load the processor and model for ``model_id`` and move the model to a device.

        Raises:
            ImportError: If the installed ``transformers`` exposes no suitable
                processor or model class.
        """
        import transformers  # type: ignore

        processor_cls = (
            getattr(transformers, "Sam3Processor", None)
            or getattr(transformers, "AutoProcessor", None)
            or getattr(transformers, "AutoImageProcessor", None)
        )
        model_cls = getattr(transformers, "Sam3Model", None) or getattr(
            transformers, "AutoModelForMaskGeneration", None
        )
        if processor_cls is None or model_cls is None:
            # Fail with an actionable message instead of an AttributeError on None.
            raise ImportError(
                "The installed transformers package provides no SAM3-compatible processor/model classes; "
                "install a transformers build that includes Sam3Processor and Sam3Model."
            )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        processor = processor_cls.from_pretrained(model_id)
        model = model_cls.from_pretrained(model_id)
        try:
            model = model.to(device)
        except RuntimeError as exc:
            # Fall back to CPU if the GPU move fails (e.g., OOM or missing device)
            device = torch.device("cpu")
            model = model.to(device)
            print(f"[WARN] SAM3 fell back to CPU after .to(device) error: {exc}")
        model.eval()
        self.processor = processor
        self.model = model
        self.device = device
        if torch.cuda.is_available() and self.device.type != "cuda":
            print("[WARN] CUDA is available but SAM3 is running on CPU; mask generation will be slow.")
        else:
            print(f"[INFO] SAM3 loaded on {self.device}")

    @classmethod
    def _split_prompts(cls, text: str) -> list[str]:
        """Split a prompt into trimmed, non-empty phrases.

        If splitting leaves nothing but the stripped text is non-empty (the
        text consists only of separators), the stripped text is returned as a
        single phrase, as before.
        """
        parts = [p.strip() for p in cls._PROMPT_SEPARATORS.split(text) if p.strip()]
        if parts:
            return parts
        stripped = text.strip()
        return [stripped] if stripped else []

    def segment(
        self,
        img: Image.Image,
        max_side: int,
        prompts: Dict[str, str],
        score_threshold: float,
        mask_threshold: float,
    ) -> dict[str, np.ndarray]:
        """Return one boolean mask per prompt key, at the input image's size.

        Args:
            img: Image to segment; downscaled so its longest side is at most
                ``max_side`` before inference.
            max_side: Maximum longest side of the image given to the model.
            prompts: Mapping from result key to prompt text. Each text may
                hold several phrases separated by ``,``, ``;`` or newlines;
                the masks of all phrases are OR-ed together.
            score_threshold: Instance score threshold for post-processing.
            mask_threshold: Mask binarisation threshold for post-processing.

        Returns:
            Mapping from key to boolean mask. Keys with no detected pixels
            are omitted.

        Raises:
            ImportError: If the loaded processor does not accept text prompts.
        """
        if not prompts:
            return {}
        orig_size = img.size  # (W, H)
        img_proc = img
        if max(img.size) > max_side:
            scale = max_side / max(img.size)
            new_size = (max(1, int(round(img.size[0] * scale))), max(1, int(round(img.size[1] * scale))))
            img_proc = img.resize(new_size, resample=Image.BILINEAR)

        masks: dict[str, np.ndarray] = {}
        for key, prompt in prompts.items():
            mask_union = None
            for text in self._split_prompts(prompt or ""):
                try:
                    inputs = self.processor(images=img_proc, text=text, return_tensors="pt").to(self.device)
                except TypeError as exc:
                    raise ImportError(
                        "Loaded processor does not accept text prompts; install a transformers build with SAM3 text prompting support (e.g., pip install --upgrade transformers or a nightly that includes Sam3Processor)."
                    ) from exc
                with torch.inference_mode():
                    outputs = self.model(**inputs)
                # target_sizes is (H, W): masks come back at the original resolution.
                results = self.processor.post_process_instance_segmentation(
                    outputs,
                    threshold=score_threshold,
                    mask_threshold=mask_threshold,
                    target_sizes=[(orig_size[1], orig_size[0])],
                )[0]
                inst_masks = results.get("masks")
                if inst_masks is None or len(inst_masks) == 0:
                    continue
                if torch.is_floating_point(inst_masks):
                    inst_masks = inst_masks > 0.5
                # Collapse all instances of this phrase into one mask.
                mask_tensor = torch.any(inst_masks, dim=0)
                mask_union = mask_tensor if mask_union is None else (mask_union | mask_tensor)
            if mask_union is None:
                continue
            mask_np = mask_union.detach().cpu().numpy().astype(bool)
            if mask_np.any():
                masks[key] = mask_np
        return masks
108
+
109
+
110
@dataclass
class SegmenterRequest:
    """Inputs for ``SegmenterService.get_masks``."""

    # Image to segment.
    image: Image.Image
    # Origin of the image; part of the mask cache key, and results are only
    # cached when it is set.
    source_path: Optional[str] = None
    # Which masks the caller wants back.
    want_water: bool = False
    want_road: bool = False
    # Longest image side passed to the segmenter.
    max_side: int = SEGMENTATION_MAX_SIDE
    # Text prompts used for the water and road masks.
    water_prompt: str = WATER_PROMPT
    road_prompt: str = ROAD_PROMPT
    # Thresholds forwarded to the segmenter's post-processing.
    score_threshold: float = SEGMENTATION_SCORE_THRESH
    mask_threshold: float = SEGMENTATION_MASK_THRESH
121
+
122
+
123
class SegmenterService:
    """Caches segmenters and mask outputs across UI interactions."""

    def __init__(self, model_id: str = SEGMENTATION_MODEL_ID):
        """Create an empty service; the model is loaded lazily on first use.

        Args:
            model_id: Identifier of the segmentation model to load.
        """
        self.model_id = model_id
        self._segmenters: Dict[str, SemanticSegmenter] = {}
        # Key: (model id, source path, max side, effective water prompt,
        # effective road prompt, score threshold, mask threshold).
        self._mask_cache: Dict[
            tuple[str, str, int, str, str, float, float], dict[str, np.ndarray]
        ] = {}

    def _get_segmenter(self, model_id: str) -> SemanticSegmenter:
        """Return the segmenter for ``model_id``, constructing it on first request."""
        if model_id not in self._segmenters:
            self._segmenters[model_id] = SemanticSegmenter(model_id)
        return self._segmenters[model_id]

    def get_masks(self, request: SegmenterRequest) -> dict[str, np.ndarray]:
        """Return the requested masks, using cached results when available.

        Args:
            request: Image, prompts, thresholds and which masks are wanted.

        Returns:
            Mapping with a ``"water"`` and/or ``"road"`` entry for every
            wanted mask that was detected; missing keys mean "not wanted" or
            "nothing detected". A ``RuntimeError`` during segmentation is
            reported on stdout and yields an empty mapping.
        """
        if not (request.want_water or request.want_road):
            return {}

        prompts: dict[str, str] = {}
        if request.want_water and request.water_prompt:
            prompts["water"] = request.water_prompt
        if request.want_road and request.road_prompt:
            prompts["road"] = request.road_prompt
        if not prompts:
            # Nothing to ask the model; avoid loading it just to return nothing.
            return {}

        # The key uses the prompts actually sent to the model, so a result
        # computed for "water only" is never reused for a request that also
        # wants roads (an unwanted mask contributes an empty prompt).
        key = (
            self.model_id,
            request.source_path or "",
            request.max_side,
            prompts.get("water", "").strip(),
            prompts.get("road", "").strip(),
            float(request.score_threshold),
            float(request.mask_threshold),
        )
        masks = self._mask_cache.get(key)
        if masks is None:
            segmenter = self._get_segmenter(self.model_id)
            try:
                masks = segmenter.segment(
                    request.image,
                    request.max_side,
                    prompts=prompts,
                    score_threshold=float(request.score_threshold),
                    mask_threshold=float(request.mask_threshold),
                )
            except RuntimeError as exc:
                print(f"[WARN] Segmentation failed; skipping masks: {exc}")
                masks = {}
            # Only results tied to a known source are cached, and empty
            # results (including failures) are retried on the next call.
            if request.source_path and masks:
                self._mask_cache[key] = masks

        result: dict[str, np.ndarray] = {}
        if request.want_water and masks.get("water") is not None:
            result["water"] = masks["water"]
        if request.want_road and masks.get("road") is not None:
            result["road"] = masks["road"]
        return result
175
+
176
+
177
+ __all__ = ["SegmenterService", "SegmenterRequest", "SemanticSegmenter"]