""" YOLACT+ with ResNet-18 Backbone ================================= A faithful implementation of YOLACT+ adapted for a lightweight ResNet-18 backbone. Architecture: Backbone : ResNet-18 (torchvision, ImageNet pre-trained) Neck : FPN (Feature Pyramid Network) Head : PredictionHead (class + box + mask coefficient) Proto : ProtoNet (generates prototype masks) Mask : linear combination of prototypes × coefficients NMS : Fast NMS (YOLACT-style) """ import torch import torch.nn as nn import torch.nn.functional as F from torchvision.models import resnet18, ResNet18_Weights from torchvision.ops import nms # ─── Constants ──────────────────────────────────────────────────────────────── NUM_PROTOTYPES = 32 FPN_CHANNELS = 256 PROTO_CHANNELS = 256 # ─── Backbone ───────────────────────────────────────────────────────────────── class ResNet18Backbone(nn.Module): """ ResNet-18 feature extractor. Returns C3, C4, C5 feature maps (strides 8, 16, 32). """ def __init__(self, pretrained: bool = True): super().__init__() weights = ResNet18_Weights.IMAGENET1K_V1 if pretrained else None base = resnet18(weights=weights) self.layer0 = nn.Sequential(base.conv1, base.bn1, base.relu, base.maxpool) self.layer1 = base.layer1 # stride 4, channels 64 self.layer2 = base.layer2 # stride 8, channels 128 → C3 self.layer3 = base.layer3 # stride 16, channels 256 → C4 self.layer4 = base.layer4 # stride 32, channels 512 → C5 self.out_channels = [128, 256, 512] # C3, C4, C5 def forward(self, x): x = self.layer0(x) x = self.layer1(x) c3 = self.layer2(x) c4 = self.layer3(c3) c5 = self.layer4(c4) return c3, c4, c5 # ─── FPN ────────────────────────────────────────────────────────────────────── class FPN(nn.Module): """ 5-level FPN: P3–P7 (P6, P7 generated by strided convolution on P5). """ def __init__(self, in_channels: list, out_channels: int = FPN_CHANNELS): super().__init__() self.lateral_convs = nn.ModuleList( [nn.Conv2d(c, out_channels, 1) for c in in_channels] ) self.output_convs = nn.ModuleList( [nn.Conv2d(out_channels, out_channels, 3, padding=1) for _ in in_channels] ) # Extra levels P6, P7 self.p6_conv = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1) self.p7_conv = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1) def forward(self, features): c3, c4, c5 = features lat = [l(f) for l, f in zip(self.lateral_convs, [c3, c4, c5])] # Top-down pathway lat[1] = lat[1] + F.interpolate(lat[2], size=lat[1].shape[-2:], mode="nearest") lat[0] = lat[0] + F.interpolate(lat[1], size=lat[0].shape[-2:], mode="nearest") p3 = self.output_convs[0](lat[0]) p4 = self.output_convs[1](lat[1]) p5 = self.output_convs[2](lat[2]) p6 = self.p6_conv(p5) p7 = self.p7_conv(F.relu(p6)) return [p3, p4, p5, p6, p7] # ─── ProtoNet ───────────────────────────────────────────────────────────────── class ProtoNet(nn.Module): """ Generates K prototype masks from the P3 feature map. Output: [B, K, H/4, W/4] """ def __init__(self, in_channels: int = FPN_CHANNELS, num_protos: int = NUM_PROTOTYPES): super().__init__() self.proto_net = nn.Sequential( nn.Conv2d(in_channels, PROTO_CHANNELS, 3, padding=1), nn.ReLU(), nn.Conv2d(PROTO_CHANNELS, PROTO_CHANNELS, 3, padding=1), nn.ReLU(), nn.Conv2d(PROTO_CHANNELS, PROTO_CHANNELS, 3, padding=1), nn.ReLU(), nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), nn.Conv2d(PROTO_CHANNELS, PROTO_CHANNELS, 3, padding=1), nn.ReLU(), nn.Conv2d(PROTO_CHANNELS, num_protos, 1), ) def forward(self, p3): return self.proto_net(p3) # [B, K, H', W'] # ─── Anchor generator ───────────────────────────────────────────────────────── class AnchorGenerator: """ Pre-computed anchor boxes for each FPN level. Scales: [24, 48, 96, 192, 384] (for 550×550 input) Aspect ratios: [1.0, 0.5, 2.0] """ SCALES = [24, 48, 96, 192, 384] ASPECT_RATIOS = [1.0, 0.5, 2.0] def __init__(self, img_size: int = 550): self.img_size = img_size self.num_anchors_per_cell = len(self.ASPECT_RATIOS) def make_anchors(self, feature_sizes: list) -> torch.Tensor: """Returns [total_anchors, 4] in cx/cy/w/h format (normalised 0-1).""" all_anchors = [] for lvl, (fh, fw) in enumerate(feature_sizes): scale = self.SCALES[lvl] for row in range(fh): for col in range(fw): cx = (col + 0.5) / fw cy = (row + 0.5) / fh for ar in self.ASPECT_RATIOS: w = scale * (ar ** 0.5) / self.img_size h = scale / (ar ** 0.5) / self.img_size all_anchors.append([cx, cy, w, h]) return torch.tensor(all_anchors, dtype=torch.float32) # ─── Prediction Head ────────────────────────────────────────────────────────── class PredictionHead(nn.Module): """ Shared prediction head applied to each FPN level. Outputs: cls_pred : [B, A, num_classes+1] box_pred : [B, A, 4] coef_pred : [B, A, K] """ def __init__( self, in_channels: int = FPN_CHANNELS, num_classes: int = 2, num_anchors: int = 3, num_protos: int = NUM_PROTOTYPES, ): super().__init__() self.num_classes = num_classes self.num_anchors = num_anchors self.shared = nn.Sequential( nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(), nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(), nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(), nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(), ) self.cls_layer = nn.Conv2d(in_channels, num_anchors * (num_classes + 1), 1) self.box_layer = nn.Conv2d(in_channels, num_anchors * 4, 1) self.coef_layer = nn.Conv2d(in_channels, num_anchors * num_protos, 1) def forward(self, feat): B, _, H, W = feat.shape x = self.shared(feat) cls = self.cls_layer(x) # [B, A*(C+1), H, W] box = self.box_layer(x) # [B, A*4, H, W] coef = self.coef_layer(x) # [B, A*K, H, W] # Reshape to [B, H*W*A, ...] A, C, K = self.num_anchors, self.num_classes, NUM_PROTOTYPES cls = cls.permute(0, 2, 3, 1).contiguous().view(B, -1, C + 1) box = box.permute(0, 2, 3, 1).contiguous().view(B, -1, 4) coef = coef.permute(0, 2, 3, 1).contiguous().view(B, -1, K) coef = torch.tanh(coef) return cls, box, coef # ─── YOLACT+ ────────────────────────────────────────────────────────────────── class YOLACTPlus(nn.Module): """ YOLACT+ with ResNet-18 backbone. Args: num_classes : number of foreground classes (background added internally) img_size : input image resolution (square, default 550) pretrained : use ImageNet-pretrained ResNet-18 """ def __init__( self, num_classes: int = 1, img_size: int = 550, pretrained: bool = True, ): super().__init__() self.num_classes = num_classes self.img_size = img_size self.backbone = ResNet18Backbone(pretrained=pretrained) self.fpn = FPN(self.backbone.out_channels) self.proto_net = ProtoNet(FPN_CHANNELS, NUM_PROTOTYPES) self.head = PredictionHead( FPN_CHANNELS, num_classes, len(AnchorGenerator.ASPECT_RATIOS), NUM_PROTOTYPES ) self.anchor_gen = AnchorGenerator(img_size) self._anchors = None # cached after first forward pass # ── Forward ─────────────────────────────────────────────────────────────── def forward(self, images: torch.Tensor): """ Args: images : [B, 3, H, W] Returns (training mode): { "cls_pred" : [B, total_anchors, num_classes+1] "box_pred" : [B, total_anchors, 4] "coef_pred" : [B, total_anchors, K] "proto_out" : [B, K, H', W'] "anchors" : [total_anchors, 4] (cx/cy/w/h, normalised) } """ features = self.backbone(images) fpn_feats = self.fpn(features) proto_out = self.proto_net(fpn_feats[0]) # P3 → prototypes # Cache anchors (they depend only on feature map sizes) if self._anchors is None or self._anchors.device != images.device: feat_sizes = [(f.shape[2], f.shape[3]) for f in fpn_feats] self._anchors = self.anchor_gen.make_anchors(feat_sizes).to(images.device) cls_preds, box_preds, coef_preds = [], [], [] for feat in fpn_feats: cls, box, coef = self.head(feat) cls_preds.append(cls) box_preds.append(box) coef_preds.append(coef) return { "cls_pred": torch.cat(cls_preds, dim=1), "box_pred": torch.cat(box_preds, dim=1), "coef_pred": torch.cat(coef_preds, dim=1), "proto_out": proto_out, "anchors": self._anchors, } # ── Post-processing (inference) ─────────────────────────────────────────── @torch.no_grad() def predict( self, images: torch.Tensor, score_thresh: float = 0.3, nms_thresh: float = 0.5, top_k: int = 100, ) -> list: """ Run inference and return decoded predictions. Returns: List (one per image) of dicts: boxes : [N, 4] x1y1x2y2 normalised scores : [N] labels : [N] masks : [N, H, W] float binary masks (upsampled to input size) """ self.eval() out = self.forward(images) cls_pred = out["cls_pred"] # [B, A, C+1] box_pred = out["box_pred"] # [B, A, 4] coef_pred = out["coef_pred"] # [B, A, K] proto = out["proto_out"] # [B, K, H', W'] anchors = out["anchors"] # [A, 4] results = [] B = images.shape[0] for i in range(B): scores_all = torch.softmax(cls_pred[i], dim=-1) # [A, C+1] scores, labels = scores_all[:, 1:].max(dim=-1) # foreground only keep_mask = scores > score_thresh if keep_mask.sum() == 0: results.append({"boxes": torch.zeros(0, 4), "scores": torch.zeros(0), "labels": torch.zeros(0, dtype=torch.long), "masks": torch.zeros(0, self.img_size, self.img_size)}) continue scores = scores[keep_mask] labels = labels[keep_mask] boxes_d = box_pred[i][keep_mask] # deltas coefs = coef_pred[i][keep_mask] # [N, K] anch = anchors[keep_mask] # [N, 4] # Decode box deltas → cx/cy/w/h pred_cx = boxes_d[:, 0] * anch[:, 2] + anch[:, 0] pred_cy = boxes_d[:, 1] * anch[:, 3] + anch[:, 1] pred_w = torch.exp(boxes_d[:, 2]) * anch[:, 2] pred_h = torch.exp(boxes_d[:, 3]) * anch[:, 3] # → x1y1x2y2 x1 = torch.clamp(pred_cx - pred_w / 2, 0, 1) y1 = torch.clamp(pred_cy - pred_h / 2, 0, 1) x2 = torch.clamp(pred_cx + pred_w / 2, 0, 1) y2 = torch.clamp(pred_cy + pred_h / 2, 0, 1) boxes_xyxy = torch.stack([x1, y1, x2, y2], dim=1) # ── Filter out oversized boxes ──────────────────────────── # Remove boxes whose area exceeds 50% of the image area. # These are almost always spurious full-image anchors. box_w = boxes_xyxy[:, 2] - boxes_xyxy[:, 0] box_h = boxes_xyxy[:, 3] - boxes_xyxy[:, 1] box_area = box_w * box_h size_mask = box_area < 0.50 # keep boxes < 50% of image area boxes_xyxy = boxes_xyxy[size_mask] scores = scores[size_mask] labels = labels[size_mask] coefs = coefs[size_mask] if boxes_xyxy.shape[0] == 0: results.append({"boxes": torch.zeros(0, 4), "scores": torch.zeros(0), "labels": torch.zeros(0, dtype=torch.long), "masks": torch.zeros(0, self.img_size, self.img_size)}) continue # NMS (pixel-scale for torchvision nms) scale = float(self.img_size) keep = nms(boxes_xyxy * scale, scores, nms_thresh) keep = keep[:top_k] boxes_xyxy = boxes_xyxy[keep] scores = scores[keep] labels = labels[keep] coefs = coefs[keep] # [N, K] # Decode masks: proto [K, H', W'], coefs [N, K] proto_i = proto[i] # [K, H', W'] K, pH, pW = proto_i.shape proto_flat = proto_i.view(K, -1).T # [H'W', K] mask_flat = torch.sigmoid(proto_flat @ coefs.T) # [H'W', N] masks_raw = mask_flat.T.view(len(keep), pH, pW) # [N, H', W'] # Upsample to input resolution masks_up = F.interpolate( masks_raw.unsqueeze(0), size=(self.img_size, self.img_size), mode="bilinear", align_corners=False ).squeeze(0) masks_bin = (masks_up > 0.5).float() results.append({ "boxes": boxes_xyxy.cpu(), "scores": scores.cpu(), "labels": labels.cpu(), "masks": masks_bin.cpu(), }) return results