OncoVision-X / src /models /dca_net.py
adityasync's picture
Clean OncoVision-X deployment with LFS
8960670
#!/usr/bin/env python3
"""
DCA-Net: Dual-Context Attention Network for Lung Nodule Classification
Architecture per roadmap1.md Phase 2:
- Stream 1 (Nodule): 2.5D CNN with EfficientNet-B0 + cross-slice attention
- Stream 2 (Context): Lightweight 3D CNN with spatial attention
- Fusion: Multi-head attention fusion module
- Prediction head with dropout
- Uncertainty quantification via MC Dropout
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import timm
# ─────────────────────────────────────────────────────────────
# Cross-Slice Attention Module
# ─────────────────────────────────────────────────────────────
class CrossSliceAttention(nn.Module):
"""Learn spatial dependencies across adjacent slices (Β±k neighbors)."""
def __init__(self, feature_dim, num_neighbors=2):
super().__init__()
self.num_neighbors = num_neighbors
self.query = nn.Linear(feature_dim, feature_dim)
self.key = nn.Linear(feature_dim, feature_dim)
self.value = nn.Linear(feature_dim, feature_dim)
self.scale = feature_dim ** 0.5
self.norm = nn.LayerNorm(feature_dim)
def forward(self, slice_features):
"""
Args:
slice_features: (B, num_slices, D)
Returns:
attended: (B, num_slices, D)
"""
B, S, D = slice_features.shape
Q = self.query(slice_features) # (B, S, D)
K = self.key(slice_features)
V = self.value(slice_features)
# Build a local attention mask so each slice only attends to Β±k neighbors
mask = torch.zeros(S, S, device=slice_features.device, dtype=torch.bool)
for i in range(S):
lo = max(0, i - self.num_neighbors)
hi = min(S, i + self.num_neighbors + 1)
mask[i, lo:hi] = True
# Scaled dot-product attention
attn = torch.matmul(Q, K.transpose(-1, -2)) / self.scale # (B, S, S)
attn = attn.masked_fill(~mask.unsqueeze(0), float('-inf'))
attn = F.softmax(attn, dim=-1)
out = torch.matmul(attn, V) # (B, S, D)
out = self.norm(out + slice_features) # residual + LayerNorm
return out, attn
# ─────────────────────────────────────────────────────────────
# Stream 1: Nodule Feature Extractor (2.5D CNN)
# ─────────────────────────────────────────────────────────────
class NoduleStream(nn.Module):
"""
Process 3D nodule patch as stack of 2D slices through EfficientNet-B0,
then apply cross-slice attention + temporal 1D convolution.
Input: (B, 1, 64, 64, 64)
Output: (B, 512)
"""
def __init__(self, backbone_name="efficientnet_b0", feature_dim=512,
num_neighbors=2, ablation=None):
super().__init__()
self.ablation = ablation
# 2D backbone (pretrained on ImageNet)
self.backbone = timm.create_model(
backbone_name, pretrained=False, in_chans=1, num_classes=0
)
backbone_out = self.backbone.num_features # e.g. 1280 for efficientnet_b0
# Project backbone features to feature_dim
self.proj = nn.Linear(backbone_out, feature_dim)
# Cross-slice attention
self.cross_attn = CrossSliceAttention(feature_dim, num_neighbors)
# Temporal 1D convolution across slices
self.temporal_conv = nn.Sequential(
nn.Conv1d(feature_dim, feature_dim, kernel_size=3, padding=1),
nn.BatchNorm1d(feature_dim),
nn.ReLU(inplace=True),
nn.Conv1d(feature_dim, feature_dim, kernel_size=3, padding=1),
nn.BatchNorm1d(feature_dim),
nn.ReLU(inplace=True),
)
# Adaptive pooling to single vector
self.pool = nn.AdaptiveAvgPool1d(1)
def forward(self, x):
"""
Args:
x: (B, 1, D, H, W) – e.g. (B, 1, 64, 64, 64)
Returns:
features: (B, 512)
attn_weights: (B, num_slices, num_slices)
"""
# --- DATAPARALLEL FIX ---
# When using nn.DataParallel, x is split across GPUs.
# B here is the *local* batch size on this specific GPU.
B, C, D, H, W = x.shape
# Reshape: treat depth as batch dim β†’ (local_B*D, 1, H, W)
slices = x.squeeze(1).contiguous() # (local_B, D, H, W)
slices = slices.view(B * D, 1, H, W) # (local_B*D, 1, H, W)
# Forward through 2D backbone
slice_feats = self.backbone(slices) # (local_B*D, backbone_out)
slice_feats = self.proj(slice_feats) # (local_B*D, feature_dim)
# Reshape back using the dynamically inferred local_B
slice_feats = slice_feats.view(B, D, -1) # (local_B, D, feature_dim)
# Cross-slice attention
if self.ablation == 'no_attention':
attended = slice_feats
attn_weights = None
else:
attended, attn_weights = self.cross_attn(slice_feats) # (local_B, D, feature_dim)
# Temporal 1D conv: (local_B, feature_dim, D)
temporal = attended.permute(0, 2, 1).contiguous()
temporal = self.temporal_conv(temporal)
# Pool across slices β†’ (local_B, feature_dim)
features = self.pool(temporal).squeeze(-1)
return features, attn_weights
# ─────────────────────────────────────────────────────────────
# Stream 2: Anatomical Context Extractor (Lightweight 3D CNN)
# ─────────────────────────────────────────────────────────────
class SpatialAttention3D(nn.Module):
"""Channel-wise spatial attention for 3D features."""
def __init__(self, in_channels):
super().__init__()
self.conv = nn.Sequential(
nn.Conv3d(in_channels, 1, kernel_size=1),
nn.Sigmoid()
)
def forward(self, x):
attn = self.conv(x)
return x * attn
class ContextStream(nn.Module):
"""
Lightweight 3D CNN to capture surrounding anatomy.
Input: (B, 1, 48, 48, 48)
Output: (B, 256)
"""
def __init__(self, feature_dim=256):
super().__init__()
# Scale internal channels based on output dim
c1, c2, c3 = 64, 128, 256
if feature_dim > 256:
c1, c2, c3 = 64, 128, 512
self.block1 = nn.Sequential(
nn.Conv3d(1, c1, kernel_size=3, stride=2, padding=1),
nn.BatchNorm3d(c1),
nn.ReLU(inplace=True),
) # 48β†’24
self.block2 = nn.Sequential(
nn.Conv3d(c1, c2, kernel_size=3, stride=2, padding=1),
nn.BatchNorm3d(c2),
nn.ReLU(inplace=True),
) # 24β†’12
self.block3 = nn.Sequential(
nn.Conv3d(c2, c3, kernel_size=3, stride=2, padding=1),
nn.BatchNorm3d(c3),
nn.ReLU(inplace=True),
) # 12β†’6
self.spatial_attn = SpatialAttention3D(c3)
self.gap = nn.AdaptiveAvgPool3d(1) # β†’ (B, c3, 1, 1, 1)
self.fc = nn.Linear(c3, feature_dim)
def forward(self, x):
"""
Args:
x: (B, 1, 48, 48, 48)
Returns:
features: (B, feature_dim)
"""
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.spatial_attn(x)
x = self.gap(x).flatten(1) # (B, c3)
x = self.fc(x)
return x
# ─────────────────────────────────────────────────────────────
# Fusion Module (Multi-Head Attention)
# ─────────────────────────────────────────────────────────────
class FusionModule(nn.Module):
"""
Fuse nodule and context features via multi-head attention.
Input: [B, 512] + [B, 256] β†’ concatenated [B, 768]
Output: [B, 256]
"""
def __init__(self, nodule_dim=512, context_dim=256, fused_dim=256,
num_heads=4, dropout=0.5):
super().__init__()
total_dim = nodule_dim + context_dim # 768
# Project to a dimension divisible by num_heads
self.proj_in = nn.Linear(total_dim, fused_dim * 2)
self.attn = nn.MultiheadAttention(
embed_dim=fused_dim * 2, num_heads=num_heads,
dropout=dropout, batch_first=True
)
self.ffn = nn.Sequential(
nn.Linear(fused_dim * 2, fused_dim),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
)
self.norm = nn.LayerNorm(fused_dim)
def forward(self, nodule_feats, context_feats):
"""
Args:
nodule_feats: (B, 512)
context_feats: (B, 256) or None (if no_context ablation)
Returns:
fused: (B, 256)
"""
if context_feats is not None:
combined = torch.cat([nodule_feats, context_feats], dim=-1) # (B, 768)
else:
combined = nodule_feats # (B, 512)
proj = self.proj_in(combined) # (B, fused_dim*2)
# Self-attention expects (B, seq_len, embed_dim) β€” treat as seq_len=1
proj = proj.unsqueeze(1) # (B, 1, fused_dim*2)
attn_out, _ = self.attn(proj, proj, proj) # (B, 1, fused_dim*2)
attn_out = attn_out.squeeze(1) # (B, fused_dim*2)
fused = self.ffn(attn_out) # (B, fused_dim)
fused = self.norm(fused)
return fused
# ─────────────────────────────────────────────────────────────
# Prediction Head
# ─────────────────────────────────────────────────────────────
class PredictionHead(nn.Module):
"""
Dense 256 β†’ 128 β†’ 1 with BatchNorm, ReLU, Dropout.
"""
def __init__(self, in_dim=256, hidden_dim=128, dropout=0.3):
super().__init__()
self.head = nn.Sequential(
nn.Linear(in_dim, hidden_dim),
nn.BatchNorm1d(hidden_dim),
nn.ReLU(inplace=True),
nn.Dropout(dropout),
nn.Linear(hidden_dim, 1),
)
def forward(self, x):
return self.head(x)
# ─────────────────────────────────────────────────────────────
# Full DCA-Net Model
# ─────────────────────────────────────────────────────────────
class DCANet(nn.Module):
"""
Dual-Context Attention Network for lung nodule classification.
Inputs:
nodule_patch: (B, 1, 64, 64, 64)
context_patch: (B, 1, 48, 48, 48)
Outputs (training):
logits: (B, 1)
Outputs (uncertainty mode):
mean_prob, confidence: (B,), (B,)
"""
def __init__(self, config=None):
super().__init__()
# Parse config or use defaults
if config is None:
config = {}
model_cfg = config.get('model', {})
# Read ablation from model section first (where YAML stores it),
# then fall back to top-level key for backward compatibility.
self.ablation = model_cfg.get('ablation', None) or config.get('ablation', None)
backbone = model_cfg.get('backbone', 'efficientnet_b0')
nodule_dim = model_cfg.get('nodule_feature_dim', 512)
context_dim = model_cfg.get('context_feature_dim', 256)
fusion_dim = model_cfg.get('fusion_dim', 256)
num_heads = model_cfg.get('num_attention_heads', 4)
dropout = model_cfg.get('dropout', 0.5)
pred_dropout = model_cfg.get('prediction_dropout', 0.3)
num_neighbors = model_cfg.get('slice_neighbors', 2)
self.mc_passes = model_cfg.get('mc_dropout_passes', 5)
# Streams
self.nodule_stream = NoduleStream(
backbone_name=backbone, feature_dim=nodule_dim,
num_neighbors=num_neighbors, ablation=self.ablation
)
# Context stream (skip entirely for no_context ablation)
if self.ablation != 'no_context':
self.context_stream = ContextStream(feature_dim=context_dim)
else:
self.context_stream = None
# Fusion β€” adjust input dimension when context is disabled
effective_context_dim = context_dim if self.ablation != 'no_context' else 0
self.fusion = FusionModule(
nodule_dim=nodule_dim, context_dim=effective_context_dim,
fused_dim=fusion_dim, num_heads=num_heads, dropout=dropout
)
# Prediction
self.prediction_head = PredictionHead(
in_dim=fusion_dim, hidden_dim=128, dropout=pred_dropout
)
def forward(self, nodule_patch, context_patch):
"""Standard forward pass (training mode).
Args:
nodule_patch: (B, 1, 64, 64, 64)
context_patch: (B, 1, 48, 48, 48)
Returns:
logits: (B, 1)
"""
nodule_feats, attn_weights = self.nodule_stream(nodule_patch)
if self.context_stream is not None:
context_feats = self.context_stream(context_patch)
else:
context_feats = None
fused = self.fusion(nodule_feats, context_feats)
logits = self.prediction_head(fused)
return logits
@torch.no_grad()
def predict_with_uncertainty(self, nodule_patch, context_patch):
"""Monte Carlo Dropout uncertainty estimation.
Runs multiple forward passes with dropout enabled,
computes mean prediction and confidence.
Args:
nodule_patch: (B, 1, 64, 64, 64)
context_patch: (B, 1, 48, 48, 48)
Returns:
mean_prob: (B,) mean probability
confidence: (B,) confidence score (1 - normalized variance)
"""
# Set to eval mode but keep dropout active
self.eval()
if self.ablation != 'no_uncertainty':
for module in self.modules():
if isinstance(module, torch.nn.Dropout):
module.train()
preds = []
for _ in range(self.mc_passes):
logits = self.forward(nodule_patch, context_patch)
prob = torch.sigmoid(logits.squeeze(-1))
preds.append(prob)
preds = torch.stack(preds, dim=0) # (mc_passes, B)
mean_prob = preds.mean(dim=0) # (B,)
variance = preds.var(dim=0) # (B,)
# Confidence: 1 - normalized variance (variance is max 0.25 for Bernoulli)
confidence = 1.0 - (variance / 0.25).clamp(0, 1)
self.eval()
return mean_prob, confidence
def get_slice_importance(self, nodule_patch, context_patch):
"""Get per-slice importance scores from cross-slice attention.
Returns:
importance: (B, num_slices) attention-based importance
"""
self.eval()
with torch.no_grad():
_, attn_weights = self.nodule_stream(nodule_patch)
# Average attention received by each slice
importance = attn_weights.mean(dim=1) # (B, num_slices)
return importance