yakvrz commited on
Commit
af8f4ba
·
1 Parent(s): ba15bb2

Disable mask caching and adjust defaults

Browse files
Files changed (3) hide show
  1. ARCHITECTURE.md +2 -2
  2. app/config.py +2 -2
  3. app/segmentation.py +17 -31
ARCHITECTURE.md CHANGED
@@ -5,7 +5,7 @@ This document describes the flow in the current Gradio app (`app/ui.py`), from i
5
  ## Data and Models
6
  - **Inputs**: Images from `data/Image/VISLOC` (populated via `list_all_data_inputs`) with a 5% border crop (`crop_nonblack`) to drop black padding. Supported extensions: jpg/jpeg/png (any case).
7
  - **Depth model**: Depth Anything 3, cached per model id (`DepthEngine`). Inference caps the long side to `process_res_cap` (default 1024) using `upper_bound_resize` before predicting.
8
- - **Segmentation model**: SAM3 (`facebook/sam3`) for promptable water/road masking. Loaded once per model id; masks are cached per `(model_id, source_path, prompts, thresholds, max_side)`. Default `segmentation_max_side` is 384 to keep it fast on CUDA.
9
 
10
  ## Constants and Defaults
11
  - Altitude/FOV defaults: 450 m, 90°.
@@ -56,7 +56,7 @@ This document describes the flow in the current Gradio app (`app/ui.py`), from i
56
 
57
  ## Caching and State
58
  - Depth model cache keyed by model id (`DepthEngine`).
59
- - SAM3 cache keyed by model id + source + prompts + thresholds + max_side (`SegmenterService`).
60
  - `images_state` holds the latest rendered layers; overlay-only changes don’t rerun inference. Prompt changes only re-trigger processing on submit/Run, not every keystroke.
61
 
62
  ## User Controls and Effects
 
5
  ## Data and Models
6
  - **Inputs**: Images from `data/Image/VISLOC` (populated via `list_all_data_inputs`) with a 5% border crop (`crop_nonblack`) to drop black padding. Supported extensions: jpg/jpeg/png (any case).
7
  - **Depth model**: Depth Anything 3, cached per model id (`DepthEngine`). Inference caps the long side to `process_res_cap` (default 1024) using `upper_bound_resize` before predicting.
8
+ - **Segmentation model**: SAM3 (`facebook/sam3`) for promptable water/road masking. Loaded once per model id; masks are recomputed every run (no caching). Default `segmentation_max_side` is 384 to keep it fast on CUDA.
9
 
10
  ## Constants and Defaults
11
  - Altitude/FOV defaults: 450 m, 90°.
 
56
 
57
  ## Caching and State
58
  - Depth model cache keyed by model id (`DepthEngine`).
59
+ - SAM3 masks are not cached; every run recomputes them to reflect real-time cost. Segmenter models stay loaded per id.
60
  - `images_state` holds the latest rendered layers; overlay-only changes don’t rerun inference. Prompt changes only re-trigger processing on submit/Run, not every keystroke.
61
 
62
  ## User Controls and Effects
app/config.py CHANGED
@@ -28,7 +28,7 @@ class AnalyzerSettings:
28
  grad_thresh: float = 0.1
29
  clearance_factor: float = 0.0
30
  process_res_cap: int = 1024
31
- depth_smoothing_base: float = 0.8
32
  segmentation_max_side: int = SEGMENTATION_MAX_SIDE
33
  segmentation_score_thresh: float = SEGMENTATION_SCORE_THRESH
34
  segmentation_mask_thresh: float = SEGMENTATION_MASK_THRESH
@@ -36,7 +36,7 @@ class AnalyzerSettings:
36
  road_prompt: str = ROAD_PROMPT
37
  coverage_strictness: float = 0.95
38
  openness_weight: float = 0.3
39
- texture_threshold: float = 0.5
40
  altitude_m: float = DEFAULT_ALTITUDE_M
41
  fov_deg: float = ASSUMED_FOV_DEG
42
  model_id: str = DEFAULT_MODEL_ID
 
28
  grad_thresh: float = 0.1
29
  clearance_factor: float = 0.0
30
  process_res_cap: int = 1024
31
+ depth_smoothing_base: float = 0.0
32
  segmentation_max_side: int = SEGMENTATION_MAX_SIDE
33
  segmentation_score_thresh: float = SEGMENTATION_SCORE_THRESH
34
  segmentation_mask_thresh: float = SEGMENTATION_MASK_THRESH
 
36
  road_prompt: str = ROAD_PROMPT
37
  coverage_strictness: float = 0.95
38
  openness_weight: float = 0.3
39
+ texture_threshold: float = 0.1
40
  altitude_m: float = DEFAULT_ALTITUDE_M
41
  fov_deg: float = ASSUMED_FOV_DEG
42
  model_id: str = DEFAULT_MODEL_ID
app/segmentation.py CHANGED
@@ -126,7 +126,6 @@ class SegmenterService:
126
  def __init__(self, model_id: str = SEGMENTATION_MODEL_ID):
127
  self.model_id = model_id
128
  self._segmenters: Dict[str, SemanticSegmenter] = {}
129
- self._mask_cache: Dict[tuple[str, str, int], dict[str, np.ndarray]] = {}
130
 
131
  def _get_segmenter(self, model_id: str) -> SemanticSegmenter:
132
  if model_id not in self._segmenters:
@@ -136,36 +135,23 @@ class SegmenterService:
136
  def get_masks(self, request: SegmenterRequest) -> dict[str, np.ndarray]:
137
  if not (request.want_water or request.want_road):
138
  return {}
139
- key = (
140
- self.model_id,
141
- request.source_path or "",
142
- request.max_side,
143
- (request.water_prompt or "").strip(),
144
- (request.road_prompt or "").strip(),
145
- float(request.score_threshold),
146
- float(request.mask_threshold),
147
- )
148
- masks = self._mask_cache.get(key)
149
- if masks is None:
150
- segmenter = self._get_segmenter(self.model_id)
151
- prompts: dict[str, str] = {}
152
- if request.want_water and request.water_prompt:
153
- prompts["water"] = request.water_prompt
154
- if request.want_road and request.road_prompt:
155
- prompts["road"] = request.road_prompt
156
- try:
157
- masks = segmenter.segment(
158
- request.image,
159
- request.max_side,
160
- prompts=prompts,
161
- score_threshold=float(request.score_threshold),
162
- mask_threshold=float(request.mask_threshold),
163
- )
164
- except RuntimeError as exc:
165
- print(f"[WARN] Segmentation failed; skipping masks: {exc}")
166
- masks = {}
167
- if request.source_path and masks:
168
- self._mask_cache[key] = masks
169
  result: dict[str, np.ndarray] = {}
170
  if request.want_water and masks.get("water") is not None:
171
  result["water"] = masks["water"]
 
126
  def __init__(self, model_id: str = SEGMENTATION_MODEL_ID):
127
  self.model_id = model_id
128
  self._segmenters: Dict[str, SemanticSegmenter] = {}
 
129
 
130
  def _get_segmenter(self, model_id: str) -> SemanticSegmenter:
131
  if model_id not in self._segmenters:
 
135
  def get_masks(self, request: SegmenterRequest) -> dict[str, np.ndarray]:
136
  if not (request.want_water or request.want_road):
137
  return {}
138
+ segmenter = self._get_segmenter(self.model_id)
139
+ prompts: dict[str, str] = {}
140
+ if request.want_water and request.water_prompt:
141
+ prompts["water"] = request.water_prompt
142
+ if request.want_road and request.road_prompt:
143
+ prompts["road"] = request.road_prompt
144
+ try:
145
+ masks = segmenter.segment(
146
+ request.image,
147
+ request.max_side,
148
+ prompts=prompts,
149
+ score_threshold=float(request.score_threshold),
150
+ mask_threshold=float(request.mask_threshold),
151
+ )
152
+ except RuntimeError as exc:
153
+ print(f"[WARN] Segmentation failed; skipping masks: {exc}")
154
+ masks = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  result: dict[str, np.ndarray] = {}
156
  if request.want_water and masks.get("water") is not None:
157
  result["water"] = masks["water"]