File size: 6,931 Bytes
28b13fc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8356dae
28b13fc
 
8356dae
 
 
 
 
28b13fc
 
 
8356dae
 
 
 
 
 
28b13fc
 
 
cba2b6c
 
 
 
 
 
 
 
 
 
 
 
28b13fc
 
 
 
 
 
 
 
 
 
8356dae
 
 
 
 
28b13fc
8356dae
 
28b13fc
 
 
 
 
8356dae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""
projection.py
-------------
MLP alignment layer that projects BioViL-T patch features (768-dim)
into the LLM token embedding space (4096-dim for Vicuna-7B).

Inspired by RaDialog v2: uses a simple MLP projection instead of
the heavier Q-Former used in the original RaDialog / XrayGPT.
This is more parameter-efficient and easier to train.

The projection learns to:
  1. Pool patch features into a fixed number of visual tokens (32)
  2. Project each token from 768 → 4096 dims
  3. These tokens are then prepended to the text token sequence
"""

import torch
import torch.nn as nn
from typing import Optional


class MLPProjection(nn.Module):
    """
    Two-stage MLP alignment module:
      Stage 1 — Spatial pooling: a learned set of query tokens cross-attends
                over the patch features, reducing a variable number of
                patches → num_image_tokens
      Stage 2 — Dimension projection: input_dim → hidden_dim → output_dim
                (Linear → GELU → Dropout → Linear)

    Args:
        input_dim:        BioViL-T output dim (768)
        hidden_dim:       intermediate MLP dim (1024)
        output_dim:       LLM hidden size (4096 for Vicuna-7B)
        num_image_tokens: number of visual tokens passed to LLM (32, same as RaDialog)
        dropout:          dropout rate, used both inside the cross-attention
                          and after the MLP's GELU
        num_heads:        number of attention heads in the pooling
                          cross-attention; must divide input_dim. Defaults
                          to 8, the value this module previously hard-coded,
                          so existing checkpoints keep their behavior.

    Raises:
        ValueError: if num_heads is not a positive divisor of input_dim.
    """

    def __init__(
        self,
        input_dim:        int = 768,
        hidden_dim:       int = 1024,
        output_dim:       int = 4096,
        num_image_tokens: int = 32,
        dropout:          float = 0.1,
        num_heads:        int = 8,
    ):
        super().__init__()

        # Fail early with a readable message instead of MHA's bare assertion.
        if num_heads <= 0 or input_dim % num_heads != 0:
            raise ValueError(
                f"input_dim ({input_dim}) must be divisible by a positive "
                f"num_heads (got {num_heads})"
            )

        self.num_image_tokens = num_image_tokens
        self.input_dim        = input_dim
        self.hidden_dim       = hidden_dim
        self.output_dim       = output_dim

        # Learnable pooling: reduce patch sequence → num_image_tokens
        # Uses a learned query matrix (similar to perceiver resampler idea)
        self.query_tokens = nn.Parameter(
            torch.randn(1, num_image_tokens, input_dim)
        )
        self.cross_attn = nn.MultiheadAttention(
            embed_dim   = input_dim,
            num_heads   = num_heads,
            dropout     = dropout,
            batch_first = True,
        )

        # MLP projection: input_dim → hidden_dim → output_dim.
        # Kept as one nn.Sequential so checkpoint keys stay mlp.0.* / mlp.3.*.
        self.mlp = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim),
        )

        self._init_weights()

    def _init_weights(self):
        """
        Initialize query tokens and the MLP's Linear weights from N(0, 0.02)
        with zero biases, for stable early training. The cross-attention
        keeps nn.MultiheadAttention's own default initialization.
        """
        nn.init.normal_(self.query_tokens, std=0.02)
        for module in self.mlp.modules():
            if isinstance(module, nn.Linear):
                nn.init.normal_(module.weight, std=0.02)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)

    def forward(
        self,
        patch_features: torch.Tensor,
        return_intermediate: bool = False,
        key_padding_mask: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            patch_features:      (B, num_patches, input_dim) — output from BioViL-T
            return_intermediate: also return the hidden feature tapped
                                 between the two MLP linears. This is the
                                 grounding feature the ITC head operates on
                                 (Stage-1 image-text contrastive alignment).
            key_padding_mask:    optional (B, num_patches) mask forwarded to
                                 nn.MultiheadAttention; for a bool mask, True
                                 marks a padded patch to ignore. A row that
                                 masks every patch may produce NaN tokens.

        Returns:
            image_tokens: (B, num_image_tokens, output_dim) — visual tokens for LLM
            (if return_intermediate) hidden: (B, num_image_tokens, hidden_dim)

        Raises:
            ValueError: if patch_features is not 3-D with last dim == input_dim.

        Note: `self.mlp` is kept as a single nn.Sequential (NOT split into
        named submodules) so existing stage1/stage2 checkpoints
        (mlp.0.*, mlp.3.*) load unchanged. We just run it in two slices to
        tap the intermediate activation.
        """
        if patch_features.dim() != 3 or patch_features.size(-1) != self.input_dim:
            raise ValueError(
                f"patch_features must have shape (B, num_patches, {self.input_dim}), "
                f"got {tuple(patch_features.shape)}"
            )

        B = patch_features.size(0)

        # Align input dtype with the projection's own parameter dtype.
        # The frozen image encoder may run in bf16/fp16 (llm_dtype) while
        # the projection's MLP/MHA weights stay fp32. Under bf16 autocast,
        # nn.MultiheadAttention's in-projection sometimes bypasses autocast
        # (cross-attention path), giving:
        #   RuntimeError: mat1 and mat2 must have the same dtype: BFloat16 vs Float
        # Casting patch_features keeps the matmul self-consistent on any
        # GPU/precision. No-op when dtypes already match.
        target_dtype = self.query_tokens.dtype
        if patch_features.dtype != target_dtype:
            patch_features = patch_features.to(target_dtype)

        # Expand query tokens to batch size (a view, no copy)
        queries = self.query_tokens.expand(B, -1, -1)  # (B, T, input_dim)

        # Cross-attention: queries attend over patch features.
        # need_weights=False: the attention map is discarded anyway, so don't
        # materialize/average it (this also lets PyTorch use its fused
        # scaled-dot-product attention kernel).
        pooled, _ = self.cross_attn(
            query            = queries,          # (B, T, input_dim)
            key              = patch_features,   # (B, num_patches, input_dim)
            value            = patch_features,   # (B, num_patches, input_dim)
            key_padding_mask = key_padding_mask,
            need_weights     = False,
        )  # pooled: (B, T, input_dim)

        # MLP projection → LLM space, tapping the intermediate.
        #   self.mlp[:3] = Linear(input_dim→hidden_dim) + GELU + Dropout
        #   self.mlp[3:] = Linear(hidden_dim→output_dim)
        hidden       = self.mlp[:3](pooled)   # (B, T, hidden_dim)
        image_tokens = self.mlp[3:](hidden)   # (B, T, output_dim)

        if return_intermediate:
            return image_tokens, hidden
        return image_tokens

    @property
    def num_trainable_params(self) -> int:
        """Total number of elements across parameters with requires_grad=True."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


class ITCHead(nn.Module):
    """
    Image-Text Contrastive head (Stage-1 explicit alignment, BLIP-2 ITC style).

    Mean-pools the projection's intermediate visual tokens (hidden_dim, 1024
    by default) into a single vector and linearly projects it into the joint
    contrastive space shared with CXR-BERT's `get_projected_text_embeddings`
    output (128-d, L2-norm).

    Used ONLY in the ITC Stage-1 mode; it never touches the LLM. The output
    is compared against precomputed, cached text embeddings via InfoNCE.

    Args:
        hidden_dim: projection intermediate dim (1024)
        proj_dim:   joint contrastive space dim — MUST match the text
                    encoder's projected dim (CXR-BERT-specialized = 128)
    """

    def __init__(self, hidden_dim: int = 1024, proj_dim: int = 128):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.proj_dim = proj_dim
        self.proj = nn.Linear(hidden_dim, proj_dim)
        nn.init.normal_(self.proj.weight, std=0.02)
        nn.init.zeros_(self.proj.bias)

    def forward(self, hidden: torch.Tensor) -> torch.Tensor:
        """
        Args:
            hidden: (B, num_image_tokens, hidden_dim) — projection intermediate

        Returns:
            img_embed: (B, proj_dim) — L2-normalized image embedding in the
                       joint image-text space, in the head's parameter dtype.

        Raises:
            ValueError: if hidden is not 3-D with last dim == hidden_dim. A
                        2-D input would otherwise be silently mean-pooled over
                        the feature axis.
        """
        if hidden.dim() != 3 or hidden.size(-1) != self.hidden_dim:
            raise ValueError(
                f"hidden must have shape (B, num_image_tokens, {self.hidden_dim}), "
                f"got {tuple(hidden.shape)}"
            )

        # Same dtype alignment as MLPProjection: the intermediate may arrive
        # in bf16/fp16 while this head's Linear stays fp32 (or vice versa),
        # which raises "mat1 and mat2 must have the same dtype" outside
        # autocast. No-op when dtypes already match.
        target_dtype = self.proj.weight.dtype
        if hidden.dtype != target_dtype:
            hidden = hidden.to(target_dtype)

        pooled = hidden.mean(dim=1)                       # (B, hidden_dim)
        embed  = self.proj(pooled)                        # (B, proj_dim)
        return torch.nn.functional.normalize(embed, dim=-1)