|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.nn.functional as F
|
|
|
import math
|
|
|
|
|
|
from models.mae import MaskedAutoEncoder
|
|
|
from models.densenet import DenseNet
|
|
|
|
|
|
class AttentionPool(nn.Module):
    """Collapse a token sequence into one vector using a single learned query.

    A learnable query of width ``dim`` attends over the input tokens through
    multi-head attention (``batch_first=True``); the attended vector is then
    linearly projected from ``dim`` to ``embed_dim``.
    """

    def __init__(self, dim=768, embed_dim=2048, num_heads=8):
        super().__init__()
        # Sub-modules are created in the same order as before so that seeded
        # initialisation and state_dict layout are unaffected.
        self.query = nn.Parameter(torch.randn(1, 1, dim))
        self.attn = nn.MultiheadAttention(
            embed_dim=dim,
            num_heads=num_heads,
            batch_first=True,
        )
        self.proj = nn.Linear(dim, embed_dim)

    def forward(self, x):
        """Return the projected, attention-pooled summary of ``x``.

        ``x`` serves as both key and value; its first dimension is treated
        as the batch. The output has one row per batch element.
        """
        batch_size = x.size(0)
        # Share the single learned query across the batch (expand, no copy).
        batched_query = self.query.expand(batch_size, -1, -1)
        pooled, _unused_weights = self.attn(batched_query, x, x)
        # The query has length 1, so drop that axis before projecting.
        return self.proj(pooled[:, 0, :])
|
|
|
|
|
|
class CrossAttentionBlock(nn.Module):
    """
    Cross-attention: Query tokens attend to Key/Value tokens from another modality.

    Both inputs are layer-normalised, projected to a shared width
    ``proj_dim`` (defaults to ``dim_q``), split into ``num_heads`` heads and
    combined with scaled dot-product attention. The attended result is
    projected back to ``dim_q`` and added to the un-normalised query as a
    residual.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not evenly
            divide the projection width.
    """

    def __init__(self, dim_q, dim_kv, num_heads=8, dropout=0.1, proj_dim=None):
        super().__init__()
        self.proj_dim = proj_dim or dim_q
        # Fail early: floor division below would otherwise silently truncate
        # and only surface later as an opaque shape error inside forward().
        if num_heads <= 0 or self.proj_dim % num_heads != 0:
            raise ValueError(
                f"proj_dim ({self.proj_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.head_dim = self.proj_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim_q, self.proj_dim)
        self.k_proj = nn.Linear(dim_kv, self.proj_dim)
        self.v_proj = nn.Linear(dim_kv, self.proj_dim)
        self.out_proj = nn.Linear(self.proj_dim, dim_q)

        # The same dropout module is applied to the attention weights and to
        # the projected output.
        self.dropout = nn.Dropout(dropout)
        self.norm_q = nn.LayerNorm(dim_q)
        self.norm_kv = nn.LayerNorm(dim_kv)

    def forward(self, query, key_value):
        """Attend ``query`` tokens over ``key_value`` tokens.

        Args:
            query: 3-D tensor ``(B, N_q, dim_q)``.
            key_value: tensor whose second dimension is the number of
                key/value tokens and whose last dimension is ``dim_kv``.

        Returns:
            Tensor with the same shape as ``query``: the input query plus the
            (dropout-regularised) attention output.
        """
        B, N_q, _ = query.shape
        N_kv = key_value.shape[1]

        q = self.norm_q(query)
        kv = self.norm_kv(key_value)

        # (B, N, proj_dim) -> (B, heads, N, head_dim)
        Q = self.q_proj(q).view(B, N_q, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)

        attn = (Q @ K.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        # Merge the heads back: (B, heads, N_q, head_dim) -> (B, N_q, proj_dim)
        out = (attn @ V).transpose(1, 2).reshape(B, N_q, self.proj_dim)
        out = self.out_proj(out)

        # Residual connection uses the original (un-normalised) query.
        return query + self.dropout(out)
|
|
|
|
|
|
|
|
|
class BidirectionalCrossAttention(nn.Module):
    """
    Bidirectional: MAE attends to DenseNet AND DenseNet attends to MAE.

    Each direction is a ``CrossAttentionBlock`` followed by a residual
    feed-forward network. The MAE branch expands its hidden width by 4x,
    the DenseNet branch by 2x.
    """

    def __init__(self, mae_dim=768, dense_dim=2048, num_heads=8, dropout=0.1, proj_dim=512):
        super().__init__()

        # Creation order (mae_cross, dense_cross, mae_ffn, dense_ffn) is kept
        # so parameter registration and seeded init stay the same.
        self.mae_cross = CrossAttentionBlock(mae_dim, dense_dim, num_heads, dropout, proj_dim)
        self.dense_cross = CrossAttentionBlock(dense_dim, mae_dim, num_heads, dropout, proj_dim)

        self.mae_ffn = self._build_ffn(mae_dim, mae_dim * 4, dropout)
        self.dense_ffn = self._build_ffn(dense_dim, dense_dim * 2, dropout)

    @staticmethod
    def _build_ffn(width, hidden_width, dropout):
        """Build LayerNorm -> Linear -> GELU -> Dropout -> Linear -> Dropout."""
        return nn.Sequential(
            nn.LayerNorm(width),
            nn.Linear(width, hidden_width),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_width, width),
            nn.Dropout(dropout),
        )

    def forward(self, mae_tokens, dense_tokens):
        """Return ``(mae_out, dense_out)`` after cross-attention and FFN.

        Both cross-attention blocks read the incoming tokens; neither sees
        the other's freshly attended output.
        """
        attended_mae = self.mae_cross(mae_tokens, dense_tokens)
        attended_dense = self.dense_cross(dense_tokens, mae_tokens)

        # Residual FFNs; tuple elements evaluate left to right, so the MAE
        # branch still runs before the DenseNet branch.
        return (
            attended_mae + self.mae_ffn(attended_mae),
            attended_dense + self.dense_ffn(attended_dense),
        )
|
|
|
class LearnedLogitEnsemble(nn.Module):
    """Fuse the logits of several classification heads with learned weights.

    Every head's logits are first divided by its own learned temperature
    (stored in log-space as ``log_temps``). The scaled logits are then
    combined by a weighted sum over heads, where the weights are a softmax
    over either

    * one learned scalar per head (``use_gate=False``), or
    * the output of a small MLP fed with all scaled logits, giving
      per-sample weights (``use_gate=True``).

    Args:
        num_heads: number of logit tensors that will be fused.
        num_classes: width of each logit tensor (used to size the gate MLP).
        temperature_init: initial temperature for every head.
        use_gate: choose the input-dependent gate instead of static weights.
    """

    def __init__(self, num_heads=7, num_classes=14, temperature_init=1.0, use_gate=False):
        super().__init__()
        self.num_classes = num_classes
        self.num_heads = num_heads

        # Temperatures are parameterised in log-space; exp() in forward keeps
        # them positive.
        self.log_temps = nn.Parameter(torch.ones(num_heads) * math.log(temperature_init))

        self.use_gate = use_gate

        if use_gate:
            # The gate sees every head's scaled logits flattened together.
            gate_input_dim = num_classes * num_heads
            self.gate = nn.Sequential(
                nn.Linear(gate_input_dim, 256),
                nn.GELU(),
                nn.LayerNorm(256),
                nn.Dropout(0.1),
                nn.Linear(256, num_heads),
            )
        else:
            self.raw_weights = nn.Parameter(torch.ones(num_heads))

    def forward(self, logits_list):
        """Return the fused logits.

        Args:
            logits_list: list/tuple of ``num_heads`` tensors with identical
                shape; the first dimension is the batch.

        Returns:
            Tensor with the shape of a single input: the weighted sum of the
            temperature-scaled logits over heads.

        Raises:
            ValueError: if the number of tensors differs from ``num_heads``.
        """
        # Without this check a single tensor would silently broadcast against
        # the (1, num_heads, 1) weights, and other mismatches would fail with
        # unrelated index/shape errors.
        if len(logits_list) != self.num_heads:
            raise ValueError(
                f"expected {self.num_heads} logit tensors, got {len(logits_list)}"
            )

        # Small epsilon guards the division if a temperature collapses to 0.
        temps = torch.exp(self.log_temps) + 1e-8
        stacked = torch.stack(
            [logits / temp for logits, temp in zip(logits_list, temps)],
            dim=1,
        )

        if self.use_gate:
            # Per-sample weights: (B, num_heads) -> (B, num_heads, 1)
            raw_gate = self.gate(stacked.flatten(1))
            weights = torch.softmax(raw_gate, dim=-1).unsqueeze(-1)
        else:
            # Static weights shared by every sample.
            weights = torch.softmax(self.raw_weights, dim=0)
            weights = weights.view(1, self.num_heads, 1).to(stacked.device)

        return (stacked * weights).sum(dim=1)
|
|
|
class XRAYClassifier(nn.Module):
    """Multi-branch classifier that fuses seven sets of logits.

    Branches, in the order they are handed to the ensemble:

    1. attention-pooled MAE encoder tokens      -> ``classifier_mae``
    2. DenseNet global-average-pooled features  -> ``dense.classifier``
    3. bidirectional MAE/DenseNet cross-attention, pooled and concatenated
       -> ``classifier_attn``
    4-7. four top-down pyramid levels built from the DenseNet stage outputs,
       each spatially averaged and linearly classified.

    The seven logit tensors are combined by ``LearnedLogitEnsemble``.
    All MAE parameters have ``requires_grad`` set to False.
    """

    def __init__(self, num_classes=14, c=1, mask_ratio=0, dropout=0.25, img_size=384,
                 encoder_dim=768, mlp_dim=3072, decoder_dim=512, encoder_depth=12,
                 encoder_head=8, decoder_depth=8, decoder_head=8, patch_size=8):
        super().__init__()

        # NOTE: ``mask_ratio`` is accepted by the signature but not forwarded;
        # the MAE is always constructed with mask_ratio=0.
        self.mae = MaskedAutoEncoder(
            c=c, mask_ratio=0, dropout=dropout, img_size=img_size,
            encoder_dim=encoder_dim, mlp_dim=mlp_dim, decoder_dim=decoder_dim,
            encoder_depth=encoder_depth, encoder_head=encoder_head,
            decoder_depth=decoder_depth, decoder_head=decoder_head, patch_size=patch_size
        )
        for p in self.mae.parameters():
            p.requires_grad = False

        self.token_ln = nn.LayerNorm(encoder_dim)
        self.attn_selfpool_mae = AttentionPool(encoder_dim, 1024)

        # forward() feeds the DenseNet the input concatenated with itself
        # along dim 1.
        # NOTE(review): c=2 presumably matches a single-channel input doubled
        # to two channels — confirm against DenseNet if ``c`` != 1 is used.
        self.dense = DenseNet(c=2, k=64, num_classes=num_classes)

        # Token width used for the DenseNet side of the cross-attention.
        # NOTE(review): assumed to equal the channel count of the last
        # DenseNet stage — confirm against models.densenet.
        self.dn_feat_dim = 2048

        self.cross_attn_layers = nn.ModuleList([
            BidirectionalCrossAttention(
                mae_dim=encoder_dim,
                dense_dim=self.dn_feat_dim,
                num_heads=8,
                dropout=0.1,
                proj_dim=512
            )
            for _ in range(12)
        ])

        self.attn_pool_mae = AttentionPool(encoder_dim, 1024)

        self.classifier_mae = nn.Sequential(
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

        self.attn_pool_dense = AttentionPool(self.dn_feat_dim, 1024)

        # Input is the concatenation of the two 1024-wide pooled vectors.
        self.classifier_attn = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

        # 1x1 lateral convs bring each DenseNet stage output to 256 channels;
        # 3x3 output convs smooth each merged pyramid level.
        self.lateral5 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)
        self.lateral4 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)
        self.lateral3 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
        self.lateral2 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0)
        self.output5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        self._classify_out5 = nn.Linear(256, num_classes)
        self._classify_out4 = nn.Linear(256, num_classes)
        self._classify_out3 = nn.Linear(256, num_classes)
        self._classify_out2 = nn.Linear(256, num_classes)

        # Uses the ensemble's default of 7 heads, matching the 7 logit
        # tensors assembled in forward().
        self.learned_logit_ensemble = LearnedLogitEnsemble(num_classes=num_classes)

    def forward(self, x):
        """Return the fused logits for input batch ``x``.

        ``x`` is passed to the MAE encoder as-is and, concatenated with
        itself along dim 1, to the DenseNet stem.
        """
        # ---- MAE branch: the first of the encoder's four outputs is used.
        mae_tokens, _, _, _ = self.mae.encoder(x)
        mae_tokens = self.token_ln(mae_tokens)

        # ---- DenseNet trunk, keeping each stage's output for the pyramid.
        doublex = torch.cat([x, x], dim=1)
        xdense = self.dense.initialconv(doublex)

        feat1 = self.dense.layer1(xdense)
        feat1 = self.dense.dropout1(feat1)
        feat1 = self.dense.eca1(feat1)
        xdense1 = self.dense.trans1(feat1)

        feat2 = self.dense.layer2(xdense1)
        feat2 = self.dense.dropout2(feat2)
        feat2 = self.dense.eca2(feat2)
        xdense2 = self.dense.trans2(feat2)

        feat3 = self.dense.layer3(xdense2)
        feat3 = self.dense.dropout3(feat3)
        feat3 = self.dense.eca3(feat3)
        xdense3 = self.dense.trans3(feat3)

        feat4 = self.dense.layer4(xdense3)
        feat4 = self.dense.dropout4(feat4)
        feat4 = self.dense.eca4(feat4)
        xdense4 = feat4  # no transition after the last stage

        # ---- DenseNet's own classification head.
        xdense_pooled = self.dense.global_average_pool(xdense4)
        xdense_pooled = xdense_pooled.view(xdense_pooled.size(0), -1)
        xdense_pooled = self.dense.dropout(xdense_pooled)
        classifier_xdense = self.dense.classifier(xdense_pooled)

        # Flatten the spatial grid into a token sequence: (B, C, H, W) -> (B, H*W, C)
        dense_tokens = xdense4.flatten(2).transpose(1, 2)

        # ---- Top-down feature pyramid over the four stage outputs.
        c4 = self.lateral5(feat4)
        c3 = self.lateral4(feat3)
        c2 = self.lateral3(feat2)
        c1 = self.lateral2(feat1)

        p4 = self.output5(c4)
        # Each finer level adds the 2x-upsampled (already smoothed) coarser level.
        p3 = self.output4(self.upsample(p4) + c3)
        p2 = self.output3(self.upsample(p3) + c2)
        p1 = self.output2(self.upsample(p2) + c1)

        # Spatial mean per level, then a per-level linear classifier.
        out4 = self._classify_out5(p4.mean([2, 3]))
        out3 = self._classify_out4(p3.mean([2, 3]))
        out2 = self._classify_out3(p2.mean([2, 3]))
        out1 = self._classify_out2(p1.mean([2, 3]))

        # ---- MAE-only head.
        mae_tokens_pooled = self.attn_selfpool_mae(mae_tokens)
        classifier_mae = self.classifier_mae(mae_tokens_pooled)

        # ---- Cross-attention stack.
        # FIX: each layer now consumes the previous layer's output. Before,
        # every layer received the original tokens and its result was
        # overwritten, so only the last of the 12 layers influenced the
        # output while the rest were computed and thrown away.
        # NOTE: this changes numerical results for checkpoints trained with
        # the earlier forward pass.
        mae_cross, dense_cross = mae_tokens, dense_tokens
        for cross_layer in self.cross_attn_layers:
            mae_cross, dense_cross = cross_layer(mae_cross, dense_cross)

        mae_cross = self.attn_pool_mae(mae_cross)
        dense_cross = self.attn_pool_dense(dense_cross)
        out = torch.cat([mae_cross, dense_cross], dim=1)
        classifier_attn = self.classifier_attn(out)

        # ---- Learned fusion of all seven heads.
        merged_classifier = self.learned_logit_ensemble([
            classifier_mae,
            classifier_xdense,
            classifier_attn,
            out4, out3, out2, out1
        ])

        return merged_classifier