# modeling_sac.py — uploaded with huggingface_hub (repo: rocker417/sac-patch-segmenter-2)
"""SAC Patch Segmenter model for outpost deployment.
Accepts PIL images directly, runs U-Net segmentation, and returns
patch detection score + mask fraction.
Usage (inside outpost):
result = model.predict(image=pil_image)
# returns {"score": 0.85, "mask_fraction": 0.12}
Reference: Liu et al., CVPR 2022, "Segment and Complete"
"""
from __future__ import annotations
from typing import Optional
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from transformers import PreTrainedModel
from .configuration_sac import SACPatchSegmenterConfig
def _log(msg):
print(f"[SAC-DEBUG] {msg}", file=sys.stderr, flush=True)
# ---------------------------------------------------------------------------
# U-Net architecture (matches joellliu/SegmentAndComplete coco_at.pth)
# ---------------------------------------------------------------------------
class _DoubleConv(nn.Module):
def __init__(self, in_ch, out_ch, mid_ch=None):
super().__init__()
mid = mid_ch or out_ch
self.double_conv = nn.Sequential(
nn.Conv2d(in_ch, mid, 3, padding=1), nn.BatchNorm2d(mid), nn.ReLU(inplace=True),
nn.Conv2d(mid, out_ch, 3, padding=1), nn.BatchNorm2d(out_ch), nn.ReLU(inplace=True),
)
def forward(self, x):
return self.double_conv(x)
class _Down(nn.Module):
    """Encoder step: halve spatial resolution (2x2 max-pool), then double-conv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Name must remain `maxpool_conv` so checkpoint state-dict keys match.
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            _DoubleConv(in_ch, out_ch),
        )

    def forward(self, x):
        return self.maxpool_conv(x)
class _Up(nn.Module):
    """Decoder step: bilinear 2x upsample, pad to the skip tensor, concat, conv."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Names must remain `up` / `conv` so checkpoint state-dict keys match.
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = _DoubleConv(in_ch, out_ch, in_ch // 2)

    def forward(self, x1, x2):
        upsampled = self.up(x1)
        # Pad the upsampled tensor so it matches the skip connection x2
        # spatially (handles odd input sizes where the 2x scale undershoots).
        diff_h = x2.size(2) - upsampled.size(2)
        diff_w = x2.size(3) - upsampled.size(3)
        upsampled = F.pad(
            upsampled,
            [diff_w // 2, diff_w - diff_w // 2, diff_h // 2, diff_h - diff_h // 2],
        )
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
class _OutConv(nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv = nn.Conv2d(in_ch, out_ch, kernel_size=1)
def forward(self, x):
return self.conv(x)
# ---------------------------------------------------------------------------
# HuggingFace PreTrainedModel wrapper
# ---------------------------------------------------------------------------
class SACPatchSegmenterModel(PreTrainedModel):
    """SAC U-Net patch segmenter with integrated preprocessing.

    Accepts PIL images, resizes them to ``config.input_size`` square, runs
    U-Net segmentation, and returns patch detection results as a dict with
    ``"score"`` (float in [0, 1]) and ``"mask_fraction"`` (fraction of
    pixels whose sigmoid probability exceeds 0.5).
    """

    config_class = SACPatchSegmenterConfig
    supports_gradient_checkpointing = False

    def __init__(self, config: SACPatchSegmenterConfig) -> None:
        super().__init__(config)
        bf = config.base_filter
        self._input_size = config.input_size
        # U-Net layers. Attribute names must stay exactly as below so that
        # existing checkpoints (joellliu/SegmentAndComplete coco_at.pth,
        # per the comment at the top of this file) load without remapping.
        self.inc = _DoubleConv(3, bf)
        self.down1 = _Down(bf, bf * 2)
        self.down2 = _Down(bf * 2, bf * 4)
        self.down3 = _Down(bf * 4, bf * 8)
        # Bottleneck and decoder widths are halved (`// 2`) because the
        # decoder upsamples bilinearly (no transposed conv) and concatenates
        # skip features, doubling the channel count at each _Up input.
        self.down4 = _Down(bf * 8, bf * 16 // 2)
        self.up1 = _Up(bf * 16, bf * 8 // 2)
        self.up2 = _Up(bf * 8, bf * 4 // 2)
        self.up3 = _Up(bf * 4, bf * 2 // 2)
        self.up4 = _Up(bf * 2, bf)
        self.outc = _OutConv(bf, 1)
        self._to_tensor = transforms.ToTensor()

    def forward(self, pixel_values: Optional[torch.Tensor] = None, **kwargs):
        """Standard tensor forward pass; also routes forward(image=pil) to predict().

        Args:
            pixel_values: (B, 3, H, W) float tensor, or None when calling
                with an ``image=`` keyword instead.

        Raises:
            ValueError: if neither ``pixel_values`` nor ``image`` is given.
        """
        if "image" in kwargs:
            return self.predict(**kwargs)
        if pixel_values is None:
            raise ValueError("Provide pixel_values tensor or image=PIL")
        return self._unet_forward(pixel_values)

    def _unet_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the raw U-Net: (B, 3, H, W) -> per-pixel logits (B, 1, H, W)."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.outc(x)

    @torch.no_grad()
    def predict(self, image: Image.Image, **kwargs) -> dict:
        """Accept a PIL image and return patch detection results.

        Args:
            image: input PIL image; converted to RGB and resized to the
                configured square input size with bilinear resampling.

        Returns:
            dict with ``"score"`` in [0, 1] and ``"mask_fraction"``
            (fraction of pixels with sigmoid probability > 0.5).
        """
        img = image.convert("RGB").resize(
            (self._input_size, self._input_size), Image.Resampling.BILINEAR
        )
        tensor = self._to_tensor(img).unsqueeze(0).to(device=self.device, dtype=self.dtype)
        # Fix: force eval mode for the duration of inference. If the module
        # were left in training mode, BatchNorm2d would normalize with the
        # single-image batch statistics AND mutate its running stats — wrong
        # results and a side effect for what should be a read-only call.
        was_training = self.training
        self.eval()
        try:
            logits = self._unet_forward(tensor)
        finally:
            # Restore whatever mode the caller had set.
            self.train(was_training)
        prob = torch.sigmoid(logits)
        mask = (prob[0, 0] > 0.5).float()
        mask_fraction = float(mask.sum().item()) / mask.numel()
        if mask_fraction > 0.001:
            # Confident detection: scale so ~10% coverage saturates at 1.0.
            score = min(1.0, mask_fraction * 10.0)
        else:
            # No meaningful mask: fall back to a damped max-probability score.
            score = float(prob.max().item()) * 0.5
        return {"score": score, "mask_fraction": mask_fraction}

    @torch.no_grad()
    def score_image(self, image: Image.Image, **kwargs) -> dict:
        """Alias for predict — matches the outpost calling convention."""
        return self.predict(image=image, **kwargs)