"""Promptable semantic segmentation (SAM3) services: model wrapper, request
dataclass, caching service, and a process-wide singleton accessor."""
| from __future__ import annotations | |
| from dataclasses import dataclass | |
| from typing import Dict, Optional | |
| import re | |
| import numpy as np | |
| import torch | |
| from PIL import Image | |
| from .config import ( | |
| ROAD_PROMPT, | |
| ROOF_PROMPT, | |
| SEGMENTATION_MASK_THRESH, | |
| SEGMENTATION_MAX_SIDE, | |
| SEGMENTATION_MODEL_ID, | |
| SEGMENTATION_SCORE_THRESH, | |
| WATER_PROMPT, | |
| TREE_PROMPT, | |
| ) | |
class SemanticSegmenter:
    """Promptable segmenter backed by SAM3.

    Wraps a HuggingFace transformers processor/model pair and turns free-text
    prompts (e.g. "water", "road") into boolean masks at the original image
    resolution.
    """

    def __init__(self, model_id: str):
        """Load the SAM3 processor and model for ``model_id``.

        Prefers CUDA when available and falls back to CPU if moving the model
        to the GPU fails (e.g. out of memory).

        Raises:
            ImportError: if the installed transformers build exposes neither
                the SAM3 classes nor a generic mask-generation fallback.
        """
        import transformers  # type: ignore
        from transformers.utils import logging as hf_logging  # type: ignore

        # Silence transformers chatter; progress-bar toggling is best-effort
        # because older builds lack disable_progress_bar().
        hf_logging.set_verbosity_error()
        try:
            hf_logging.disable_progress_bar()
        except Exception:
            pass

        # Prefer the dedicated SAM3 classes; fall back to the generic Auto*
        # entry points on older transformers builds.
        processor_cls = (
            getattr(transformers, "Sam3Processor", None)
            or getattr(transformers, "AutoProcessor", None)
            or getattr(transformers, "AutoImageProcessor", None)
        )
        model_cls = getattr(transformers, "Sam3Model", None) or getattr(
            transformers, "AutoModelForMaskGeneration", None
        )
        # Fail fast with an actionable message instead of the opaque
        # AttributeError that `None.from_pretrained(...)` would raise.
        if processor_cls is None or model_cls is None:
            raise ImportError(
                "Installed transformers build exposes no SAM3 or mask-generation "
                "processor/model classes; upgrade transformers to a version with "
                "SAM3 support."
            )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        processor = processor_cls.from_pretrained(model_id)
        model = model_cls.from_pretrained(model_id)
        try:
            model = model.to(device)
        except RuntimeError as exc:
            # Fall back to CPU if the GPU move fails (e.g., OOM or missing device)
            device = torch.device("cpu")
            model = model.to(device)
            print(f"[WARN] SAM3 fell back to CPU after .to(device) error: {exc}")
        model.eval()

        self.processor = processor
        self.model = model
        self.device = device
        if torch.cuda.is_available() and self.device.type != "cuda":
            print("[WARN] CUDA is available but SAM3 is running on CPU; mask generation will be slow.")
        else:
            print(f"[INFO] SAM3 loaded on {self.device}")

    def segment(
        self,
        img: Image.Image,
        max_side: int,
        prompts: Dict[str, str],
        score_threshold: float,
        mask_threshold: float,
    ) -> dict[str, np.ndarray]:
        """Return one boolean mask per prompt key.

        Args:
            img: Input image; masks are returned at its original resolution.
            max_side: Inference-time cap on the longest image side; larger
                images are downscaled before being fed to the model.
            prompts: Mapping of output key -> prompt text. A prompt may hold
                several phrases separated by commas/semicolons/newlines; their
                masks are unioned.
            score_threshold: Instance score threshold for post-processing.
            mask_threshold: Per-pixel mask binarization threshold.

        Returns:
            dict mapping each key to a boolean (H, W) array at the original
            image size. Keys with no detected pixels are omitted.

        Raises:
            ImportError: if the loaded processor does not accept text prompts.
        """
        if not prompts:
            return {}
        orig_size = img.size  # (W, H)
        img_proc = img
        if max(img.size) > max_side:
            # Downscale for inference; post-processing maps masks back to the
            # original resolution via target_sizes below.
            scale = max_side / max(img.size)
            new_size = (
                max(1, int(round(img.size[0] * scale))),
                max(1, int(round(img.size[1] * scale))),
            )
            img_proc = img.resize(new_size, resample=Image.BILINEAR)

        def _split_prompts(text: str) -> list[str]:
            # Split a multi-phrase prompt on commas/semicolons/newlines.
            parts = [p.strip() for p in re.split(r"[;,\n]", text) if p.strip()]
            return parts if parts else ([text.strip()] if text.strip() else [])

        masks: dict[str, np.ndarray] = {}
        for key, prompt in prompts.items():
            prompt_texts = _split_prompts(prompt or "")
            if not prompt_texts:
                continue
            mask_union = None
            for text in prompt_texts:
                try:
                    inputs = self.processor(images=img_proc, text=text, return_tensors="pt").to(self.device)
                except TypeError as exc:
                    raise ImportError(
                        "Loaded processor does not accept text prompts; install a transformers build with SAM3 text prompting support (e.g., pip install --upgrade transformers or a nightly that includes Sam3Processor)."
                    ) from exc
                with torch.inference_mode():
                    outputs = self.model(**inputs)
                # target_sizes expects (H, W); PIL's img.size is (W, H).
                results = self.processor.post_process_instance_segmentation(
                    outputs,
                    threshold=score_threshold,
                    mask_threshold=mask_threshold,
                    target_sizes=[(orig_size[1], orig_size[0])],
                )[0]
                inst_masks = results.get("masks")
                if inst_masks is None or len(inst_masks) == 0:
                    continue
                if torch.is_floating_point(inst_masks):
                    inst_masks = inst_masks > 0.5
                # Union all instances for this phrase, then union across phrases.
                mask_tensor = torch.any(inst_masks, dim=0)
                mask_union = mask_tensor if mask_union is None else (mask_union | mask_tensor)
            if mask_union is None:
                continue
            mask_np = mask_union.detach().cpu().numpy().astype(bool)
            if mask_np.any():
                masks[key] = mask_np
        return masks
@dataclass
class SegmenterRequest:
    """Parameters for one mask-generation request.

    Written as a dataclass so callers can construct it with keyword arguments;
    the missing ``@dataclass`` decorator previously meant no ``__init__`` was
    generated and ``request.image`` was never set on instances, even though
    ``SegmenterService.get_masks`` reads it.
    """

    # Image to segment; masks come back at this image's resolution.
    image: Image.Image
    # Optional path the image was loaded from (informational only here).
    source_path: Optional[str] = None
    # Per-category toggles; a category is segmented only when its flag is set
    # and its prompt is non-empty.
    want_water: bool = False
    want_road: bool = False
    want_roof: bool = False
    want_tree: bool = False
    # Longest-side cap applied before inference.
    max_side: int = SEGMENTATION_MAX_SIDE
    # Text prompts per category (defaults come from config).
    water_prompt: str = WATER_PROMPT
    road_prompt: str = ROAD_PROMPT
    roof_prompt: str = ROOF_PROMPT
    tree_prompt: str = TREE_PROMPT
    # Post-processing thresholds.
    score_threshold: float = SEGMENTATION_SCORE_THRESH
    mask_threshold: float = SEGMENTATION_MASK_THRESH
class SegmenterService:
    """Caches segmenters and mask outputs across UI interactions."""

    def __init__(self, model_id: str = SEGMENTATION_MODEL_ID):
        """Create the service and eagerly warm the default model.

        A preload failure is logged rather than raised so the UI can still
        come up; the model is retried lazily on first use.
        """
        self.model_id = model_id
        self._segmenters: Dict[str, SemanticSegmenter] = {}
        # Eagerly load the default model once to avoid repeated cold-starts.
        try:
            self._segmenters[model_id] = SemanticSegmenter(model_id)
        except Exception as exc:
            print(f"[WARN] Failed to preload segmentation model {model_id}: {exc}")

    def _get_segmenter(self, model_id: str) -> SemanticSegmenter:
        # Lazily instantiate and memoize one segmenter per model id.
        segmenter = self._segmenters.get(model_id)
        if segmenter is None:
            segmenter = SemanticSegmenter(model_id)
            self._segmenters[model_id] = segmenter
        return segmenter

    def get_masks(self, request: SegmenterRequest, model_id: str | None = None) -> dict[str, np.ndarray]:
        """Return boolean masks for every category the request enables.

        Categories whose flag is off or whose prompt is empty are skipped;
        a RuntimeError during segmentation is logged and yields no masks.
        """
        # (output key, enabled flag, prompt text) per supported category.
        categories = (
            ("water", request.want_water, request.water_prompt),
            ("road", request.want_road, request.road_prompt),
            ("roof", request.want_roof, request.roof_prompt),
            ("tree", request.want_tree, request.tree_prompt),
        )
        if not any(enabled for _, enabled, _ in categories):
            return {}
        segmenter = self._get_segmenter(model_id or self.model_id)
        prompts: dict[str, str] = {
            key: prompt for key, enabled, prompt in categories if enabled and prompt
        }
        try:
            masks = segmenter.segment(
                request.image,
                request.max_side,
                prompts=prompts,
                score_threshold=float(request.score_threshold),
                mask_threshold=float(request.mask_threshold),
            )
        except RuntimeError as exc:
            print(f"[WARN] Segmentation failed; skipping masks: {exc}")
            masks = {}
        # Keep only requested categories that actually produced a mask.
        result: dict[str, np.ndarray] = {}
        for key, enabled, _ in categories:
            mask = masks.get(key)
            if enabled and mask is not None:
                result[key] = mask
        return result
__all__ = ["SegmenterService", "SegmenterRequest", "SemanticSegmenter"]

# Shared singleton to avoid reloads across analyzer instances
_GLOBAL_SEGMENTER: SegmenterService | None = None


def get_global_segmenter(default_model_id: str = SEGMENTATION_MODEL_ID) -> SegmenterService:
    """Return the process-wide SegmenterService.

    The cached service is rebuilt whenever it does not exist yet or its
    default model id differs from ``default_model_id``.
    """
    global _GLOBAL_SEGMENTER
    service = _GLOBAL_SEGMENTER
    if service is None or service.model_id != default_model_id:
        service = SegmenterService(default_model_id)
        _GLOBAL_SEGMENTER = service
    return service