from __future__ import annotations

import re
from dataclasses import dataclass
from typing import Dict, Optional

import numpy as np
import torch
from PIL import Image

from .config import (
    ROAD_PROMPT,
    ROOF_PROMPT,
    SEGMENTATION_MASK_THRESH,
    SEGMENTATION_MAX_SIDE,
    SEGMENTATION_MODEL_ID,
    SEGMENTATION_SCORE_THRESH,
    WATER_PROMPT,
    TREE_PROMPT,
)


class SemanticSegmenter:
    """Promptable segmenter backed by SAM3."""

    def __init__(self, model_id: str):
        import transformers  # type: ignore
        from transformers.utils import logging as hf_logging  # type: ignore

        hf_logging.set_verbosity_error()
        try:
            hf_logging.disable_progress_bar()
        except Exception:
            pass

        # Prefer the dedicated SAM3 classes; fall back to the Auto* entry points
        # for transformers builds that resolve the architecture from the config.
        processor_cls = (
            getattr(transformers, "Sam3Processor", None)
            or getattr(transformers, "AutoProcessor", None)
            or getattr(transformers, "AutoImageProcessor", None)
        )
        model_cls = getattr(transformers, "Sam3Model", None) or getattr(
            transformers, "AutoModelForMaskGeneration", None
        )
        if processor_cls is None or model_cls is None:
            raise ImportError(
                "transformers does not expose a SAM3-compatible processor/model "
                "class; upgrade to a build that includes Sam3Processor/Sam3Model."
            )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        processor = processor_cls.from_pretrained(model_id)
        model = model_cls.from_pretrained(model_id)
        try:
            model = model.to(device)
        except RuntimeError as exc:
            # Fall back to CPU if the GPU move fails (e.g., OOM or missing device).
            device = torch.device("cpu")
            model = model.to(device)
            print(f"[WARN] SAM3 fell back to CPU after .to(device) error: {exc}")
        model.eval()

        self.processor = processor
        self.model = model
        self.device = device
        if torch.cuda.is_available() and self.device.type != "cuda":
            print(
                "[WARN] CUDA is available but SAM3 is running on CPU; "
                "mask generation will be slow."
            )
        else:
            print(f"[INFO] SAM3 loaded on {self.device}")

    def segment(
        self,
        img: Image.Image,
        max_side: int,
        prompts: Dict[str, str],
        score_threshold: float,
        mask_threshold: float,
    ) -> dict[str, np.ndarray]:
        if not prompts:
            return {}

        orig_size = img.size  # (W, H)

        # Downscale large inputs before inference; post-processing resizes the
        # masks back to the original resolution via target_sizes.
        img_proc = img
        if max(img.size) > max_side:
            scale = max_side / max(img.size)
            new_size = (
                max(1, int(round(img.size[0] * scale))),
                max(1, int(round(img.size[1] * scale))),
            )
            img_proc = img.resize(new_size, resample=Image.BILINEAR)

        def _split_prompts(text: str) -> list[str]:
            # A prompt string may carry several phrases separated by commas,
            # semicolons, or newlines; each phrase is queried independently.
            parts = [p.strip() for p in re.split(r"[;,\n]", text) if p.strip()]
            return parts if parts else ([text.strip()] if text.strip() else [])

        masks: dict[str, np.ndarray] = {}
        for key, prompt in prompts.items():
            prompt_texts = _split_prompts(prompt or "")
            if not prompt_texts:
                continue
            mask_union = None
            for text in prompt_texts:
                try:
                    inputs = self.processor(
                        images=img_proc, text=text, return_tensors="pt"
                    ).to(self.device)
                except TypeError as exc:
                    raise ImportError(
                        "Loaded processor does not accept text prompts; install a "
                        "transformers build with SAM3 text prompting support "
                        "(e.g., pip install --upgrade transformers or a nightly "
                        "that includes Sam3Processor)."
                    ) from exc
                with torch.inference_mode():
                    outputs = self.model(**inputs)
                results = self.processor.post_process_instance_segmentation(
                    outputs,
                    threshold=score_threshold,
                    mask_threshold=mask_threshold,
                    target_sizes=[(orig_size[1], orig_size[0])],
                )[0]
                inst_masks = results.get("masks")
                if inst_masks is None or len(inst_masks) == 0:
                    continue
                if torch.is_floating_point(inst_masks):
                    inst_masks = inst_masks > 0.5
                # Union all instances for this phrase into one binary mask, then
                # OR it into the running union for the prompt key.
                mask_tensor = torch.any(inst_masks, dim=0)
                mask_union = mask_tensor if mask_union is None else (mask_union | mask_tensor)
            if mask_union is None:
                continue
            mask_np = mask_union.detach().cpu().numpy().astype(bool)
            if mask_np.any():
                masks[key] = mask_np
        return masks
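

# Illustrative sketch of direct SemanticSegmenter usage; nothing in this module
# calls it. The image path is a caller-supplied placeholder and the prompt
# selection below is an arbitrary example, not the service's default behavior.
def _example_segment(image_path: str) -> dict[str, np.ndarray]:
    segmenter = SemanticSegmenter(SEGMENTATION_MODEL_ID)
    img = Image.open(image_path).convert("RGB")
    return segmenter.segment(
        img,
        max_side=SEGMENTATION_MAX_SIDE,
        prompts={"water": WATER_PROMPT, "road": ROAD_PROMPT},
        score_threshold=SEGMENTATION_SCORE_THRESH,
        mask_threshold=SEGMENTATION_MASK_THRESH,
    )
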

@dataclass
class SegmenterRequest:
    image: Image.Image
    source_path: Optional[str] = None
    want_water: bool = False
    want_road: bool = False
    want_roof: bool = False
    want_tree: bool = False
    max_side: int = SEGMENTATION_MAX_SIDE
    water_prompt: str = WATER_PROMPT
    road_prompt: str = ROAD_PROMPT
    roof_prompt: str = ROOF_PROMPT
    tree_prompt: str = TREE_PROMPT
    score_threshold: float = SEGMENTATION_SCORE_THRESH
    mask_threshold: float = SEGMENTATION_MASK_THRESH


class SegmenterService:
    """Caches SemanticSegmenter instances across UI interactions."""

    def __init__(self, model_id: str = SEGMENTATION_MODEL_ID):
        self.model_id = model_id
        self._segmenters: Dict[str, SemanticSegmenter] = {}
        # Eagerly load the default model once to avoid repeated cold starts.
        try:
            self._segmenters[model_id] = SemanticSegmenter(model_id)
        except Exception as exc:
            print(f"[WARN] Failed to preload segmentation model {model_id}: {exc}")

    def _get_segmenter(self, model_id: str) -> SemanticSegmenter:
        if model_id not in self._segmenters:
            self._segmenters[model_id] = SemanticSegmenter(model_id)
        return self._segmenters[model_id]

    def get_masks(
        self, request: SegmenterRequest, model_id: str | None = None
    ) -> dict[str, np.ndarray]:
        if not (request.want_water or request.want_road or request.want_tree or request.want_roof):
            return {}
        segmenter = self._get_segmenter(model_id or self.model_id)

        # Build the prompt map from the requested layers.
        prompts: dict[str, str] = {}
        if request.want_water and request.water_prompt:
            prompts["water"] = request.water_prompt
        if request.want_road and request.road_prompt:
            prompts["road"] = request.road_prompt
        if request.want_roof and request.roof_prompt:
            prompts["roof"] = request.roof_prompt
        if request.want_tree and request.tree_prompt:
            prompts["tree"] = request.tree_prompt

        try:
            masks = segmenter.segment(
                request.image,
                request.max_side,
                prompts=prompts,
                score_threshold=float(request.score_threshold),
                mask_threshold=float(request.mask_threshold),
            )
        except RuntimeError as exc:
            print(f"[WARN] Segmentation failed; skipping masks: {exc}")
            masks = {}

        # Only return masks for the layers that were actually requested.
        result: dict[str, np.ndarray] = {}
        for name, wanted in (
            ("water", request.want_water),
            ("road", request.want_road),
            ("roof", request.want_roof),
            ("tree", request.want_tree),
        ):
            if wanted and masks.get(name) is not None:
                result[name] = masks[name]
        return result


__all__ = ["SegmenterService", "SegmenterRequest", "SemanticSegmenter"]

# Shared singleton to avoid reloads across analyzer instances.
_GLOBAL_SEGMENTER: SegmenterService | None = None


def get_global_segmenter(default_model_id: str = SEGMENTATION_MODEL_ID) -> SegmenterService:
    global _GLOBAL_SEGMENTER
    if _GLOBAL_SEGMENTER is None or _GLOBAL_SEGMENTER.model_id != default_model_id:
        _GLOBAL_SEGMENTER = SegmenterService(default_model_id)
    return _GLOBAL_SEGMENTER
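

# Minimal end-to-end sketch. Because this module uses a relative config import,
# it must be run as a module (e.g., `python -m <package>.<module>`, placeholder
# names). The input file "scene.png" is a hypothetical example, not a file this
# repo ships.
if __name__ == "__main__":
    service = get_global_segmenter()
    req = SegmenterRequest(
        image=Image.open("scene.png").convert("RGB"),  # hypothetical input image
        want_water=True,
        want_road=True,
    )
    for name, mask in service.get_masks(req).items():
        # Each mask is a boolean array at the original image resolution.
        print(f"{name}: {int(mask.sum())} of {mask.size} pixels flagged")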