# cxr-vlm-code / model/projection.py
# convitom · cba2b6c
"""
projection.py
-------------
MLP alignment layer that projects BioViL-T patch features (768-dim)
into the LLM token embedding space (4096-dim for Vicuna-7B).
Inspired by RaDialog v2: uses learned-query cross-attention pooling followed
by a small MLP projection, instead of the heavier Q-Former used in the
original RaDialog / XrayGPT. This is more parameter-efficient and easier to train.
The module:
1. Pools patch features into a fixed number of visual tokens (32 by default)
2. Projects each token from 768 → 4096 dims
The resulting tokens are then prepended to the text token sequence outside this module.
"""
import torch
import torch.nn as nn
from typing import Optional
class MLPProjection(nn.Module):
    """
    Two-stage alignment module mapping vision-encoder patch features to LLM tokens.

    Stage 1 — Spatial pooling: a fixed set of learned query tokens cross-attends
              over the (variable-length) patch sequence, producing exactly
              ``num_image_tokens`` pooled tokens.
    Stage 2 — Dimension projection: input_dim → hidden_dim → output_dim via
              Linear / GELU / Dropout / Linear.

    Args:
        input_dim:        patch feature dim (768 for BioViL-T).
        hidden_dim:       intermediate MLP dim (1024).
        output_dim:       LLM hidden size (4096 for Vicuna-7B).
        num_image_tokens: number of visual tokens passed to the LLM
                          (32, same as RaDialog).
        dropout:          dropout rate, used inside both the cross-attention
                          and the MLP.
        num_heads:        number of heads of the pooling cross-attention;
                          ``input_dim`` must be divisible by it. Defaults to 8,
                          the previously hard-coded value, so existing configs
                          and checkpoints are unaffected.
    """

    def __init__(
        self,
        input_dim: int = 768,
        hidden_dim: int = 1024,
        output_dim: int = 4096,
        num_image_tokens: int = 32,
        dropout: float = 0.1,
        num_heads: int = 8,
    ):
        super().__init__()
        self.num_image_tokens = num_image_tokens
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Learnable pooling: reduce patch sequence → num_image_tokens.
        # Uses a learned query matrix (similar to the perceiver-resampler idea).
        self.query_tokens = nn.Parameter(
            torch.randn(1, num_image_tokens, input_dim)
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # MLP projection: input_dim → hidden_dim → output_dim.
        # Kept as a single nn.Sequential so checkpoint keys stay mlp.0.* / mlp.3.*.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )
        self._init_weights()

    def _init_weights(self):
        """Initialize query tokens and MLP weights with small normal values for stable training.

        The cross-attention keeps PyTorch's default initialization.
        """
        nn.init.normal_(self.query_tokens, std=0.02)
        for module in self.mlp.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        patch_features: torch.Tensor,
        return_intermediate: bool = False,
        *,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """
        Pool patch features into visual tokens and project them into LLM space.

        Args:
            patch_features:      (B, num_patches, input_dim) — output from BioViL-T.
            return_intermediate: also return the hidden ``hidden_dim`` feature
                                 tapped between the two MLP linears. This is the
                                 grounding feature the ITC head operates on
                                 (Stage-1 image-text contrastive alignment).
            key_padding_mask:    optional (B, num_patches) mask forwarded to
                                 ``nn.MultiheadAttention``. For a bool mask, True
                                 marks patches to ignore, e.g. padding when images
                                 yield different patch counts. If every patch of a
                                 sample is masked, that sample's output is NaN.

        Returns:
            image_tokens: (B, num_image_tokens, output_dim) — visual tokens for the LLM.
            (if return_intermediate) hidden: (B, num_image_tokens, hidden_dim).

        Raises:
            ValueError: if ``patch_features`` is not 3-D or its last dim is not
                        ``input_dim``.

        Note: `self.mlp` is kept as a single nn.Sequential (NOT split into
              named submodules) so existing stage1/stage2 checkpoints
              (mlp.0.*, mlp.3.*) load unchanged. It is run in two slices to
              tap the intermediate activation.
        """
        # Fail early with a readable message instead of an opaque error from
        # deep inside nn.MultiheadAttention.
        if patch_features.dim() != 3 or patch_features.size(-1) != self.input_dim:
            raise ValueError(
                f"patch_features must have shape (B, num_patches, {self.input_dim}), "
                f"got {tuple(patch_features.shape)}"
            )
        B = patch_features.size(0)

        # Align the input dtype with the projection's own parameter dtype.
        # The frozen image encoder may run in bf16/fp16 while the projection's
        # MLP/MHA weights stay fp32. Under bf16 autocast, nn.MultiheadAttention's
        # in-projection can bypass autocast on the cross-attention path, giving:
        #   RuntimeError: mat1 and mat2 must have the same dtype: BFloat16 vs Float
        # Casting patch_features keeps the matmul self-consistent on any
        # GPU/precision. No-op when the dtypes already match.
        target_dtype = self.query_tokens.dtype
        if patch_features.dtype != target_dtype:
            patch_features = patch_features.to(target_dtype)

        # Share the learned queries across the batch (expand creates a view, not a copy).
        queries = self.query_tokens.expand(B, -1, -1)  # (B, num_image_tokens, input_dim)

        # Cross-attention: queries attend over the patch features.
        pooled, _ = self.cross_attn(
            query=queries,                       # (B, num_image_tokens, input_dim)
            key=patch_features,                  # (B, num_patches, input_dim)
            value=patch_features,                # (B, num_patches, input_dim)
            key_padding_mask=key_padding_mask,
        )  # pooled: (B, num_image_tokens, input_dim)

        # MLP projection → LLM space, tapping the intermediate activation.
        #   self.mlp[:3] = Linear(input_dim→hidden_dim) + GELU + Dropout
        #   self.mlp[3:] = Linear(hidden_dim→output_dim)
        hidden = self.mlp[:3](pooled)        # (B, num_image_tokens, hidden_dim)
        image_tokens = self.mlp[3:](hidden)  # (B, num_image_tokens, output_dim)
        if return_intermediate:
            return image_tokens, hidden
        return image_tokens

    @property
    def num_trainable_params(self) -> int:
        """Total number of elements across parameters with ``requires_grad=True``."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
class ITCHead(nn.Module):
    """
    Image-Text Contrastive head for Stage-1 explicit alignment (BLIP-2 ITC style).

    Averages the projection's intermediate visual tokens (hidden_dim each) into
    one vector, then maps it linearly into the joint contrastive space shared
    with CXR-BERT's ``get_projected_text_embeddings`` output, and L2-normalizes it.

    Only the ITC Stage-1 mode uses this head; it never touches the LLM. Its output
    is scored against precomputed, cached text embeddings with InfoNCE.

    Args:
        hidden_dim: projection intermediate dim (1024).
        proj_dim:   joint contrastive space dim. It MUST equal the text
                    encoder's projected dim (128 for CXR-BERT-specialized).
    """

    def __init__(self, hidden_dim: int = 1024, proj_dim: int = 128):
        super().__init__()
        self.proj_dim = proj_dim
        linear = nn.Linear(hidden_dim, proj_dim)
        # Small-std weights and zero bias, matching the projection module's init.
        nn.init.normal_(linear.weight, mean=0.0, std=0.02)
        nn.init.zeros_(linear.bias)
        self.proj = linear

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Map intermediate visual tokens to a unit-norm image embedding.

        Args:
            hidden: (B, num_image_tokens, hidden_dim) intermediate features
                    from the projection module.

        Returns:
            (B, proj_dim) image embedding with unit L2 norm along the last dim.
        """
        image_vec = self.proj(hidden.mean(dim=1))  # token-average, then project
        return nn.functional.normalize(image_vec, p=2.0, dim=-1)