# lfm2-transaction-encoder/encoder/src/model/projection_adapter.py
"""Projection adapter: encoder output → LFM2.5 hidden dim.

This is the load-bearing module that lifts the per-transaction encoder's
output into the pretrained LFM2.5 backbone's continuous-token space. The
layer stack mirrors LFM2-VL's `Lfm2VlMultiModalProjector`; only the hidden
width differs (see the last bullet):

    LayerNorm → Linear(d_encoder → hidden) → GELU → Linear(hidden → d_lfm)

Why this shape:

- **LayerNorm at the front.** The encoder is freshly initialized, so its
  outputs have arbitrary scale. LayerNorm normalizes each token's features
  to zero mean and unit variance before the first Linear, which stabilizes
  early training. LFM2-VL's config defaults to
  `projector_use_layernorm=True`.
- **2 layers, not 1.** LLaVA-1.5 ([arXiv 2310.03744](https://arxiv.org/abs/2310.03744))
  switched from a single Linear (LLaVA-1.0) to a 2-layer MLP and reported
  better benchmark results. LFM2-VL adopted this. The non-linear bridge
  handles the encoder-domain ↔ text-pretrained-domain gap that a single
  Linear cannot.
- **GELU activation.** LFM2-VL uses GELU (`projector_hidden_act="gelu"`).
  The encoder uses SiLU internally (matching LFM2's SwiGLU MLPs), but the
  projector sits outside the LFM backbone and follows the VL projector
  convention. Cross-modality consistency matters more here than
  intra-encoder consistency.
- **hidden = 2 * d_lfm by default.** LFM2-VL uses `projector_hidden_size=2560`
  for `text_hidden_size=2048` (ratio 1.25). We use 2x for the 350M model
  (hidden=2048 for d_lfm=1024), which buys slightly more capacity at minor
  parameter cost. This is a defensible default; revisit it if the projector
  becomes a bottleneck.

Shape contract:

    (B, T, d_encoder) → (B, T, d_lfm)
"""
from __future__ import annotations

import torch
import torch.nn as nn


class ProjectionAdapter(nn.Module):
    """LFM2-VL-shaped 2-layer MLP projector with input LayerNorm.

    Args:
        d_encoder: input feature dim from the per-transaction encoder.
        d_lfm: output feature dim; must match the LFM2.5 backbone's hidden
            size (1024 for LFM2.5-350M, 2048 for LFM2.5-1.2B).
        hidden: intermediate projector hidden size. Defaults to `2 * d_lfm`
            following LFM2-VL's pattern of `projector_hidden_size > d_lfm`.
        use_layernorm: include LayerNorm at the input (LFM2-VL default: True).
    """

    def __init__(
        self,
        d_encoder: int = 256,
        d_lfm: int = 1024,
        hidden: int | None = None,
        use_layernorm: bool = True,
    ) -> None:
        super().__init__()
        if hidden is None:
            hidden = 2 * d_lfm
        self.input_norm: nn.Module = (
            nn.LayerNorm(d_encoder) if use_layernorm else nn.Identity()
        )
        self.up = nn.Linear(d_encoder, hidden)
        self.act = nn.GELU()
        self.down = nn.Linear(hidden, d_lfm)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (B, T, d_encoder)
        x = self.input_norm(x)
        x = self.up(x)
        x = self.act(x)
        return self.down(x)  # → (B, T, d_lfm)
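
The "minor parameter cost" of the 2x hidden width noted in the module docstring can be checked with plain arithmetic. The sketch below is not part of the module: `projector_param_count` is a hypothetical helper, and it assumes bias terms on both Linear layers and an affine LayerNorm (weight plus bias), which are the PyTorch defaults the class above relies on.

```python
from __future__ import annotations


def projector_param_count(
    d_encoder: int = 256,
    d_lfm: int = 1024,
    hidden: int | None = None,
    use_layernorm: bool = True,
) -> int:
    """Count trainable parameters of LayerNorm → Linear → GELU → Linear.

    Hypothetical helper for illustration; assumes biased Linear layers
    and an affine LayerNorm.
    """
    if hidden is None:
        hidden = 2 * d_lfm  # same default as the adapter
    norm = 2 * d_encoder if use_layernorm else 0  # LayerNorm weight + bias
    up = d_encoder * hidden + hidden              # first Linear: weight + bias
    down = hidden * d_lfm + d_lfm                 # second Linear: weight + bias
    return norm + up + down                       # GELU has no parameters


# 350M configuration: default 2x width vs. an LFM2-VL-style 1.25x width.
default_2x = projector_param_count()             # hidden = 2048
ratio_1_25 = projector_param_count(hidden=1280)  # hidden = 1.25 * 1024
extra = default_2x - ratio_1_25

print(f"2x width:    {default_2x:,} parameters")
print(f"1.25x width: {ratio_1_25:,} parameters")
print(f"difference:  {extra:,} parameters")
```

Under these assumptions the 2x default adds roughly 0.98M parameters over a 1.25x width, which is small next to a 350M-parameter backbone. If PyTorch is available, the total should match `sum(p.numel() for p in ProjectionAdapter().parameters())`.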