import torch
import torch.nn as nn
import torch.nn.functional as F
import math

from models.mae import MaskedAutoEncoder
from models.densenet import DenseNet


class AttentionPool(nn.Module):
    def __init__(self, dim=768, embed_dim=2048, num_heads=8):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, dim))
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.proj = nn.Linear(dim, embed_dim)

    def forward(self, x):
        # x: (B, N, dim), e.g. (B, 576, 768)
        B = x.size(0)
        q = self.query.expand(B, -1, -1)       # (B, 1, dim)
        attn_out, _ = self.attn(q, x, x)       # (B, 1, dim)
        return self.proj(attn_out.squeeze(1))  # (B, embed_dim)


class CrossAttentionBlock(nn.Module):
    """
    Cross-attention: query tokens attend to key/value tokens from another modality.
    """
    def __init__(self, dim_q, dim_kv, num_heads=8, dropout=0.1, proj_dim=None):
        super().__init__()
        self.proj_dim = proj_dim or dim_q
        self.num_heads = num_heads
        self.head_dim = self.proj_dim // num_heads
        self.scale = self.head_dim ** -0.5

        self.q_proj = nn.Linear(dim_q, self.proj_dim)
        self.k_proj = nn.Linear(dim_kv, self.proj_dim)
        self.v_proj = nn.Linear(dim_kv, self.proj_dim)
        self.out_proj = nn.Linear(self.proj_dim, dim_q)

        self.dropout = nn.Dropout(dropout)
        self.norm_q = nn.LayerNorm(dim_q)
        self.norm_kv = nn.LayerNorm(dim_kv)

    def forward(self, query, key_value):
        B, N_q, _ = query.shape
        N_kv = key_value.shape[1]

        q = self.norm_q(query)
        kv = self.norm_kv(key_value)

        Q = self.q_proj(q).view(B, N_q, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)

        attn = (Q @ K.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)

        out = (attn @ V).transpose(1, 2).reshape(B, N_q, self.proj_dim)
        out = self.out_proj(out)
        return query + self.dropout(out)


class BidirectionalCrossAttention(nn.Module):
    """
    Bidirectional fusion: MAE tokens attend to DenseNet tokens AND vice versa.
    """
    def __init__(self, mae_dim=768, dense_dim=2048, num_heads=8, dropout=0.1, proj_dim=512):
        super().__init__()
        # MAE queries DenseNet
        self.mae_cross = CrossAttentionBlock(mae_dim, dense_dim, num_heads, dropout, proj_dim)
        # DenseNet queries MAE
        self.dense_cross = CrossAttentionBlock(dense_dim, mae_dim, num_heads, dropout, proj_dim)

        # Per-branch feed-forward blocks
        self.mae_ffn = nn.Sequential(
            nn.LayerNorm(mae_dim),
            nn.Linear(mae_dim, mae_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mae_dim * 4, mae_dim),
            nn.Dropout(dropout),
        )
        self.dense_ffn = nn.Sequential(
            nn.LayerNorm(dense_dim),
            nn.Linear(dense_dim, dense_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dense_dim * 2, dense_dim),
            nn.Dropout(dropout),
        )

    def forward(self, mae_tokens, dense_tokens):
        # Cross-attention in both directions
        mae_out = self.mae_cross(mae_tokens, dense_tokens)
        dense_out = self.dense_cross(dense_tokens, mae_tokens)
        # Feed-forward with residual connections
        mae_out = mae_out + self.mae_ffn(mae_out)
        dense_out = dense_out + self.dense_ffn(dense_out)
        return mae_out, dense_out
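
def _demo_bidirectional_cross_attention():
    # Illustrative smoke test (not part of the original model): checks that
    # BidirectionalCrossAttention preserves the token shapes of both branches.
    # Batch size, token counts, and dims below are assumptions for the demo.
    torch.manual_seed(0)
    mae_tokens = torch.randn(2, 576, 768)     # (B, N_mae, mae_dim)
    dense_tokens = torch.randn(2, 576, 2048)  # (B, N_dense, dense_dim)
    block = BidirectionalCrossAttention(mae_dim=768, dense_dim=2048,
                                        num_heads=8, dropout=0.1, proj_dim=512)
    block.eval()  # disable dropout for a deterministic check
    mae_out, dense_out = block(mae_tokens, dense_tokens)
    assert mae_out.shape == mae_tokens.shape
    assert dense_out.shape == dense_tokens.shape
    print("BidirectionalCrossAttention OK:", mae_out.shape, dense_out.shape)
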
class LearnedLogitEnsemble(nn.Module):
    def __init__(self, num_heads=7, num_classes=14, temperature_init=1.0, use_gate=False):
        super().__init__()
        self.num_classes = num_classes
        self.num_heads = num_heads

        # 1. Per-head temperature scaling
        self.log_temps = nn.Parameter(torch.ones(num_heads) * math.log(temperature_init))

        # 2. Learned head weights: either a tiny gating network fed the
        #    concatenated raw logits (predicting soft weights per sample),
        #    or a single fixed weight vector shared across samples.
        gate_input_dim = num_classes * num_heads
        self.use_gate = use_gate
        if use_gate:
            self.gate = nn.Sequential(
                nn.Linear(gate_input_dim, 256),
                nn.GELU(),
                nn.LayerNorm(256),
                nn.Dropout(0.1),
                nn.Linear(256, num_heads),
            )
        else:
            # Simpler alternative: fixed learned weights
            # (add L2 regularization in the training loop if desired)
            self.raw_weights = nn.Parameter(torch.ones(num_heads))

    def forward(self, logits_list):
        """
        logits_list: list/tuple of `num_heads` tensors, each (B, num_classes),
        e.g. 7 tensors of shape (B, 14).
        """
        # Step 1: per-head temperature scaling
        scaled_logits = []
        for i, logits in enumerate(logits_list):
            T = torch.exp(self.log_temps[i])  # exp() guarantees T > 0
            scaled_logits.append(logits / (T + 1e-8))

        # Stack -> (B, num_heads, num_classes)
        stacked = torch.stack(scaled_logits, dim=1)

        if self.use_gate:
            # Step 2a: dynamic gating, aware of the sample's own logits
            gate_in = stacked.flatten(1)   # (B, num_heads * num_classes)
            raw_gate = self.gate(gate_in)  # (B, num_heads)
            weights = torch.softmax(raw_gate, dim=-1).unsqueeze(-1)  # (B, num_heads, 1)
        else:
            # Step 2b: fixed learned weights
            weights = torch.softmax(self.raw_weights, dim=0)  # (num_heads,)
            weights = weights.view(1, self.num_heads, 1)      # (1, num_heads, 1)

        # Step 3: weighted average in logit space
        fused_logits = (stacked * weights).sum(dim=1)  # (B, num_classes)
        return fused_logits
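
def _demo_learned_logit_ensemble():
    # Illustrative smoke test (not part of the original model): fuses 7 random
    # (B, 14) logit tensors and checks the fused shape for both the
    # fixed-weight and gated variants. Sizes are demo assumptions.
    torch.manual_seed(0)
    logits_list = [torch.randn(4, 14) for _ in range(7)]
    for use_gate in (False, True):
        ens = LearnedLogitEnsemble(num_heads=7, num_classes=14, use_gate=use_gate)
        ens.eval()  # disable dropout for a deterministic check
        fused = ens(logits_list)
        assert fused.shape == (4, 14)
        print(f"LearnedLogitEnsemble(use_gate={use_gate}) OK:", fused.shape)
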
class XRAYClassifier(nn.Module):
    def __init__(self, num_classes=14, c=1, mask_ratio=0, dropout=0.25, img_size=384,
                 encoder_dim=768, mlp_dim=3072, decoder_dim=512, encoder_depth=12,
                 encoder_head=8, decoder_depth=8, decoder_head=8, patch_size=8):
        super().__init__()

        # ---- MAE branch (frozen; masking disabled for classification) ----
        self.mae = MaskedAutoEncoder(
            c=c, mask_ratio=0, dropout=dropout, img_size=img_size,
            encoder_dim=encoder_dim, mlp_dim=mlp_dim, decoder_dim=decoder_dim,
            encoder_depth=encoder_depth, encoder_head=encoder_head,
            decoder_depth=decoder_depth, decoder_head=decoder_head,
            patch_size=patch_size,
        )
        for p in self.mae.parameters():
            p.requires_grad = False
        self.token_ln = nn.LayerNorm(encoder_dim)
        self.attn_selfpool_mae = AttentionPool(encoder_dim, 1024)

        # ---- DenseNet branch (separately pretrained) ----
        # If your DenseNet supports 1-channel input, set c=1 here and remove
        # the channel duplication at the top of forward().
        self.dense = DenseNet(c=2, k=64, num_classes=num_classes)
        self.dn_feat_dim = 2048

        # ---- Cross-attention fusion ----
        self.cross_attn_layers = nn.ModuleList([
            BidirectionalCrossAttention(
                mae_dim=encoder_dim,         # 768
                dense_dim=self.dn_feat_dim,  # 2048
                num_heads=8,
                dropout=0.1,
                proj_dim=512,
            )
            for _ in range(12)
        ])
        self.attn_pool_mae = AttentionPool(encoder_dim, 1024)
        self.classifier_mae = nn.Sequential(
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )
        self.attn_pool_dense = AttentionPool(self.dn_feat_dim, 1024)
        self.classifier_attn = nn.Sequential(
            nn.Linear(2048, 1024),  # 2048 = two pooled 1024-dim branches, concatenated
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )

        # ---- FPN over the DenseNet feature pyramid ----
        # Lateral 1x1 convs reduce each stage to 256 channels.
        self.lateral5 = nn.Conv2d(2048, 256, kernel_size=1)  # feat4: 2048 ch
        self.lateral4 = nn.Conv2d(2048, 256, kernel_size=1)  # feat3: 2048 ch
        self.lateral3 = nn.Conv2d(1024, 256, kernel_size=1)  # feat2: 1024 ch
        self.lateral2 = nn.Conv2d(512, 256, kernel_size=1)   # feat1: 512 ch

        # 3x3 smoothing convs applied after each top-down merge
        self.output5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')

        # Per-scale classification heads
        self._classify_out5 = nn.Linear(256, num_classes)
        self._classify_out4 = nn.Linear(256, num_classes)
        self._classify_out3 = nn.Linear(256, num_classes)
        self._classify_out2 = nn.Linear(256, num_classes)

        self.learned_logit_ensemble = LearnedLogitEnsemble(num_classes=num_classes)
    def forward(self, x):
        # ---- MAE tokens (frozen encoder) ----
        mae_tokens, _, _, _ = self.mae.encoder(x)
        mae_tokens = self.token_ln(mae_tokens)

        # Duplicate the single channel to match the 2-channel DenseNet stem
        doublex = torch.cat([x, x], dim=1)  # [B, 2, 384, 384]

        # ---- DenseNet path: extract multi-scale features ----
        xdense = self.dense.initialconv(doublex)  # [B, 128, 192, 192]

        # Layer 1 + ECA (features taken BEFORE the transition layer)
        feat1 = self.dense.layer1(xdense)
        feat1 = self.dense.dropout1(feat1)
        feat1 = self.dense.eca1(feat1)      # [B, 512, 192, 192]
        xdense1 = self.dense.trans1(feat1)  # [B, 256, 96, 96]

        # Layer 2 + ECA (before transition)
        feat2 = self.dense.layer2(xdense1)
        feat2 = self.dense.dropout2(feat2)
        feat2 = self.dense.eca2(feat2)      # [B, 1024, 96, 96]
        xdense2 = self.dense.trans2(feat2)  # [B, 512, 48, 48]

        # Layer 3 + ECA (before transition)
        feat3 = self.dense.layer3(xdense2)
        feat3 = self.dense.dropout3(feat3)
        feat3 = self.dense.eca3(feat3)      # [B, 2048, 48, 48]
        xdense3 = self.dense.trans3(feat3)  # [B, 1024, 24, 24]

        # Layer 4 (no transition)
        feat4 = self.dense.layer4(xdense3)
        feat4 = self.dense.dropout4(feat4)
        feat4 = self.dense.eca4(feat4)      # [B, 2048, 24, 24]
        xdense4 = feat4

        # Global pooling for the DenseNet classifier head
        xdense_pooled = self.dense.global_average_pool(xdense4)
        xdense_pooled = xdense_pooled.view(xdense_pooled.size(0), -1)
        xdense_pooled = self.dense.dropout(xdense_pooled)
        classifier_xdense = self.dense.classifier(xdense_pooled)

        # DenseNet tokens for cross-attention
        dense_tokens = xdense4.flatten(2).transpose(1, 2)  # [B, 576, 2048]

        # ---- FPN over the multi-scale features ----
        c4 = self.lateral5(feat4)  # [B, 2048, 24, 24]  -> [B, 256, 24, 24]
        c3 = self.lateral4(feat3)  # [B, 2048, 48, 48]  -> [B, 256, 48, 48]
        c2 = self.lateral3(feat2)  # [B, 1024, 96, 96]  -> [B, 256, 96, 96]
        c1 = self.lateral2(feat1)  # [B, 512, 192, 192] -> [B, 256, 192, 192]

        # Top-down pathway with nearest-neighbor upsampling
        p4 = self.output5(c4)                      # 24x24
        p3 = self.output4(self.upsample(p4) + c3)  # 48x48
        p2 = self.output3(self.upsample(p3) + c2)  # 96x96
        p1 = self.output2(self.upsample(p2) + c1)  # 192x192

        # Per-scale classification heads (global average pool over H, W)
        out4 = self._classify_out5(p4.mean([2, 3]))
        out3 = self._classify_out4(p3.mean([2, 3]))
        out2 = self._classify_out3(p2.mean([2, 3]))
        out1 = self._classify_out2(p1.mean([2, 3]))

        # ---- MAE classifier head ----
        mae_tokens_pooled = self.attn_selfpool_mae(mae_tokens)
        classifier_mae = self.classifier_mae(mae_tokens_pooled)

        # ---- Stacked cross-attention: each layer's output feeds the next ----
        mae_cross, dense_cross = mae_tokens, dense_tokens
        for cross_layer in self.cross_attn_layers:
            mae_cross, dense_cross = cross_layer(mae_cross, dense_cross)
        mae_cross = self.attn_pool_mae(mae_cross)
        dense_cross = self.attn_pool_dense(dense_cross)
        out = torch.cat([mae_cross, dense_cross], dim=1)
        classifier_attn = self.classifier_attn(out)

        # ---- Learned logit ensemble over all 7 heads ----
        merged_classifier = self.learned_logit_ensemble([
            classifier_mae, classifier_xdense, classifier_attn,
            out4, out3, out2, out1,
        ])
        return merged_classifier
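
if __name__ == "__main__":
    # Smoke tests for the self-contained modules defined above. The full
    # XRAYClassifier is not exercised here because it depends on the external
    # models.mae and models.densenet implementations, whose exact interfaces
    # are assumed rather than verified.
    _demo_bidirectional_cross_attention()
    _demo_learned_logit_ensemble()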