"""
projection.py
-------------
MLP alignment layer that projects BioViL-T patch features (768-dim)
into the LLM token embedding space (4096-dim for Vicuna-7B).
Inspired by RaDialog v2: uses a simple MLP projection instead of
the heavier Q-Former used in the original RaDialog / XrayGPT.
This is more parameter-efficient and easier to train.
The projection learns to:
1. Pool patch features into a fixed number of visual tokens (32) using learned query tokens that cross-attend over the patches
2. Project each token from 768 → 4096 dims
3. These tokens are then prepended to the text token sequence
"""
import torch
import torch.nn as nn
from typing import Optional
class MLPProjection(nn.Module):
    """
    Two-stage alignment module from image-encoder patches to LLM tokens.

    Stage 1 — Spatial pooling: ``num_image_tokens`` learned query tokens
        cross-attend over the patch sequence, reducing any (non-zero) number
        of patches to a fixed number of visual tokens.
    Stage 2 — Dimension projection: an MLP maps every pooled token
        input_dim → hidden_dim → output_dim.

    Args:
        input_dim: image-encoder output dim (768 for BioViL-T)
        hidden_dim: intermediate MLP dim (1024)
        output_dim: LLM hidden size (4096 for Vicuna-7B)
        num_image_tokens: number of visual tokens passed to the LLM
            (32, same as RaDialog)
        dropout: dropout rate, used inside the cross-attention and after the
            MLP's GELU activation
        num_heads: number of cross-attention heads; must be a positive
            divisor of ``input_dim``. Defaults to 8, the value that was
            previously hard-coded, so default behavior is unchanged.

    Raises:
        ValueError: if ``num_heads`` is not a positive divisor of ``input_dim``.
    """

    def __init__(
        self,
        input_dim: int = 768,
        hidden_dim: int = 1024,
        output_dim: int = 4096,
        num_image_tokens: int = 32,
        dropout: float = 0.1,
        num_heads: int = 8,
    ):
        super().__init__()
        # Validate up front so a bad config fails with an explicit message
        # rather than a less specific error from nn.MultiheadAttention.
        if num_heads <= 0 or input_dim % num_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )
        self.num_image_tokens = num_image_tokens
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim

        # Learnable pooling: a learned query matrix (perceiver-resampler idea)
        # that attends over the patch sequence → num_image_tokens tokens.
        # NOTE: construction order (query_tokens, cross_attn, mlp) is kept so
        # RNG consumption — and therefore seeded init — matches earlier runs.
        self.query_tokens = nn.Parameter(
            torch.randn(1, num_image_tokens, input_dim)
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # MLP projection: input_dim → hidden_dim → output_dim.
        # Indices matter: checkpoints store these as mlp.0.* and mlp.3.*.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize query tokens and MLP linears with small normal values (std 0.02), zero biases."""
        nn.init.normal_(self.query_tokens, std=0.02)
        for module in self.mlp.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(self, patch_features: torch.Tensor, return_intermediate: bool = False):
        """
        Pool patch features into visual tokens and project them into LLM space.

        Args:
            patch_features: (B, num_patches, input_dim), e.g. BioViL-T output.
                num_patches may vary between calls but must be > 0.
            return_intermediate: also return the hidden_dim feature tapped
                between the two MLP linears (i.e. after GELU + Dropout).
                This is the grounding feature the ITC head operates on
                (Stage-1 image-text contrastive alignment).

        Returns:
            image_tokens: (B, num_image_tokens, output_dim), the visual tokens
                for the LLM.
            hidden: (B, num_image_tokens, hidden_dim). Returned as the second
                element of a tuple only when ``return_intermediate`` is True.

        Raises:
            ValueError: if ``patch_features`` is not 3-D, contains no
                patches, or its last dim differs from ``input_dim``.

        Note: `self.mlp` is kept as a single nn.Sequential (NOT split into
        named submodules) so existing stage1/stage2 checkpoints
        (mlp.0.*, mlp.3.*) load unchanged. It is run in two slices to tap
        the intermediate activation.
        """
        # Attention over zero keys has no meaningful result, and a wrong
        # feature dim would otherwise surface as an opaque matmul error.
        if (
            patch_features.dim() != 3
            or patch_features.size(1) == 0
            or patch_features.size(-1) != self.input_dim
        ):
            raise ValueError(
                "patch_features must have shape (B, num_patches > 0, "
                f"{self.input_dim}); got {tuple(patch_features.shape)}"
            )

        # Align input dtype with the projection's own parameter dtype.
        # The frozen image encoder may run in bf16/fp16 (llm_dtype) while
        # the projection's MLP/MHA weights stay fp32. Under bf16 autocast,
        # nn.MultiheadAttention's in-projection sometimes bypasses autocast
        # (cross-attention path), giving:
        #   RuntimeError: mat1 and mat2 must have the same dtype: BFloat16 vs Float
        # Casting patch_features keeps the matmul self-consistent on any
        # GPU/precision. No-op when dtypes already match.
        target_dtype = self.query_tokens.dtype
        if patch_features.dtype != target_dtype:
            patch_features = patch_features.to(target_dtype)

        # Broadcast the shared learned queries across the batch (no copy).
        queries = self.query_tokens.expand(patch_features.size(0), -1, -1)

        # Cross-attention: queries attend over patch features.
        # The attention map is discarded, so skip computing it.
        pooled, _ = self.cross_attn(
            query=queries,          # (B, num_image_tokens, input_dim)
            key=patch_features,     # (B, num_patches, input_dim)
            value=patch_features,   # (B, num_patches, input_dim)
            need_weights=False,
        )                           # pooled: (B, num_image_tokens, input_dim)

        # MLP projection → LLM space, tapping the hidden_dim intermediate.
        #   self.mlp[:3] = Linear(input_dim→hidden_dim) + GELU + Dropout
        #   self.mlp[3:] = Linear(hidden_dim→output_dim)
        hidden = self.mlp[:3](pooled)
        image_tokens = self.mlp[3:](hidden)

        if return_intermediate:
            return image_tokens, hidden
        return image_tokens

    @property
    def num_trainable_params(self) -> int:
        """Total number of elements across parameters with requires_grad=True."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
class ITCHead(nn.Module):
    """
    Image-text contrastive (ITC) head for Stage-1 explicit alignment (BLIP-2 style).

    Averages the projection's intermediate visual tokens into a single vector
    per image, maps that vector linearly into the joint contrastive space, and
    L2-normalizes it. The target space is the one produced by CXR-BERT's
    `get_projected_text_embeddings` (128-d, L2-normalized), so the result can
    be scored against precomputed, cached text embeddings with InfoNCE.
    This head is used only in the ITC Stage-1 mode and never feeds the LLM.

    Args:
        hidden_dim: width of the projection's intermediate features (1024)
        proj_dim: width of the joint contrastive space; MUST equal the text
            encoder's projected dim (128 for CXR-BERT-specialized)
    """

    def __init__(self, hidden_dim: int = 1024, proj_dim: int = 128):
        super().__init__()
        self.proj_dim = proj_dim
        linear = nn.Linear(hidden_dim, proj_dim)
        # Small-std weights and a zero bias.
        nn.init.normal_(linear.weight, std=0.02)
        nn.init.zeros_(linear.bias)
        self.proj = linear

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Embed a batch of images into the joint image-text space.

        Args:
            hidden: (B, num_image_tokens, hidden_dim), the projection's
                intermediate features.

        Returns:
            (B, proj_dim) image embeddings, each with unit L2 norm.
        """
        # Average over the token axis (dim 1), project, then unit-normalize.
        return nn.functional.normalize(
            self.proj(torch.mean(hidden, dim=1)), p=2.0, dim=-1
        )
|