| """ |
| projection.py |
| ------------- |
| MLP alignment layer that projects BioViL-T patch features (768-dim) |
| into the LLM token embedding space (4096-dim for Vicuna-7B). |
| |
| Inspired by RaDialog v2: uses a simple MLP projection instead of |
| the heavier Q-Former used in the original RaDialog / XrayGPT. |
| This is more parameter-efficient and easier to train. |
| |
| The projection learns to: |
| 1. Pool patch features into a fixed number of visual tokens (32) |
| 2. Project each token from 768 → 4096 dims |
| 3. These tokens are then prepended to the text token sequence |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| from typing import Optional |
|
|
|
|
class MLPProjection(nn.Module):
    """
    Two-stage alignment module mapping BioViL-T patch features into the LLM
    token-embedding space.

    Stage 1 — Attention pooling: ``num_image_tokens`` learned query vectors
        cross-attend over the patch features, reducing a variable number of
        patches to exactly ``num_image_tokens`` tokens of size ``input_dim``.
    Stage 2 — Dimension projection: a two-layer MLP
        ``input_dim → hidden_dim → output_dim`` (Linear, GELU, Dropout, Linear).

    The parameter names ``query_tokens``, ``cross_attn.*``, ``mlp.0.*`` and
    ``mlp.3.*`` form the checkpoint layout and must stay stable.

    Args:
        input_dim: patch feature dim (768 for BioViL-T).
        hidden_dim: intermediate MLP dim (1024).
        output_dim: LLM hidden size (4096 for Vicuna-7B).
        num_image_tokens: number of visual tokens passed to the LLM (32, same
            as RaDialog).
        dropout: dropout rate used both in the cross-attention and in the MLP.
        num_heads: attention heads in the pooling cross-attention (default 8).
            ``input_dim`` must be divisible by it.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``input_dim``.
    """

    def __init__(
        self,
        input_dim: int = 768,
        hidden_dim: int = 1024,
        output_dim: int = 4096,
        num_image_tokens: int = 32,
        dropout: float = 0.1,
        num_heads: int = 8,
    ):
        super().__init__()
        # Fail early with a readable message instead of MHA's internal assert.
        if num_heads <= 0 or input_dim % num_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by a positive "
                f"num_heads ({num_heads})"
            )

        self.num_image_tokens = num_image_tokens
        self.input_dim = input_dim
        self.output_dim = output_dim

        # Stage 1: learned queries that attention-pool the patch sequence.
        # The randn values are overwritten by _init_weights().
        self.query_tokens = nn.Parameter(
            torch.randn(1, num_image_tokens, input_dim)
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim=input_dim,
            num_heads=num_heads,
            dropout=dropout,
            batch_first=True,
        )

        # Stage 2: per-token projection. Kept as a single Sequential so the
        # checkpoint keys stay mlp.0.* / mlp.3.*; forward() slices it at index
        # 3 to expose the hidden activation (after GELU and dropout).
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize queries and MLP linears with N(0, 0.02) weights and zero biases.

        The cross-attention keeps PyTorch's default initialization.
        """
        nn.init.normal_(self.query_tokens, std=0.02)
        for module in self.mlp.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        patch_features: torch.Tensor,
        return_intermediate: bool = False,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            patch_features: (B, num_patches, input_dim) — output from BioViL-T.
                Cast to the module's parameter dtype before use.
            return_intermediate: also return the hidden_dim feature tapped
                between the two MLP linears. This is the grounding feature the
                ITC head operates on (Stage-1 image-text contrastive alignment).
            key_padding_mask: optional (B, num_patches) bool mask for batches
                whose images were padded to a common patch count; ``True``
                marks padding positions the queries must ignore. A row that
                masks every patch produces NaN outputs for that sample.

        Returns:
            image_tokens: (B, num_image_tokens, output_dim) — visual tokens
                for the LLM.
            (if return_intermediate) hidden: (B, num_image_tokens, hidden_dim)

        Raises:
            ValueError: if ``patch_features`` is not 3-D, has zero patches, or
                its last dimension differs from ``input_dim``.

        Note: `self.mlp` is kept as a single nn.Sequential (NOT split into
        named submodules) so existing stage1/stage2 checkpoints
        (mlp.0.*, mlp.3.*) load unchanged. It is run in two slices to tap
        the intermediate activation.
        """
        if (
            patch_features.dim() != 3
            or patch_features.size(1) == 0
            or patch_features.size(-1) != self.input_dim
        ):
            raise ValueError(
                "expected patch_features of shape (B, num_patches>=1, "
                f"{self.input_dim}), got {tuple(patch_features.shape)}"
            )

        # Match the parameter dtype so encoder outputs produced in another
        # precision can be fed straight into the attention / linear layers.
        target_dtype = self.query_tokens.dtype
        if patch_features.dtype != target_dtype:
            patch_features = patch_features.to(target_dtype)

        # Share the same learned queries across the batch (no copy).
        queries = self.query_tokens.expand(patch_features.size(0), -1, -1)

        # The attention map is discarded, so do not ask MHA to build it.
        pooled, _ = self.cross_attn(
            query=queries,
            key=patch_features,
            value=patch_features,
            key_padding_mask=key_padding_mask,
            need_weights=False,
        )

        # mlp[:3] = Linear -> GELU -> Dropout ; mlp[3:] = final Linear.
        hidden = self.mlp[:3](pooled)
        image_tokens = self.mlp[3:](hidden)

        if return_intermediate:
            return image_tokens, hidden
        return image_tokens

    @property
    def num_trainable_params(self) -> int:
        """Number of scalar parameters with requires_grad=True."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|
|
class ITCHead(nn.Module):
    """
    Image-text contrastive (ITC) head for Stage-1 explicit alignment,
    following the BLIP-2 ITC recipe.

    The projection's intermediate visual tokens are averaged into a single
    vector per image, mapped linearly to ``proj_dim`` and L2-normalized. The
    result lives in the same joint space as CXR-BERT's
    ``get_projected_text_embeddings`` output, so the two can be scored
    against each other directly.

    This head is only used in the ITC Stage-1 mode and never feeds the LLM;
    its outputs are matched against precomputed, cached text embeddings with
    an InfoNCE objective.

    Args:
        hidden_dim: size of each incoming intermediate token (1024 here).
        proj_dim: dimension of the joint contrastive space. It MUST equal the
            text encoder's projected dim (128 for CXR-BERT-specialized).
    """

    def __init__(self, hidden_dim: int = 1024, proj_dim: int = 128):
        super().__init__()
        self.proj_dim = proj_dim

        # Build and initialize the linear map before registering it: small
        # N(0, 0.02) weights, zero bias.
        linear = nn.Linear(hidden_dim, proj_dim)
        nn.init.normal_(linear.weight, std=0.02)
        nn.init.zeros_(linear.bias)
        self.proj = linear

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden: (B, num_image_tokens, hidden_dim) — intermediate
                projection features.

        Returns:
            (B, proj_dim) image embeddings with unit L2 norm.
        """
        # Token-average -> linear map -> unit-length vector.
        token_mean = torch.mean(hidden, dim=1)
        projected = self.proj(token_mean)
        return nn.functional.normalize(projected, p=2.0, dim=-1)
|
|