# lfm2-transaction-encoder/encoder/src/model/lfm_pseudo_token_wrapper.py
"""LFM2.5 backbone wrapper that consumes pre-computed pseudo-token embeddings.

This is the load-bearing module. It wraps `transformers.Lfm2Model` so the
pretrained LFM2.5 backbone processes our 64 per-transaction pseudo-tokens as
if they were 64 text tokens. This is the same hook LFM2.5-Audio and LFM2.5-VL
use to feed their continuous-embedding streams into the LM.

Why `Lfm2Model` and not `Lfm2ForCausalLM`:
    We don't use the LM head; downstream task heads pool the hidden states
    instead. `Lfm2Model` returns `last_hidden_state` without the vocab
    projection, which is exactly what we want and skips ~67M unused params
    (vocab_size 65536 * d_lfm 1024).

The injection hook (verified against transformers/models/lfm2/modeling_lfm2.py
lines 523-558):
    Lfm2Model.forward(input_ids=None, inputs_embeds=<our_tensor>) skips
    embed_tokens entirely, applies RoPE to whatever hidden states arrive,
    and runs the full conv+attention stack. The XOR guard at line 533
    requires us to pass exactly one of input_ids / inputs_embeds, so we
    explicitly pass input_ids=None.
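
Sketch of the XOR guard described above, written as a standalone function.
This is a minimal illustration, not the transformers source:
`check_exactly_one` and its return labels are hypothetical names, and the
error text is paraphrased.

```python
def check_exactly_one(input_ids, inputs_embeds):
    # Hypothetical helper mirroring the "exactly one of" check.
    # The XOR is True when both arguments are given or both are missing.
    if (input_ids is None) ^ (inputs_embeds is not None):
        raise ValueError("pass exactly one of input_ids or inputs_embeds")
    # Report which path the forward pass would take.
    return "inputs_embeds" if input_ids is None else "input_ids"


# The wrapper's call pattern: no token ids, embeddings only.
route = check_exactly_one(None, [[0.1, 0.2]])
print(route)
```

Passing both arguments, or neither, raises; passing only embeddings selects
the path that bypasses the token-embedding lookup.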

Frozen-base invariant:
    `freeze_base()` sets requires_grad=False on every parameter at load
    time. LoRA layers, attached AFTER freezing, have requires_grad=True
    by default. The unit tests verify this invariant: if any base
    parameter accumulates a gradient, the test fails.

LoRA target modules:
    LFM2's attention modules use `q_proj`, `k_proj`, `v_proj`, `out_proj`
    (NOT LLaMA's `o_proj`). The conv block uses its own `in_proj` and
    `out_proj`; the SwiGLU MLP uses `w1`, `w2`, `w3`. POC default is
    attention-only ("q_proj", "k_proj", "v_proj", "out_proj") at r=16,
    α=32. Escalate to the Spotify-full set at r=64, α=128 if quality is
    capacity-limited.
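
Rough adapter-size arithmetic for the attention modules at r=16 versus r=64.
This is a sketch under stated assumptions: the K/V projection width `D_KV`
and the per-module shapes are NOT read from the checkpoint config, and the
helper names are hypothetical. Only d_lfm=1024 and the 6 attention layers
come from this file.

```python
# Sketch only. D_KV and the shapes below are ASSUMPTIONS; adjust them to
# match the config of the checkpoint you actually load.
D_MODEL = 1024       # d_lfm
D_KV = 512           # ASSUMED: grouped-query attention with narrower K/V
N_ATTN_LAYERS = 6    # attention layers in the 16-layer stack

ATTN_SHAPES = {      # module name -> (in_features, out_features)
    "q_proj": (D_MODEL, D_MODEL),
    "k_proj": (D_MODEL, D_KV),
    "v_proj": (D_MODEL, D_KV),
    "out_proj": (D_MODEL, D_MODEL),
}


def lora_params(r, d_in, d_out):
    # LoRA adds A with shape (r, d_in) and B with shape (d_out, r).
    return r * (d_in + d_out)


def attention_lora_total(r):
    # Sum over the four attention projections, then over attention layers.
    per_layer = sum(lora_params(r, d_in, d_out) for d_in, d_out in ATTN_SHAPES.values())
    return N_ATTN_LAYERS * per_layer


poc = attention_lora_total(16)
escalated = attention_lora_total(64)
print(poc, escalated, escalated // poc)
```

Adapter size is linear in r, so moving from r=16 to r=64 on the same modules
costs 4x the trainable parameters; α only rescales the update and does not
change the count.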

Note on `out_proj` collision:
    A list entry such as `target_modules=["out_proj"]` is matched by
    leaf-module name, so it hits BOTH attention's `out_proj` AND the conv
    block's `out_proj`. The current default
    (`q_proj/k_proj/v_proj/out_proj`) therefore also adapts conv
    `out_proj`, which `build_lora_config` accepts on purpose. For strict
    attention-only LoRA, pass `strict_attention_only=True`, which swaps
    the list for `STRICT_ATTENTION_REGEX` so only `self_attn` paths match.
    The test asserts that strict attention-only LoRA does NOT touch conv
    layers.
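
Sketch of the collision, using plain strings instead of a loaded model.
`match_targets` is a hypothetical helper that mimics the matching rule as we
understand it (a list matches by name suffix, a string is a full-match
regex); it is not PEFT's own code. The module names and layer indices are
illustrative only.

```python
import re

# Hypothetical module names following the layers.{i}.self_attn /
# layers.{i}.conv layout; the layer indices are made up for illustration.
MODULE_NAMES = [
    "layers.0.conv.in_proj",
    "layers.0.conv.out_proj",
    "layers.2.self_attn.q_proj",
    "layers.2.self_attn.k_proj",
    "layers.2.self_attn.v_proj",
    "layers.2.self_attn.out_proj",
]

# Same pattern as STRICT_ATTENTION_REGEX, with the dot written as [.] so the
# example contains no backslashes.
STRICT = ".*self_attn[.](q_proj|k_proj|v_proj|out_proj)$"


def match_targets(names, target_modules):
    # A string is treated as a regex that must match the whole module name.
    if isinstance(target_modules, str):
        return [n for n in names if re.fullmatch(target_modules, n)]
    # A list matches any module whose dotted name ends in one of the entries.
    return [n for n in names if any(n == t or n.endswith("." + t) for t in target_modules)]


by_list = match_targets(MODULE_NAMES, ["q_proj", "k_proj", "v_proj", "out_proj"])
by_regex = match_targets(MODULE_NAMES, STRICT)
print("layers.0.conv.out_proj" in by_list, "layers.0.conv.out_proj" in by_regex)
```

The list form picks up the conv `out_proj` alongside the four attention
projections, while the regex form keeps only names under `self_attn`.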
"""

from __future__ import annotations

from pathlib import Path

import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from transformers import Lfm2Model

# Canonical LFM2 LoRA target sets. The attention-only POC default targets only
# the attention proj modules; the Spotify-full set (from v8 production) adds
# conv `in_proj` and SwiGLU `w1/w2/w3`. Both the POC default and the full set
# include `out_proj`, which matches both attention `out_proj` AND conv
# `out_proj` (see Lfm2Attention vs Lfm2ShortConv). Use STRICT_ATTENTION_REGEX
# below if attention-only-strict is required.
ATTENTION_ONLY_TARGETS = ["q_proj", "k_proj", "v_proj"]  # build_lora_config's default adds out_proj
SPOTIFY_FULL_TARGETS = ["q_proj", "k_proj", "v_proj", "out_proj", "in_proj", "w1", "w2", "w3"]


class LfmPseudoTokenBackbone(nn.Module):
    """Wraps Lfm2Model to consume pre-computed pseudo-token embeddings.

    Forward:
        pseudo_tokens: (B, T, d_lfm) float, output of the projection adapter.
        attention_mask: (B, T) int or None, the padding mask. For our use
            case of fully-populated 64-tx sequences with no padding, None is
            the normal call. Pass attention_mask explicitly if any
            pseudo-token position is logically a "padding" slot.

    Returns:
        (B, T, d_lfm) float: `last_hidden_state` from the LFM2.5 stack
        after 16 layers (10 conv + 6 attention).

    Args:
        model_path: HF-format directory containing config.json + safetensors.
            For local-only POC we use ~/Projects/_models/LFM25-350M-Base; for
            HF-hosted models pass a repo ID string.
        lora: LoraConfig instance, or None for frozen-base-only (no LoRA).
        dtype: bfloat16 for training (matches LFM2's pretraining dtype) or
            float32 for CPU smoke tests.
        device_map: "auto" for GPU runs, None for CPU/MPS local smoke runs.
        trust_remote_code: required True for LiquidAI checkpoints that ship
            modeling code alongside weights. Safe for LiquidAI/* paths.
        freeze_base: if True (default), set requires_grad=False on every base
            parameter right after loading. Pass False to fine-tune the base
            end-to-end.
    """

    def __init__(
        self,
        model_path: str | Path,
        lora: LoraConfig | None = None,
        dtype: torch.dtype = torch.bfloat16,
        device_map: str | None = "auto",
        trust_remote_code: bool = True,
        freeze_base: bool = True,
    ) -> None:
        super().__init__()
        load_kwargs: dict = {"torch_dtype": dtype, "trust_remote_code": trust_remote_code}
        if device_map is not None:
            load_kwargs["device_map"] = device_map
        self.base = Lfm2Model.from_pretrained(str(model_path), **load_kwargs)

        # Capture d_lfm before freezing for downstream consumers.
        self.d_lfm: int = self.base.config.hidden_size

        # Freeze controls. The default (freeze_base=True) is the "encoder
        # pattern with a frozen base", which is what makes the
        # shared-base-across-customers story possible. Setting
        # freeze_base=False is the stage-2 VL recipe ("unfreeze the base and
        # fine-tune end-to-end"), used as a diagnostic upper-bound experiment.
        if freeze_base:
            self.freeze_base()

        if lora is not None:
            # get_peft_model wraps `self.base` and registers LoRA layers as
            # trainable. The XOR guard inside Lfm2Model.forward still
            # accepts inputs_embeds + input_ids=None: PEFT does not modify
            # the forward signature, it only wraps the modules selected by
            # target_modules.
            self.base = get_peft_model(self.base, lora)

    def freeze_base(self) -> None:
        """Set requires_grad=False on every base parameter."""
        for p in self.base.parameters():
            p.requires_grad = False

    def forward(
        self,
        pseudo_tokens: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        # pseudo_tokens: (B, T, d_lfm), must already be in the base's dtype
        # (caller's responsibility; the projector should produce bf16 on GPU).
        outputs = self.base(
            input_ids=None,
            inputs_embeds=pseudo_tokens,
            attention_mask=attention_mask,
            use_cache=False,
        )
        return outputs.last_hidden_state  # (B, T, d_lfm)

    def trainable_parameters(self) -> int:
        """Number of trainable params (should be LoRA-only when LoRA is on)."""
        return sum(p.numel() for p in self.base.parameters() if p.requires_grad)

    def total_parameters(self) -> int:
        return sum(p.numel() for p in self.base.parameters())
# Regex matching attention modules only (strict mode). LFM2 attention modules
# live under `layers.{i}.self_attn.{q_proj,k_proj,v_proj,out_proj}`; conv
# `out_proj` lives under `layers.{i}.conv.out_proj`. The regex includes
# `self_attn` in the path to disambiguate.
STRICT_ATTENTION_REGEX = r".*self_attn\.(q_proj|k_proj|v_proj|out_proj)$"


def build_lora_config(
    r: int = 16,
    alpha: int = 32,
    dropout: float = 0.05,
    target_modules: list[str] | str | None = None,
    strict_attention_only: bool = False,
) -> LoraConfig:
    """Build a LoraConfig tuned for the LFM2 attention-LoRA POC default.

    Default behavior (`target_modules=None`, `strict_attention_only=False`):
        target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
        PEFT matches by leaf-module name, so `out_proj` matches BOTH
        attention.out_proj (6 occurrences) AND conv.out_proj (10 occurrences).
        Net LoRA count for LFM2.5-350M: ~1.01M params.

    We accept this conv.out_proj collision because:
        - The conv layers do local sequence mixing; having them slightly
          adaptable helps with our pseudo-token input distribution.
        - It's ~0.33M extra trainable params (10 layers x r=16 x (1024 + 1024)),
          dwarfed by the projector (~2.6M) and downstream heads.
        - Spotify production uses an even larger set (`SPOTIFY_FULL_TARGETS`).

    To opt out of conv adaptation, pass `strict_attention_only=True`. This
    selects target modules by regex so only attention paths match.

    Args:
        r: LoRA rank. 16 for POC, escalate to 64 for production.
        alpha: LoRA scaling. Conventionally 2*r.
        dropout: LoRA dropout. 0.05 keeps regularization light at POC scale.
        target_modules: override the default. Pass `SPOTIFY_FULL_TARGETS`
            (list) to escalate, or a regex string for custom matching.
        strict_attention_only: if True and `target_modules` is None, use
            the `STRICT_ATTENTION_REGEX` so only attention modules get LoRA.

    Returns:
        LoraConfig ready for `get_peft_model(base, config)`.
    """
    if target_modules is None:
        if strict_attention_only:
            target_modules = STRICT_ATTENTION_REGEX
        else:
            target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]

    # FEATURE_EXTRACTION task type: we use Lfm2Model (no LM head), and
    # downstream task heads handle the loss. CAUSAL_LM would try to save
    # an lm_head that does not exist on Lfm2Model.
    return LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=target_modules,
        bias="none",
        task_type="FEATURE_EXTRACTION",
    )