"""
projection.py
-------------
MLP alignment layer that projects BioViL-T patch features (768-dim) into the
LLM token embedding space (4096-dim for Vicuna-7B).

Inspired by RaDialog v2: uses a simple MLP projection instead of the heavier
Q-Former used in the original RaDialog / XrayGPT. This is more
parameter-efficient and easier to train.

The projection learns to:
    1. Pool patch features into a fixed number of visual tokens (32).
    2. Project each token from 768 → 4096 dims.
The caller then prepends these visual tokens to the text token sequence.
"""

from typing import Optional

import torch
import torch.nn as nn


class MLPProjection(nn.Module):
    """
    Two-stage MLP alignment module:
      Stage 1 — Spatial pooling: learned query tokens cross-attend over a
                variable number of patches → ``num_image_tokens`` tokens.
      Stage 2 — Dimension projection: input_dim → hidden_dim → output_dim.

    Args:
        input_dim:        BioViL-T output dim (768)
        hidden_dim:       intermediate MLP dim (1024)
        output_dim:       LLM hidden size (4096 for Vicuna-7B)
        num_image_tokens: number of visual tokens passed to LLM (32, same as RaDialog)
        dropout:          dropout rate (applied in the cross-attention and the MLP)
        num_heads:        attention heads of the pooling cross-attention; must
                          be a positive divisor of ``input_dim``. The default
                          (8) is the original architecture. The head count
                          does not change any parameter shape, so existing
                          checkpoints load for any valid value.

    Raises:
        ValueError: if ``num_heads`` is not a positive divisor of ``input_dim``.
    """

    def __init__(
        self,
        input_dim: int = 768,
        hidden_dim: int = 1024,
        output_dim: int = 4096,
        num_image_tokens: int = 32,
        dropout: float = 0.1,
        num_heads: int = 8,
    ):
        super().__init__()
        # Check explicitly: nn.MultiheadAttention only uses `assert`, which is
        # stripped under `python -O`. The short-circuit avoids modulo by zero.
        if num_heads <= 0 or input_dim % num_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )

        self.num_image_tokens = num_image_tokens
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Learnable pooling: reduce patch sequence → num_image_tokens.
        # Uses a learned query matrix (similar to the perceiver resampler idea).
        self.query_tokens = nn.Parameter(
            torch.randn(1, num_image_tokens, input_dim)
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # MLP projection: input_dim → hidden_dim → output_dim.
        # Indices matter for checkpoints: mlp.0.* and mlp.3.* are the Linears.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize weights with small normal values for stable training."""
        nn.init.normal_(self.query_tokens, std=0.02)
        for module in self.mlp.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        patch_features: torch.Tensor,
        return_intermediate: bool = False,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            patch_features:      (B, num_patches, input_dim) — output from BioViL-T
            return_intermediate: also return the hidden_dim feature tapped
                                 between the two MLP linears. This is the
                                 grounding feature the ITC head operates on
                                 (Stage-1 image-text contrastive alignment).
            key_padding_mask:    optional (B, num_patches) mask forwarded to
                                 the cross-attention. With a bool mask, True
                                 marks patches to IGNORE (PyTorch MHA
                                 convention). A row that masks every patch
                                 may produce NaN for that sample. None (the
                                 default) attends over all patches, as before.

        Returns:
            image_tokens: (B, num_image_tokens, output_dim) — visual tokens for LLM
            (if return_intermediate) hidden: (B, num_image_tokens, hidden_dim)

        Note: `self.mlp` is kept as a single nn.Sequential (NOT split into
        named submodules), so existing stage1/stage2 checkpoints (mlp.0.*,
        mlp.3.*) load unchanged. It is run in two slices to tap the
        intermediate activation.
        """
        B = patch_features.size(0)

        # Align the input dtype with the projection's own parameter dtype.
        # The frozen image encoder may run in bf16/fp16 (llm_dtype) while
        # the projection's MLP/MHA weights stay fp32. Under bf16 autocast,
        # nn.MultiheadAttention's in-projection sometimes bypasses autocast
        # (cross-attention path), giving:
        #   RuntimeError: mat1 and mat2 must have the same dtype: BFloat16 vs Float
        # Upcasting patch_features keeps the matmul self-consistent on any
        # GPU/precision. No-op when the dtypes already match.
        target_dtype = self.query_tokens.dtype
        if patch_features.dtype != target_dtype:
            patch_features = patch_features.to(target_dtype)

        # Expand query tokens to batch size (a view, no copy).
        queries = self.query_tokens.expand(B, -1, -1)  # (B, T, input_dim)

        # Cross-attention: queries attend over patch features.
        # need_weights=False: the weights are discarded anyway, and skipping
        # them avoids materializing/averaging them (allows the fused SDPA path).
        pooled, _ = self.cross_attn(
            query=queries,                      # (B, T, input_dim)
            key=patch_features,                 # (B, num_patches, input_dim)
            value=patch_features,               # (B, num_patches, input_dim)
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )
        # pooled: (B, T, input_dim)

        # MLP projection → LLM space, tapping the hidden_dim intermediate.
        #   self.mlp[:3] = Linear(input_dim→hidden_dim) + GELU + Dropout
        #   self.mlp[3:] = Linear(hidden_dim→output_dim)
        hidden = self.mlp[:3](pooled)          # (B, T, hidden_dim)
        image_tokens = self.mlp[3:](hidden)    # (B, T, output_dim)

        if return_intermediate:
            return image_tokens, hidden
        return image_tokens

    @property
    def num_trainable_params(self) -> int:
        """Total number of elements across parameters with requires_grad=True."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


class ITCHead(nn.Module):
    """
    Image-Text Contrastive head (Stage-1 explicit alignment, BLIP-2 ITC style).

    Mean-pools the projection's intermediate visual tokens (hidden_dim) into
    a single vector. It then projects that vector into the joint contrastive
    space shared with CXR-BERT's `get_projected_text_embeddings` output
    (128-d, L2-normalized).

    Used ONLY in the ITC Stage-1 mode; it never touches the LLM. The output
    is compared against precomputed, cached text embeddings via InfoNCE.

    Args:
        hidden_dim: projection intermediate dim (1024)
        proj_dim:   joint contrastive space dim — MUST match the text
                    encoder's projected dim (CXR-BERT-specialized = 128)
    """

    def __init__(self, hidden_dim: int = 1024, proj_dim: int = 128):
        super().__init__()
        self.proj_dim = proj_dim
        self.proj = nn.Linear(hidden_dim, proj_dim)
        nn.init.normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden: (B, num_image_tokens, hidden_dim) — projection intermediate

        Returns:
            img_embed: (B, proj_dim) — L2-normalized image embedding in the
                       joint image-text space.
        """
        pooled = hidden.mean(dim=1)      # (B, hidden_dim)
        embed = self.proj(pooled)        # (B, proj_dim)
        return torch.nn.functional.normalize(embed, dim=-1)