from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Optional
import re
import numpy as np
import torch
from PIL import Image
from .config import (
    ROAD_PROMPT,
    ROOF_PROMPT,
    SEGMENTATION_MASK_THRESH,
    SEGMENTATION_MAX_SIDE,
    SEGMENTATION_MODEL_ID,
    SEGMENTATION_SCORE_THRESH,
    WATER_PROMPT,
    TREE_PROMPT,
)
class SemanticSegmenter:
    """Promptable segmenter backed by SAM3."""

    def __init__(self, model_id: str):
        import transformers  # type: ignore
        from transformers.utils import logging as hf_logging  # type: ignore

        # Keep Hugging Face logging quiet so the app logs stay readable.
        hf_logging.set_verbosity_error()
        try:
            hf_logging.disable_progress_bar()
        except Exception:
            pass

        # Prefer the dedicated SAM3 classes; fall back to the generic Auto* classes
        # for transformers builds that do not expose them yet.
        processor_cls = (
            getattr(transformers, "Sam3Processor", None)
            or getattr(transformers, "AutoProcessor", None)
            or getattr(transformers, "AutoImageProcessor", None)
        )
        model_cls = getattr(transformers, "Sam3Model", None) or getattr(
            transformers, "AutoModelForMaskGeneration", None
        )
        if processor_cls is None or model_cls is None:
            raise ImportError(
                "Installed transformers build exposes neither SAM3 nor Auto* mask-generation "
                "classes; upgrade transformers to load the segmentation model."
            )

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        processor = processor_cls.from_pretrained(model_id)
        model = model_cls.from_pretrained(model_id)
        try:
            model = model.to(device)
        except RuntimeError as exc:
            # Fall back to CPU if the GPU move fails (e.g., OOM or missing device).
            device = torch.device("cpu")
            model = model.to(device)
            print(f"[WARN] SAM3 fell back to CPU after .to(device) error: {exc}")
        model.eval()

        self.processor = processor
        self.model = model
        self.device = device
        if torch.cuda.is_available() and self.device.type != "cuda":
            print("[WARN] CUDA is available but SAM3 is running on CPU; mask generation will be slow.")
        else:
            print(f"[INFO] SAM3 loaded on {self.device}")
    def segment(
        self,
        img: Image.Image,
        max_side: int,
        prompts: Dict[str, str],
        score_threshold: float,
        mask_threshold: float,
    ) -> dict[str, np.ndarray]:
        if not prompts:
            return {}

        # Downscale large images before inference; post-processing resizes the masks
        # back to the original resolution via target_sizes.
        orig_size = img.size  # (W, H)
        img_proc = img
        if max(img.size) > max_side:
            scale = max_side / max(img.size)
            new_size = (
                max(1, int(round(img.size[0] * scale))),
                max(1, int(round(img.size[1] * scale))),
            )
            img_proc = img.resize(new_size, resample=Image.BILINEAR)

        def _split_prompts(text: str) -> list[str]:
            parts = [p.strip() for p in re.split(r"[;,\n]", text) if p.strip()]
            return parts if parts else ([text.strip()] if text.strip() else [])

        masks: dict[str, np.ndarray] = {}
        for key, prompt in prompts.items():
            prompt_texts = _split_prompts(prompt or "")
            if not prompt_texts:
                continue
            # Union the instance masks produced by each sub-prompt into one class mask.
            mask_union = None
            for text in prompt_texts:
                try:
                    inputs = self.processor(images=img_proc, text=text, return_tensors="pt").to(self.device)
                except TypeError as exc:
                    raise ImportError(
                        "Loaded processor does not accept text prompts; install a transformers build "
                        "with SAM3 text prompting support (e.g., pip install --upgrade transformers "
                        "or a nightly that includes Sam3Processor)."
                    ) from exc
                with torch.inference_mode():
                    outputs = self.model(**inputs)
                results = self.processor.post_process_instance_segmentation(
                    outputs,
                    threshold=score_threshold,
                    mask_threshold=mask_threshold,
                    target_sizes=[(orig_size[1], orig_size[0])],
                )[0]
                inst_masks = results.get("masks")
                if inst_masks is None or len(inst_masks) == 0:
                    continue
                if torch.is_floating_point(inst_masks):
                    inst_masks = inst_masks > 0.5
                mask_tensor = torch.any(inst_masks, dim=0)
                mask_union = mask_tensor if mask_union is None else (mask_union | mask_tensor)
            if mask_union is None:
                continue
            mask_np = mask_union.detach().cpu().numpy().astype(bool)
            if mask_np.any():
                masks[key] = mask_np
        return masks
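# Direct-use sketch for SemanticSegmenter (illustrative only; the image path and
# prompt strings below are made-up examples, and the thresholds simply reuse the
# config defaults):
#
#     seg = SemanticSegmenter(SEGMENTATION_MODEL_ID)
#     masks = seg.segment(
#         Image.open("tile.png"),
#         max_side=SEGMENTATION_MAX_SIDE,
#         prompts={"water": "river; lake", "tree": "tree canopy"},
#         score_threshold=SEGMENTATION_SCORE_THRESH,
#         mask_threshold=SEGMENTATION_MASK_THRESH,
#     )
#     # Each value is a boolean numpy array with the same height/width as the input image.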
@dataclass
class SegmenterRequest:
    image: Image.Image
    source_path: Optional[str] = None
    want_water: bool = False
    want_road: bool = False
    want_roof: bool = False
    want_tree: bool = False
    max_side: int = SEGMENTATION_MAX_SIDE
    water_prompt: str = WATER_PROMPT
    road_prompt: str = ROAD_PROMPT
    roof_prompt: str = ROOF_PROMPT
    tree_prompt: str = TREE_PROMPT
    score_threshold: float = SEGMENTATION_SCORE_THRESH
    mask_threshold: float = SEGMENTATION_MASK_THRESH
class SegmenterService:
    """Caches segmenters and mask outputs across UI interactions."""

    def __init__(self, model_id: str = SEGMENTATION_MODEL_ID):
        self.model_id = model_id
        self._segmenters: Dict[str, SemanticSegmenter] = {}
        # Eagerly load the default model once to avoid repeated cold-starts.
        try:
            self._segmenters[model_id] = SemanticSegmenter(model_id)
        except Exception as exc:
            print(f"[WARN] Failed to preload segmentation model {model_id}: {exc}")

    def _get_segmenter(self, model_id: str) -> SemanticSegmenter:
        if model_id not in self._segmenters:
            self._segmenters[model_id] = SemanticSegmenter(model_id)
        return self._segmenters[model_id]

    def get_masks(self, request: SegmenterRequest, model_id: str | None = None) -> dict[str, np.ndarray]:
        if not (request.want_water or request.want_road or request.want_tree or request.want_roof):
            return {}
        segmenter = self._get_segmenter(model_id or self.model_id)
        prompts: dict[str, str] = {}
        if request.want_water and request.water_prompt:
            prompts["water"] = request.water_prompt
        if request.want_road and request.road_prompt:
            prompts["road"] = request.road_prompt
        if request.want_roof and request.roof_prompt:
            prompts["roof"] = request.roof_prompt
        if request.want_tree and request.tree_prompt:
            prompts["tree"] = request.tree_prompt
        try:
            masks = segmenter.segment(
                request.image,
                request.max_side,
                prompts=prompts,
                score_threshold=float(request.score_threshold),
                mask_threshold=float(request.mask_threshold),
            )
        except RuntimeError as exc:
            print(f"[WARN] Segmentation failed; skipping masks: {exc}")
            masks = {}
        # Only return masks for the layers that were actually requested.
        result: dict[str, np.ndarray] = {}
        if request.want_water and masks.get("water") is not None:
            result["water"] = masks["water"]
        if request.want_road and masks.get("road") is not None:
            result["road"] = masks["road"]
        if request.want_roof and masks.get("roof") is not None:
            result["roof"] = masks["roof"]
        if request.want_tree and masks.get("tree") is not None:
            result["tree"] = masks["tree"]
        return result
__all__ = ["SegmenterService", "SegmenterRequest", "SemanticSegmenter"]
# Shared singleton to avoid reloads across analyzer instances
_GLOBAL_SEGMENTER: SegmenterService | None = None
def get_global_segmenter(default_model_id: str = SEGMENTATION_MODEL_ID) -> SegmenterService:
    global _GLOBAL_SEGMENTER
    if _GLOBAL_SEGMENTER is None or _GLOBAL_SEGMENTER.model_id != default_model_id:
        _GLOBAL_SEGMENTER = SegmenterService(default_model_id)
    return _GLOBAL_SEGMENTER
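# Minimal end-to-end sketch of the cached service (illustrative only; "scene.png"
# is a placeholder path and the selected layers are arbitrary):
#
#     service = get_global_segmenter()
#     request = SegmenterRequest(
#         image=Image.open("scene.png"),
#         want_water=True,
#         want_road=True,
#     )
#     layer_masks = service.get_masks(request)
#     for name, mask in layer_masks.items():
#         print(f"{name}: {mask.sum()} masked pixels")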