import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from models.mae import MaskedAutoEncoder
from models.densenet import DenseNet

class AttentionPool(nn.Module):
    """Pools a token sequence to a single vector via a learned query token."""

    def __init__(self, dim=768, out_dim=2048, num_heads=8):
        super().__init__()
        self.query = nn.Parameter(torch.randn(1, 1, dim))
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=num_heads, batch_first=True)
        self.proj = nn.Linear(dim, out_dim)

    def forward(self, x):  # x: (B, N, dim), e.g. (B, 576, 768)
        B = x.size(0)
        q = self.query.expand(B, -1, -1)       # (B, 1, dim)
        attn_out, _ = self.attn(q, x, x)       # (B, 1, dim)
        return self.proj(attn_out.squeeze(1))  # (B, out_dim)
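
# A minimal usage sketch for AttentionPool (illustrative numbers only, not
# called anywhere): pool a (B, N, 768) token grid, e.g. N=576, down to one
# 1024-d vector per sample:
#
#     pool = AttentionPool(dim=768, out_dim=1024)
#     tokens = torch.randn(2, 576, 768)
#     pooled = pool(tokens)  # -> torch.Size([2, 1024])
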
class CrossAttentionBlock(nn.Module):
    """
    Cross-attention: query tokens attend to key/value tokens from another modality.
    """

    def __init__(self, dim_q, dim_kv, num_heads=8, dropout=0.1, proj_dim=None):
        super().__init__()
        self.proj_dim = proj_dim or dim_q
        self.num_heads = num_heads
        self.head_dim = self.proj_dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.q_proj = nn.Linear(dim_q, self.proj_dim)
        self.k_proj = nn.Linear(dim_kv, self.proj_dim)
        self.v_proj = nn.Linear(dim_kv, self.proj_dim)
        self.out_proj = nn.Linear(self.proj_dim, dim_q)
        self.dropout = nn.Dropout(dropout)
        self.norm_q = nn.LayerNorm(dim_q)
        self.norm_kv = nn.LayerNorm(dim_kv)

    def forward(self, query, key_value):
        B, N_q, _ = query.shape
        N_kv = key_value.shape[1]
        q = self.norm_q(query)
        kv = self.norm_kv(key_value)
        Q = self.q_proj(q).view(B, N_q, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(kv).view(B, N_kv, self.num_heads, self.head_dim).transpose(1, 2)
        attn = (Q @ K.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        attn = self.dropout(attn)
        out = (attn @ V).transpose(1, 2).reshape(B, N_q, self.proj_dim)
        out = self.out_proj(out)
        return query + self.dropout(out)  # residual in the query's own dimension
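
# Hedged shape sketch for CrossAttentionBlock (illustrative numbers): with
# dim_q=768, dim_kv=2048, proj_dim=512, a (B, 576, 768) query stream attends
# to a (B, 576, 2048) key/value stream and returns (B, 576, 768); the output
# always stays in the query's dimension because of the residual connection.
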
class BidirectionalCrossAttention(nn.Module):
    """
    Bidirectional: MAE attends to DenseNet AND DenseNet attends to MAE.
    """

    def __init__(self, mae_dim=768, dense_dim=2048, num_heads=8, dropout=0.1, proj_dim=512):
        super().__init__()
        # MAE queries DenseNet
        self.mae_cross = CrossAttentionBlock(mae_dim, dense_dim, num_heads, dropout, proj_dim)
        # DenseNet queries MAE
        self.dense_cross = CrossAttentionBlock(dense_dim, mae_dim, num_heads, dropout, proj_dim)
        # FFN blocks
        self.mae_ffn = nn.Sequential(
            nn.LayerNorm(mae_dim),
            nn.Linear(mae_dim, mae_dim * 4),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mae_dim * 4, mae_dim),
            nn.Dropout(dropout)
        )
        self.dense_ffn = nn.Sequential(
            nn.LayerNorm(dense_dim),
            nn.Linear(dense_dim, dense_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dense_dim * 2, dense_dim),
            nn.Dropout(dropout)
        )

    def forward(self, mae_tokens, dense_tokens):
        # Cross-attention in both directions
        mae_out = self.mae_cross(mae_tokens, dense_tokens)
        dense_out = self.dense_cross(dense_tokens, mae_tokens)
        # FFN with residual
        mae_out = mae_out + self.mae_ffn(mae_out)
        dense_out = dense_out + self.dense_ffn(dense_out)
        return mae_out, dense_out
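
# Note on the FFN expansion ratios above: the MAE stream uses the usual 4x
# expansion at dim 768, while the DenseNet stream uses only 2x at dim 2048,
# presumably to keep the parameter count of the wider stream in check. Both
# streams keep their shapes: (B, N_mae, 768) and (B, N_dense, 2048) in and out.
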
class LearnedLogitEnsemble(nn.Module):
    """Fuses several classification heads via per-head temperatures and learned weights."""

    def __init__(self, num_heads=7, num_classes=14, temperature_init=1.0, use_gate=False):
        super().__init__()
        self.num_classes = num_classes
        self.num_heads = num_heads
        # 1. Per-head temperature, parameterized in log space so T stays positive.
        self.log_temps = nn.Parameter(torch.ones(num_heads) * math.log(temperature_init))
        # 2. Head weights: either a tiny gating network over the concatenated
        #    raw logits (sample-wise soft weights) or a fixed learned weight vector.
        gate_input_dim = num_classes * num_heads
        self.use_gate = use_gate
        if use_gate:
            self.gate = nn.Sequential(
                nn.Linear(gate_input_dim, 256),
                nn.GELU(),
                nn.LayerNorm(256),
                nn.Dropout(0.1),
                nn.Linear(256, num_heads),
            )
        else:
            # Simpler alternative: fixed weights (optionally L2-regularized in the loss).
            self.raw_weights = nn.Parameter(torch.ones(num_heads))

    def forward(self, logits_list):
        """
        logits_list: list/tuple of num_heads tensors, each (B, num_classes).
        """
        device = logits_list[0].device
        # Step 1: temperature scaling per head.
        scaled_logits = []
        for i, logits in enumerate(logits_list):
            T = torch.exp(self.log_temps[i])  # always > 0
            scaled_logits.append(logits / (T + 1e-8))
        # Stack to (B, num_heads, num_classes).
        stacked = torch.stack(scaled_logits, dim=1)
        if self.use_gate:
            # Step 2a: dynamic gating, aware of each sample's logit pattern.
            gate_in = stacked.flatten(1)                             # (B, num_heads * num_classes)
            raw_gate = self.gate(gate_in)                            # (B, num_heads)
            weights = torch.softmax(raw_gate, dim=-1).unsqueeze(-1)  # (B, num_heads, 1)
        else:
            # Step 2b: fixed learned weights, shared across the batch.
            weights = torch.softmax(self.raw_weights, dim=0)         # (num_heads,)
            weights = weights.view(1, self.num_heads, 1).to(device)  # (1, num_heads, 1)
        # Step 3: weighted average in logit space.
        fused_logits = (stacked * weights).sum(dim=1)                # (B, num_classes)
        return fused_logits
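
# Hedged usage sketch for LearnedLogitEnsemble (illustrative tensors only):
#
#     ensemble = LearnedLogitEnsemble(num_heads=7, num_classes=14)
#     heads = [torch.randn(2, 14) for _ in range(7)]
#     fused = ensemble(heads)  # -> torch.Size([2, 14])
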
class XRAYClassifier(nn.Module):
    def __init__(self, num_classes=14, c=1, mask_ratio=0, dropout=0.25, img_size=384,
                 encoder_dim=768, mlp_dim=3072, decoder_dim=512, encoder_depth=12,
                 encoder_head=8, decoder_depth=8, decoder_head=8, patch_size=8):
        super().__init__()
        # ---- MAE branch (frozen) ----
        self.mae = MaskedAutoEncoder(
            c=c, mask_ratio=mask_ratio, dropout=dropout, img_size=img_size,
            encoder_dim=encoder_dim, mlp_dim=mlp_dim, decoder_dim=decoder_dim,
            encoder_depth=encoder_depth, encoder_head=encoder_head,
            decoder_depth=decoder_depth, decoder_head=decoder_head, patch_size=patch_size
        )
        for p in self.mae.parameters():
            p.requires_grad = False
        self.token_ln = nn.LayerNorm(encoder_dim)
        self.attn_selfpool_mae = AttentionPool(encoder_dim, 1024)
        # ---- DenseNet branch (pretrained separately) ----
        # If your DenseNet supports 1 channel, set c=1 and remove the input duplication in forward.
        self.dense = DenseNet(c=2, k=64, num_classes=num_classes)
        self.dn_feat_dim = 2048
        # ---- Cross-attention fusion ----
        self.cross_attn_layers = nn.ModuleList([
            BidirectionalCrossAttention(
                mae_dim=encoder_dim,         # 768
                dense_dim=self.dn_feat_dim,  # 2048
                num_heads=8,
                dropout=0.1,
                proj_dim=512
            )
            for _ in range(12)
        ])
        self.attn_pool_mae = AttentionPool(encoder_dim, 1024)
        self.classifier_mae = nn.Sequential(
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )
        self.attn_pool_dense = AttentionPool(self.dn_feat_dim, 1024)
        self.classifier_attn = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.GELU(),
            nn.Dropout(0.2),
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(512, num_classes),
        )
        # ---- FPN ----
        # Lateral 1x1 convs map each DenseNet stage to 256 channels
        # (lateral5 <- feat4, lateral4 <- feat3, lateral3 <- feat2, lateral2 <- feat1).
        self.lateral5 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)  # feat4: 2048 ch
        self.lateral4 = nn.Conv2d(2048, 256, kernel_size=1, stride=1, padding=0)  # feat3: 2048 ch
        self.lateral3 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)  # feat2: 1024 ch
        self.lateral2 = nn.Conv2d(512, 256, kernel_size=1, stride=1, padding=0)   # feat1: 512 ch
        self.output5 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output4 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.output2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self._classify_out5 = nn.Linear(256, num_classes)
        self._classify_out4 = nn.Linear(256, num_classes)
        self._classify_out3 = nn.Linear(256, num_classes)
        self._classify_out2 = nn.Linear(256, num_classes)
        self.learned_logit_ensemble = LearnedLogitEnsemble(num_classes=num_classes)
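
    # The ensemble at the end of forward fuses seven logit heads: the pooled
    # MAE head (classifier_mae), the DenseNet classifier (classifier_xdense),
    # the cross-attention head (classifier_attn), and the four FPN level heads.
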
    def forward(self, x):
        # ---- MAE path: frozen encoder tokens ----
        mae_tokens, _, _, _ = self.mae.encoder(x)
        mae_tokens = self.token_ln(mae_tokens)
        # Duplicate the single channel to match the 2-channel DenseNet stem.
        doublex = torch.cat([x, x], dim=1)                 # [B, 2, 384, 384]
        # ---- DenseNet path: extract multi-scale features ----
        xdense = self.dense.initialconv(doublex)           # [B, 128, 192, 192]
        # Layer 1 + ECA (before transition)
        feat1 = self.dense.layer1(xdense)
        feat1 = self.dense.dropout1(feat1)
        feat1 = self.dense.eca1(feat1)                     # [B, 512, 192, 192]
        xdense1 = self.dense.trans1(feat1)                 # [B, 256, 96, 96]
        # Layer 2 + ECA (before transition)
        feat2 = self.dense.layer2(xdense1)
        feat2 = self.dense.dropout2(feat2)
        feat2 = self.dense.eca2(feat2)                     # [B, 1024, 96, 96]
        xdense2 = self.dense.trans2(feat2)                 # [B, 512, 48, 48]
        # Layer 3 + ECA (before transition)
        feat3 = self.dense.layer3(xdense2)
        feat3 = self.dense.dropout3(feat3)
        feat3 = self.dense.eca3(feat3)                     # [B, 2048, 48, 48]
        xdense3 = self.dense.trans3(feat3)                 # [B, 1024, 24, 24]
        # Layer 4 (no transition)
        feat4 = self.dense.layer4(xdense3)
        feat4 = self.dense.dropout4(feat4)
        feat4 = self.dense.eca4(feat4)                     # [B, 2048, 24, 24]
        xdense4 = feat4
        # Global pooling for the DenseNet classifier head
        xdense_pooled = self.dense.global_average_pool(xdense4)
        xdense_pooled = xdense_pooled.view(xdense_pooled.size(0), -1)
        xdense_pooled = self.dense.dropout(xdense_pooled)
        classifier_xdense = self.dense.classifier(xdense_pooled)
        # Dense tokens for cross-attention
        dense_tokens = xdense4.flatten(2).transpose(1, 2)  # [B, 576, 2048]
        # ---- FPN over the multi-scale features ----
        c4 = self.lateral5(feat4)   # [B, 2048, 24, 24]  -> [B, 256, 24, 24]
        c3 = self.lateral4(feat3)   # [B, 2048, 48, 48]  -> [B, 256, 48, 48]
        c2 = self.lateral3(feat2)   # [B, 1024, 96, 96]  -> [B, 256, 96, 96]
        c1 = self.lateral2(feat1)   # [B, 512, 192, 192] -> [B, 256, 192, 192]
        # Top-down pathway
        p4 = self.output5(c4)                      # 24x24
        p3 = self.output4(self.upsample(p4) + c3)  # 48x48
        p2 = self.output3(self.upsample(p3) + c2)  # 96x96
        p1 = self.output2(self.upsample(p2) + c1)  # 192x192
        # Per-level classification heads (global average pool + linear)
        out4 = self._classify_out5(p4.mean([2, 3]))
        out3 = self._classify_out4(p3.mean([2, 3]))
        out2 = self._classify_out3(p2.mean([2, 3]))
        out1 = self._classify_out2(p1.mean([2, 3]))
        # ---- MAE classifier head ----
        mae_tokens_pooled = self.attn_selfpool_mae(mae_tokens)
        classifier_mae = self.classifier_mae(mae_tokens_pooled)
        # ---- Cross attention: chain the 12 layers so each refines the last
        # (the original loop fed the unmodified tokens into every layer, so
        # layers 1-11 were dead computation) ----
        mae_cross, dense_cross = mae_tokens, dense_tokens
        for cross_layer in self.cross_attn_layers:
            mae_cross, dense_cross = cross_layer(mae_cross, dense_cross)
        mae_cross = self.attn_pool_mae(mae_cross)
        dense_cross = self.attn_pool_dense(dense_cross)
        out = torch.cat([mae_cross, dense_cross], dim=1)
        classifier_attn = self.classifier_attn(out)
        # ---- Ensemble the seven heads ----
        merged_classifier = self.learned_logit_ensemble([
            classifier_mae,
            classifier_xdense,
            classifier_attn,
            out4, out3, out2, out1
        ])
        return merged_classifier
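
if __name__ == "__main__":
    # Hedged smoke test, assuming models.mae.MaskedAutoEncoder and
    # models.densenet.DenseNet are importable with the signatures used above,
    # and a single-channel 384x384 input. Expect a (2, num_classes) logit tensor.
    model = XRAYClassifier(num_classes=14, c=1, img_size=384).eval()
    with torch.no_grad():
        dummy = torch.randn(2, 1, 384, 384)
        logits = model(dummy)
    print(logits.shape)  # torch.Size([2, 14])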