"""NAPGuard Patch Detector model for outpost deployment. Accepts PIL images directly, applies NFSI preprocessing, runs YOLOv5s detection, and returns patch detection results. Usage (inside outpost): result = model.predict(image=pil_image) # returns {"score": 0.85, "num_detections": 2} Reference: Wu et al., CVPR 2024 """ from __future__ import annotations import sys from typing import List, Optional import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from torchvision import transforms from torchvision.ops import nms from transformers import PreTrainedModel from .configuration_napguard import NAPGuardPatchDetectorConfig def _log(msg): print(f"[NAPGUARD-DEBUG] {msg}", file=sys.stderr, flush=True) # --------------------------------------------------------------------------- # YOLOv5s building blocks # --------------------------------------------------------------------------- def _autopad(k, p=None, d=1): if d > 1: k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] if p is None: p = k // 2 if isinstance(k, int) else [x // 2 for x in k] return p class Conv(nn.Module): default_act = nn.SiLU() def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True): super().__init__() self.conv = nn.Conv2d(c1, c2, k, s, _autopad(k, p, d), groups=g, dilation=d, bias=False) self.bn = nn.BatchNorm2d(c2) self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity() def forward(self, x): return self.act(self.bn(self.conv(x))) class Bottleneck(nn.Module): def __init__(self, c1, c2, shortcut=True, g=1, e=0.5): super().__init__() c_ = int(c2 * e) self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c_, c2, 3, 1, g=g) self.add = shortcut and c1 == c2 def forward(self, x): return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x)) class C3(nn.Module): def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): super().__init__() c_ = int(c2 * e) self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c1, c_, 1, 1) self.cv3 = Conv(2 * c_, c2, 1) self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n))) def forward(self, x): return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1)) class SPPF(nn.Module): def __init__(self, c1, c2, k=5): super().__init__() c_ = c1 // 2 self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c_ * 4, c2, 1, 1) self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) def forward(self, x): x = self.cv1(x) y1 = self.m(x) y2 = self.m(y1) return self.cv2(torch.cat((x, y1, y2, self.m(y2)), 1)) class Detect(nn.Module): stride = None def __init__(self, nc=1, anchors=(), ch=()): super().__init__() self.nc = nc self.no = nc + 5 self.nl = len(anchors) self.na = len(anchors[0]) // 2 self.grid = [torch.empty(0) for _ in range(self.nl)] self.anchor_grid = [torch.empty(0) for _ in range(self.nl)] self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2)) self.m = nn.ModuleList(nn.Conv2d(x, self.no * self.na, 1) for x in ch) def forward(self, x): z = [] for i in range(self.nl): x[i] = self.m[i](x[i]) bs, _, ny, nx = x[i].shape x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous() if self.grid[i].shape[2:4] != x[i].shape[2:4]: self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i) xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4) xy = (xy * 2 - 0.5 + self.grid[i]) * self.stride[i] wh = (wh * 2) ** 2 * self.anchor_grid[i] z.append(torch.cat((xy, wh, conf), 4).view(bs, self.na * nx * ny, self.no)) return (torch.cat(z, 1),) def _make_grid(self, nx, ny, i): d = self.anchors[i].device t = self.anchors[i].dtype shape = 1, self.na, ny, nx, 2 y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t) yv, xv = torch.meshgrid(y, x, indexing='ij') grid = torch.stack((xv, yv), 2).expand(shape) - 0.5 anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape) return grid, anchor_grid class _Upsample(nn.Module): """Placeholder for upsample layers (no parameters, needed for indexing).""" def __init__(self): super().__init__() self.up = nn.Upsample(None, 2, 'nearest') def forward(self, x): return self.up(x) class _Concat(nn.Module): """Placeholder for concat layers (no parameters, needed for indexing).""" def forward(self, x): return torch.cat(x, 1) # --------------------------------------------------------------------------- # NFSI # --------------------------------------------------------------------------- def _nfsi(imgs, sigma=3.0, threshold_factor=2.0): # FFT requires float32 (cuFFT doesn't support fp16 for non-power-of-2 sizes) orig_dtype = imgs.dtype imgs_f32 = imgs.float() blur = transforms.GaussianBlur(3, sigma) _, _, height, width = imgs_f32.shape R = (height + width) // 8 yy, xx = torch.meshgrid(torch.arange(height), torch.arange(width), indexing="ij") lpf = (((xx - (width - 1) / 2) ** 2 + (yy - (height - 1) / 2) ** 2) < R ** 2).float().to(imgs_f32.device) im_copy = imgs_f32.clone() mask_bg = torch.ones_like(imgs_f32) f = torch.fft.fftn(im_copy, dim=(2, 3)) f = torch.roll(f, (height // 2, width // 2), dims=(2, 3)) f_l = torch.roll(f * lpf, (-height // 2, -width // 2), dims=(2, 3)) x_l = torch.abs(torch.fft.ifftn(f_l, dim=(2, 3))).clamp(0, 1).mean(dim=1) mu, std = x_l.mean(dim=(1, 2)), x_l.std(dim=(1, 2)) for idx in range(x_l.shape[0]): x_l[idx] = torch.where( torch.abs(x_l[idx] - mu[idx]) > threshold_factor * std[idx], torch.ones_like(x_l[idx]), torch.zeros_like(x_l[idx])) mask = x_l.unsqueeze(1).repeat(1, 3, 1, 1) result = blur(im_copy).clamp_(0, 1) * mask + imgs_f32 * (mask_bg - mask) return result.to(orig_dtype) # --------------------------------------------------------------------------- # HuggingFace wrapper # --------------------------------------------------------------------------- # YOLOv5s layer structure (matching original state_dict keys model.0 - model.24) # Layers 11, 12, 15, 16, 19, 22 are Upsample/Concat (no params) _ANCHORS = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]] class NAPGuardPatchDetectorModel(PreTrainedModel): """NAPGuard YOLOv5s patch detector with state_dict key prefix `model.N.*`. The nn.ModuleList index matches the original YOLOv5 layer numbering so that state_dict keys align for from_pretrained loading. """ config_class = NAPGuardPatchDetectorConfig supports_gradient_checkpointing = False def __init__(self, config: NAPGuardPatchDetectorConfig) -> None: super().__init__(config) self._input_size = config.input_size self._conf_thres = config.conf_thres self._iou_thres = config.iou_thres self._use_nfsi = config.use_nfsi self._nfsi_sigma = config.nfsi_sigma self._nfsi_threshold_factor = config.nfsi_threshold_factor self._to_tensor = transforms.ToTensor() # Build model as ModuleList to match state_dict key prefix model.N self.model = nn.ModuleList([ Conv(3, 32, 6, 2, 2), # 0: P1/2 Conv(32, 64, 3, 2), # 1: P2/4 C3(64, 64, 1), # 2 Conv(64, 128, 3, 2), # 3: P3/8 C3(128, 128, 2), # 4 Conv(128, 256, 3, 2), # 5: P4/16 C3(256, 256, 3), # 6 Conv(256, 512, 3, 2), # 7: P5/32 C3(512, 512, 1), # 8 SPPF(512, 512, 5), # 9 Conv(512, 256, 1, 1), # 10 _Upsample(), # 11 (no params) _Concat(), # 12 (no params) C3(512, 256, 1, False), # 13 Conv(256, 128, 1, 1), # 14 _Upsample(), # 15 (no params) _Concat(), # 16 (no params) C3(256, 128, 1, False), # 17 Conv(128, 128, 3, 2), # 18 _Concat(), # 19 (no params) C3(256, 256, 1, False), # 20 Conv(256, 256, 3, 2), # 21 _Concat(), # 22 (no params) C3(512, 512, 1, False), # 23 Detect(1, _ANCHORS, (128, 256, 512)), # 24 ]) # Set detect stride self.model[24].stride = torch.tensor([8., 16., 32.]) # Save/restore indices for the PANet forward pass self._save_indices = {4, 6, 9, 10, 13, 14, 17, 20} def forward(self, pixel_values=None, **kwargs): if "image" in kwargs: return self.predict(**kwargs) if pixel_values is None: raise ValueError("Provide pixel_values or image=PIL") return self._yolo_forward(pixel_values) def _yolo_forward(self, x): """YOLOv5s forward with PANet skip connections.""" saved = {} # Backbone: layers 0-9 for i in range(10): x = self.model[i](x) if i in self._save_indices: saved[i] = x # Neck: 10 → upsample → cat(P4) → 13 x = self.model[10](x) saved[10] = x x = self.model[11](x) # upsample x = self.model[12]([x, saved[6]]) # cat with P4 x = self.model[13](x) saved[13] = x # 14 → upsample → cat(P3) → 17 x = self.model[14](x) saved[14] = x x = self.model[15](x) # upsample x = self.model[16]([x, saved[4]]) # cat with P3 x = self.model[17](x) det_small = x # P3 output saved[17] = x # 18 → cat(layer 14 output) → 20 x = self.model[18](x) x = self.model[19]([x, saved[14]]) # cat with layer 14 x = self.model[20](x) det_mid = x # P4 output saved[20] = x # 21 → cat(10 output) → 23 x = self.model[21](x) x = self.model[22]([x, saved[10]]) # cat x = self.model[23](x) det_large = x # P5 output # Detect return self.model[24]([det_small, det_mid, det_large]) @torch.no_grad() def predict(self, image: Image.Image, **kwargs) -> dict: img = image.convert("RGB").resize( (self._input_size, self._input_size), Image.Resampling.BILINEAR ) tensor = self._to_tensor(img).unsqueeze(0).to(device=self.device, dtype=self.dtype) if self._use_nfsi: tensor = _nfsi(tensor, self._nfsi_sigma, self._nfsi_threshold_factor) pred = self._yolo_forward(tensor) prediction = pred[0] if isinstance(pred, tuple) else pred if prediction.ndim == 3: prediction = prediction[0] # Filter by obj conf xc = prediction[..., 4] > self._conf_thres x = prediction[xc] if x.shape[0] == 0: return {"score": 0.0, "num_detections": 0} # Combine obj conf * class conf x[:, 5:] *= x[:, 4:5] conf, cls = x[:, 5:].max(1, keepdim=True) x = torch.cat([x[:, :4], conf, cls], dim=1) x = x[x[:, 4] > self._conf_thres] if x.shape[0] == 0: return {"score": 0.0, "num_detections": 0} # xywh → xyxy boxes = x[:, :4].clone() boxes[:, 0] = x[:, 0] - x[:, 2] / 2 boxes[:, 1] = x[:, 1] - x[:, 3] / 2 boxes[:, 2] = x[:, 0] + x[:, 2] / 2 boxes[:, 3] = x[:, 1] + x[:, 3] / 2 keep = nms(boxes, x[:, 4], self._iou_thres) x = x[keep] return {"score": float(x[:, 4].max().item()), "num_detections": int(x.shape[0])} @torch.no_grad() def score_image(self, image: Image.Image, **kwargs) -> dict: return self.predict(image=image, **kwargs)