File size: 2,924 Bytes
fd601de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
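    "## Modality Field Adapter (MFA): architecture sketch\n",
    "\n",
    "```\n",
    "[Segmentation Mask] --> [Encoder (U-Net style or ViT Patch Embedding)] --> Q\n",
    "\n",
    "[Text Condition: \"CT\", \"T1-MR\"] --> [Condition Embedding] --> K, V\n",
    "\n",
    "Q, K, V --> [Cross-Attention Block] --> Fused Feature Map --> UNet Backbone --> DDPM/Diffusion Head\n",
    "```\n",
    "\n",
    "**Implementation note:** the `ModalityFieldAdapter` code below swaps the attention roles: the modality embedding is the single query token, and the anatomy tokens supply K and V. With only one condition token as K/V, the softmax over keys would always equal 1, so every anatomy token would receive the same projected condition vector regardless of its content. The condition is also passed as a numeric vector (e.g. one-hot `[1, 0, 0]` for CT) rather than as raw text.\n"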
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let's write the PyTorch code for the Modality Field Adapter (MFA) module.\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "\n",
    "class ModalityFieldAdapter(nn.Module):\n",
    "    \"\"\"Inject a modality condition into an anatomy feature map via cross-attention.\n",
    "\n",
    "    A single query token built from the modality vector attends over the\n",
    "    per-pixel anatomy tokens; the attended vector is added to every spatial\n",
    "    position, layer-normalised over channels and projected back to in_channels.\n",
    "\n",
    "    Note: the Q/K/V roles are the reverse of the diagram above (there, Q comes\n",
    "    from the mask). With a single condition token as K/V, the softmax over keys\n",
    "    is always 1, so the output would not depend on the anatomy tokens at all.\n",
    "\n",
    "    Args:\n",
    "        in_channels: number of channels C of the input feature map.\n",
    "        cond_dim: length of the modality condition vector.\n",
    "        embed_dim: internal token width; must be divisible by num_heads.\n",
    "        num_heads: number of attention heads.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels, cond_dim, embed_dim=128, num_heads=4):\n",
    "        super().__init__()\n",
    "        self.anatomy_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=1)  # C -> embed_dim per pixel\n",
    "        self.modality_fc = nn.Linear(cond_dim, embed_dim)\n",
    "        self.cross_attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)\n",
    "        self.out_proj = nn.Conv2d(embed_dim, in_channels, kernel_size=1)  # embed_dim -> C per pixel\n",
    "        self.norm = nn.LayerNorm(embed_dim)\n",
    "\n",
    "    def forward(self, x, modality_cond):\n",
    "        \"\"\"\n",
    "        x: segmentation-based feature map (B, C, H, W)\n",
    "        modality_cond: modality condition vector (B, cond_dim), e.g., one-hot [1,0,0] for CT.\n",
    "            Non-float tensors (e.g. an int64 one-hot) are cast to the weight dtype.\n",
    "        Returns: fused feature map of shape (B, C, H, W).\n",
    "        Raises: ValueError if x is not 4-D.\n",
    "        \"\"\"\n",
    "        if x.dim() != 4:\n",
    "            raise ValueError(f\"expected x of shape (B, C, H, W), got {tuple(x.shape)}\")\n",
    "\n",
    "        # Project anatomy features to token space\n",
    "        anatomy_feat = self.anatomy_proj(x)  # (B, embed_dim, H, W)\n",
    "        anatomy_tokens = anatomy_feat.flatten(2).transpose(1, 2)  # (B, HW, embed_dim)\n",
    "\n",
    "        # nn.Linear rejects integer inputs, and one-hot vectors are often built as int64\n",
    "        modality_cond = modality_cond.to(self.modality_fc.weight.dtype)\n",
    "        modality_embed = self.modality_fc(modality_cond).unsqueeze(1)  # (B, 1, embed_dim)\n",
    "\n",
    "        # Cross attention: Q=modality (1 token), K/V=anatomy tokens -> one vector per sample\n",
    "        attn_out, _ = self.cross_attn(query=modality_embed, key=anatomy_tokens, value=anatomy_tokens)  # (B, 1, embed_dim)\n",
    "\n",
    "        # That vector is identical for every pixel, so broadcast it instead of\n",
    "        # materialising a (B, embed_dim, H, W) copy with repeat()\n",
    "        attn_map = attn_out.transpose(1, 2).unsqueeze(-1)  # (B, embed_dim, 1, 1)\n",
    "        fused = anatomy_feat + attn_map  # (B, embed_dim, H, W)\n",
    "\n",
    "        # LayerNorm over channels at each pixel: move channels last, normalise, move back\n",
    "        fused = self.norm(fused.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)\n",
    "\n",
    "        return self.out_proj(fused)  # (B, C, H, W)\n",
    "\n",
    "# Sample instantiation, seeded so weight initialisation is reproducible\n",
    "torch.manual_seed(0)\n",
    "mfa = ModalityFieldAdapter(in_channels=1, cond_dim=3)\n",
    "\n",
    "# Smoke test: a dummy single-channel mask with one-hot conditions must map back to the input shape\n",
    "dummy_mask = torch.rand(2, 1, 16, 16)\n",
    "dummy_cond = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])\n",
    "with torch.no_grad():\n",
    "    dummy_out = mfa(dummy_mask, dummy_cond)\n",
    "assert dummy_out.shape == dummy_mask.shape, dummy_out.shape\n",
    "mfa\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}