Spaces:
Running
Running
| """ | |
| YOLACT+ with ResNet-18 Backbone | |
| ================================= | |
| A faithful implementation of YOLACT+ adapted for a lightweight ResNet-18 backbone. | |
| Architecture: | |
| Backbone : ResNet-18 (torchvision, ImageNet pre-trained) | |
| Neck : FPN (Feature Pyramid Network) | |
| Head : PredictionHead (class + box + mask coefficient) | |
| Proto : ProtoNet (generates prototype masks) | |
| Mask : linear combination of prototypes × coefficients | |
| NMS : standard hard NMS via torchvision.ops.nms (applied across all classes together) | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision.models import resnet18, ResNet18_Weights | |
| from torchvision.ops import nms | |
| # ─── Constants ──────────────────────────────────────────────────────────────── | |
# Number of prototype masks (K) produced by ProtoNet and, equally, the number
# of mask coefficients predicted per anchor.
NUM_PROTOTYPES: int = 32
# Channel width of every FPN output level (and of the prediction head).
FPN_CHANNELS: int = 256
# Channel width of the hidden convolutions inside ProtoNet.
PROTO_CHANNELS: int = 256
| # ─── Backbone ───────────────────────────────────────────────────────────────── | |
class ResNet18Backbone(nn.Module):
    """
    Feature extractor built from torchvision's ResNet-18.

    The stem and the four residual stages are taken from the torchvision
    model. ``forward`` returns the outputs of the last three stages as
    (C3, C4, C5), i.e. strides 8, 16 and 32 with 128, 256 and 512 channels.

    Args:
        pretrained : load the ``IMAGENET1K_V1`` weights when True, otherwise
                     start from torchvision's default initialisation.
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        net = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)

        # Stem: conv -> batch-norm -> ReLU -> max-pool.
        self.layer0 = nn.Sequential(net.conv1, net.bn1, net.relu, net.maxpool)
        # Residual stages, reused unchanged from the torchvision model.
        self.layer1 = net.layer1  # stride 4,  channels 64
        self.layer2 = net.layer2  # stride 8,  channels 128 -> C3
        self.layer3 = net.layer3  # stride 16, channels 256 -> C4
        self.layer4 = net.layer4  # stride 32, channels 512 -> C5

        # Channel counts of the returned maps, in (C3, C4, C5) order.
        self.out_channels = [128, 256, 512]

    def forward(self, x):
        """Run the stem and all stages; return the tuple (C3, C4, C5)."""
        current = self.layer1(self.layer0(x))
        pyramid = []
        for stage in (self.layer2, self.layer3, self.layer4):
            current = stage(current)
            pyramid.append(current)
        return tuple(pyramid)
| # ─── FPN ────────────────────────────────────────────────────────────────────── | |
class FPN(nn.Module):
    """
    Feature Pyramid Network producing five levels, P3–P7.

    P3–P5 come from the three backbone maps via 1x1 lateral convolutions, a
    nearest-neighbour top-down pathway and a 3x3 output convolution per level.
    P6 is a stride-2 3x3 convolution of P5; P7 is a stride-2 3x3 convolution
    of ReLU(P6).

    Args:
        in_channels  : channel counts of the backbone maps, in (C3, C4, C5) order
        out_channels : channel count shared by every pyramid level
    """

    def __init__(self, in_channels: list, out_channels: int = FPN_CHANNELS):
        super().__init__()
        # 1x1 convolutions that bring each backbone map to `out_channels`.
        laterals = []
        for channels in in_channels:
            laterals.append(nn.Conv2d(channels, out_channels, 1))
        self.lateral_convs = nn.ModuleList(laterals)

        # 3x3 convolutions applied after the top-down merge.
        smoothers = []
        for _ in in_channels:
            smoothers.append(nn.Conv2d(out_channels, out_channels, 3, padding=1))
        self.output_convs = nn.ModuleList(smoothers)

        # Extra, coarser levels P6 and P7.
        self.p6_conv = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)
        self.p7_conv = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)

    def forward(self, features):
        """Map (C3, C4, C5) to the list [P3, P4, P5, P6, P7]."""
        c3, c4, c5 = features

        # Top-down pathway: start at the coarsest map and add the upsampled
        # coarser level into each finer lateral.
        merged5 = self.lateral_convs[2](c5)

        lateral4 = self.lateral_convs[1](c4)
        merged4 = lateral4 + F.interpolate(
            merged5, size=lateral4.shape[-2:], mode="nearest"
        )

        lateral3 = self.lateral_convs[0](c3)
        merged3 = lateral3 + F.interpolate(
            merged4, size=lateral3.shape[-2:], mode="nearest"
        )

        p3, p4, p5 = (
            conv(merged)
            for conv, merged in zip(self.output_convs, (merged3, merged4, merged5))
        )
        p6 = self.p6_conv(p5)
        p7 = self.p7_conv(F.relu(p6))
        return [p3, p4, p5, p6, p7]
| # ─── ProtoNet ───────────────────────────────────────────────────────────────── | |
class ProtoNet(nn.Module):
    """
    Prototype-mask branch.

    Three 3x3 conv + ReLU layers, a 2x bilinear upsample, one more
    3x3 conv + ReLU and a final 1x1 convolution down to ``num_protos``
    channels. The output therefore has ``num_protos`` channels and twice the
    spatial size of its input: [B, K, H/4, W/4] when fed a stride-8 P3 map.

    Args:
        in_channels : channel count of the input feature map
        num_protos  : number of prototype masks K
    """

    def __init__(self, in_channels: int = FPN_CHANNELS, num_protos: int = NUM_PROTOTYPES):
        super().__init__()

        def conv_relu(channels_in):
            # One 3x3 convolution to PROTO_CHANNELS followed by ReLU.
            return [nn.Conv2d(channels_in, PROTO_CHANNELS, 3, padding=1), nn.ReLU()]

        stack = []
        stack += conv_relu(in_channels)
        stack += conv_relu(PROTO_CHANNELS)
        stack += conv_relu(PROTO_CHANNELS)
        stack.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False))
        stack += conv_relu(PROTO_CHANNELS)
        # No activation after the last layer: raw prototype responses.
        stack.append(nn.Conv2d(PROTO_CHANNELS, num_protos, 1))
        self.proto_net = nn.Sequential(*stack)

    def forward(self, p3):
        """Return the prototype tensor [B, K, H', W'] for feature map ``p3``."""
        prototypes = self.proto_net(p3)
        return prototypes
| # ─── Anchor generator ───────────────────────────────────────────────────────── | |
class AnchorGenerator:
    """
    Anchor boxes for each FPN level.

    Scales: [24, 48, 96, 192, 384] pixels, one per level (for 550×550 input)
    Aspect ratios: [1.0, 0.5, 2.0], giving three anchors per feature-map cell

    Args:
        img_size : image side length in pixels used to normalise anchor
                   widths and heights to the 0-1 range
    """

    SCALES = [24, 48, 96, 192, 384]
    ASPECT_RATIOS = [1.0, 0.5, 2.0]

    def __init__(self, img_size: int = 550):
        self.img_size = img_size
        self.num_anchors_per_cell = len(self.ASPECT_RATIOS)

    def make_anchors(self, feature_sizes: list) -> torch.Tensor:
        """
        Build the anchors for the given feature-map sizes.

        Args:
            feature_sizes : list of (height, width) pairs, one per level; level
                            ``i`` uses ``SCALES[i]``, so passing more levels
                            than scales raises IndexError.

        Returns:
            float32 tensor [total_anchors, 4] in cx/cy/w/h format (normalised
            0-1). Anchors are ordered level by level, then row-major over
            cells, then by aspect ratio. An empty ``feature_sizes`` yields a
            [0, 4] tensor.
        """
        num_ratios = len(self.ASPECT_RATIOS)
        per_level = []
        for lvl, (fh, fw) in enumerate(feature_sizes):
            scale = self.SCALES[lvl]
            fh, fw = int(fh), int(fw)

            # Cell centres, computed in float64 so the values match plain
            # Python-float arithmetic before the final cast to float32.
            cx = (torch.arange(fw, dtype=torch.float64) + 0.5) / fw
            cy = (torch.arange(fh, dtype=torch.float64) + 0.5) / fh

            # (w, h) per aspect ratio: the anchor keeps area scale**2 while
            # its width/height ratio equals `ar`.
            sizes = torch.tensor(
                [
                    [scale * (ar ** 0.5) / self.img_size,
                     scale / (ar ** 0.5) / self.img_size]
                    for ar in self.ASPECT_RATIOS
                ],
                dtype=torch.float64,
            ).reshape(num_ratios, 2)

            # Broadcast everything to [fh, fw, num_ratios] and stack the four
            # coordinates; flattening then gives row -> col -> ratio order.
            grid = (fh, fw, num_ratios)
            level = torch.stack(
                [
                    cx.view(1, fw, 1).expand(grid),
                    cy.view(fh, 1, 1).expand(grid),
                    sizes[:, 0].view(1, 1, num_ratios).expand(grid),
                    sizes[:, 1].view(1, 1, num_ratios).expand(grid),
                ],
                dim=-1,
            )
            per_level.append(level.reshape(fh * fw * num_ratios, 4))

        if not per_level:
            return torch.zeros((0, 4), dtype=torch.float32)
        return torch.cat(per_level, dim=0).to(torch.float32)
| # ─── Prediction Head ────────────────────────────────────────────────────────── | |
class PredictionHead(nn.Module):
    """
    Shared prediction head applied to each FPN level.

    Four 3x3 conv + ReLU layers feed three 1x1 convolutions that predict, for
    every anchor at every cell, class logits, box regression values and mask
    coefficients.

    Args:
        in_channels : channel count of the incoming feature map
        num_classes : number of foreground classes (one background logit is added)
        num_anchors : anchors per feature-map cell
        num_protos  : number of mask coefficients per anchor

    Outputs of ``forward`` (A = H * W * num_anchors for this level):
        cls_pred  : [B, A, num_classes+1]  raw logits
        box_pred  : [B, A, 4]              raw regression values
        coef_pred : [B, A, num_protos]     tanh-squashed mask coefficients
    """

    def __init__(
        self,
        in_channels: int = FPN_CHANNELS,
        num_classes: int = 2,
        num_anchors: int = 3,
        num_protos: int = NUM_PROTOTYPES,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        # Kept so forward() reshapes with the value the layers were built for,
        # rather than with the module-level default.
        self.num_protos = num_protos
        self.shared = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(),
        )
        self.cls_layer = nn.Conv2d(in_channels, num_anchors * (num_classes + 1), 1)
        self.box_layer = nn.Conv2d(in_channels, num_anchors * 4, 1)
        self.coef_layer = nn.Conv2d(in_channels, num_anchors * num_protos, 1)

    def forward(self, feat):
        """Return (cls_pred, box_pred, coef_pred) for one feature map [B, C, H, W]."""
        batch = feat.shape[0]
        x = self.shared(feat)

        def per_anchor(pred, width):
            # [B, A*width, H, W] -> [B, H*W*A, width]; channels go last so the
            # flattened order is row, column, then anchor.
            return pred.permute(0, 2, 3, 1).contiguous().view(batch, -1, width)

        cls = per_anchor(self.cls_layer(x), self.num_classes + 1)
        box = per_anchor(self.box_layer(x), 4)
        coef = torch.tanh(per_anchor(self.coef_layer(x), self.num_protos))
        return cls, box, coef
| # ─── YOLACT+ ────────────────────────────────────────────────────────────────── | |
class YOLACTPlus(nn.Module):
    """
    YOLACT+ with ResNet-18 backbone.

    Pipeline: ResNet-18 backbone -> FPN (P3-P7) -> ProtoNet on P3 plus one
    PredictionHead shared across all pyramid levels.

    Args:
        num_classes : number of foreground classes (background added internally)
        img_size    : nominal input resolution (square, default 550); used to
                      normalise anchor widths/heights
        pretrained  : use ImageNet-pretrained ResNet-18
    """

    def __init__(
        self,
        num_classes: int = 1,
        img_size: int = 550,
        pretrained: bool = True,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.img_size = img_size
        self.backbone = ResNet18Backbone(pretrained=pretrained)
        self.fpn = FPN(self.backbone.out_channels)
        self.proto_net = ProtoNet(FPN_CHANNELS, NUM_PROTOTYPES)
        self.head = PredictionHead(
            FPN_CHANNELS, num_classes, len(AnchorGenerator.ASPECT_RATIOS), NUM_PROTOTYPES
        )
        self.anchor_gen = AnchorGenerator(img_size)
        # Anchor cache: the tensor plus the (feature sizes, device) it was
        # built for, so a change in either one triggers a rebuild.
        self._anchors = None
        self._anchor_key = None

    # ── Forward ───────────────────────────────────────────────────────────────
    def forward(self, images: torch.Tensor):
        """
        Args:
            images : [B, 3, H, W]
        Returns:
            {
                "cls_pred"  : [B, total_anchors, num_classes+1]
                "box_pred"  : [B, total_anchors, 4]
                "coef_pred" : [B, total_anchors, K]
                "proto_out" : [B, K, H', W']
                "anchors"   : [total_anchors, 4] (cx/cy/w/h, normalised)
            }
        """
        features = self.backbone(images)
        fpn_feats = self.fpn(features)
        proto_out = self.proto_net(fpn_feats[0])  # P3 -> prototypes

        # Anchors depend on the feature-map sizes, so the cache must be
        # refreshed whenever those (or the device) change; otherwise a new
        # input resolution would reuse anchors with the wrong count.
        feat_sizes = tuple((f.shape[2], f.shape[3]) for f in fpn_feats)
        anchor_key = (feat_sizes, images.device)
        if self._anchors is None or self._anchor_key != anchor_key:
            anchors = self.anchor_gen.make_anchors(list(feat_sizes))
            self._anchors = anchors.to(images.device)
            self._anchor_key = anchor_key

        # The same head is applied to every pyramid level.
        cls_preds, box_preds, coef_preds = [], [], []
        for feat in fpn_feats:
            cls, box, coef = self.head(feat)
            cls_preds.append(cls)
            box_preds.append(box)
            coef_preds.append(coef)

        return {
            "cls_pred": torch.cat(cls_preds, dim=1),
            "box_pred": torch.cat(box_preds, dim=1),
            "coef_pred": torch.cat(coef_preds, dim=1),
            "proto_out": proto_out,
            "anchors": self._anchors,
        }

    # ── Post-processing (inference) ───────────────────────────────────────────
    @staticmethod
    def _empty_result(height: int, width: int) -> dict:
        """Return the result dict for an image with no surviving detections."""
        return {
            "boxes": torch.zeros(0, 4),
            "scores": torch.zeros(0),
            "labels": torch.zeros(0, dtype=torch.long),
            "masks": torch.zeros(0, height, width),
        }

    @torch.no_grad()
    def predict(
        self,
        images: torch.Tensor,
        score_thresh: float = 0.3,
        nms_thresh: float = 0.5,
        top_k: int = 100,
        max_box_area: float = 0.50,
    ) -> list:
        """
        Run inference and return decoded predictions.

        Switches the module to eval mode (and leaves it there) and runs without
        gradient tracking.

        Args:
            images       : [B, 3, H, W]
            score_thresh : minimum foreground score for a detection to be kept
            nms_thresh   : IoU threshold passed to NMS
            top_k        : maximum number of detections kept per image after NMS
            max_box_area : boxes whose normalised area is not below this
                           fraction of the image are discarded

        Returns:
            List (one per image) of dicts of CPU tensors:
                boxes  : [N, 4] x1y1x2y2 normalised
                scores : [N]
                labels : [N] index of the best foreground class (0-based)
                masks  : [N, H, W] float binary masks (upsampled to input size)
        """
        self.eval()
        out = self.forward(images)
        anchors = out["anchors"]  # [A, 4]
        out_h, out_w = int(images.shape[-2]), int(images.shape[-1])

        results = []
        per_image = zip(out["cls_pred"], out["box_pred"], out["coef_pred"], out["proto_out"])
        for cls_i, box_i, coef_i, proto_i in per_image:
            scores_all = torch.softmax(cls_i, dim=-1)        # [A, C+1]
            scores, labels = scores_all[:, 1:].max(dim=-1)   # foreground only

            keep_mask = scores > score_thresh
            if not bool(keep_mask.any()):
                results.append(self._empty_result(out_h, out_w))
                continue

            scores = scores[keep_mask]
            labels = labels[keep_mask]
            deltas = box_i[keep_mask]    # [N, 4] regression values
            coefs = coef_i[keep_mask]    # [N, K]
            anch = anchors[keep_mask]    # [N, 4]

            # Decode deltas -> cx/cy/w/h: centre offsets are scaled by the
            # anchor size, width/height are exponential multipliers.
            pred_cx = deltas[:, 0] * anch[:, 2] + anch[:, 0]
            pred_cy = deltas[:, 1] * anch[:, 3] + anch[:, 1]
            pred_w = torch.exp(deltas[:, 2]) * anch[:, 2]
            pred_h = torch.exp(deltas[:, 3]) * anch[:, 3]

            # -> x1y1x2y2, clipped to the image.
            boxes_xyxy = torch.stack(
                [
                    torch.clamp(pred_cx - pred_w / 2, 0, 1),
                    torch.clamp(pred_cy - pred_h / 2, 0, 1),
                    torch.clamp(pred_cx + pred_w / 2, 0, 1),
                    torch.clamp(pred_cy + pred_h / 2, 0, 1),
                ],
                dim=1,
            )

            # Drop oversized boxes; these are almost always spurious
            # full-image anchors.
            box_w = boxes_xyxy[:, 2] - boxes_xyxy[:, 0]
            box_h = boxes_xyxy[:, 3] - boxes_xyxy[:, 1]
            size_mask = (box_w * box_h) < max_box_area
            boxes_xyxy = boxes_xyxy[size_mask]
            scores = scores[size_mask]
            labels = labels[size_mask]
            coefs = coefs[size_mask]
            if boxes_xyxy.shape[0] == 0:
                results.append(self._empty_result(out_h, out_w))
                continue

            # NMS on pixel-scale boxes; all labels are suppressed together.
            scale = float(self.img_size)
            keep = nms(boxes_xyxy * scale, scores, nms_thresh)
            keep = keep[:top_k]
            boxes_xyxy = boxes_xyxy[keep]
            scores = scores[keep]
            labels = labels[keep]
            coefs = coefs[keep]  # [N, K]

            # Assemble masks: sigmoid(prototypes · coefficients).
            num_protos, proto_h, proto_w = proto_i.shape
            proto_flat = proto_i.view(num_protos, -1).T            # [H'W', K]
            mask_flat = torch.sigmoid(proto_flat @ coefs.T)        # [H'W', N]
            masks_raw = mask_flat.T.view(keep.numel(), proto_h, proto_w)  # [N, H', W']

            # Upsample to the resolution of the images actually passed in,
            # then binarise.
            masks_up = F.interpolate(
                masks_raw.unsqueeze(0),
                size=(out_h, out_w),
                mode="bilinear",
                align_corners=False,
            ).squeeze(0)
            masks_bin = (masks_up > 0.5).float()

            results.append({
                "boxes": boxes_xyxy.cpu(),
                "scores": scores.cpu(),
                "labels": labels.cpu(),
                "masks": masks_bin.cpu(),
            })
        return results