# drone-landing-safety / app / segmentation.py
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, Optional
import re

import numpy as np
import torch
from PIL import Image

from .config import (
    ROAD_PROMPT,
    ROOF_PROMPT,
    SEGMENTATION_MASK_THRESH,
    SEGMENTATION_MAX_SIDE,
    SEGMENTATION_MODEL_ID,
    SEGMENTATION_SCORE_THRESH,
    TREE_PROMPT,
    WATER_PROMPT,
)


class SemanticSegmenter:
    """Promptable segmenter backed by SAM3."""

    def __init__(self, model_id: str):
        import transformers  # type: ignore
        from transformers.utils import logging as hf_logging  # type: ignore

        hf_logging.set_verbosity_error()
        try:
            hf_logging.disable_progress_bar()
        except Exception:
            pass
        # Prefer the dedicated SAM3 classes; fall back to the Auto* factories
        # on older transformers builds.
        processor_cls = getattr(transformers, "Sam3Processor", None) or getattr(
            transformers, "AutoProcessor", None
        ) or getattr(transformers, "AutoImageProcessor", None)
        model_cls = getattr(transformers, "Sam3Model", None) or getattr(
            transformers, "AutoModelForMaskGeneration", None
        )
        if processor_cls is None or model_cls is None:
            raise ImportError(
                "This transformers build exposes no SAM3-compatible processor/model "
                "class; upgrade to a release that includes Sam3Processor and Sam3Model."
            )
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        processor = processor_cls.from_pretrained(model_id)
        model = model_cls.from_pretrained(model_id)
        try:
            model = model.to(device)
        except RuntimeError as exc:
            # Fall back to CPU if the GPU move fails (e.g., OOM or missing device).
            device = torch.device("cpu")
            model = model.to(device)
            print(f"[WARN] SAM3 fell back to CPU after .to(device) error: {exc}")
        model.eval()
        self.processor = processor
        self.model = model
        self.device = device
        if torch.cuda.is_available() and self.device.type != "cuda":
            print("[WARN] CUDA is available but SAM3 is running on CPU; mask generation will be slow.")
        else:
            print(f"[INFO] SAM3 loaded on {self.device}")

    def segment(
        self,
        img: Image.Image,
        max_side: int,
        prompts: Dict[str, str],
        score_threshold: float,
        mask_threshold: float,
    ) -> dict[str, np.ndarray]:
        if not prompts:
            return {}
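        # Downscale the long side to max_side before inference; post-processing
        # maps the masks back to the original resolution.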
        orig_size = img.size  # (W, H)
        img_proc = img
        if max(img.size) > max_side:
            scale = max_side / max(img.size)
            new_size = (
                max(1, int(round(img.size[0] * scale))),
                max(1, int(round(img.size[1] * scale))),
            )
            img_proc = img.resize(new_size, resample=Image.BILINEAR)
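
        # A prompt string may hold several phrases separated by ";", "," or
        # newlines; each phrase is segmented on its own and the hits are unioned.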
        def _split_prompts(text: str) -> list[str]:
            parts = [p.strip() for p in re.split(r"[;,\n]", text) if p.strip()]
            return parts if parts else ([text.strip()] if text.strip() else [])

        masks: dict[str, np.ndarray] = {}
        for key, prompt in prompts.items():
            prompt_texts = _split_prompts(prompt or "")
            if not prompt_texts:
                continue
            mask_union = None
            for text in prompt_texts:
                try:
                    inputs = self.processor(images=img_proc, text=text, return_tensors="pt").to(self.device)
                except TypeError as exc:
                    raise ImportError(
                        "Loaded processor does not accept text prompts; install a "
                        "transformers build with SAM3 text prompting support (e.g., "
                        "pip install --upgrade transformers or a nightly that "
                        "includes Sam3Processor)."
                    ) from exc
                with torch.inference_mode():
                    outputs = self.model(**inputs)
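                # Resize instance masks back to the original (H, W) and apply
                # the score/mask thresholds.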
                results = self.processor.post_process_instance_segmentation(
                    outputs,
                    threshold=score_threshold,
                    mask_threshold=mask_threshold,
                    target_sizes=[(orig_size[1], orig_size[0])],
                )[0]
                inst_masks = results.get("masks")
                if inst_masks is None or len(inst_masks) == 0:
                    continue
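                # Binarize probabilistic masks, then OR all instances into one
                # boolean mask for this class.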
                if torch.is_floating_point(inst_masks):
                    inst_masks = inst_masks > 0.5
                mask_tensor = torch.any(inst_masks, dim=0)
                mask_union = mask_tensor if mask_union is None else (mask_union | mask_tensor)
            if mask_union is None:
                continue
            mask_np = mask_union.detach().cpu().numpy().astype(bool)
            if mask_np.any():
                masks[key] = mask_np
        return masks


@dataclass
class SegmenterRequest:
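    """Inputs for one masking pass; defaults come from the app config."""
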
    image: Image.Image
    source_path: Optional[str] = None
    want_water: bool = False
    want_road: bool = False
    want_roof: bool = False
    want_tree: bool = False
    max_side: int = SEGMENTATION_MAX_SIDE
    water_prompt: str = WATER_PROMPT
    road_prompt: str = ROAD_PROMPT
    roof_prompt: str = ROOF_PROMPT
    tree_prompt: str = TREE_PROMPT
    score_threshold: float = SEGMENTATION_SCORE_THRESH
    mask_threshold: float = SEGMENTATION_MASK_THRESH


class SegmenterService:
    """Caches SemanticSegmenter instances across UI interactions."""

    def __init__(self, model_id: str = SEGMENTATION_MODEL_ID):
        self.model_id = model_id
        self._segmenters: Dict[str, SemanticSegmenter] = {}
        # Eagerly load the default model once to avoid repeated cold starts.
        try:
            self._segmenters[model_id] = SemanticSegmenter(model_id)
        except Exception as exc:
            print(f"[WARN] Failed to preload segmentation model {model_id}: {exc}")

    def _get_segmenter(self, model_id: str) -> SemanticSegmenter:
        if model_id not in self._segmenters:
            self._segmenters[model_id] = SemanticSegmenter(model_id)
        return self._segmenters[model_id]

    def get_masks(self, request: SegmenterRequest, model_id: str | None = None) -> dict[str, np.ndarray]:
        if not (request.want_water or request.want_road or request.want_tree or request.want_roof):
            return {}
        segmenter = self._get_segmenter(model_id or self.model_id)
        prompts: dict[str, str] = {}
        if request.want_water and request.water_prompt:
            prompts["water"] = request.water_prompt
        if request.want_road and request.road_prompt:
            prompts["road"] = request.road_prompt
        if request.want_roof and request.roof_prompt:
            prompts["roof"] = request.roof_prompt
        if request.want_tree and request.tree_prompt:
            prompts["tree"] = request.tree_prompt
        try:
            masks = segmenter.segment(
                request.image,
                request.max_side,
                prompts=prompts,
                score_threshold=float(request.score_threshold),
                mask_threshold=float(request.mask_threshold),
            )
        except RuntimeError as exc:
            print(f"[WARN] Segmentation failed; skipping masks: {exc}")
            masks = {}
        result: dict[str, np.ndarray] = {}
        if request.want_water and masks.get("water") is not None:
            result["water"] = masks["water"]
        if request.want_road and masks.get("road") is not None:
            result["road"] = masks["road"]
        if request.want_roof and masks.get("roof") is not None:
            result["roof"] = masks["roof"]
        if request.want_tree and masks.get("tree") is not None:
            result["tree"] = masks["tree"]
        return result


__all__ = ["SegmenterService", "SegmenterRequest", "SemanticSegmenter"]

# Shared singleton to avoid reloads across analyzer instances.
_GLOBAL_SEGMENTER: SegmenterService | None = None


def get_global_segmenter(default_model_id: str = SEGMENTATION_MODEL_ID) -> SegmenterService:
    global _GLOBAL_SEGMENTER
    if _GLOBAL_SEGMENTER is None or _GLOBAL_SEGMENTER.model_id != default_model_id:
        _GLOBAL_SEGMENTER = SegmenterService(default_model_id)
    return _GLOBAL_SEGMENTER
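

# Minimal usage sketch, kept out of the app flow. It assumes a local image at
# "demo.jpg" and network access to download the default SAM3 checkpoint; the
# path and the enabled classes below are illustrative, not part of the app.
if __name__ == "__main__":
    service = get_global_segmenter()
    request = SegmenterRequest(
        image=Image.open("demo.jpg").convert("RGB"),
        want_water=True,
        want_roof=True,
    )
    for name, mask in service.get_masks(request).items():
        print(f"{name}: {int(mask.sum())} px flagged")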