# napguard-patch-detector-3 / modeling_napguard.py
# Uploaded by rocker417 with huggingface_hub (commit 7766391, verified).
"""NAPGuard Patch Detector model for outpost deployment.
Accepts PIL images directly, applies NFSI preprocessing, runs YOLOv5s
detection, and returns patch detection results.
Usage (inside outpost):
result = model.predict(image=pil_image)
# returns {"score": 0.85, "num_detections": 2}
Reference: Wu et al., CVPR 2024
"""
from __future__ import annotations
import sys
from typing import List, Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms
from torchvision.ops import nms
from transformers import PreTrainedModel
from .configuration_napguard import NAPGuardPatchDetectorConfig
def _log(msg):
print(f"[NAPGUARD-DEBUG] {msg}", file=sys.stderr, flush=True)
# ---------------------------------------------------------------------------
# YOLOv5s building blocks
# ---------------------------------------------------------------------------
def _autopad(k, p=None, d=1):
if d > 1:
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k]
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k]
return p
class Conv(nn.Module):
    """YOLOv5 basic unit: Conv2d (no bias) -> BatchNorm2d -> activation.

    Args:
        c1, c2: input / output channels.
        k, s, p, g, d: kernel size, stride, padding (auto 'same' when None),
            groups and dilation forwarded to ``nn.Conv2d``.
        act: ``True`` for the shared SiLU, an ``nn.Module`` to use as-is,
            anything else for no activation (identity).
    """

    # Shared (parameter-free) SiLU instance reused by every Conv with act=True.
    default_act = nn.SiLU()

    def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
        super().__init__()
        padding = _autopad(k, p, d)
        self.conv = nn.Conv2d(c1, c2, k, s, padding, groups=g, dilation=d, bias=False)
        self.bn = nn.BatchNorm2d(c2)
        if act is True:
            self.act = self.default_act
        elif isinstance(act, nn.Module):
            self.act = act
        else:
            self.act = nn.Identity()

    def forward(self, x):
        y = self.conv(x)
        y = self.bn(y)
        return self.act(y)
class Bottleneck(nn.Module):
    """1x1 Conv -> 3x3 Conv, with a residual add when shapes allow.

    Args:
        c1, c2: input / output channels.
        shortcut: request the residual connection (only used when c1 == c2).
        g: groups for the 3x3 conv.
        e: hidden-channel expansion ratio relative to ``c2``.
    """

    def __init__(self, c1, c2, shortcut=True, g=1, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden, c2, 3, 1, g=g)
        self.add = shortcut and c1 == c2

    def forward(self, x):
        out = self.cv2(self.cv1(x))
        if self.add:
            return x + out
        return out
class C3(nn.Module):
    """CSP block with three convolutions and ``n`` stacked Bottlenecks.

    One 1x1 branch runs through the Bottleneck stack, a second 1x1 branch is a
    plain projection; both are concatenated on channels and fused by a final
    1x1 conv to ``c2`` channels.
    """

    def __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5):
        super().__init__()
        hidden = int(c2 * e)
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(c1, hidden, 1, 1)
        self.cv3 = Conv(2 * hidden, c2, 1)
        blocks = [Bottleneck(hidden, hidden, shortcut, g, e=1.0) for _ in range(n)]
        self.m = nn.Sequential(*blocks)

    def forward(self, x):
        main = self.m(self.cv1(x))
        skip = self.cv2(x)
        return self.cv3(torch.cat((main, skip), 1))
class SPPF(nn.Module):
    """Spatial Pyramid Pooling - Fast.

    Halves the channels with a 1x1 conv, applies the same stride-1 max-pool
    three times in sequence, concatenates the un-pooled map with the three
    pooled maps (4x hidden channels) and projects to ``c2`` with a 1x1 conv.
    """

    def __init__(self, c1, c2, k=5):
        super().__init__()
        hidden = c1 // 2
        self.cv1 = Conv(c1, hidden, 1, 1)
        self.cv2 = Conv(hidden * 4, c2, 1, 1)
        self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)

    def forward(self, x):
        feats = [self.cv1(x)]
        for _ in range(3):
            feats.append(self.m(feats[-1]))
        return self.cv2(torch.cat(feats, 1))
class Detect(nn.Module):
    """YOLOv5 detection head: per-level 1x1 conv followed by box decoding.

    For each of the ``nl`` feature maps ``x[i]`` of shape ``(bs, ch[i], ny, nx)``
    a 1x1 conv predicts ``na`` anchors x ``no = nc + 5`` values
    (x, y, w, h, objectness, ``nc`` class scores). All values are passed
    through a sigmoid and the box part is decoded to input-pixel ``xywh``.

    ``stride`` (one entry per level) must be assigned before the first
    forward call.

    Returns:
        A 1-tuple holding a ``(bs, sum_i(na * ny_i * nx_i), no)`` tensor.
    """

    stride = None  # per-level downsampling factors; assigned externally

    def __init__(self, nc=1, anchors=(), ch=()):
        super().__init__()
        self.nc = nc
        self.no = nc + 5  # x, y, w, h, obj + nc class scores
        self.nl = len(anchors)
        self.na = len(anchors[0]) // 2
        # Decoding grids are built lazily and cached per level.
        self.grid = [torch.empty(0) for _ in range(self.nl)]
        self.anchor_grid = [torch.empty(0) for _ in range(self.nl)]
        # NOTE(review): _make_grid multiplies these by stride, i.e. treats them
        # as stride-relative units (as YOLOv5 checkpoints store them). Values
        # given at construction are presumably overwritten by the checkpoint's
        # persistent `anchors` buffer on load -- confirm.
        self.register_buffer('anchors', torch.tensor(anchors).float().view(self.nl, -1, 2))
        self.m = nn.ModuleList(nn.Conv2d(c, self.no * self.na, 1) for c in ch)

    def forward(self, x):
        z = []
        for i in range(self.nl):
            # Note: x[i] is replaced in place with the raw (pre-sigmoid) head output.
            x[i] = self.m[i](x[i])
            bs, _, ny, nx = x[i].shape
            x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
            # Rebuild the cached grid on a size change or after the module
            # was moved to another device (a stale grid would cause a
            # device-mismatch error).
            if self.grid[i].shape[2:4] != x[i].shape[2:4] or self.grid[i].device != x[i].device:
                self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
            xy, wh, conf = x[i].sigmoid().split((2, 2, self.nc + 1), 4)
            # The grid already carries the -0.5 cell offset, so the centre is
            # (2 * sigmoid + cell - 0.5) * stride; subtracting 0.5 here again
            # would shift every box by half a cell.
            xy = (xy * 2 + self.grid[i]) * self.stride[i]
            wh = (wh * 2) ** 2 * self.anchor_grid[i]
            z.append(torch.cat((xy, wh, conf), 4).view(bs, self.na * nx * ny, self.no))
        return (torch.cat(z, 1),)

    def _make_grid(self, nx, ny, i):
        """Build the decoding grids for level ``i``.

        Returns:
            grid: ``(1, na, ny, nx, 2)`` cell (x, y) indices minus 0.5.
            anchor_grid: ``(1, na, ny, nx, 2)`` anchor (w, h) scaled by the
                level's stride.
        """
        d = self.anchors[i].device
        t = self.anchors[i].dtype
        shape = 1, self.na, ny, nx, 2
        y, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)
        yv, xv = torch.meshgrid(y, x, indexing='ij')
        grid = torch.stack((xv, yv), 2).expand(shape) - 0.5
        anchor_grid = (self.anchors[i] * self.stride[i]).view((1, self.na, 1, 1, 2)).expand(shape)
        return grid, anchor_grid
class _Upsample(nn.Module):
"""Placeholder for upsample layers (no parameters, needed for indexing)."""
def __init__(self):
super().__init__()
self.up = nn.Upsample(None, 2, 'nearest')
def forward(self, x):
return self.up(x)
class _Concat(nn.Module):
"""Placeholder for concat layers (no parameters, needed for indexing)."""
def forward(self, x):
return torch.cat(x, 1)
# ---------------------------------------------------------------------------
# NFSI
# ---------------------------------------------------------------------------
def _nfsi(imgs, sigma=3.0, threshold_factor=2.0):
    """Apply NFSI frequency-domain filtering to a batch of images.

    Each image is low-pass filtered in the frequency domain. The magnitude of
    that low-pass reconstruction is clamped to [0, 1] and averaged over
    channels. Pixels whose value deviates from the per-image mean by more
    than ``threshold_factor`` per-image standard deviations are replaced by a
    3x3 Gaussian-blurred copy of the input. All other pixels are kept as-is.

    Args:
        imgs: tensor of shape (B, C, H, W); values presumably in [0, 1] (the
            blurred branch is clamped to that range). Any channel count is
            accepted -- the mask is broadcast over C.
        sigma: standard deviation of the 3x3 Gaussian blur.
        threshold_factor: deviation threshold in units of the per-image std.

    Returns:
        Tensor with the same shape and dtype as ``imgs``.
    """
    # FFT requires float32 (cuFFT doesn't support fp16 for non-power-of-2 sizes)
    orig_dtype = imgs.dtype
    imgs_f32 = imgs.float()
    _, _, height, width = imgs_f32.shape
    device = imgs_f32.device

    # Circular low-pass mask of radius (H + W) // 8 around the image centre,
    # built directly on the target device.
    radius = (height + width) // 8
    yy, xx = torch.meshgrid(
        torch.arange(height, device=device),
        torch.arange(width, device=device),
        indexing="ij",
    )
    lpf = (((xx - (width - 1) / 2) ** 2 + (yy - (height - 1) / 2) ** 2) < radius ** 2).float()

    spectrum = torch.fft.fftn(imgs_f32, dim=(2, 3))
    # Move the zero frequency towards the centre so the circular mask keeps low frequencies.
    spectrum = torch.roll(spectrum, (height // 2, width // 2), dims=(2, 3))
    # Only |ifft| is used below, so the exact inverse shift is immaterial: a
    # circular shift of the spectrum only multiplies the spatial signal by a
    # unit-modulus phase ramp.
    low = torch.roll(spectrum * lpf, (-height // 2, -width // 2), dims=(2, 3))
    low_mag = torch.abs(torch.fft.ifftn(low, dim=(2, 3))).clamp(0, 1).mean(dim=1)  # (B, H, W)

    # Per-image outlier mask, vectorised over the batch.
    mu = low_mag.mean(dim=(1, 2), keepdim=True)
    std = low_mag.std(dim=(1, 2), keepdim=True)
    mask = (torch.abs(low_mag - mu) > threshold_factor * std).to(low_mag.dtype)
    mask = mask.unsqueeze(1)  # (B, 1, H, W): broadcasts over every channel

    blur = transforms.GaussianBlur(3, sigma)
    result = blur(imgs_f32).clamp_(0, 1) * mask + imgs_f32 * (1 - mask)
    return result.to(orig_dtype)
# ---------------------------------------------------------------------------
# HuggingFace wrapper
# ---------------------------------------------------------------------------
# YOLOv5s layer structure (matching original state_dict keys model.0 - model.24)
# Layers 11, 12, 15, 16, 19, 22 are Upsample/Concat (no params)
_ANCHORS = [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]]
class NAPGuardPatchDetectorModel(PreTrainedModel):
    """NAPGuard YOLOv5s patch detector with state_dict key prefix `model.N.*`.

    The nn.ModuleList index matches the original YOLOv5 layer numbering
    so that state_dict keys align for from_pretrained loading.

    Entry points:
      * ``forward(pixel_values=...)`` runs the raw network and returns the
        Detect head output: a 1-tuple holding a ``(bs, N, 6)`` tensor of
        decoded ``x, y, w, h, obj, cls`` predictions (single class).
      * ``predict(image=...)``, ``forward(image=...)`` and ``score_image``
        take a PIL image and return ``{"score": float, "num_detections": int}``.
    """

    config_class = NAPGuardPatchDetectorConfig
    supports_gradient_checkpointing = False

    def __init__(self, config: NAPGuardPatchDetectorConfig) -> None:
        super().__init__(config)
        self._input_size = config.input_size  # images are resized to a square of this side
        self._conf_thres = config.conf_thres
        self._iou_thres = config.iou_thres
        self._use_nfsi = config.use_nfsi
        self._nfsi_sigma = config.nfsi_sigma
        self._nfsi_threshold_factor = config.nfsi_threshold_factor
        self._to_tensor = transforms.ToTensor()
        # Build model as ModuleList to match state_dict key prefix model.N
        self.model = nn.ModuleList([
            Conv(3, 32, 6, 2, 2),  # 0: P1/2
            Conv(32, 64, 3, 2),  # 1: P2/4
            C3(64, 64, 1),  # 2
            Conv(64, 128, 3, 2),  # 3: P3/8
            C3(128, 128, 2),  # 4
            Conv(128, 256, 3, 2),  # 5: P4/16
            C3(256, 256, 3),  # 6
            Conv(256, 512, 3, 2),  # 7: P5/32
            C3(512, 512, 1),  # 8
            SPPF(512, 512, 5),  # 9
            Conv(512, 256, 1, 1),  # 10
            _Upsample(),  # 11 (no params)
            _Concat(),  # 12 (no params)
            C3(512, 256, 1, False),  # 13
            Conv(256, 128, 1, 1),  # 14
            _Upsample(),  # 15 (no params)
            _Concat(),  # 16 (no params)
            C3(256, 128, 1, False),  # 17
            Conv(128, 128, 3, 2),  # 18
            _Concat(),  # 19 (no params)
            C3(256, 256, 1, False),  # 20
            Conv(256, 256, 3, 2),  # 21
            _Concat(),  # 22 (no params)
            C3(512, 512, 1, False),  # 23
            Detect(1, _ANCHORS, (128, 256, 512)),  # 24
        ])
        # Detect-head strides for the P3 / P4 / P5 outputs.
        self.model[24].stride = torch.tensor([8., 16., 32.])
        # Layer indices whose outputs are kept for the PANet skip connections.
        self._save_indices = {4, 6, 9, 10, 13, 14, 17, 20}

    def forward(self, pixel_values=None, **kwargs):
        """Run the raw network on ``pixel_values``, or ``predict`` if ``image=`` is given.

        Raises:
            ValueError: if neither ``pixel_values`` nor ``image`` is provided.
        """
        if "image" in kwargs:
            return self.predict(**kwargs)
        if pixel_values is None:
            raise ValueError("Provide pixel_values or image=PIL")
        return self._yolo_forward(pixel_values)

    def _yolo_forward(self, x):
        """YOLOv5s forward with PANet skip connections.

        Args:
            x: ``(bs, 3, H, W)`` tensor; H and W presumably must be multiples
                of 32 so the upsampled maps line up with the skip connections.

        Returns:
            The Detect head output (a 1-tuple of decoded predictions).
        """
        saved = {}
        # Backbone: layers 0-9
        for i in range(10):
            x = self.model[i](x)
            if i in self._save_indices:
                saved[i] = x
        # Neck: 10 -> upsample -> cat(P4) -> 13
        x = self.model[10](x)
        saved[10] = x
        x = self.model[11](x)  # upsample
        x = self.model[12]([x, saved[6]])  # cat with P4
        x = self.model[13](x)
        saved[13] = x
        # 14 -> upsample -> cat(P3) -> 17
        x = self.model[14](x)
        saved[14] = x
        x = self.model[15](x)  # upsample
        x = self.model[16]([x, saved[4]])  # cat with P3
        x = self.model[17](x)
        det_small = x  # P3 output
        saved[17] = x
        # 18 -> cat(layer 14 output) -> 20
        x = self.model[18](x)
        x = self.model[19]([x, saved[14]])  # cat with layer 14
        x = self.model[20](x)
        det_mid = x  # P4 output
        saved[20] = x
        # 21 -> cat(layer 10 output) -> 23
        x = self.model[21](x)
        x = self.model[22]([x, saved[10]])  # cat with layer 10
        x = self.model[23](x)
        det_large = x  # P5 output
        # Detect
        return self.model[24]([det_small, det_mid, det_large])

    @torch.no_grad()
    def predict(self, image: Image.Image, **kwargs) -> dict:
        """Detect adversarial patches in a single PIL image.

        The image is converted to RGB and resized (bilinear, no letterbox) to
        ``input_size x input_size``. NFSI filtering is applied when
        ``config.use_nfsi`` is set. The image then goes through the network,
        and the predictions are confidence-filtered and reduced with NMS.

        Args:
            image: input image (any mode ``Image.convert("RGB")`` accepts).
            **kwargs: accepted for API compatibility and ignored.

        Returns:
            ``{"score": highest surviving confidence, "num_detections": count}``,
            or ``{"score": 0.0, "num_detections": 0}`` when nothing passes
            the confidence threshold.
        """
        tensor = self._preprocess(image)
        pred = self._yolo_forward(tensor)
        prediction = pred[0] if isinstance(pred, tuple) else pred
        if prediction.ndim == 3:
            prediction = prediction[0]  # drop the batch dim (single image)
        return self._postprocess(prediction)

    def _preprocess(self, image):
        """Convert a PIL image to a (1, 3, S, S) tensor on the model's device and dtype."""
        img = image.convert("RGB").resize(
            (self._input_size, self._input_size), Image.Resampling.BILINEAR
        )
        tensor = self._to_tensor(img).unsqueeze(0).to(device=self.device, dtype=self.dtype)
        if self._use_nfsi:
            tensor = _nfsi(tensor, self._nfsi_sigma, self._nfsi_threshold_factor)
        return tensor

    def _postprocess(self, prediction):
        """Reduce (N, 5 + nc) decoded predictions to the score / count summary."""
        # Cheap pre-filter on objectness alone.
        x = prediction[prediction[..., 4] > self._conf_thres]
        if x.shape[0] == 0:
            return {"score": 0.0, "num_detections": 0}
        # Final confidence = objectness * best class score.
        x[:, 5:] *= x[:, 4:5]
        conf, cls = x[:, 5:].max(1, keepdim=True)
        # cls is an int64 index; cast it so the concatenation keeps the float dtype.
        x = torch.cat([x[:, :4], conf, cls.to(conf.dtype)], dim=1)
        x = x[x[:, 4] > self._conf_thres]
        if x.shape[0] == 0:
            return {"score": 0.0, "num_detections": 0}
        # xywh (centre) -> xyxy (corners) for torchvision NMS.
        centre, half_wh = x[:, :2], x[:, 2:4] / 2
        boxes = torch.cat([centre - half_wh, centre + half_wh], dim=1)
        # Run NMS in float32: torchvision's CPU NMS kernel is not implemented
        # for fp16, which this model's dtype may be.
        keep = nms(boxes.float(), x[:, 4].float(), self._iou_thres)
        x = x[keep]
        return {"score": float(x[:, 4].max().item()), "num_detections": int(x.shape[0])}

    @torch.no_grad()
    def score_image(self, image: Image.Image, **kwargs) -> dict:
        """Alias for ``predict(image=image, **kwargs)``."""
        return self.predict(image=image, **kwargs)