|
|
"""
|
|
|
Hybrid TransMIL + Query2Label Architecture
|
|
|
|
|
|
Combines:
|
|
|
- TransMIL's instance-level feature aggregation (with Nystrom attention)
|
|
|
- Query2Label's learnable label queries with cross-attention decoder
|
|
|
- End-to-end training with ResNet-50 backbone
|
|
|
|
|
|
Key Innovation: Extract sequence features from TransMIL BEFORE CLS aggregation,
|
|
|
allowing Q2L label queries to cross-attend across all ultrasound images per case.
|
|
|
"""
|
|
|
|
|
|
import os
|
|
|
import sys
|
|
|
|
|
|
|
|
|
_models_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
if _models_dir not in sys.path:
|
|
|
sys.path.insert(0, _models_dir)
|
|
|
|
|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import torchvision
|
|
|
import numpy as np
|
|
|
from torch.utils.checkpoint import checkpoint_sequential
|
|
|
|
|
|
|
|
|
from nystrom_attention import NystromAttention
|
|
|
|
|
|
|
|
|
try:
|
|
|
from models.transformer import TransformerDecoder, TransformerDecoderLayer
|
|
|
except ImportError:
|
|
|
try:
|
|
|
from transformer import TransformerDecoder, TransformerDecoderLayer
|
|
|
except ImportError:
|
|
|
print("Warning: Could not import Q2L Transformer components.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransLayer(nn.Module):
    """Pre-norm transformer block built on Nystrom self-attention (from TransMIL).

    Computes ``x + NystromAttention(LayerNorm(x))`` — a residual, pre-normalized
    attention layer with no feed-forward sublayer.
    """

    def __init__(self, norm_layer=nn.LayerNorm, dim=512):
        super().__init__()
        self.norm = norm_layer(dim)
        self.attn = NystromAttention(
            dim=dim,
            dim_head=dim // 8,       # 8 heads, each of width dim/8
            heads=8,
            num_landmarks=dim // 2,  # landmarks for the Nystrom approximation
            pinv_iterations=6,       # iterations for the Moore-Penrose pseudo-inverse
            residual=True,           # extra conv residual inside the attention module
            dropout=0.1
        )

    def forward(self, x):
        """Apply residual pre-norm attention to a [B, N, dim] token sequence."""
        attended = self.attn(self.norm(x))
        return x + attended
|
|
|
|
|
|
|
|
|
class TransMILFeatureExtractor(nn.Module):
    """
    Modified TransMIL that outputs sequence features instead of an aggregated CLS token.

    Based on TransMIL.py but extracts features BEFORE CLS aggregation, so a
    downstream decoder (Q2L) can cross-attend over every instance token.
    Uses a learned 1D position encoding by default, or PPEG when ``use_ppeg=True``.

    Args:
        input_dim: Dimension of input features (2048 for ResNet-50)
        hidden_dim: Dimension of hidden features (512 default)
        use_ppeg: Whether to use PPEG (2D positional encoding) or learned 1D encoding
        max_seq_len: Maximum sequence length for the learned position encoding
            (only used when ``use_ppeg=False``)
    """

    def __init__(self, input_dim=2048, hidden_dim=512, use_ppeg=False, max_seq_len=1024):
        super().__init__()

        # Project backbone features down to the transformer width.
        self.fc1 = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())

        # Learnable CLS token prepended to every sequence.
        self.cls_token = nn.Parameter(torch.randn(1, 1, hidden_dim))

        self.layer1 = TransLayer(dim=hidden_dim)
        self.layer2 = TransLayer(dim=hidden_dim)

        self.norm = nn.LayerNorm(hidden_dim)

        self.use_ppeg = use_ppeg
        if not use_ppeg:
            # Learned 1D positional table covering CLS + up to max_seq_len-1 instances.
            self.pos_embedding = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))
        else:
            # 2D conv-based positional encoding from TransMIL.
            self.pos_layer = PPEG(dim=hidden_dim)

    def forward(self, features, mask=None):
        """
        Args:
            features: [B, N, input_dim] - Instance features (e.g., from ResNet-50)
            mask: [B, N] - Padding mask (True = valid instance, False = padded)

        Returns:
            seq_features: [B, 1+N, hidden_dim] - Sequence features (CLS + instances)
            attn_mask: [B, 1+N] - Attention mask for the decoder (True = valid)

        Raises:
            ValueError: if the (CLS-extended) sequence is longer than the learned
                positional table when ``use_ppeg=False``.
        """
        B, N, _ = features.shape

        h = self.fc1(features)  # [B, N, hidden_dim]

        if self.use_ppeg:
            # Pad the sequence up to a square number so tokens can be laid out
            # on an _H x _W grid for PPEG's 2D convolutions.
            H = h.shape[1]
            _H, _W = int(np.ceil(np.sqrt(H))), int(np.ceil(np.sqrt(H)))
            add_length = _H * _W - H
            if add_length > 0:
                # TransMIL convention: duplicate leading instances as padding ...
                h = torch.cat([h, h[:, :add_length, :]], dim=1)
                # ... but mark the duplicates invalid so attention ignores them.
                if mask is not None:
                    pad_mask = torch.zeros(B, add_length, dtype=torch.bool, device=mask.device)
                    mask = torch.cat([mask, pad_mask], dim=1)

        # Prepend the CLS token: [B, 1+N', hidden_dim].
        cls_tokens = self.cls_token.expand(B, -1, -1)
        h = torch.cat([cls_tokens, h], dim=1)

        # Build the validity mask over [CLS + instances]; CLS is always valid.
        if mask is not None:
            cls_mask = torch.ones(B, 1, dtype=torch.bool, device=mask.device)
            attn_mask = torch.cat([cls_mask, mask], dim=1)
        else:
            attn_mask = torch.ones(B, h.shape[1], dtype=torch.bool, device=h.device)

        h = self.layer1(h)

        if self.use_ppeg:
            h = self.pos_layer(h, _H, _W)
        else:
            seq_len = h.shape[1]
            # Fail loudly instead of with a cryptic broadcast error when the
            # sequence outgrows the learned positional table.
            if seq_len > self.pos_embedding.shape[1]:
                raise ValueError(
                    f"Sequence length {seq_len} (CLS + instances) exceeds "
                    f"max_seq_len={self.pos_embedding.shape[1]} of the learned "
                    f"position encoding; increase max_seq_len"
                )
            h = h + self.pos_embedding[:, :seq_len, :]

        h = self.layer2(h)

        h = self.norm(h)

        return h, attn_mask
|
|
|
|
|
|
|
|
|
class PPEG(nn.Module):
    """Position-aware Patch Embedding Generator (from TransMIL).

    Reshapes the instance tokens to a 2D grid and adds three depthwise
    convolutions (7x7, 5x5, 3x3, all "same"-padded) to inject spatial
    positional information. The CLS token bypasses the convolutions unchanged.
    """

    def __init__(self, dim=512):
        super().__init__()
        # Depthwise (groups=dim) convs at three receptive-field scales.
        self.proj = nn.Conv2d(dim, dim, 7, 1, 7 // 2, groups=dim)
        self.proj1 = nn.Conv2d(dim, dim, 5, 1, 5 // 2, groups=dim)
        self.proj2 = nn.Conv2d(dim, dim, 3, 1, 3 // 2, groups=dim)

    def forward(self, x, H, W):
        """
        Args:
            x: [B, 1+N, C] - Token sequence (CLS first), with N == H * W
            H, W: Grid dimensions (H * W >= N)

        Returns:
            [B, 1+N, C] tokens with positional information added to the instances.
        """
        B, _, C = x.shape

        # Split off the CLS token; lay the instance tokens out on the grid.
        cls_tok = x[:, 0]
        grid = x[:, 1:].transpose(1, 2).view(B, C, H, W)

        # Multi-scale depthwise convs plus identity (same summation order as TransMIL).
        enriched = self.proj(grid) + grid + self.proj1(grid) + self.proj2(grid)

        # Back to sequence layout and re-attach the untouched CLS token.
        tokens = enriched.flatten(2).transpose(1, 2)
        return torch.cat((cls_tok.unsqueeze(1), tokens), dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GroupWiseLinear(nn.Module):
    """
    Group-wise linear layer for per-class classification (from Q2L).

    Applies a separate linear map (dot product + optional scalar bias) for
    each class: ``logits[b, c] = W[c] . x[b, c] + b[c]``.

    Args:
        num_class: Number of classes (one weight vector per class)
        hidden_dim: Width of each per-class embedding
        bias: Whether to add a per-class scalar bias
    """

    def __init__(self, num_class, hidden_dim, bias=True):
        super().__init__()
        self.num_class = num_class
        self.hidden_dim = hidden_dim
        self.bias = bias

        # One weight vector per class; leading 1 broadcasts over the batch.
        self.W = nn.Parameter(torch.Tensor(1, num_class, hidden_dim))
        if bias:
            self.b = nn.Parameter(torch.Tensor(1, num_class))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(hidden_dim), 1/sqrt(hidden_dim)], as nn.Linear."""
        import math
        stdv = 1. / math.sqrt(self.W.size(2))
        # Vectorized: a single uniform_ call instead of a Python loop per class
        # (same distribution, no per-row indexing overhead).
        self.W.data.uniform_(-stdv, stdv)
        if self.bias:
            self.b.data.uniform_(-stdv, stdv)

    def forward(self, x):
        """
        Args:
            x: [B, num_class, hidden_dim] - one embedding per class
        Returns:
            logits: [B, num_class]
        """
        # Per-class dot product: elementwise multiply then reduce over hidden_dim.
        logits = (self.W * x).sum(-1)
        if self.bias:
            logits = logits + self.b
        return logits
|
|
|
|
|
|
|
|
|
class HybridQuery2Label(nn.Module):
    """
    Query2Label decoder adapted for sequence inputs (not spatial features).

    Learnable label queries cross-attend to the instance sequence produced by
    TransMIL. Based on query2label.py but consumes [B, 1+N, hidden_dim]
    sequences instead of [B, C, H, W] spatial feature maps.

    Args:
        num_class: Number of label classes
        hidden_dim: Dimension of features (512)
        nheads: Number of attention heads
        num_decoder_layers: Number of transformer decoder layers
        dim_feedforward: Dimension of feedforward network
        dropout: Dropout rate
        normalize_before: Use pre-norm decoder layers instead of post-norm
    """

    def __init__(
        self,
        num_class,
        hidden_dim=512,
        nheads=8,
        num_decoder_layers=2,
        dim_feedforward=2048,
        dropout=0.1,
        normalize_before=False
    ):
        super().__init__()
        self.num_class = num_class
        self.hidden_dim = hidden_dim

        # One learnable query embedding per label.
        self.query_embed = nn.Embedding(num_class, hidden_dim)

        layer = TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=nheads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            normalize_before=normalize_before
        )
        self.decoder = TransformerDecoder(
            layer,
            num_decoder_layers,
            nn.LayerNorm(hidden_dim),
            return_intermediate=False
        )

        # Per-class projection from decoded query embedding to a single logit.
        self.fc = GroupWiseLinear(num_class, hidden_dim, bias=True)

    def forward(self, sequence_features, memory_key_padding_mask=None):
        """
        Args:
            sequence_features: [B, 1+N, hidden_dim] - Sequence from TransMIL
            memory_key_padding_mask: [B, 1+N] - Padding mask (True = ignore,
                False = valid). NOTE: PyTorch convention is inverted!

        Returns:
            logits: [B, num_class] - Multi-label classification logits
        """
        batch = sequence_features.shape[0]

        # The decoder expects sequence-first layout: [1+N, B, hidden_dim].
        memory = sequence_features.permute(1, 0, 2)

        # Broadcast the label queries across the batch: [num_class, B, hidden_dim].
        queries = self.query_embed.weight.unsqueeze(1).repeat(1, batch, 1)

        # DETR-style decoding: content starts at zero, position comes from query_pos.
        hs = self.decoder(
            tgt=torch.zeros_like(queries),
            memory=memory,
            memory_key_padding_mask=memory_key_padding_mask,
            pos=None,
            query_pos=queries
        )

        # Keep only the final layer if the decoder stacked intermediate outputs.
        if hs.dim() == 4:
            hs = hs[-1]

        # Back to batch-first: [B, num_class, hidden_dim].
        hs = hs.permute(1, 0, 2)

        return self.fc(hs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ResNet50Backbone(nn.Module):
    """
    ResNet-50 feature extractor with Global Average Pooling.

    Produces a 2048-dimensional feature per image for TransMIL input, and can
    apply gradient checkpointing to the convolutional stack to save memory.

    Args:
        pretrained: Use ImageNet pre-trained weights
        use_checkpointing: Enable gradient checkpointing (saves memory)
    """

    def __init__(self, pretrained=True, use_checkpointing=False):
        super().__init__()

        resnet = torchvision.models.resnet50(pretrained=pretrained)

        # Drop the final avgpool + fc; keep the conv stem through layer4.
        self.features = nn.Sequential(*list(resnet.children())[:-2])

        self.gap = nn.AdaptiveAvgPool2d(1)
        self.use_checkpointing = use_checkpointing

    def forward(self, images):
        """
        Args:
            images: [B*N, 3, H, W] - Batch of images (flattened across cases)

        Returns:
            features: [B*N, 2048] - Instance features
        """
        if self.training and self.use_checkpointing:
            # Checkpoint the conv stack in 4 segments: activations are
            # recomputed in backward, trading compute for memory.
            feat_maps = checkpoint_sequential(self.features, segments=4, input=images)
        else:
            feat_maps = self.features(images)

        pooled = self.gap(feat_maps)
        return pooled.flatten(1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TransMIL_Query2Label_E2E(nn.Module):
    """
    Complete end-to-end model: Images → ResNet-50 → TransMIL → Q2L → Logits

    Pipeline:
    1. ResNet-50 extracts features from each ultrasound image
    2. TransMIL aggregates variable-length instance sequences with attention
    3. Query2Label decoder performs multi-label classification via cross-attention

    Args:
        num_class: Number of label classes (default 30)
        hidden_dim: Hidden dimension for TransMIL and Q2L (default 512)
        nheads: Number of attention heads in Q2L decoder
        num_decoder_layers: Number of Q2L decoder layers
        pretrained_resnet: Use ImageNet pre-trained ResNet-50
        use_checkpointing: Enable gradient checkpointing for ResNet-50
        use_ppeg: Use PPEG position encoding (vs learned 1D)
    """

    def __init__(
        self,
        num_class=30,
        hidden_dim=512,
        nheads=8,
        num_decoder_layers=2,
        pretrained_resnet=True,
        use_checkpointing=False,
        use_ppeg=False
    ):
        super().__init__()

        self.backbone = ResNet50Backbone(
            pretrained=pretrained_resnet,
            use_checkpointing=use_checkpointing
        )

        self.feature_extractor = TransMILFeatureExtractor(
            input_dim=2048,
            hidden_dim=hidden_dim,
            use_ppeg=use_ppeg
        )

        self.q2l_decoder = HybridQuery2Label(
            num_class=num_class,
            hidden_dim=hidden_dim,
            nheads=nheads,
            num_decoder_layers=num_decoder_layers
        )

    def forward(self, images, num_instances_per_case):
        """
        Args:
            images: [B*N_total, 3, H, W] - All images flattened across the batch
            num_instances_per_case: [B] or list - Number of images per case

        Returns:
            logits: [B, num_class] - Multi-label classification logits

        Raises:
            ValueError: if sum(num_instances_per_case) != images.shape[0]
        """
        if isinstance(num_instances_per_case, list):
            num_instances_per_case = torch.tensor(num_instances_per_case, device=images.device)

        B = len(num_instances_per_case)

        # Guard against silent mis-grouping of the flattened image batch.
        total = int(num_instances_per_case.sum().item())
        if total != images.shape[0]:
            raise ValueError(
                f"sum(num_instances_per_case)={total} does not match "
                f"number of images {images.shape[0]}"
            )

        # [B*N_total, 2048] instance features.
        all_features = self.backbone(images)

        # Re-group flat features into a padded [B, max_N, feat_dim] tensor plus
        # a validity mask. new_zeros inherits dtype/device from all_features
        # (robust under AMP / half precision), and the feature width is taken
        # from the backbone output instead of being hard-coded.
        max_N = int(num_instances_per_case.max().item())
        features_padded = all_features.new_zeros(B, max_N, all_features.shape[1])
        masks = torch.zeros(B, max_N, dtype=torch.bool, device=images.device)

        idx = 0
        for i, n in enumerate(num_instances_per_case):
            n = int(n.item()) if torch.is_tensor(n) else int(n)
            features_padded[i, :n] = all_features[idx:idx + n]
            masks[i, :n] = True
            idx += n

        seq_features, attn_mask = self.feature_extractor(features_padded, masks)

        # PyTorch key-padding masks are inverted: True = ignore, False = attend.
        decoder_mask = ~attn_mask
        logits = self.q2l_decoder(seq_features, memory_key_padding_mask=decoder_mask)

        return logits

    def freeze_backbone(self):
        """Freeze ResNet-50 backbone for training only TransMIL+Q2L"""
        for param in self.backbone.parameters():
            param.requires_grad = False

    def unfreeze_backbone(self):
        """Unfreeze ResNet-50 for end-to-end fine-tuning"""
        for param in self.backbone.parameters():
            param.requires_grad = True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # --- Smoke test: end-to-end forward pass on random data ---
    print("Testing TransMIL_Query2Label_E2E model...")

    num_class = 30
    batch_size = 2
    num_instances = [8, 12]  # variable number of images per case
    img_size = 224

    model = TransMIL_Query2Label_E2E(
        num_class=num_class,
        hidden_dim=512,
        nheads=8,
        num_decoder_layers=2,
        pretrained_resnet=False,  # random weights are enough for a shape test
        use_checkpointing=False,
        use_ppeg=False
    )

    images = torch.randn(sum(num_instances), 3, img_size, img_size)

    print(f"\nInput shapes:")
    print(f" Images: {images.shape}")
    print(f" Num instances per case: {num_instances}")

    model.eval()
    with torch.no_grad():
        logits = model(images, num_instances)

    print(f"\nOutput shape:")
    print(f" Logits: {logits.shape}")
    print(f" Expected: [{batch_size}, {num_class}]")

    assert logits.shape == (batch_size, num_class), "Output shape mismatch!"
    print("\n✓ Model test passed!")

    print("\n" + "="*60)
    print("Testing individual components...")
    print("="*60)

    # --- Component 1: TransMIL feature extractor ---
    print("\n1. TransMILFeatureExtractor")
    feature_extractor = TransMILFeatureExtractor(input_dim=2048, hidden_dim=512)
    features = torch.randn(2, 10, 2048)
    mask = torch.ones(2, 10, dtype=torch.bool)
    seq_features, attn_mask = feature_extractor(features, mask)
    print(f" Input: {features.shape}, Output: {seq_features.shape}")
    assert seq_features.shape == (2, 11, 512)  # CLS + 10 instances
    print(" ✓ Passed")

    # --- Component 2: Q2L decoder ---
    print("\n2. HybridQuery2Label")
    decoder = HybridQuery2Label(num_class=30, hidden_dim=512)
    seq_features = torch.randn(2, 11, 512)
    logits = decoder(seq_features)
    print(f" Input: {seq_features.shape}, Output: {logits.shape}")
    assert logits.shape == (2, 30)
    print(" ✓ Passed")

    # --- Component 3: ResNet-50 backbone ---
    print("\n3. ResNet50Backbone")
    backbone = ResNet50Backbone(pretrained=False)
    images = torch.randn(20, 3, 224, 224)
    features = backbone(images)
    print(f" Input: {images.shape}, Output: {features.shape}")
    assert features.shape == (20, 2048)
    print(" ✓ Passed")

    print("\n" + "="*60)
    print("All tests passed! ✓")
    print("="*60)
|
|
|
|