# FracAtlas-YOLACT / model.py
# Commit fcec417 by MuhammadAdil63: "deploy YOLACT+ fracture detection demo" (15.5 kB)
"""
YOLACT+ with ResNet-18 Backbone
=================================
A faithful implementation of YOLACT+ adapted for a lightweight ResNet-18 backbone.
Architecture:
Backbone : ResNet-18 (torchvision, ImageNet pre-trained)
Neck : FPN (Feature Pyramid Network)
Head : PredictionHead (class + box + mask coefficient)
Proto : ProtoNet (generates prototype masks)
Mask : linear combination of prototypes Γ— coefficients
NMS : Fast NMS (YOLACT-style)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.ops import nms
# ─── Constants ────────────────────────────────────────────────────────────────
NUM_PROTOTYPES: int = 32  # K: prototype masks produced by ProtoNet = mask coefficients per anchor
FPN_CHANNELS: int = 256  # channel width of every FPN output level and of the prediction head
PROTO_CHANNELS: int = 256  # hidden channel width inside ProtoNet
# ─── Backbone ─────────────────────────────────────────────────────────────────
class ResNet18Backbone(nn.Module):
    """
    Feature extractor built from torchvision's ResNet-18.

    Only the stem and the four residual stages are kept (the pooling head and
    classifier are not used). The outputs of ``layer2``, ``layer3`` and
    ``layer4`` are returned as C3, C4 and C5 (strides 8, 16, 32).
    """

    def __init__(self, pretrained: bool = True):
        super().__init__()
        net = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1 if pretrained else None)
        # Stem: conv1 -> bn1 -> relu -> maxpool (overall stride 4).
        self.layer0 = nn.Sequential(net.conv1, net.bn1, net.relu, net.maxpool)
        self.layer1 = net.layer1  # stride 4,  64 channels
        self.layer2 = net.layer2  # stride 8,  128 channels -> C3
        self.layer3 = net.layer3  # stride 16, 256 channels -> C4
        self.layer4 = net.layer4  # stride 32, 512 channels -> C5
        # Channel counts of (C3, C4, C5); used to size the FPN lateral convs.
        self.out_channels = [128, 256, 512]

    def forward(self, x):
        """Return the (C3, C4, C5) feature maps for the image batch ``x``."""
        stem = self.layer1(self.layer0(x))
        c3 = self.layer2(stem)
        c4 = self.layer3(c3)
        c5 = self.layer4(c4)
        return c3, c4, c5
# ─── FPN ──────────────────────────────────────────────────────────────────────
class FPN(nn.Module):
    """
    Feature Pyramid Network producing five levels, P3-P7.

    P3-P5 : 1x1 lateral conv on C3-C5, top-down merge (nearest-neighbour
            upsample + add), then one 3x3 output conv per level.
    P6    : 3x3 stride-2 conv on P5.
    P7    : 3x3 stride-2 conv on ReLU(P6).
    Every level has ``out_channels`` channels.
    """

    def __init__(self, in_channels: list, out_channels: int = FPN_CHANNELS):
        super().__init__()
        self.lateral_convs = nn.ModuleList(
            nn.Conv2d(ch, out_channels, 1) for ch in in_channels
        )
        self.output_convs = nn.ModuleList(
            nn.Conv2d(out_channels, out_channels, 3, padding=1) for _ in in_channels
        )
        # Extra coarse levels P6 and P7, each halving the resolution.
        self.p6_conv = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)
        self.p7_conv = nn.Conv2d(out_channels, out_channels, 3, stride=2, padding=1)

    def forward(self, features):
        """Map (C3, C4, C5) to the list [P3, P4, P5, P6, P7]."""
        c3, c4, c5 = features
        lateral = [conv(f) for conv, f in zip(self.lateral_convs, (c3, c4, c5))]

        # Top-down pathway: coarser level is upsampled and added to the finer one.
        p5_in = lateral[2]
        p4_in = lateral[1] + F.interpolate(p5_in, size=lateral[1].shape[-2:], mode="nearest")
        p3_in = lateral[0] + F.interpolate(p4_in, size=lateral[0].shape[-2:], mode="nearest")

        p3, p4, p5 = (
            conv(t) for conv, t in zip(self.output_convs, (p3_in, p4_in, p5_in))
        )
        p6 = self.p6_conv(p5)
        p7 = self.p7_conv(F.relu(p6))
        return [p3, p4, p5, p6, p7]
# ─── ProtoNet ─────────────────────────────────────────────────────────────────
class ProtoNet(nn.Module):
    """
    Prototype-mask branch.

    Three 3x3 conv+ReLU layers, a 2x bilinear upsample, one more 3x3
    conv+ReLU, then a 1x1 projection to ``num_protos`` channels (no final
    activation).

    Input : [B, in_channels, h, w]   (the model feeds P3, stride 8)
    Output: [B, num_protos, 2h, 2w]  (stride 4 w.r.t. the input image)
    """

    def __init__(self, in_channels: int = FPN_CHANNELS, num_protos: int = NUM_PROTOTYPES):
        super().__init__()
        layers = []
        for stage in range(3):
            src = in_channels if stage == 0 else PROTO_CHANNELS
            layers += [nn.Conv2d(src, PROTO_CHANNELS, 3, padding=1), nn.ReLU()]
        layers.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False))
        layers += [nn.Conv2d(PROTO_CHANNELS, PROTO_CHANNELS, 3, padding=1), nn.ReLU()]
        layers.append(nn.Conv2d(PROTO_CHANNELS, num_protos, 1))
        # The layer order defines the state_dict keys (proto_net.0 ... proto_net.9);
        # keep it stable so saved checkpoints still load.
        self.proto_net = nn.Sequential(*layers)

    def forward(self, p3):
        """Return prototype masks [B, K, H', W'] for the feature map ``p3``."""
        return self.proto_net(p3)
# ─── Anchor generator ─────────────────────────────────────────────────────────
class AnchorGenerator:
    """
    Anchor boxes for each FPN level.

    One anchor per aspect ratio is centred in every feature-map cell. Level
    ``lvl`` uses base size ``SCALES[lvl]`` (pixels of an ``img_size`` square
    input). For aspect ratio ``ar`` the anchor is ``scale * sqrt(ar)`` wide and
    ``scale / sqrt(ar)`` tall; both are then normalised by ``img_size``.

    Scales: [24, 48, 96, 192, 384] (for 550×550 input)
    Aspect ratios: [1.0, 0.5, 2.0]
    """

    SCALES = [24, 48, 96, 192, 384]
    ASPECT_RATIOS = [1.0, 0.5, 2.0]

    def __init__(self, img_size: int = 550):
        self.img_size = img_size
        self.num_anchors_per_cell = len(self.ASPECT_RATIOS)

    def make_anchors(self, feature_sizes: list) -> torch.Tensor:
        """
        Build anchors for a list of per-level feature-map sizes.

        Args:
            feature_sizes : [(H_0, W_0), (H_1, W_1), ...], one entry per FPN
                            level, in ``SCALES`` order. More than
                            ``len(SCALES)`` levels raises IndexError.
        Returns:
            [total_anchors, 4] float32 tensor in cx/cy/w/h format (normalised
            0-1), ordered level -> row -> column -> aspect ratio. This matches
            the [B, H, W, A, ...] flattening of the prediction head. Empty
            input (or a zero-sized level) contributes no rows; the result is
            always 2-D.
        """
        num_ar = len(self.ASPECT_RATIOS)
        levels = []
        for lvl, (fh, fw) in enumerate(feature_sizes):
            scale = self.SCALES[lvl]
            # Values are computed as Python floats (float64) and cast to
            # float32 once at the end, so they are bit-identical to a
            # per-anchor Python computation.
            wh = torch.tensor(
                [
                    [scale * (ar ** 0.5) / self.img_size, scale / (ar ** 0.5) / self.img_size]
                    for ar in self.ASPECT_RATIOS
                ],
                dtype=torch.float64,
            )  # [num_ar, 2], identical for every cell of this level
            xs = torch.tensor([(col + 0.5) / fw for col in range(fw)], dtype=torch.float64)
            ys = torch.tensor([(row + 0.5) / fh for row in range(fh)], dtype=torch.float64)
            # Row-major cell grid: x varies fastest, y is constant along a row.
            cx = xs.repeat(fh)
            cy = ys.repeat_interleave(fw)
            centers = torch.stack([cx, cy], dim=1)  # [fh*fw, 2]
            levels.append(
                torch.cat(
                    [
                        centers.repeat_interleave(num_ar, dim=0),  # each cell once per ratio
                        wh.repeat(fh * fw, 1),                     # ratios cycle within a cell
                    ],
                    dim=1,
                )
            )
        if not levels:
            return torch.zeros(0, 4, dtype=torch.float32)
        return torch.cat(levels, dim=0).to(torch.float32)
# ─── Prediction Head ──────────────────────────────────────────────────────────
class PredictionHead(nn.Module):
    """
    Shared prediction head applied to each FPN level.

    Four 3x3 conv+ReLU layers feed three parallel 1x1 convs. For a level of
    size H x W with A = ``num_anchors`` anchors per cell, the outputs are:
        cls_pred  : [B, H*W*A, num_classes+1]  raw class logits
        box_pred  : [B, H*W*A, 4]              raw box deltas
        coef_pred : [B, H*W*A, num_protos]     tanh-squashed mask coefficients
    Rows are ordered (row, column, anchor).
    """

    def __init__(
        self,
        in_channels: int = FPN_CHANNELS,
        num_classes: int = 2,
        num_anchors: int = 3,
        num_protos: int = NUM_PROTOTYPES,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.num_anchors = num_anchors
        # Stored so forward() reshapes with the configured prototype count
        # instead of the module-level default.
        self.num_protos = num_protos
        self.shared = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(),
            nn.Conv2d(in_channels, in_channels, 3, padding=1), nn.ReLU(),
        )
        self.cls_layer = nn.Conv2d(in_channels, num_anchors * (num_classes + 1), 1)
        self.box_layer = nn.Conv2d(in_channels, num_anchors * 4, 1)
        self.coef_layer = nn.Conv2d(in_channels, num_anchors * num_protos, 1)

    def forward(self, feat):
        """
        Args:
            feat : [B, in_channels, H, W] feature map of one FPN level
        Returns:
            (cls, box, coef) shaped [B, H*W*A, C+1], [B, H*W*A, 4], [B, H*W*A, K]
        """
        B, _, H, W = feat.shape
        x = self.shared(feat)  # 3x3 convs with padding 1 keep H, W
        rows = H * W * self.num_anchors
        # [B, A*D, H, W] -> [B, H, W, A*D] -> [B, H*W*A, D]; explicit sizes
        # (rather than -1) also work for an empty batch.
        cls = self.cls_layer(x).permute(0, 2, 3, 1).reshape(B, rows, self.num_classes + 1)
        box = self.box_layer(x).permute(0, 2, 3, 1).reshape(B, rows, 4)
        coef = self.coef_layer(x).permute(0, 2, 3, 1).reshape(B, rows, self.num_protos)
        return cls, box, torch.tanh(coef)
# ─── YOLACT+ ──────────────────────────────────────────────────────────────────
class YOLACTPlus(nn.Module):
    """
    YOLACT+ with ResNet-18 backbone.

    Args:
        num_classes : number of foreground classes (background is added
                      internally as column 0 of the classifier output)
        img_size    : square input resolution the anchors are normalised to
                      and that predicted masks are upsampled to (default 550)
        pretrained  : use ImageNet-pretrained ResNet-18
    """

    def __init__(
        self,
        num_classes: int = 1,
        img_size: int = 550,
        pretrained: bool = True,
    ):
        super().__init__()
        self.num_classes = num_classes
        self.img_size = img_size
        self.backbone = ResNet18Backbone(pretrained=pretrained)
        self.fpn = FPN(self.backbone.out_channels)
        self.proto_net = ProtoNet(FPN_CHANNELS, NUM_PROTOTYPES)
        self.head = PredictionHead(
            FPN_CHANNELS, num_classes, len(AnchorGenerator.ASPECT_RATIOS), NUM_PROTOTYPES
        )
        self.anchor_gen = AnchorGenerator(img_size)
        # Anchors depend only on the FPN feature-map sizes. They are cached
        # together with the sizes they were built for, so a new input
        # resolution rebuilds them instead of reusing a stale, mis-sized set.
        self._anchors = None
        self._anchor_sizes = None

    # ── Forward ───────────────────────────────────────────────────────────────
    def forward(self, images: torch.Tensor):
        """
        Args:
            images : [B, 3, H, W]
        Returns (same dict in train and eval mode):
            {
                "cls_pred"  : [B, total_anchors, num_classes+1]  class logits
                "box_pred"  : [B, total_anchors, 4]              box deltas
                "coef_pred" : [B, total_anchors, K]              coefficients in (-1, 1)
                "proto_out" : [B, K, H', W']                     prototype masks
                "anchors"   : [total_anchors, 4]  (cx/cy/w/h, normalised)
            }
        """
        features = self.backbone(images)
        fpn_feats = self.fpn(features)
        proto_out = self.proto_net(fpn_feats[0])  # P3 -> prototypes

        feat_sizes = [(f.shape[2], f.shape[3]) for f in fpn_feats]
        # getattr keeps models pickled before _anchor_sizes existed working.
        cached_sizes = getattr(self, "_anchor_sizes", None)
        if (
            self._anchors is None
            or self._anchors.device != images.device
            or cached_sizes != feat_sizes
        ):
            self._anchors = self.anchor_gen.make_anchors(feat_sizes).to(images.device)
            self._anchor_sizes = feat_sizes

        # The same head is shared across all FPN levels.
        cls_preds, box_preds, coef_preds = zip(*(self.head(feat) for feat in fpn_feats))
        return {
            "cls_pred": torch.cat(cls_preds, dim=1),
            "box_pred": torch.cat(box_preds, dim=1),
            "coef_pred": torch.cat(coef_preds, dim=1),
            "proto_out": proto_out,
            "anchors": self._anchors,
        }

    # ── Post-processing (inference) ───────────────────────────────────────────
    def _empty_result(self) -> dict:
        """Result dict (CPU tensors) for an image with no surviving detections."""
        return {
            "boxes": torch.zeros(0, 4),
            "scores": torch.zeros(0),
            "labels": torch.zeros(0, dtype=torch.long),
            "masks": torch.zeros(0, self.img_size, self.img_size),
        }

    @staticmethod
    def _decode_boxes(deltas: torch.Tensor, anchors: torch.Tensor) -> torch.Tensor:
        """Decode [N, 4] deltas against [N, 4] cx/cy/w/h anchors into x1y1x2y2 in [0, 1]."""
        # NOTE(review): no variance scaling is applied; this presumably mirrors
        # the encoding used by the training loss (not in this file) - confirm.
        cx = deltas[:, 0] * anchors[:, 2] + anchors[:, 0]
        cy = deltas[:, 1] * anchors[:, 3] + anchors[:, 1]
        w = torch.exp(deltas[:, 2]) * anchors[:, 2]
        h = torch.exp(deltas[:, 3]) * anchors[:, 3]
        corners = torch.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], dim=1)
        return torch.clamp(corners, 0, 1)

    def _assemble_masks(self, proto_i: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
        """Combine [K, H', W'] prototypes with [N, K] coefficients into [N, img_size, img_size] {0, 1} masks."""
        K, pH, pW = proto_i.shape
        proto_flat = proto_i.reshape(K, -1).T             # [H'W', K]
        mask_flat = torch.sigmoid(proto_flat @ coefs.T)    # [H'W', N]
        masks_raw = mask_flat.T.reshape(coefs.shape[0], pH, pW)
        masks_up = F.interpolate(
            masks_raw.unsqueeze(0), size=(self.img_size, self.img_size),
            mode="bilinear", align_corners=False,
        ).squeeze(0)
        return (masks_up > 0.5).float()

    @torch.no_grad()
    def predict(
        self,
        images: torch.Tensor,
        score_thresh: float = 0.3,
        nms_thresh: float = 0.5,
        top_k: int = 100,
    ) -> list:
        """
        Run inference and return decoded predictions.

        Switches the model to eval mode (and leaves it there). For each image:
        softmax -> best foreground class per anchor -> score threshold ->
        box decoding -> drop boxes covering >= 50% of the image ->
        class-agnostic NMS -> keep at most ``top_k`` -> mask assembly.

        Args:
            images       : [B, 3, H, W]
            score_thresh : minimum foreground probability to keep an anchor
            nms_thresh   : IoU threshold for torchvision NMS
            top_k        : maximum detections kept per image (highest scores)
        Returns:
            List (one per image) of dicts of CPU tensors:
                boxes  : [N, 4] x1y1x2y2 normalised
                scores : [N]
                labels : [N]  0-based foreground class index (background excluded)
                masks  : [N, img_size, img_size] float binary masks
        """
        self.eval()
        out = self.forward(images)
        cls_pred = out["cls_pred"]    # [B, A, C+1]
        box_pred = out["box_pred"]    # [B, A, 4]
        coef_pred = out["coef_pred"]  # [B, A, K]
        proto = out["proto_out"]      # [B, K, H', W']
        anchors = out["anchors"]      # [A, 4]
        scale = float(self.img_size)  # torchvision nms works in pixel units

        results = []
        for i in range(images.shape[0]):
            probs = torch.softmax(cls_pred[i], dim=-1)      # [A, C+1]
            scores, labels = probs[:, 1:].max(dim=-1)       # column 0 is background

            keep_mask = scores > score_thresh
            if not keep_mask.any():
                results.append(self._empty_result())
                continue
            scores = scores[keep_mask]
            labels = labels[keep_mask]
            coefs = coef_pred[i][keep_mask]                 # [N, K]
            boxes = self._decode_boxes(box_pred[i][keep_mask], anchors[keep_mask])

            # Remove boxes covering 50% or more of the image area; these are
            # almost always spurious full-image anchors.
            area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
            size_mask = area < 0.50
            boxes = boxes[size_mask]
            scores = scores[size_mask]
            labels = labels[size_mask]
            coefs = coefs[size_mask]
            if boxes.shape[0] == 0:
                results.append(self._empty_result())
                continue

            # nms returns indices sorted by decreasing score, so slicing keeps the top_k best.
            keep = nms(boxes * scale, scores, nms_thresh)[:top_k]
            boxes, scores, labels, coefs = boxes[keep], scores[keep], labels[keep], coefs[keep]

            masks = self._assemble_masks(proto[i], coefs)
            results.append({
                "boxes": boxes.cpu(),
                "scores": scores.cpu(),
                "labels": labels.cpu(),
                "masks": masks.cpu(),
            })
        return results