# BrainAnytime-Demo / models/multimae3d.py
# Committed by Simmonstt
# Import BrainAnytime code from GitHub and configure Gradio Space
# Commit 041602e (verified)
"""
MultiMAE3D: Multi-modal Masked Autoencoder for 3D Medical Images
Architecture:
- Per-modality input adapters (Conv3D patch embedding)
- Shared ViT encoder
- Per-modality output adapters (cross-attn decoder)
- Handles arbitrary missing modalities via observed mask
Based on MultiMAE_reference, simplified for our use case:
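- Default input size 128^3 (configurable via img_size), 4 modalities (T1, T2, Flair, PET)
- Reconstruction pretraining (MSE loss on per-patch normalized targets), with an
  optional cross-modal mutual-prediction loss against an EMA teacher encoder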
- No Hydra/Lightning dependencies
"""
import copy
import math
from typing import Union, Tuple, Dict, List, Optional
from collections import OrderedDict
from functools import partial
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.layers import DropPath
from einops import rearrange
from models.multimae3d_utils import (
to_3tuple,
calc_patchified_dim,
patchify,
unpatchify,
shuffle_patches,
unshuffle_patches,
build_3d_sincos_position_embedding,
mask_data,
)
# =============================================================================
# Input Adapter: Conv3D patch embedding (per modality)
# =============================================================================
class PatchedInputAdapter(nn.Module):
    """
    Per-modality patch tokenizer.

    Projects already-patchified (selected / shuffled) patches of a single
    modality into token embeddings. The projection is a Conv3d whose kernel
    and stride both equal the patch size, so each patch collapses to one
    embed_dim-dimensional vector.

    Input:  [B, N_selected, C, pd, ph, pw]
    Output: [B, N_selected, embed_dim]
    """

    def __init__(
        self,
        in_channels: int = 1,
        patch_size: Union[int, Tuple[int, int, int]] = 16,
        embed_dim: int = 768,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.patch_size = to_3tuple(patch_size)
        self.embed_dim = embed_dim
        # kernel == stride == patch size -> one output voxel per patch,
        # i.e. a learned linear map from patch voxels to embed_dim.
        self.proj = nn.Conv3d(
            in_channels,
            embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, N, C, pd, ph, pw] selected patches.
        Returns:
            [B, N, embed_dim] patch tokens.
        """
        batch_size = x.shape[0]
        # Fold the patch axis into the batch axis so Conv3d sees plain volumes.
        volumes = rearrange(x, "b n c d h w -> (b n) c d h w")
        # [(B*N), embed_dim, 1, 1, 1] -> [(B*N), embed_dim, 1] -> [(B*N), embed_dim]
        embedded = self.proj(volumes).flatten(2).squeeze(-1)
        # Restore the [B, N, embed_dim] token layout.
        return rearrange(embedded, "(b n) d -> b n d", b=batch_size)
# =============================================================================
# Cross Attention (for decoder)
# =============================================================================
class CrossAttention(nn.Module):
    """Multi-head cross attention: queries from ``x`` attend over ``context``.

    Args:
        dim: channel width shared by ``x`` and ``context``.
        num_heads: number of attention heads (``dim`` is split evenly).
        qkv_bias: whether the query and key/value projections use a bias.
        attn_drop: dropout probability on the attention weights.
        proj_drop: dropout probability after the output projection.
    """

    def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = True,
                 attn_drop: float = 0.0, proj_drop: float = 0.0):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        # Keys and values come from one fused projection of the context.
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: [B, N, C] queries.
            context: [B, M, C] keys/values source.
        Returns:
            [B, N, C]
        """
        batch, n_query, channels = x.shape
        _, n_ctx, _ = context.shape
        heads = self.num_heads
        head_dim = channels // heads

        # [B, heads, N, head_dim]
        query = self.q(x).reshape(batch, n_query, heads, head_dim).permute(0, 2, 1, 3)
        # [2, B, heads, M, head_dim] -> key, value each [B, heads, M, head_dim]
        fused = self.kv(context).reshape(batch, n_ctx, 2, heads, head_dim).permute(2, 0, 3, 1, 4)
        key, value = fused[0], fused[1]

        scores = (query @ key.transpose(-2, -1)) * self.scale
        weights = self.attn_drop(scores.softmax(dim=-1))
        mixed = (weights @ value).transpose(1, 2).reshape(batch, n_query, -1)
        return self.proj_drop(self.proj(mixed))
# =============================================================================
# Transformer blocks with attention mask support
# =============================================================================
class Mlp(nn.Module):
    """Two-layer feed-forward network: fc1 -> act -> dropout -> fc2 -> dropout.

    ``hidden_features`` and ``out_features`` fall back to ``in_features``
    when not given (or falsy).
    """

    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_dim = hidden_features or in_features
        output_dim = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_dim)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # A single Dropout module is reused after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class MaskedAttention(nn.Module):
    """Multi-head self-attention with an optional additive attention mask."""

    def __init__(self, dim, num_heads=12, qkv_bias=True,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        # Queries, keys and values share one fused projection.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask=None):
        """
        Args:
            x: [B, N, C] tokens.
            attn_mask: optional additive mask, added to the [B, heads, N, N]
                attention logits before the softmax. Callers in this module
                pass [B, 1, 1, N] with -inf on key columns to ignore.
        Returns:
            [B, N, C]
        """
        batch, n_tok, channels = x.shape
        head_dim = channels // self.num_heads
        fused = self.qkv(x).reshape(batch, n_tok, 3, self.num_heads, head_dim)
        # each of q, k, v: [B, heads, N, head_dim]
        q, k, v = fused.permute(2, 0, 3, 1, 4).unbind(0)

        logits = (q @ k.transpose(-2, -1)) * self.scale
        if attn_mask is not None:
            logits = logits + attn_mask
        weights = self.attn_drop(logits.softmax(dim=-1))

        merged = (weights @ v).transpose(1, 2).reshape(batch, n_tok, channels)
        return self.proj_drop(self.proj(merged))
class MaskedBlock(nn.Module):
    """Pre-LayerNorm Transformer block with an optional additive attention mask.

    Computes ``x + DropPath(Attn(LN(x)))`` followed by
    ``x + DropPath(MLP(LN(x)))``. The encoder passes a column mask; the
    decoder calls it without one.
    """

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True,
                 drop_path=0., act_layer=nn.GELU,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6)):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = MaskedAttention(dim, num_heads=num_heads, qkv_bias=qkv_bias)
        # Stochastic depth only when a positive rate is requested.
        if drop_path > 0.:
            self.drop_path = DropPath(drop_path)
        else:
            self.drop_path = nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
        )

    def forward(self, x, attn_mask=None):
        attended = self.attn(self.norm1(x), attn_mask=attn_mask)
        x = x + self.drop_path(attended)
        transformed = self.mlp(self.norm2(x))
        return x + self.drop_path(transformed)
# =============================================================================
# Cross-Modal Predictor (for cross-level mutual prediction)
# =============================================================================
class CrossModalPredictor(nn.Module):
    """Three-layer MLP that maps a feature vector from one modality group's
    space to another's.

    Layout (kept inside ``self.net`` so state_dict keys are ``net.{0,2,4}.*``):
        Linear(D, 2D) -> GELU -> Linear(2D, 2D) -> GELU -> Linear(2D, D)
    """

    def __init__(self, dim: int):
        super().__init__()
        wide = 2 * dim
        layers = [
            nn.Linear(dim, wide),
            nn.GELU(),
            nn.Linear(wide, wide),
            nn.GELU(),
            nn.Linear(wide, dim),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``[..., D]`` features to ``[..., D]`` predictions."""
        return self.net(x)
# =============================================================================
# Output Adapter: Decoder (per modality)
# =============================================================================
class SpatialOutputAdapter(nn.Module):
    """
    Per-modality decoder that predicts every patch of one modality.

    Pipeline:
        1. Project all encoder tokens from encoder_dim to the decoder width;
           the projected sequence is the cross-attention context.
        2. Take this modality's visible tokens out of the context and append
           one learnable mask token per hidden patch (shuffled order).
        3. Add the sincos positional embedding, gathered through the
           per-sample permutation so each query receives its own patch's
           position.
        4. Cross-attention from the queries to the full context, followed by a
           residual MLP.
        5. Self-attention MaskedBlocks (no attention mask).
        6. Linear projection to patch voxels, reshape to patches, and
           unshuffle back to spatial order.
    """

    def __init__(
        self,
        out_channels: int = 1,
        img_size: Union[int, Tuple[int, int, int]] = 128,
        patch_size: Union[int, Tuple[int, int, int]] = 16,
        encoder_embed_dim: int = 768,
        embed_dim: int = 384,
        num_heads: int = 12,
        depth: int = 2,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
    ):
        super().__init__()
        self.out_channels = out_channels
        self.img_size = to_3tuple(img_size)
        self.patch_size = to_3tuple(patch_size)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.depth = depth
        self.patchified_dim = calc_patchified_dim(self.img_size, self.patch_size)
        self.num_patches = (
            self.patchified_dim[0] * self.patchified_dim[1] * self.patchified_dim[2]
        )

        # Encoder width -> decoder width.
        self.proj_context = nn.Linear(encoder_embed_dim, embed_dim)

        # Shared learnable placeholder for every hidden patch.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        nn.init.normal_(self.mask_token, std=0.02)

        # Decoder positional embedding over the patch grid.
        self.pos_embed = build_3d_sincos_position_embedding(
            grid_size=self.patchified_dim,
            embed_dim=embed_dim,
        )

        # Cross-attention followed by an MLP.
        self.xattn = CrossAttention(
            dim=embed_dim, num_heads=num_heads, qkv_bias=qkv_bias,
        )
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.context_norm = norm_layer(embed_dim)
        self.query_norm = norm_layer(embed_dim)
        self.out_norm = norm_layer(embed_dim)
        hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, embed_dim),
        )

        # Self-attention stack (the decoder never uses an attention mask).
        if depth > 0:
            self.blocks = nn.Sequential(*(
                MaskedBlock(
                    dim=embed_dim,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    act_layer=nn.GELU,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ))
        else:
            self.blocks = nn.Identity()

        # Decoder width -> voxels of one patch (times output channels).
        voxels_per_patch = (
            self.patch_size[0] * self.patch_size[1] * self.patch_size[2] * out_channels
        )
        self.out_proj = nn.Linear(embed_dim, voxels_per_patch)

    def forward(
        self,
        encoder_tokens: torch.Tensor,
        task_range: Tuple[int, int],
        perm_idx: torch.Tensor,
        num_patches: int,
    ) -> torch.Tensor:
        """
        Args:
            encoder_tokens: [B, total_visible_tokens, encoder_dim] encoder output.
            task_range: (start, end) slice of this modality's tokens inside
                ``encoder_tokens``.
            perm_idx: [B, num_patches] permutation for this modality; position
                i of the shuffled order holds spatial patch ``perm_idx[:, i]``.
            num_patches: total patch count for this modality.
        Returns:
            [B, num_patches, out_channels, pd, ph, pw] in spatial (unshuffled)
            order.
        """
        batch = encoder_tokens.shape[0]
        start, end = task_range[0], task_range[1]

        context = self.proj_context(encoder_tokens)

        # Shuffled-order queries: visible tokens first, then mask tokens.
        visible = context[:, start:end]
        hidden_count = num_patches - (end - start)
        placeholders = self.mask_token.repeat(batch, hidden_count, 1)
        query = torch.cat([visible, placeholders], dim=1)  # [B, num_patches, dim]

        # Gather positions so shuffled slot i gets the embedding of perm_idx[:, i].
        positions = self.pos_embed.expand(batch, -1, -1)
        rows = torch.arange(batch, device=positions.device).unsqueeze(1)
        query = query + positions[rows, perm_idx]

        # NOTE(review): the cross-attention output *replaces* the query instead
        # of being added to it (no ``query + xattn(...)`` residual), unlike a
        # standard pre-norm block and, as far as I recall, the MultiMAE
        # reference decoder. Kept unchanged for compatibility with trained
        # weights — confirm it is intentional.
        # NOTE(review): no attention mask is applied here, so the queries attend
        # to every context token, including global tokens and any tokens of
        # modalities flagged missing for a sample — verify this is acceptable.
        x = self.xattn(self.query_norm(query), self.context_norm(context))
        x = x + self.mlp(self.out_norm(x))

        if self.depth > 0:
            x = self.blocks(x)

        voxels = self.out_proj(x)  # [B, num_patches, patch_voxels]
        pd, ph, pw = self.patch_size
        patches = rearrange(
            voxels,
            "b n (c pd ph pw) -> b n c pd ph pw",
            c=self.out_channels,
            pd=pd,
            ph=ph,
            pw=pw,
        )
        # Back to spatial order.
        return unshuffle_patches(patches, perm_idx)
# =============================================================================
# MultiMAE3D: Main Model
# =============================================================================
class MultiMAE3D(nn.Module):
    """
    Multi-modal Masked Autoencoder for 3D Medical Images.

    Handles 4 modalities (T1, T2, Flair, PET) with arbitrary missing modalities.

    Forward pass:
    1. Split stacked input into per-modality volumes
    2. Patchify and mask each modality via ``mask_data`` (missing → 100% masked)
    3. Tokenize visible patches via per-modality input adapters
    4. Add positional embeddings + global (CLS) tokens, if any
    5. Concatenate all visible tokens → shared ViT encoder, with an additive
       column mask hiding tokens of modalities marked missing per sample
    6. Per-modality decoder → reconstruct all patches
    7. Compute MSE loss (against per-patch normalized targets) only on present
       modalities' masked patches
    8. Optionally (``enable_cross_modal=True``) compute a cross-modal
       mutual-prediction loss against an EMA teacher encoder

    ``num_global_tokens=0`` is supported: no global tokens are prepended and
    ``self.global_tokens`` is ``None``.
    """
    MODALITY_NAMES = ["T1", "T2", "Flair", "PET"]

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int, int]] = 128,
        patch_size: Union[int, Tuple[int, int, int]] = 16,
        embed_dim: int = 768,
        depth: int = 12,
        num_heads: int = 12,
        mlp_ratio: float = 4.0,
        decoder_embed_dim: int = 384,
        decoder_depth: int = 2,
        decoder_num_heads: int = 12,
        mask_ratio: float = 0.75,
        use_dirichlet: bool = True,
        dirichlet_alpha: float = 1.0,
        num_global_tokens: int = 1,
        qkv_bias: bool = True,
        drop_path_rate: float = 0.0,
        enable_cross_modal: bool = False,
    ):
        super().__init__()
        self.img_size = to_3tuple(img_size)
        self.patch_size = to_3tuple(patch_size)
        self.embed_dim = embed_dim
        self.depth = depth
        self.mask_ratio = mask_ratio
        self.use_dirichlet = use_dirichlet
        self.dirichlet_alpha = dirichlet_alpha
        self.num_global_tokens = num_global_tokens
        self.enable_cross_modal = enable_cross_modal
        self.patchified_dim = calc_patchified_dim(self.img_size, self.patch_size)
        self.num_patches = self.patchified_dim[0] * self.patchified_dim[1] * self.patchified_dim[2]

        # ----- Input adapters (per modality) -----
        self.input_adapters = nn.ModuleDict({
            name: PatchedInputAdapter(
                in_channels=1,
                patch_size=patch_size,
                embed_dim=embed_dim,
            )
            for name in self.MODALITY_NAMES
        })

        # ----- Encoder positional embedding (sincos) -----
        self.pos_embed = build_3d_sincos_position_embedding(
            grid_size=self.patchified_dim,
            embed_dim=embed_dim,
        )

        # ----- Global (CLS) tokens -----
        if num_global_tokens > 0:
            self.global_tokens = nn.Parameter(torch.zeros(num_global_tokens, embed_dim))
            nn.init.normal_(self.global_tokens, std=0.02)
        else:
            # forward() passes self.global_tokens to _encode_with unconditionally;
            # without this attribute num_global_tokens=0 raised AttributeError.
            # _encode_with skips global tokens when given None.
            self.global_tokens = None

        # ----- Shared Transformer encoder (ModuleList for attn_mask support) -----
        # Stochastic-depth rate grows linearly from 0 to drop_path_rate.
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
        norm_layer = partial(nn.LayerNorm, eps=1e-6)
        self.encoder = nn.ModuleList([
            MaskedBlock(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                drop_path=dpr[i],
                act_layer=nn.GELU,
                norm_layer=norm_layer,
            )
            for i in range(depth)
        ])

        # ----- Output adapters / decoders (per modality) -----
        self.output_adapters = nn.ModuleDict({
            name: SpatialOutputAdapter(
                out_channels=1,
                img_size=img_size,
                patch_size=patch_size,
                encoder_embed_dim=embed_dim,
                embed_dim=decoder_embed_dim,
                num_heads=decoder_num_heads,
                depth=decoder_depth,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
            )
            for name in self.MODALITY_NAMES
        })

        # Initialize weights (before the teacher is copied from the student).
        self._initialize_weights()

        # ----- Cross-modal mutual prediction components -----
        if self.enable_cross_modal:
            # Teacher encoder (EMA copy of student) — no gradients
            self.teacher_input_adapters = copy.deepcopy(self.input_adapters)
            for p in self.teacher_input_adapters.parameters():
                p.requires_grad = False
            self.teacher_encoder = copy.deepcopy(self.encoder)
            for p in self.teacher_encoder.parameters():
                p.requires_grad = False
            # Teacher global tokens stored as buffer (auto-moves with .to(device))
            if self.num_global_tokens > 0:
                self.register_buffer(
                    "teacher_global_tokens",
                    self.global_tokens.data.clone(),
                )
            # Cross-modal predictors (student-only, learnable)
            self.predictor_mri_to_pet = CrossModalPredictor(embed_dim)
            self.predictor_pet_to_mri = CrossModalPredictor(embed_dim)
            # Initialize predictor weights
            self.predictor_mri_to_pet.apply(self._init_weights)
            self.predictor_pet_to_mri.apply(self._init_weights)

    def _initialize_weights(self):
        """Apply the default init, then MAE-style special cases.

        - Fused ``qkv`` / ``kv`` Linear layers: uniform init whose bound uses
          the output rows of a single projection (rows // 3 or rows // 2).
        - Conv3d layers whose module name contains ``.proj`` (the patch
          embeddings): xavier-uniform on the weight viewed as 2D.
        """
        self.apply(self._init_weights)
        for name, m in self.named_modules():
            if isinstance(m, nn.Linear):
                if "qkv" in name:
                    val = math.sqrt(6.0 / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
                elif "kv" in name:
                    val = math.sqrt(6.0 / float(m.weight.shape[0] // 2 + m.weight.shape[1]))
                    nn.init.uniform_(m.weight, -val, val)
            if isinstance(m, nn.Conv3d):
                if ".proj" in name:
                    w = m.weight.data
                    nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

    @staticmethod
    def _init_weights(m):
        """Xavier-uniform Linear weights with zero bias; LayerNorm to identity."""
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def _split_modalities(self, images: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Split stacked [B, 4, D, H, W] into per-modality dict {name: [B, 1, D, H, W]}."""
        return {
            name: images[:, i:i+1]
            for i, name in enumerate(self.MODALITY_NAMES)
        }

    # -----------------------------------------------------------------
    # Cross-modal mutual prediction helpers
    # -----------------------------------------------------------------
    def _encode_with(
        self,
        selected_patches: Dict[str, torch.Tensor],
        perm_indices: Dict[str, torch.Tensor],
        observed: torch.Tensor,
        input_adapters: nn.ModuleDict,
        global_tokens: Optional[torch.Tensor],
        encoder_blocks: nn.ModuleList,
    ):
        """
        Shared encoding logic used by both student and teacher.

        Args:
            selected_patches: {name: [B, N_sel, 1, pd, ph, pw]} visible patches.
            perm_indices: {name: [B, num_patches]} permutations; the first
                N_sel entries are the spatial indices of the visible patches.
            observed: [B, 4] presence flags (>= 0.5 means present).
            input_adapters: per-modality tokenizers to use.
            global_tokens: [G, D] or [1, G, D] tokens to prepend, or None.
            encoder_blocks: Transformer blocks to run.
        Returns:
            encoder_output: [B, total_tokens, D] or None if no modality has any
                visible token
            task_ranges: OrderedDict {modality_name: (start, end)} token slices
                (empty ranges for modalities with no tokens)
        """
        B = observed.shape[0]
        device = observed.device

        # Tokenize each modality's visible patches and add their positions.
        tokens = {}
        for name in self.MODALITY_NAMES:
            sel = selected_patches[name]
            if sel.shape[1] == 0:
                continue
            tok = input_adapters[name](sel)
            perm = perm_indices[name]
            pos_emb = self.pos_embed.expand(B, -1, -1)
            pos_emb_selected = pos_emb[
                torch.arange(B, device=device)[:, None], perm[:, :sel.shape[1]]
            ]
            tok = tok + pos_emb_selected
            tokens[name] = tok

        # Concatenate in fixed modality order, recording each modality's slice.
        token_list = []
        task_ranges = OrderedDict()
        offset = self.num_global_tokens
        for name in self.MODALITY_NAMES:
            if name in tokens:
                n_tok = tokens[name].shape[1]
                task_ranges[name] = (offset, offset + n_tok)
                token_list.append(tokens[name])
                offset += n_tok
            else:
                task_ranges[name] = (offset, offset)
        if len(token_list) == 0:
            return None, task_ranges
        input_tokens = torch.cat(token_list, dim=1)

        if self.num_global_tokens > 0 and global_tokens is not None:
            if global_tokens.dim() == 2:
                cls = global_tokens.unsqueeze(0).expand(B, -1, -1)
            else:
                cls = global_tokens.expand(B, -1, -1)
            input_tokens = torch.cat([cls, input_tokens], dim=1)

        # Column masking for missing modalities: -inf on their key columns so
        # no token attends to them.
        total_tokens = input_tokens.shape[1]
        attn_mask = torch.zeros(B, 1, 1, total_tokens, device=device)
        for i, name in enumerate(self.MODALITY_NAMES):
            start, end = task_ranges[name]
            if start == end:
                continue
            missing = (observed[:, i] < 0.5)
            if missing.any():
                attn_mask[missing, :, :, start:end] = float("-inf")
        if (attn_mask == 0).all():
            attn_mask = None

        encoder_output = input_tokens
        for block in encoder_blocks:
            encoder_output = block(encoder_output, attn_mask=attn_mask)
        return encoder_output, task_ranges

    def _compute_cross_modal_loss(
        self,
        selected_patches: Dict[str, torch.Tensor],
        perm_indices: Dict[str, torch.Tensor],
        observed: torch.Tensor,
        student_encoder_output: torch.Tensor,
        task_ranges: OrderedDict,
    ) -> torch.Tensor:
        """
        Cross-level mutual prediction loss (simplified global-average-pooling version).
        Two groups:
        - MRI group: all T1 + T2 + Flair tokens → z_MRI (D-dim vector)
        - PET group: all PET tokens → z_PET (D-dim vector)
        Predictions (student → teacher):
        - predictor_mri_to_pet(z_MRI_student) → predict z_PET_teacher
        - predictor_pet_to_mri(z_PET_student) → predict z_MRI_teacher
        Loss: 2 - 2·cos(pred, target), averaged over paired samples only.
        """
        B = observed.shape[0]
        device = observed.device
        # Paired = has at least one MRI modality AND PET
        has_mri = (observed[:, :3].sum(dim=1) > 0.5)  # [B]
        has_pet = (observed[:, 3] > 0.5)  # [B]
        is_paired = has_mri & has_pet  # [B]
        if not is_paired.any():
            return torch.tensor(0.0, device=device, requires_grad=True)

        # --- Teacher forward (no gradients) ---
        # NOTE(review): the teacher blocks follow the model's train/eval mode,
        # so with drop_path_rate > 0 the teacher also applies stochastic depth
        # during training — confirm that is intended.
        with torch.no_grad():
            teacher_gt = (
                self.teacher_global_tokens
                if self.num_global_tokens > 0 else None
            )
            teacher_output, _ = self._encode_with(
                selected_patches, perm_indices, observed,
                self.teacher_input_adapters, teacher_gt,
                self.teacher_encoder,
            )
        if teacher_output is None:
            return torch.tensor(0.0, device=device, requires_grad=True)

        # --- Build group masks [B, L] (teacher and student share the layout) ---
        total_tokens = student_encoder_output.shape[1]
        mri_mask = torch.zeros(B, total_tokens, device=device)
        pet_mask = torch.zeros(B, total_tokens, device=device)
        # MRI group: T1 (idx 0), T2 (idx 1), Flair (idx 2)
        for idx, name in enumerate(["T1", "T2", "Flair"]):
            start, end = task_ranges[name]
            if start < end:
                mri_mask[:, start:end] = observed[:, idx:idx+1].expand(-1, end - start)
        # PET group: idx 3
        start, end = task_ranges["PET"]
        if start < end:
            pet_mask[:, start:end] = observed[:, 3:4].expand(-1, end - start)

        # --- Global average pooling per group ---
        mri_count = mri_mask.sum(dim=1, keepdim=True).clamp(min=1)
        pet_count = pet_mask.sum(dim=1, keepdim=True).clamp(min=1)
        z_mri_s = (student_encoder_output * mri_mask.unsqueeze(-1)).sum(dim=1) / mri_count  # [B, D]
        z_pet_s = (student_encoder_output * pet_mask.unsqueeze(-1)).sum(dim=1) / pet_count  # [B, D]
        z_mri_t = (teacher_output * mri_mask.unsqueeze(-1)).sum(dim=1) / mri_count  # [B, D]
        z_pet_t = (teacher_output * pet_mask.unsqueeze(-1)).sum(dim=1) / pet_count  # [B, D]

        # --- L2 normalize onto unit hypersphere ---
        z_mri_s = F.normalize(z_mri_s, dim=-1)
        z_pet_s = F.normalize(z_pet_s, dim=-1)
        z_mri_t = F.normalize(z_mri_t, dim=-1)
        z_pet_t = F.normalize(z_pet_t, dim=-1)

        # --- Cross-modal predictions + normalize ---
        pred_pet = F.normalize(self.predictor_mri_to_pet(z_mri_s), dim=-1)  # [B, D]
        pred_mri = F.normalize(self.predictor_pet_to_mri(z_pet_s), dim=-1)  # [B, D]

        # --- Negative cosine similarity: L = 2 - 2·cos(pred, target) ---
        loss_m2p = 2 - 2 * (pred_pet * z_pet_t.detach()).sum(dim=-1)  # [B]
        loss_p2m = 2 - 2 * (pred_mri * z_mri_t.detach()).sum(dim=-1)  # [B]

        # Average only over paired samples
        paired_f = is_paired.float()
        n_paired = paired_f.sum().clamp(min=1)
        loss_m2p = (loss_m2p * paired_f).sum() / n_paired
        loss_p2m = (loss_p2m * paired_f).sum() / n_paired
        return 0.5 * (loss_m2p + loss_p2m)

    @torch.no_grad()
    def update_teacher(self, momentum: float):
        """EMA update: θ_teacher ← m·θ_teacher + (1-m)·θ_student.

        No-op when cross-modal prediction is disabled.
        """
        if not self.enable_cross_modal:
            return
        for p_s, p_t in zip(
            self.input_adapters.parameters(),
            self.teacher_input_adapters.parameters(),
        ):
            p_t.data.mul_(momentum).add_(p_s.data, alpha=1 - momentum)
        if self.num_global_tokens > 0:
            self.teacher_global_tokens.mul_(momentum).add_(
                self.global_tokens.data, alpha=1 - momentum
            )
        for p_s, p_t in zip(
            self.encoder.parameters(),
            self.teacher_encoder.parameters(),
        ):
            p_t.data.mul_(momentum).add_(p_s.data, alpha=1 - momentum)

    @torch.no_grad()
    def init_teacher_from_student(self):
        """Copy current student weights to teacher (call after loading checkpoint)."""
        if not self.enable_cross_modal:
            return
        for p_s, p_t in zip(
            self.input_adapters.parameters(),
            self.teacher_input_adapters.parameters(),
        ):
            p_t.data.copy_(p_s.data)
        if self.num_global_tokens > 0:
            self.teacher_global_tokens.copy_(self.global_tokens.data)
        for p_s, p_t in zip(
            self.encoder.parameters(),
            self.teacher_encoder.parameters(),
        ):
            p_t.data.copy_(p_s.data)

    def forward(
        self,
        images: torch.Tensor,
        observed: torch.Tensor,
        return_loss: bool = True,
        patch_mask_probs: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            images: [B, 4, D, H, W] stacked multi-modal 3D volumes
            observed: [B, 4] float tensor, 1.0=present, 0.0=missing
            return_loss: if True, compute and return reconstruction loss
            patch_mask_probs: optional [N_patches] per-patch masking probability
                from anatomy-aware masking (higher = more likely to be masked);
                only used in training mode
        Returns:
            dict with:
                'loss': scalar MSE loss (if return_loss=True)
                'cross_modal_loss': scalar (0 unless cross-modal is enabled and
                    return_loss=True)
                'per_modality_loss': {name: loss} for each present modality
                    (if return_loss=True)
                'reconstructed': {name: [B, num_patches, 1, pd, ph, pw]} in
                    spatial order, predicted in per-patch normalized space
                    (if return_loss=False)
                'mask_ratios': mask ratios reported by mask_data
            If no modality has any visible token, the 'loss'-style dict is
            returned regardless of return_loss.
        """
        B = images.shape[0]
        device = images.device

        # 1. Split into per-modality dict
        batch = self._split_modalities(images)

        # 2. Mask data (patchify + shuffle + split)
        # When patch_mask_probs is provided, uses anatomy-aware weighted sampling
        selected_patches, masked_patches, perm_indices, mask_ratios = mask_data(
            batch=batch,
            modality_names=self.MODALITY_NAMES,
            observed=observed,
            mask_ratio=self.mask_ratio,
            patch_size=self.patch_size,
            use_dirichlet=self.use_dirichlet if self.training else False,
            dirichlet_alpha=self.dirichlet_alpha,
            patch_mask_probs=patch_mask_probs if self.training else None,
        )

        # 3-6. Student encoding (tokenize → concat → attn mask → encoder).
        # self.global_tokens is None when num_global_tokens == 0.
        encoder_output, task_ranges = self._encode_with(
            selected_patches, perm_indices, observed,
            self.input_adapters, self.global_tokens, self.encoder,
        )
        if encoder_output is None:
            # No modality contributed a visible token: nothing to decode.
            return {
                "loss": torch.tensor(0.0, device=device),
                "cross_modal_loss": torch.tensor(0.0, device=device),
                "per_modality_loss": {},
                "mask_ratios": mask_ratios,
            }

        # 7. Per-modality decoder
        reconstructed = {}
        for name in self.MODALITY_NAMES:
            reconstructed[name] = self.output_adapters[name](
                encoder_tokens=encoder_output,
                task_range=task_ranges[name],
                perm_idx=perm_indices[name],
                num_patches=self.num_patches,
            )
            # reconstructed[name]: [B, num_patches, 1, pd, ph, pw] in spatial order

        # 8. Compute reconstruction loss (MSE, only on present modalities' masked patches)
        if return_loss:
            total_loss = torch.tensor(0.0, device=device)
            per_mod_loss = {}
            num_present = 0
            for i, name in enumerate(self.MODALITY_NAMES):
                # Only compute loss on modalities present in at least one sample
                mod_observed = observed[:, i]  # [B]
                if mod_observed.sum() < 0.5:
                    continue
                # Ground truth: all patches in spatial order
                gt_patches = patchify(batch[name], self.patch_size)  # [B, num_patches, 1, pd, ph, pw]
                pred_patches = reconstructed[name]  # [B, num_patches, 1, pd, ph, pw]

                # Per-patch mask: 1 = masked (should reconstruct), 0 = visible.
                # In shuffled order the first num_selected are visible; scatter
                # their spatial indices to 0.
                perm = perm_indices[name]
                num_selected = selected_patches[name].shape[1]
                mask = torch.ones(B, self.num_patches, device=device)
                if num_selected > 0:
                    selected_perm = perm[:, :num_selected]  # [B, num_selected]
                    mask.scatter_(1, selected_perm, 0.0)

                # Per-sample observed mask: zero out loss for missing samples
                sample_mask = mod_observed.float()  # [B]

                # Patch normalization (per-patch zero-mean unit-variance, like original MAE)
                gt_mean = gt_patches.mean(dim=(2, 3, 4, 5), keepdim=True)
                gt_var = gt_patches.var(dim=(2, 3, 4, 5), keepdim=True)
                gt_patches_norm = (gt_patches - gt_mean) / (gt_var + 1e-6).sqrt()

                # Compute MSE on masked patches only (against normalized targets)
                per_patch_mse = ((pred_patches - gt_patches_norm) ** 2).mean(dim=(2, 3, 4, 5))  # [B, num_patches]
                masked_mse = (per_patch_mse * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)  # [B]
                mod_loss = (masked_mse * sample_mask).sum() / sample_mask.sum().clamp(min=1)
                per_mod_loss[name] = mod_loss
                total_loss = total_loss + mod_loss
                num_present += 1
            if num_present > 0:
                total_loss = total_loss / num_present

            # 9. Cross-modal mutual prediction loss
            cross_modal_loss = torch.tensor(0.0, device=device)
            if self.enable_cross_modal:
                cross_modal_loss = self._compute_cross_modal_loss(
                    selected_patches, perm_indices, observed,
                    encoder_output, task_ranges,
                )
            return {
                "loss": total_loss,
                "cross_modal_loss": cross_modal_loss,
                "per_modality_loss": per_mod_loss,
                "mask_ratios": mask_ratios,
            }

        return {
            "reconstructed": reconstructed,
            "cross_modal_loss": torch.tensor(0.0, device=device),
            "mask_ratios": mask_ratios,
        }

    def encode(
        self,
        images: torch.Tensor,
        observed: torch.Tensor,
    ) -> torch.Tensor:
        """
        Encode without masking (for downstream use).

        Every patch of every modality is tokenized; tokens of missing
        modalities are zeroed and hidden from attention via -inf key columns.

        Returns encoder output tokens [B, num_global + 4*num_patches, embed_dim].
        """
        B = images.shape[0]
        device = images.device
        batch = self._split_modalities(images)

        tokens_list = []
        for i, name in enumerate(self.MODALITY_NAMES):
            img = batch[name]  # [B, 1, D, H, W]
            patches = patchify(img, self.patch_size)  # [B, num_patches, 1, pd, ph, pw]
            # Tokenize all patches (no masking)
            tok = self.input_adapters[name](patches)  # [B, num_patches, embed_dim]
            # Add positional embedding (patches are in spatial order)
            pos_emb = self.pos_embed.expand(B, -1, -1)
            tok = tok + pos_emb
            # Zero out tokens for missing modalities
            mod_mask = observed[:, i:i+1].unsqueeze(-1)  # [B, 1, 1]
            tok = tok * mod_mask
            tokens_list.append(tok)
        input_tokens = torch.cat(tokens_list, dim=1)

        # Add CLS token
        if self.num_global_tokens > 0:
            cls = self.global_tokens.unsqueeze(0).expand(B, -1, -1)
            input_tokens = torch.cat([cls, input_tokens], dim=1)

        # Build attention mask: prevent attending to tokens from missing modalities
        total_tokens = input_tokens.shape[1]
        attn_mask = torch.zeros(B, 1, 1, total_tokens, device=device)
        mod_offset = self.num_global_tokens
        for i, name in enumerate(self.MODALITY_NAMES):
            start = mod_offset
            end = mod_offset + self.num_patches
            missing = (observed[:, i] < 0.5)  # [B]
            if missing.any():
                attn_mask[missing, :, :, start:end] = float("-inf")
            mod_offset = end
        if (attn_mask == 0).all():
            attn_mask = None

        # Encode with attention mask
        encoder_output = input_tokens
        for block in self.encoder:
            encoder_output = block(encoder_output, attn_mask=attn_mask)
        return encoder_output
def create_multimae3d(
    img_size: int = 128,
    patch_size: int = 16,
    embed_dim: int = 768,
    depth: int = 12,
    num_heads: int = 12,
    decoder_embed_dim: int = 384,
    decoder_depth: int = 2,
    decoder_num_heads: int = 12,
    mask_ratio: float = 0.75,
    use_dirichlet: bool = True,
    enable_cross_modal: bool = False,
    **kwargs,
) -> MultiMAE3D:
    """Build a MultiMAE3D; the defaults give a ViT-B-sized encoder.

    Any extra keyword arguments (e.g. ``mlp_ratio``, ``num_global_tokens``,
    ``drop_path_rate``) are forwarded unchanged to the ``MultiMAE3D``
    constructor.
    """
    config = dict(
        img_size=img_size,
        patch_size=patch_size,
        embed_dim=embed_dim,
        depth=depth,
        num_heads=num_heads,
        decoder_embed_dim=decoder_embed_dim,
        decoder_depth=decoder_depth,
        decoder_num_heads=decoder_num_heads,
        mask_ratio=mask_ratio,
        use_dirichlet=use_dirichlet,
        enable_cross_modal=enable_cross_modal,
    )
    # kwargs cannot collide with the named parameters above.
    config.update(kwargs)
    return MultiMAE3D(**config)