Spaces:
Runtime error
Runtime error
File size: 2,924 Bytes
fd601de |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 |
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[Segmentation Mask] --> [Encoder (U-Net style or ViT Patch Embedding)] ---> Q\n",
"\n",
"[Text Condition: \"CT\", \"T1-MR\"] --> [Condition Embedding] --> K, V\n",
"\n",
"Q, K, V ---> [Cross-Attention Block] ---> Fused Feature Map --> UNet Backbone --> DDPM/Diffusion Head\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Let's write the PyTorch code for the Modality Field Adapter (MFA) module.\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class ModalityFieldAdapter(nn.Module):\n",
" def __init__(self, in_channels, cond_dim, embed_dim=128, num_heads=4):\n",
" super(ModalityFieldAdapter, self).__init__()\n",
" self.anatomy_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=1)\n",
" self.modality_fc = nn.Linear(cond_dim, embed_dim)\n",
" self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)\n",
" self.out_proj = nn.Conv2d(embed_dim, in_channels, kernel_size=1)\n",
" self.norm = nn.LayerNorm(embed_dim)\n",
"\n",
" def forward(self, x, modality_cond):\n",
" \"\"\"\n",
" x: segmentation-based feature map (B, C, H, W)\n",
" modality_cond: modality condition vector (B, cond_dim), e.g., one-hot [1,0,0] for CT\n",
" \"\"\"\n",
" B, C, H, W = x.shape\n",
"\n",
" # Project anatomy features to token space\n",
" anatomy_feat = self.anatomy_proj(x) # (B, embed_dim, H, W)\n",
" anatomy_tokens = anatomy_feat.flatten(2).transpose(1, 2) # (B, HW, embed_dim)\n",
"\n",
" # Get modality embedding and expand\n",
" modality_embed = self.modality_fc(modality_cond).unsqueeze(1) # (B, 1, embed_dim)\n",
"\n",
" # Cross attention: Q=modality, K/V=anatomy tokens\n",
" attn_out, _ = self.cross_attn(query=modality_embed, key=anatomy_tokens, value=anatomy_tokens) # (B, 1, embed_dim)\n",
"\n",
" # Broadcast attention output back to spatial map\n",
" attn_map = attn_out.repeat(1, H * W, 1).reshape(B, H, W, -1).permute(0, 3, 1, 2) # (B, embed_dim, H, W)\n",
"\n",
" # Combine with anatomy features\n",
" fused = anatomy_feat + attn_map\n",
" fused = self.norm(fused.flatten(2).transpose(1, 2)).transpose(1, 2).view(B, -1, H, W)\n",
"\n",
" return self.out_proj(fused) # (B, C, H, W)\n",
"\n",
"# Sample instantiation\n",
"mfa = ModalityFieldAdapter(in_channels=1, cond_dim=3)\n",
"mfa\n"
]
}
],
"metadata": {
"language_info": {
"name": "python"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
|