| """NAPGuard Patch Detector model for outpost deployment. |
| |
| Accepts PIL images directly, applies NFSI preprocessing, runs YOLOv5s |
| detection, and returns patch detection results. |
| |
| Usage (inside outpost): |
| result = model.predict(image=pil_image) |
| # returns {"score": 0.85, "num_detections": 2} |
| |
| Reference: Wu et al., CVPR 2024 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import sys |
| from typing import List, Optional |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
| from torchvision import transforms |
| from torchvision.ops import nms |
| from transformers import PreTrainedModel |
|
|
| from .configuration_napguard import NAPGuardPatchDetectorConfig |
|
|
|
|
def _log(msg):
    """Write *msg* to stderr with the ``[NAPGUARD-DEBUG]`` prefix and flush."""
    stream = sys.stderr
    stream.write(f"[NAPGUARD-DEBUG] {msg}\n")
    stream.flush()
|
|
|
|
| |
| |
| |
|
|
def _autopad(k, p=None, d=1):
    """Return the padding to use for kernel size *k* and dilation *d*.

    Args:
        k: Kernel size, either an int or an iterable of ints.
        p: Explicit padding. When not ``None`` it is returned unchanged.
        d: Dilation. For ``d > 1`` the effective kernel size
            ``d * (k - 1) + 1`` is used to derive the padding.

    Returns:
        ``p`` if it was given, otherwise half the (effective) kernel size,
        as an int for an int ``k`` or as a list for an iterable ``k``.
    """
    is_scalar = isinstance(k, int)
    sizes = [k] if is_scalar else list(k)
    if d > 1:
        # A dilated kernel covers d * (k - 1) + 1 input positions.
        sizes = [d * (size - 1) + 1 for size in sizes]
    if p is not None:
        return p
    pads = [size // 2 for size in sizes]
    return pads[0] if is_scalar else pads
|
|
|
|
class Conv(nn.Module):
    """Bias-free ``Conv2d`` followed by ``BatchNorm2d`` and an activation.

    Args:
        c1: Input channels.
        c2: Output channels.
        k: Kernel size.
        s: Stride.
        p: Padding; ``None`` lets ``_autopad`` derive it from ``k`` and ``d``.
        g: Convolution groups.
        d: Dilation.
        act: ``True`` selects ``default_act``, an ``nn.Module`` is used as
            given, anything else disables the activation (``nn.Identity``).
    """

    # One SiLU instance shared by every Conv that uses the default activation.
    default_act = nn.SiLU()

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        padding = _autopad(k, p, d)
        self.conv = nn.Conv2d(
            c1,
            c2,
            kernel_size=k,
            stride=s,
            padding=padding,
            dilation=d,
            groups=g,
            bias=False,
        )
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        """Apply convolution, batch norm and activation to *x*."""
        normalized = self.bn(self.conv(x))
        return self.act(normalized)
|
|
|
|
class Bottleneck(nn.Module):
    """1x1 ``Conv`` then 3x3 ``Conv``, with an optional residual connection.

    Args:
        c1: Input channels.
        c2: Output channels.
        shortcut: Request the residual add; it is only active when
            ``c1 == c2``.
        g: Groups for the 3x3 convolution.
        e: Hidden-channel ratio; the hidden width is ``int(c2 * e)``.
    """

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        """Return ``x + cv2(cv1(x))`` when the shortcut is active, else ``cv2(cv1(x))``."""
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + out
        return out
|
|
|
|
class C3(nn.Module):
    """Two-branch block: a stack of ``Bottleneck`` layers beside a 1x1 bypass.

    Both branches start from the same input, are concatenated along the
    channel dimension and fused by a final 1x1 ``Conv``.

    Args:
        c1: Input channels.
        c2: Output channels.
        n: Number of ``Bottleneck`` layers in the main branch.
        shortcut: Forwarded to every ``Bottleneck``.
        g: Groups forwarded to every ``Bottleneck``.
        e: Hidden-channel ratio; each branch is ``int(c2 * e)`` wide.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1)
        blocks = [Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)]
        self.m = nn.Sequential(*blocks)

    def forward(self, x):
        """Run both branches on *x* and fuse their channel-wise concatenation."""
        main_branch = self.m(self.cv1(x))
        bypass = self.cv2(x)
        fused = torch.cat((main_branch, bypass), 1)
        return self.cv3(fused)
|
|
|
|
class SPPF(nn.Module):
    """Pooling block that chains one stride-1 max-pool three times.

    The reduced input and the three successive pooling results are
    concatenated along the channel dimension and fused by a 1x1 ``Conv``.

    Args:
        c1: Input channels (reduced to ``c1 // 2`` before pooling).
        c2: Output channels.
        k: Max-pool kernel size; padding is ``k // 2`` so the spatial size
            is kept for odd ``k``.
    """

    def __init__(self, c1, c2, k=5):
        super().__init__()
        hidden = c1 // 2
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        """Return the fused concatenation of the reduced input and its pooled versions."""
        levels = [self.cv1(x)]
        for _ in range(3):
            # Each pooling step is applied to the previous step's output.
            levels.append(self.m(levels[-1]))
        return self.cv2(torch.cat(levels, 1))
|
|
|
|
class Detect(nn.Module):
    """Inference-only detection head that decodes feature maps into boxes.

    Each input feature map goes through its own 1x1 convolution and is
    decoded into ``(x, y, w, h, objectness, class scores...)`` rows.

    Args:
        nc: Number of classes.
        anchors: One flat sequence of anchor pairs per detection level.
        ch: Input channel count of each level's feature map.

    Attributes:
        stride: Per-level stride, indexed as ``stride[i]``. It defaults to
            ``None`` and must be assigned before ``forward`` is called.

    NOTE(review): ``anchors`` is registered exactly as passed in and is
    multiplied by ``stride`` in ``_make_grid``; presumably a loaded checkpoint
    overwrites this buffer with stride-normalised anchors -- confirm.
    """

    stride = None

    def __init__(self, nc=1, anchors=(), ch=()):
        super().__init__()
        self.nc = nc
        self.no = nc + 5  # outputs per anchor: 4 box values + objectness + classes
        self.nl = len(anchors)
        self.na = len(anchors[0]) // 2
        # Per-level caches, rebuilt lazily by forward().
        self.grid = [torch.empty(0) for _ in range(self.nl)]
        self.anchor_grid = [torch.empty(0) for _ in range(self.nl)]
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))
        self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch)

    def forward(self, x):
        """Decode per-level feature maps into one prediction tensor.

        Args:
            x: Sequence of ``self.nl`` feature maps shaped
                ``(batch, channels, ny, nx)``. The sequence is not modified.

        Returns:
            A 1-tuple holding a ``(batch, total_anchors, self.no)`` tensor.
        """
        decoded = []
        for i in range(self.nl):
            # Work on a local so the caller's list is left untouched.
            raw = self.m[i](x[i])
            bs, _, ny, nx = raw.shape
            raw = raw.view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()

            cached = self.grid[i]
            if (
                cached.shape[2:4] != raw.shape[2:4]
                or cached.device != self.anchors.device
                or cached.dtype != self.anchors.dtype
            ):
                # Rebuild on a new feature-map size, and also after the module
                # was moved to another device or dtype.
                self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)

            xy, wh, conf = raw.sigmoid().split((2, 2, self.nc + 1), 4)
            # The -0.5 cell offset is already folded into the grid by
            # _make_grid, so it must not be subtracted a second time here.
            xy = (xy * 2 + self.grid[i]) * self.stride[i]
            wh = (wh * 2) ** 2 * self.anchor_grid[i]
            decoded.append(torch.cat((xy, wh, conf), 4).view(bs, self.na * nx * ny, self.no))
        return (torch.cat(decoded, 1),)

    def _make_grid(self, nx, ny, i):
        """Build the cell-offset grid and the anchor grid for level *i*.

        Returns:
            ``(grid, anchor_grid)``, both shaped ``(1, na, ny, nx, 2)`` and
            created with the device and dtype of the ``anchors`` buffer.
            ``grid`` holds ``(x, y)`` cell indices minus 0.5.
        """
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        yv, xv = torch.meshgrid(y, x, indexing='ij')
        # Fold the decode offset into the grid: xy = 2 * sigmoid - 0.5 + cell.
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        return grid, anchor_grid
|
|
|
|
class _Upsample(nn.Module):
    """Placeholder for upsample layers (no parameters, needed for indexing).

    Wraps a nearest-neighbour ``nn.Upsample`` with a scale factor of 2.
    """

    def __init__(self):
        super().__init__()
        self.up = nn.Upsample(size=None, scale_factor=2, mode='nearest')

    def forward(self, x):
        """Return *x* upsampled by a factor of 2."""
        upsampled = self.up(x)
        return upsampled
|
|
|
|
class _Concat(nn.Module):
    """Placeholder for concat layers (no parameters, needed for indexing)."""

    def forward(self, x):
        """Concatenate the tensors in *x* along dimension 1."""
        joined = torch.cat(x, dim=1)
        return joined
|
|
|
|
| |
| |
| |
|
|
def _nfsi(imgs, sigma=3.0, threshold_factor=2.0):
    """Blur the pixels whose low-frequency response is a statistical outlier.

    Steps, computed in float32:
      1. Keep only the frequencies inside a centred disc of radius
         ``(height + width) // 8`` and transform back to the spatial domain.
      2. Average the clamped magnitude over channels and flag the pixels that
         deviate from their image's mean by more than
         ``threshold_factor`` standard deviations.
      3. Replace flagged pixels with a 3x3 Gaussian-blurred version of the
         input; all other pixels are passed through unchanged.

    Args:
        imgs: Image batch shaped ``(batch, channels, height, width)``.
            The clamps to ``[0, 1]`` suggest values in that range --
            TODO confirm against callers.
        sigma: Standard deviation of the Gaussian blur.
        threshold_factor: Outlier threshold in standard deviations.

    Returns:
        A tensor with the shape and dtype of ``imgs``.
    """
    orig_dtype = imgs.dtype
    imgs_f32 = imgs.float()
    _, _, height, width = imgs_f32.shape
    device = imgs_f32.device

    # Disc-shaped low-pass filter, centred on the shifted spectrum.
    radius = (height + width) // 8
    yy, xx = torch.meshgrid(
        torch.arange(height, device=device),
        torch.arange(width, device=device),
        indexing="ij",
    )
    lpf = (((xx - (width - 1) / 2) ** 2 + (yy - (height - 1) / 2) ** 2) < radius ** 2).float()

    # fftshift / ifftshift are exact inverses for odd sizes as well; a roll
    # by ``-height // 2`` overshoots by one bin when the size is odd.
    spectrum = torch.fft.fftshift(torch.fft.fftn(imgs_f32, dim=(2, 3)), dim=(2, 3))
    low_pass = torch.fft.ifftn(torch.fft.ifftshift(spectrum * lpf, dim=(2, 3)), dim=(2, 3))
    x_l = torch.abs(low_pass).clamp(0, 1).mean(dim=1)

    # Per-image statistics, kept broadcastable against (batch, height, width).
    mu = x_l.mean(dim=(1, 2), keepdim=True)
    std = x_l.std(dim=(1, 2), keepdim=True)
    outliers = torch.abs(x_l - mu) > threshold_factor * std
    # A singleton channel axis lets the mask broadcast over any channel count.
    mask = outliers.to(imgs_f32.dtype).unsqueeze(1)

    blur = transforms.GaussianBlur(3, sigma)
    blurred = blur(imgs_f32).clamp(0, 1)
    result = blurred * mask + imgs_f32 * (1 - mask)
    return result.to(orig_dtype)
|
|
|
|
| |
| |
| |
|
|
| |
| |
# Anchor sizes handed to the Detect head: one row per detection level, each
# row a flat sequence of pairs (Detect reshapes every row to ``(-1, 2)``).
_ANCHORS = [
    [10, 13, 16, 30, 33, 23],
    [30, 61, 62, 45, 59, 119],
    [116, 90, 156, 198, 373, 326],
]
|
|
|
|
class NAPGuardPatchDetectorModel(PreTrainedModel):
    """NAPGuard YOLOv5s patch detector with state_dict key prefix `model.N.*`.

    The nn.ModuleList index matches the original YOLOv5 layer numbering
    so that state_dict keys align for from_pretrained loading.

    ``predict`` / ``score_image`` accept a PIL image and return a dict with
    the keys ``"score"`` (highest surviving confidence, ``0.0`` when nothing
    is detected) and ``"num_detections"`` (boxes kept after NMS).
    """

    config_class = NAPGuardPatchDetectorConfig
    supports_gradient_checkpointing = False

    def __init__(self, config: NAPGuardPatchDetectorConfig) -> None:
        super().__init__(config)
        self._input_size = config.input_size
        self._conf_thres = config.conf_thres
        self._iou_thres = config.iou_thres
        self._use_nfsi = config.use_nfsi
        self._nfsi_sigma = config.nfsi_sigma
        self._nfsi_threshold_factor = config.nfsi_threshold_factor
        self._to_tensor = transforms.ToTensor()

        # Layer order is significant: the list index becomes part of every
        # state_dict key (`model.<index>.*`).
        self.model = nn.ModuleList([
            Conv(3, 32, 6, 2, 2),                   # 0
            Conv(32, 64, 3, 2),                     # 1
            C3(64, 64, 1),                          # 2
            Conv(64, 128, 3, 2),                    # 3
            C3(128, 128, 2),                        # 4  (skip source)
            Conv(128, 256, 3, 2),                   # 5
            C3(256, 256, 3),                        # 6  (skip source)
            Conv(256, 512, 3, 2),                   # 7
            C3(512, 512, 1),                        # 8
            SPPF(512, 512, 5),                      # 9
            Conv(512, 256, 1, 1),                   # 10 (skip source)
            _Upsample(),                            # 11
            _Concat(),                              # 12
            C3(512, 256, 1, False),                 # 13
            Conv(256, 128, 1, 1),                   # 14 (skip source)
            _Upsample(),                            # 15
            _Concat(),                              # 16
            C3(256, 128, 1, False),                 # 17 -> first Detect input
            Conv(128, 128, 3, 2),                   # 18
            _Concat(),                              # 19
            C3(256, 256, 1, False),                 # 20 -> second Detect input
            Conv(256, 256, 3, 2),                   # 21
            _Concat(),                              # 22
            C3(512, 512, 1, False),                 # 23 -> third Detect input
            Detect(1, _ANCHORS, (128, 256, 512)),   # 24
        ])

        # Detect reads its per-level strides from this attribute.
        self.model[24].stride = torch.tensor([8., 16., 32.])

        # Layer indices whose outputs may be kept as skip connections.
        self._save_indices = {4, 6, 9, 10, 13, 14, 17, 20}

    def forward(self, pixel_values=None, **kwargs):
        """Run detection on a tensor batch, or on a PIL image via ``image=``.

        Args:
            pixel_values: Tensor batch passed straight to the network.
            **kwargs: When it contains ``image``, all keyword arguments are
                forwarded to ``predict`` and its dict result is returned.

        Returns:
            The Detect head output for ``pixel_values``, or the ``predict``
            dict when ``image`` is given.

        Raises:
            ValueError: If neither ``pixel_values`` nor ``image`` is provided.
        """
        if "image" in kwargs:
            return self.predict(**kwargs)
        if pixel_values is None:
            raise ValueError("Provide pixel_values or image=PIL")
        return self._yolo_forward(pixel_values)

    def _yolo_forward(self, x):
        """YOLOv5s forward with PANet skip connections."""
        layers = self.model

        # Layers 0-9: keep the outputs that feed the top-down concats.
        skips = {}
        for idx in range(10):
            x = layers[idx](x)
            if idx in self._save_indices:
                skips[idx] = x

        # Top-down path: upsample and merge with the outputs of layers 6 and 4.
        head_deep = layers[10](x)
        x = layers[12]([layers[11](head_deep), skips[6]])
        x = layers[13](x)
        head_mid = layers[14](x)
        x = layers[16]([layers[15](head_mid), skips[4]])
        det_small = layers[17](x)

        # Bottom-up path: downsample and merge with the outputs of layers 14 and 10.
        x = layers[19]([layers[18](det_small), head_mid])
        det_mid = layers[20](x)
        x = layers[22]([layers[21](det_mid), head_deep])
        det_large = layers[23](x)

        return layers[24]([det_small, det_mid, det_large])

    @torch.no_grad()
    def predict(self, image: Image.Image, **kwargs) -> dict:
        """Detect patches in a single PIL image.

        The image is converted to RGB, resized to the configured square input
        size, optionally passed through ``_nfsi``, and run through the network.

        Args:
            image: Input PIL image.
            **kwargs: Accepted for call compatibility; unused.

        Returns:
            ``{"score": float, "num_detections": int}``.
        """
        side = self._input_size
        resized = image.convert("RGB").resize((side, side), Image.Resampling.BILINEAR)
        batch = self._to_tensor(resized).unsqueeze(0).to(device=self.device, dtype=self.dtype)

        if self._use_nfsi:
            batch = _nfsi(batch, self._nfsi_sigma, self._nfsi_threshold_factor)

        raw = self._yolo_forward(batch)
        candidates = raw[0] if isinstance(raw, tuple) else raw
        if candidates.ndim == 3:
            # Single-image batch: drop the batch axis.
            candidates = candidates[0]
        return self._score_candidates(candidates)

    def _score_candidates(self, candidates) -> dict:
        """Threshold and NMS-filter decoded rows from the Detect head.

        Args:
            candidates: ``(num_rows, 5 + num_classes)`` tensor laid out as
                ``(cx, cy, w, h, objectness, class scores...)``.

        Returns:
            ``{"score": float, "num_detections": int}``.
        """
        rows = candidates[candidates[..., 4] > self._conf_thres]
        if rows.shape[0] == 0:
            return {"score": 0.0, "num_detections": 0}

        # Final confidence is objectness times the best class score.
        conf = (rows[:, 5:] * rows[:, 4:5]).max(1).values
        confident = conf > self._conf_thres
        rows, conf = rows[confident], conf[confident]
        if rows.shape[0] == 0:
            return {"score": 0.0, "num_detections": 0}

        # Centre/size boxes -> corner boxes, the layout `nms` is given.
        centres = rows[:, :2]
        half_sizes = rows[:, 2:4] / 2
        boxes = torch.cat((centres - half_sizes, centres + half_sizes), dim=1)

        # NMS runs on float32 copies so it does not depend on reduced-precision
        # kernels when the model itself runs in a lower-precision dtype.
        keep = nms(boxes.float(), conf.float(), self._iou_thres)
        return {"score": float(conf[keep].max().item()), "num_detections": int(keep.numel())}

    @torch.no_grad()
    def score_image(self, image: Image.Image, **kwargs) -> dict:
        """Alias of ``predict``; returns the same dict."""
        return self.predict(image=image, **kwargs)
|
|