"""LFM2.5 backbone wrapper that consumes pre-computed pseudo-token embeddings. This is the load-bearing module. It wraps `transformers.Lfm2Model` so the pretrained LFM2.5 backbone processes our 64 per-transaction pseudo-tokens as if they were 64 text tokens — the same hook LFM2.5-Audio and LFM2.5-VL use to feed their continuous-embedding streams into the LM. Why `Lfm2Model` and not `Lfm2ForCausalLM`: We don't use the LM head — downstream task heads pool the hidden states instead. `Lfm2Model` returns `last_hidden_state` without the vocab projection, which is exactly what we want and skips ~67M unused params (vocab_size 65536 * d_lfm 1024). The injection hook (verified against transformers/models/lfm2/modeling_lfm2.py lines 523-558): Lfm2Model.forward(input_ids=None, inputs_embeds=) skips embed_tokens entirely, applies RoPE to whatever hidden states arrive, and runs the full conv+attention stack. The XOR guard at line 533 requires us to pass exactly one of input_ids / inputs_embeds, so we explicitly pass input_ids=None. Frozen-base invariant: `freeze_base()` sets requires_grad=False on every parameter at load time. LoRA layers, attached AFTER freezing, have requires_grad=True by default. The unit tests verify this invariant — if any base parameter accumulates a gradient, the test fails. LoRA target modules: LFM2's attention modules use `q_proj`, `k_proj`, `v_proj`, `out_proj` (NOT LLaMA's `o_proj`). The conv block uses its own `in_proj` and `out_proj`; the SwiGLU MLP uses `w1`, `w2`, `w3`. POC default is attention-only ("q_proj", "k_proj", "v_proj", "out_proj") at r=16, α=32. Escalate to the Spotify-full set at r=64, α=128 if quality is capacity-limited. Note on `out_proj` collision: `target_modules=["out_proj"]` will match BOTH attention's `out_proj` AND the conv block's `out_proj`. For attention-only LoRA, this is not what we want. We use a regex or explicit attention-scoped names via `target_modules=["q_proj", "k_proj", "v_proj"]` and add `out_proj` only when conv-LoRA is intentional. The current default is `q_proj/k_proj/v_proj/out_proj`; the test asserts that attention-only LoRA does NOT touch conv layers. """ from __future__ import annotations from pathlib import Path import torch import torch.nn as nn from peft import LoraConfig, get_peft_model from transformers import Lfm2Model # Canonical LFM2 LoRA target sets. The attention-only POC default targets only # the attention proj modules; the Spotify-full set (from v8 production) adds # conv `in_proj` and SwiGLU `w1/w2/w3`. Both sets include `out_proj`, which # matches both attention `out_proj` AND conv `out_proj` — see Lfm2Attention # vs Lfm2ShortConv. We constrain target_modules with the LAYER_PREFIX_REGEX # fallback below if attention-only-strict is required. ATTENTION_ONLY_TARGETS = ["q_proj", "k_proj", "v_proj"] # + out_proj added below SPOTIFY_FULL_TARGETS = ["q_proj", "k_proj", "v_proj", "out_proj", "in_proj", "w1", "w2", "w3"] class LfmPseudoTokenBackbone(nn.Module): """Wraps Lfm2Model to consume pre-computed pseudo-token embeddings. Forward: pseudo_tokens: (B, T, d_lfm) float — output of the projection adapter attention_mask: (B, T) int or None — padding mask. For our use case of fully-populated 64-tx sequences with no padding, None is the normal call. Pass attention_mask explicitly if any pseudo-token position is logically a "padding" slot. Returns: (B, T, d_lfm) float — `last_hidden_state` from the LFM2.5 stack after 16 layers (10 conv + 6 attention). Args: model_path: HF-format directory containing config.json + safetensors. For local-only POC we use ~/Projects/_models/LFM25-350M-Base; for HF-hosted models pass a repo ID string. lora: LoraConfig instance, or None for frozen-base-only (no LoRA). dtype: bfloat16 for training (matches LFM2's pretraining dtype) or float32 for CPU smoke tests. device_map: "auto" for GPU runs, None for CPU/MPS local smoke runs. trust_remote_code: required True for LiquidAI checkpoints that ship modeling code alongside weights. Safe for LiquidAI/* paths. """ def __init__( self, model_path: str | Path, lora: LoraConfig | None = None, dtype: torch.dtype = torch.bfloat16, device_map: str | None = "auto", trust_remote_code: bool = True, freeze_base: bool = True, ) -> None: super().__init__() load_kwargs: dict = {"torch_dtype": dtype, "trust_remote_code": trust_remote_code} if device_map is not None: load_kwargs["device_map"] = device_map self.base = Lfm2Model.from_pretrained(str(model_path), **load_kwargs) # Capture d_lfm before freezing for downstream consumers. self.d_lfm: int = self.base.config.hidden_size # Freeze controls. The default (freeze_base=True) is the "encoder # pattern with a frozen base" — what makes the shared-base-across- # customers story possible. Setting freeze_base=False is the # stage-2 VL recipe ("unfreeze the base and fine-tune end-to-end"), # used as a diagnostic upper-bound experiment. if freeze_base: self.freeze_base() if lora is not None: # get_peft_model wraps `self.base` and registers LoRA layers as # trainable. The XOR guard inside Lfm2Model.forward still # accepts inputs_embeds + input_ids=None — PEFT does not modify # the forward signature, only the attention/MLP module weights. self.base = get_peft_model(self.base, lora) def freeze_base(self) -> None: """Set requires_grad=False on every base parameter.""" for p in self.base.parameters(): p.requires_grad = False def forward( self, pseudo_tokens: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: # pseudo_tokens: (B, T, d_lfm) — must already be in the base's dtype # (caller's responsibility; the projector should produce bf16 on GPU). outputs = self.base( input_ids=None, inputs_embeds=pseudo_tokens, attention_mask=attention_mask, use_cache=False, ) return outputs.last_hidden_state # → (B, T, d_lfm) def trainable_parameters(self) -> int: """Number of trainable params (should be LoRA-only when LoRA is on).""" return sum(p.numel() for p in self.base.parameters() if p.requires_grad) def total_parameters(self) -> int: return sum(p.numel() for p in self.base.parameters()) # Regex matching attention modules only (strict mode). LFM2 attention modules # live under `layers.{i}.self_attn.{q_proj,k_proj,v_proj,out_proj}`; conv # `out_proj` lives under `layers.{i}.conv.out_proj`. The regex includes # `self_attn` in the path to disambiguate. STRICT_ATTENTION_REGEX = r".*self_attn\.(q_proj|k_proj|v_proj|out_proj)$" def build_lora_config( r: int = 16, alpha: int = 32, dropout: float = 0.05, target_modules: list[str] | str | None = None, strict_attention_only: bool = False, ) -> LoraConfig: """Build a LoraConfig tuned for the LFM2 attention-LoRA POC default. Default behavior (`target_modules=None`, `strict_attention_only=False`): target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] PEFT matches by leaf-module name, so `out_proj` matches BOTH attention.out_proj (6 occurrences) AND conv.out_proj (10 occurrences). Net LoRA count for LFM2.5-350M: ~1.01M params. We accept this conv.out_proj collision because: - The conv layers do local sequence mixing — having them slightly adaptable helps with our pseudo-token input distribution. - It's 0.22M extra trainable params, dwarfed by the projector (~2.6M) and downstream heads. - Spotify production uses an even larger set (`SPOTIFY_FULL_TARGETS`). To opt out of conv adaptation, pass `strict_attention_only=True` — this selects target modules by regex so only attention paths match. Args: r: LoRA rank. 16 for POC, escalate to 64 for production. alpha: LoRA scaling. Conventionally 2*r. dropout: LoRA dropout. 0.05 keeps regularization light at POC scale. target_modules: override the default. Pass `SPOTIFY_FULL_TARGETS` (list) to escalate, or a regex string for custom matching. strict_attention_only: if True and `target_modules` is None, use the `STRICT_ATTENTION_REGEX` so only attention modules get LoRA. Returns: LoraConfig ready for `get_peft_model(base, config)`. """ if target_modules is None: if strict_attention_only: target_modules = STRICT_ATTENTION_REGEX else: target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"] # FEATURE_EXTRACTION task type — we use Lfm2Model (no LM head), and # downstream task heads handle the loss. CAUSAL_LM would try to save # an lm_head that does not exist on Lfm2Model. return LoraConfig( r=r, lora_alpha=alpha, lora_dropout=dropout, target_modules=target_modules, bias="none", task_type="FEATURE_EXTRACTION", )