| """SAC Patch Segmenter model for outpost deployment. |
| |
| Accepts PIL images directly, runs U-Net segmentation, and returns |
| patch detection score + mask fraction. |
| |
| Usage (inside outpost): |
| result = model.predict(image=pil_image) |
| # returns {"score": 0.85, "mask_fraction": 0.12} |
| |
| Reference: Liu et al., CVPR 2022, "Segment and Complete" |
| """ |
|
|
| from __future__ import annotations |
|
|
| from typing import Optional |
|
|
| import sys |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
| from torchvision import transforms |
| from transformers import PreTrainedModel |
|
|
| from .configuration_sac import SACPatchSegmenterConfig |
|
|
|
|
| def _log(msg): |
| print(f"[SAC-DEBUG] {msg}", file=sys.stderr, flush=True) |
|
|
|
|
| |
| |
| |
|
|
|
|
| class _DoubleConv(nn.Module): |
| def __init__(self, in_ch, out_ch, mid_ch=None): |
| super().__init__() |
| mid = mid_ch or out_ch |
| self.double_conv = nn.Sequential( |
| nn.Conv2d(in_ch, mid, 3, padding=1), nn.BatchNorm2d(mid), nn.ReLU(inplace=True), |
| nn.Conv2d(mid, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), |
| ) |
|
|
| def forward(self, x): |
| return self.double_conv(x) |
|
|
|
|
class _Down(nn.Module):
    """Encoder stage: halve spatial resolution, then apply a double conv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute name kept as "maxpool_conv" so checkpoint keys match.
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            _DoubleConv(in_ch, out_ch),
        )

    def forward(self, x):
        return self.maxpool_conv(x)
|
|
|
|
class _Up(nn.Module):
    """Decoder stage: bilinear 2x upsample, pad to the skip's size, fuse.

    The skip-connection tensor is concatenated channel-wise with the
    upsampled feature map before the double conv.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Attribute names "up"/"conv" kept so checkpoint keys match.
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = _DoubleConv(in_ch, out_ch, in_ch // 2)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Pad so upsampled matches the skip tensor x2 when the original
        # spatial size was odd (2x upsample cannot recover odd sizes).
        diff_h = x2.size(2) - upsampled.size(2)
        diff_w = x2.size(3) - upsampled.size(3)
        pad = [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2]
        upsampled = F.pad(upsampled, pad)
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
|
|
|
|
| class _OutConv(nn.Module): |
| def __init__(self, in_ch, out_ch): |
| super().__init__() |
| self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1) |
|
|
| def forward(self, x): |
| return self.conv(x) |
|
|
|
|
| |
| |
| |
|
|
|
|
class SACPatchSegmenterModel(PreTrainedModel):
    """U-Net patch segmenter (SAC) with preprocessing built in.

    Callers hand in a PIL image; the model resizes it to the configured
    square input size, runs the U-Net, and returns a patch-detection
    score plus the fraction of pixels flagged as patch.
    """

    config_class = SACPatchSegmenterConfig
    supports_gradient_checkpointing = False

    def __init__(self, config: SACPatchSegmenterConfig) -> None:
        super().__init__(config)
        width = config.base_filter
        self._input_size = config.input_size

        # Encoder: stem plus four downsampling stages. The deepest stage
        # stays at width*8 channels (half a classic U-Net bottleneck)
        # because the decoder upsamples bilinearly rather than with
        # transposed convolutions.
        self.inc = _DoubleConv(3, width)
        self.down1 = _Down(width, width * 2)
        self.down2 = _Down(width * 2, width * 4)
        self.down3 = _Down(width * 4, width * 8)
        self.down4 = _Down(width * 8, width * 8)

        # Decoder: input channel counts are doubled because each stage
        # consumes the concatenated skip connection.
        self.up1 = _Up(width * 16, width * 4)
        self.up2 = _Up(width * 8, width * 2)
        self.up3 = _Up(width * 4, width)
        self.up4 = _Up(width * 2, width)
        self.outc = _OutConv(width, 1)

        self._to_tensor = transforms.ToTensor()

    def forward(self, pixel_values: Optional[torch.Tensor] = None, **kwargs):
        """Run the U-Net on a tensor batch, or dispatch to predict() when a
        PIL image is supplied via the ``image`` keyword."""
        if "image" in kwargs:
            return self.predict(**kwargs)
        if pixel_values is None:
            raise ValueError("Provide pixel_values tensor or image=PIL")
        return self._unet_forward(pixel_values)

    def _unet_forward(self, x: torch.Tensor) -> torch.Tensor:
        """U-Net encoder/decoder pass; returns per-pixel logits (B, 1, H, W)."""
        # Encoder path, keeping every stage's output for skip connections.
        skip0 = self.inc(x)
        skip1 = self.down1(skip0)
        skip2 = self.down2(skip1)
        skip3 = self.down3(skip2)
        bottom = self.down4(skip3)
        # Decoder path, fusing skips from deepest to shallowest.
        out = self.up1(bottom, skip3)
        out = self.up2(out, skip2)
        out = self.up3(out, skip1)
        out = self.up4(out, skip0)
        return self.outc(out)

    @torch.no_grad()
    def predict(self, image: Image.Image, **kwargs) -> dict:
        """Score a single PIL image for adversarial-patch presence.

        Returns a dict with ``score`` (0..1 heuristic) and
        ``mask_fraction`` (share of pixels above the 0.5 threshold).
        """
        # NOTE(review): inference quality depends on the module being in
        # eval() mode (BatchNorm running stats); from_pretrained() does
        # this, direct construction does not — confirm with callers.
        side = self._input_size
        rgb = image.convert("RGB").resize((side, side), Image.Resampling.BILINEAR)
        batch = self._to_tensor(rgb).unsqueeze(0).to(device=self.device, dtype=self.dtype)

        prob = torch.sigmoid(self._unet_forward(batch))

        hard_mask = (prob[0, 0] > 0.5).float()
        mask_fraction = float(hard_mask.sum().item()) / hard_mask.numel()

        # Above a tiny coverage threshold, the score scales with coverage
        # and saturates at 1.0; otherwise fall back to half the peak
        # per-pixel probability as a weak-evidence score.
        if mask_fraction > 0.001:
            score = min(1.0, mask_fraction * 10.0)
        else:
            score = float(prob.max().item()) * 0.5

        return {"score": score, "mask_fraction": mask_fraction}

    @torch.no_grad()
    def score_image(self, image: Image.Image, **kwargs) -> dict:
        """Alias for predict(), matching the outpost calling convention."""
        return self.predict(image=image, **kwargs)
|
|