"""SAC Patch Segmenter model for outpost deployment. Accepts PIL images directly, runs U-Net segmentation, and returns patch detection score + mask fraction. Usage (inside outpost): result = model.predict(image=pil_image) # returns {"score": 0.85, "mask_fraction": 0.12} Reference: Liu et al., CVPR 2022, "Segment and Complete" """ from __future__ import annotations from typing import Optional import sys import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from torchvision import transforms from transformers import PreTrainedModel from .configuration_sac import SACPatchSegmenterConfig def _log(msg): print(f"[SAC-DEBUG] {msg}", file=sys.stderr, flush=True) # --------------------------------------------------------------------------- # U-Net architecture (matches joellliu/SegmentAndComplete coco_at.pth) # --------------------------------------------------------------------------- class _DoubleConv(nn.Module): def __init__(self, in_ch, out_ch, mid_ch=None): super().__init__() mid = mid_ch or out_ch self.double_conv = nn.Sequential( nn.Conv2d(in_ch, mid, 3, padding=1), nn.BatchNorm2d(mid), nn.ReLU(inplace=True), nn.Conv2d(mid, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True), ) def forward(self, x): return self.double_conv(x) class _Down(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() self.maxpool_conv = nn.Sequential(nn.MaxPool2d(2), _DoubleConv(in_ch, out_ch)) def forward(self, x): return self.maxpool_conv(x) class _Up(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True) self.conv = _DoubleConv(in_ch, out_ch, in_ch // 2) def forward(self, x1, x2): x1 = self.up(x1) dy, dx = x2.size(2) - x1.size(2), x2.size(3) - x1.size(3) x1 = F.pad(x1, [dx // 2, dx - dx // 2, dy // 2, dy - dy // 2]) return self.conv(torch.cat([x2, x1], dim=1)) class _OutConv(nn.Module): def __init__(self, in_ch, out_ch): super().__init__() self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1) def forward(self, x): return self.conv(x) # --------------------------------------------------------------------------- # HuggingFace PreTrainedModel wrapper # --------------------------------------------------------------------------- class SACPatchSegmenterModel(PreTrainedModel): """SAC U-Net patch segmenter with integrated preprocessing. Accepts PIL images, resizes to 416x416, runs U-Net segmentation, and returns patch detection results. """ config_class = SACPatchSegmenterConfig supports_gradient_checkpointing = False def __init__(self, config: SACPatchSegmenterConfig) -> None: super().__init__(config) bf = config.base_filter self._input_size = config.input_size # U-Net layers self.inc = _DoubleConv(3, bf) self.down1 = _Down(bf, bf * 2) self.down2 = _Down(bf * 2, bf * 4) self.down3 = _Down(bf * 4, bf * 8) self.down4 = _Down(bf * 8, bf * 16 // 2) self.up1 = _Up(bf * 16, bf * 8 // 2) self.up2 = _Up(bf * 8, bf * 4 // 2) self.up3 = _Up(bf * 4, bf * 2 // 2) self.up4 = _Up(bf * 2, bf) self.outc = _OutConv(bf, 1) self._to_tensor = transforms.ToTensor() def forward(self, pixel_values: Optional[torch.Tensor] = None, **kwargs): """Standard forward pass. Also supports predict(image=pil).""" if "image" in kwargs: return self.predict(**kwargs) if pixel_values is None: raise ValueError("Provide pixel_values tensor or image=PIL") return self._unet_forward(pixel_values) def _unet_forward(self, x: torch.Tensor) -> torch.Tensor: x1 = self.inc(x) x2 = self.down1(x1) x3 = self.down2(x2) x4 = self.down3(x3) x5 = self.down4(x4) x = self.up1(x5, x4) x = self.up2(x, x3) x = self.up3(x, x2) x = self.up4(x, x1) return self.outc(x) @torch.no_grad() def predict(self, image: Image.Image, **kwargs) -> dict: """Accept a PIL image and return patch detection results.""" img = image.convert("RGB").resize( (self._input_size, self._input_size), Image.Resampling.BILINEAR ) tensor = self._to_tensor(img).unsqueeze(0).to(device=self.device, dtype=self.dtype) logits = self._unet_forward(tensor) prob = torch.sigmoid(logits) mask = (prob[0, 0] > 0.5).float() mask_fraction = float(mask.sum().item()) / mask.numel() if mask_fraction > 0.001: score = min(1.0, mask_fraction * 10.0) else: score = float(prob.max().item()) * 0.5 return {"score": score, "mask_fraction": mask_fraction} @torch.no_grad() def score_image(self, image: Image.Image, **kwargs) -> dict: """Alias for predict — matches outpost calling convention.""" return self.predict(image=image, **kwargs)