"""Projection adapter: encoder output → LFM2.5 hidden dim.

This is the load-bearing module that lifts the per-transaction encoder's
output into the pretrained LFM2.5 backbone's continuous-token space. The
architecture mirrors the layer stack of LFM2-VL's `Lfm2VlMultiModalProjector`:

    LayerNorm → Linear(d_encoder → hidden) → GELU → Linear(hidden → d_lfm)

Why exactly this shape:

- **LayerNorm at the front.** The encoder is freshly initialized, so its
  outputs have arbitrary scale. LayerNorm rescales each token to zero mean
  and unit variance across the feature dimension (before its learned
  affine), which gives the projector a well-conditioned input distribution
  and stabilizes early training. LFM2-VL's config defaults to
  `projector_use_layernorm=True`.
- **2 layers, not 1.** LLaVA-1.5
  ([arXiv 2310.03744](https://arxiv.org/abs/2310.03744)) replaced
  LLaVA-1.0's single Linear projection with a 2-layer MLP and reported
  materially better benchmark results; LFM2-VL adopted the same design. A
  single Linear is only an affine map, whereas the non-linear MLP can better
  bridge the gap between the encoder's feature domain and the
  text-pretrained embedding domain.
- **GELU activation.** LFM2-VL uses GELU (`projector_hidden_act="gelu"`).
  The encoder uses SiLU internally (matching LFM2's SwiGLU MLPs), but the
  projector sits outside the LFM backbone and follows the VL projector
  convention. Here, consistency with that cross-modal projector convention
  matters more than consistency with the encoder's internal activation.
- **hidden = 2 * d_lfm by default.** LFM2-VL uses `projector_hidden_size=2560`
  for `text_hidden_size=2048` (ratio 1.25). We default to 2x instead (hidden
  2048 for LFM2.5-350M's d_lfm=1024), which buys slightly more capacity at a
  minor parameter cost. This is a defensible default; revisit it if the
  projector becomes a bottleneck.

Shape contract: (B, T, d_encoder) → (B, T, d_lfm)
"""

from __future__ import annotations

import torch
import torch.nn as nn


class ProjectionAdapter(nn.Module):
    """LFM2-VL-shaped 2-layer MLP projector with input LayerNorm.

    Args:
        d_encoder: input feature dim from the per-transaction encoder.
        d_lfm: output feature dim; must match the LFM2.5 backbone's hidden
            size (1024 for LFM2.5-350M, 2048 for LFM2.5-1.2B).
        hidden: intermediate projector hidden size. Defaults to `2 * d_lfm`,
            following LFM2-VL's pattern of `projector_hidden_size > d_lfm`.
        use_layernorm: include LayerNorm at the input (LFM2-VL default: True).
    """

    def __init__(
        self,
        d_encoder: int = 256,
        d_lfm: int = 1024,
        hidden: int | None = None,
        use_layernorm: bool = True,
    ) -> None:
        super().__init__()
        if hidden is None:
            hidden = 2 * d_lfm  # 2x ratio; see module docstring

        # Normalize arbitrary-scale encoder features; Identity when disabled
        # so forward() stays the same in both configurations.
        self.input_norm: nn.Module = (
            nn.LayerNorm(d_encoder) if use_layernorm else nn.Identity()
        )
        self.up = nn.Linear(d_encoder, hidden)
        self.act = nn.GELU()
        self.down = nn.Linear(hidden, d_lfm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, d_encoder)
        x = self.input_norm(x)
        x = self.up(x)
        x = self.act(x)
        return self.down(x)  # → (B, T, d_lfm)
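

# A rough way to quantify the "minor parameter cost" of the 2x hidden-size
# default described in the module docstring: count the projector's parameters
# for a given hidden size. `projector_param_count` is a hypothetical helper
# (not part of LFM2-VL or transformers). It assumes nn.LayerNorm with an
# elementwise affine (weight + bias) and nn.Linear layers with bias, which is
# how the LayerNorm → Linear → GELU → Linear stack is built here.


def projector_param_count(
    d_encoder: int, d_lfm: int, hidden: int, use_layernorm: bool = True
) -> int:
    """Count parameters of LayerNorm → Linear → GELU → Linear (hypothetical helper)."""
    norm = 2 * d_encoder if use_layernorm else 0  # LayerNorm weight + bias
    up = d_encoder * hidden + hidden  # Linear(d_encoder → hidden) weight + bias
    down = hidden * d_lfm + d_lfm  # Linear(hidden → d_lfm) weight + bias
    return norm + up + down  # GELU has no parameters


# Usage with the LFM2.5-350M defaults (d_encoder=256, d_lfm=1024):
#   projector_param_count(256, 1024, 2048)  # 2x ratio    → 2,625,024
#   projector_param_count(256, 1024, 1280)  # 1.25x ratio → 1,641,216
# The ~1M-parameter gap is small next to a ~350M-parameter backbone. The 2x
# figure should agree with sum(p.numel() for p in ProjectionAdapter().parameters()).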
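

# The LayerNorm bullet in the module docstring notes that a freshly
# initialized encoder can emit features of arbitrary scale. A minimal sketch
# of what the input LayerNorm does to such features at initialization
# (weight=1, bias=0): each token ends up with ~zero mean and ~unit variance
# over the feature dimension, provided the input variance is well above
# LayerNorm's eps (1e-5 by default). `layernorm_token_stats` is a hypothetical
# helper for illustration; the random tensor merely stands in for encoder output.

from typing import Tuple

import torch
import torch.nn as nn


def layernorm_token_stats(
    d_encoder: int = 256, scale: float = 50.0, seed: int = 0
) -> Tuple[float, float]:
    """Return (max |per-token mean|, average per-token variance) after LayerNorm."""
    generator = torch.Generator().manual_seed(seed)
    # Stand-in for raw encoder output: shape (B=2, T=8, d_encoder), large scale.
    x = scale * torch.randn(2, 8, d_encoder, generator=generator)
    y = nn.LayerNorm(d_encoder)(x)
    mean = y.mean(dim=-1)
    # Biased variance over the feature dim, matching how LayerNorm normalizes.
    var = ((y - mean.unsqueeze(-1)) ** 2).mean(dim=-1)
    return mean.abs().max().item(), var.mean().item()


# Usage: layernorm_token_stats() returns a mean close to 0.0 and a variance
# close to 1.0, even though the raw inputs have variance around scale**2.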