| """LFM2.5 backbone wrapper that consumes pre-computed pseudo-token embeddings. |
| |
| This is the load-bearing module. It wraps `transformers.Lfm2Model` so the |
| pretrained LFM2.5 backbone processes our 64 per-transaction pseudo-tokens as |
| if they were 64 text tokens — the same hook LFM2.5-Audio and LFM2.5-VL use |
| to feed their continuous-embedding streams into the LM. |
| |
| Why `Lfm2Model` and not `Lfm2ForCausalLM`: |
| We don't use the LM head — downstream task heads pool the hidden states |
| instead. `Lfm2Model` returns `last_hidden_state` without the vocab |
| projection, which is exactly what we want and skips ~67M unused params |
| (vocab_size 65536 * d_lfm 1024). |
| |
| The injection hook (verified against transformers/models/lfm2/modeling_lfm2.py |
| lines 523-558): |
| |
| Lfm2Model.forward(input_ids=None, inputs_embeds=<our_tensor>) skips |
| embed_tokens entirely, applies RoPE to whatever hidden states arrive, |
| and runs the full conv+attention stack. The XOR guard at line 533 |
| requires us to pass exactly one of input_ids / inputs_embeds, so we |
| explicitly pass input_ids=None. |
| |
| Frozen-base invariant: |
| When `freeze_base=True` (the default), `freeze_base()` sets |
| requires_grad=False on every base parameter at load time. LoRA layers, |
| attached AFTER freezing, have requires_grad=True |
| by default. The unit tests verify this invariant — if any base |
| parameter accumulates a gradient, the test fails. |
| |
| LoRA target modules: |
| LFM2's attention modules use `q_proj`, `k_proj`, `v_proj`, `out_proj` |
| (NOT LLaMA's `o_proj`). The conv block uses its own `in_proj` and |
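| `out_proj`; the SwiGLU MLP uses `w1`, `w2`, `w3`. The POC default |
| targets ("q_proj", "k_proj", "v_proj", "out_proj") at r=16, α=32. |
| Because `out_proj` is a leaf name shared with the conv block, this |
| default adapts attention plus the conv output projection — it is not |
| strictly attention-only (see the `out_proj` note below). Escalate to |
| the Spotify-full set (`SPOTIFY_FULL_TARGETS`) at r=64, α=128 if |
| quality is capacity-limited. |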
| |
| Note on `out_proj` collision: |
| PEFT matches a list of `target_modules` by leaf-module name, so |
| `"out_proj"` matches BOTH attention's `out_proj` AND the conv block's |
| `out_proj`. The default (`build_lora_config()`) accepts this on purpose |
| and adapts the 10 conv `out_proj` layers as well. For strictly |
| attention-only LoRA, call `build_lora_config(strict_attention_only=True)`, |
| which selects modules with `STRICT_ATTENTION_REGEX` (scoped to |
| `self_attn.`), or pass `target_modules=ATTENTION_ONLY_TARGETS` (q/k/v |
| only). The unit tests assert that attention-only LoRA does NOT touch |
| conv layers. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| from peft import LoraConfig, get_peft_model |
| from transformers import Lfm2Model |
|
|
| |
| |
| |
| |
| |
| |
# Leaf-module names for LFM2's attention q/k/v projections. `out_proj` is
# intentionally absent: that leaf name is shared with the conv block's output
# projection, so listing it would also adapt the conv layers.
ATTENTION_ONLY_TARGETS = ["q_proj", "k_proj", "v_proj"]

# Escalation set: attention q/k/v, the shared `out_proj`, the conv block's
# `in_proj`, and the SwiGLU MLP projections `w1`/`w2`/`w3`. Built as a fresh
# list so mutating one constant never affects the other.
SPOTIFY_FULL_TARGETS = [*ATTENTION_ONLY_TARGETS, "out_proj", "in_proj", "w1", "w2", "w3"]
|
|
|
|
class LfmPseudoTokenBackbone(nn.Module):
    """Wrap ``Lfm2Model`` so it consumes pre-computed pseudo-token embeddings.

    The pseudo-tokens are handed to the backbone as ``inputs_embeds`` with
    ``input_ids=None``, so the token-embedding lookup is skipped and the
    tensors go straight into the decoder stack.

    Forward:
        pseudo_tokens: (B, T, d_lfm) float — output of the projection adapter.
            Before the call it is cast to the dtype of the backbone's
            input-embedding weights, so e.g. an fp32 adapter can feed a bf16
            backbone without a dtype-mismatch error.
        attention_mask: (B, T) int or None — padding mask. For our use case
            of fully-populated 64-tx sequences with no padding, None is the
            normal call. Pass attention_mask explicitly if any pseudo-token
            position is logically a "padding" slot.

    Returns:
        (B, T, d_lfm) float — `last_hidden_state` from the LFM2.5 stack
        (for LFM2.5-350M: after 16 layers, 10 conv + 6 attention).

    Raises:
        ValueError: from ``forward`` if ``pseudo_tokens`` is not 3-D with a
            last dimension equal to ``d_lfm``.

    Args:
        model_path: HF-format directory containing config.json + safetensors.
            For local-only POC we use ~/Projects/_models/LFM25-350M-Base; for
            HF-hosted models pass a repo ID string.
        lora: LoraConfig instance, or None for frozen-base-only (no LoRA).
        dtype: dtype passed to ``from_pretrained`` — bfloat16 for training
            (matches LFM2's pretraining dtype) or float32 for CPU smoke tests.
        device_map: forwarded to ``from_pretrained`` when not None — "auto"
            for GPU runs, None for CPU/MPS local smoke runs.
        trust_remote_code: required True for LiquidAI checkpoints that ship
            modeling code alongside weights. Safe for LiquidAI/* paths.
        freeze_base: if True (default), set requires_grad=False on every
            base parameter before any LoRA adapters are attached.
    """

    def __init__(
        self,
        model_path: str | Path,
        lora: LoraConfig | None = None,
        dtype: torch.dtype = torch.bfloat16,
        device_map: str | None = "auto",
        trust_remote_code: bool = True,
        freeze_base: bool = True,
    ) -> None:
        super().__init__()
        load_kwargs: dict = {"torch_dtype": dtype, "trust_remote_code": trust_remote_code}
        if device_map is not None:
            load_kwargs["device_map"] = device_map
        self.base = Lfm2Model.from_pretrained(str(model_path), **load_kwargs)

        # Hidden width of the backbone; the projection adapter must emit this size.
        self.d_lfm: int = self.base.config.hidden_size

        # Freeze BEFORE attaching LoRA so only the adapter weights created by
        # get_peft_model remain trainable. (`freeze_base` here is the bool
        # argument; `self.freeze_base()` is the method defined below.)
        if freeze_base:
            self.freeze_base()

        if lora is not None:
            self.base = get_peft_model(self.base, lora)

    def freeze_base(self) -> None:
        """Set requires_grad=False on every base parameter."""
        for p in self.base.parameters():
            p.requires_grad = False

    def _input_embedding_dtype(self) -> torch.dtype | None:
        """Return the dtype of the backbone's input-embedding weights, or None if unavailable.

        ``inputs_embeds`` stands in for the output of that table, so matching
        its dtype keeps the first layer's matmuls dtype-consistent.
        """
        get_embeddings = getattr(self.base, "get_input_embeddings", None)
        if get_embeddings is None:
            return None
        weight = getattr(get_embeddings(), "weight", None)
        return weight.dtype if isinstance(weight, torch.Tensor) else None

    def forward(
        self,
        pseudo_tokens: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Run the backbone on pseudo-token embeddings and return the last hidden state.

        Raises:
            ValueError: if ``pseudo_tokens`` is not shaped (B, T, d_lfm).
        """
        # Fail fast with a readable message instead of an opaque matmul error
        # deep inside the stack when the adapter width is wrong.
        if pseudo_tokens.dim() != 3 or pseudo_tokens.shape[-1] != self.d_lfm:
            raise ValueError(
                f"pseudo_tokens must have shape (B, T, {self.d_lfm}); "
                f"got {tuple(pseudo_tokens.shape)}"
            )

        # Align dtype with the embedding table the pseudo-tokens replace;
        # .to() is differentiable, so gradients still reach the adapter.
        target_dtype = self._input_embedding_dtype()
        if target_dtype is not None and pseudo_tokens.dtype != target_dtype:
            pseudo_tokens = pseudo_tokens.to(dtype=target_dtype)

        # input_ids=None is explicit: the backbone requires exactly one of
        # input_ids / inputs_embeds. Caching only helps incremental decoding,
        # which this wrapper never does.
        outputs = self.base(
            input_ids=None,
            inputs_embeds=pseudo_tokens,
            attention_mask=attention_mask,
            use_cache=False,
        )
        return outputs.last_hidden_state

    def trainable_parameters(self) -> int:
        """Number of trainable params (should be LoRA-only when LoRA is on)."""
        return sum(p.numel() for p in self.base.parameters() if p.requires_grad)

    def total_parameters(self) -> int:
        """Total number of backbone params, trainable or not."""
        return sum(p.numel() for p in self.base.parameters())
|
|
|
|
| |
| |
| |
| |
# Regex alternative to a leaf-name list: build_lora_config passes this string
# as `target_modules`, where it is used as a pattern over module names, so only
# projections under a `self_attn.` path can match and the conv block's
# `out_proj` is left un-adapted.
# NOTE(review): assumes LFM2 attention modules are registered as `self_attn`
# on each decoder layer — confirm against the transformers LFM2 modeling code.
STRICT_ATTENTION_REGEX: str = r".*self_attn\.(q_proj|k_proj|v_proj|out_proj)$"
|
|
|
|
def build_lora_config(
    r: int = 16,
    alpha: int = 32,
    dropout: float = 0.05,
    target_modules: list[str] | str | None = None,
    strict_attention_only: bool = False,
) -> LoraConfig:
    """Build a LoraConfig tuned for the LFM2 attention-LoRA POC default.

    Default behavior (`target_modules=None`, `strict_attention_only=False`):
        target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
        PEFT matches by leaf-module name, so `out_proj` matches BOTH
        attention.out_proj (6 occurrences) AND conv.out_proj (10 occurrences).
        Net LoRA count for LFM2.5-350M at r=16: ~1.01M params.

    We accept this conv.out_proj collision because:
        - The conv layers do local sequence mixing — having them slightly
          adaptable helps with our pseudo-token input distribution.
        - It's ~0.33M extra trainable params at r=16
          (10 layers x 16 x (1024 + 1024) for a 1024->1024 projection),
          dwarfed by the projector (~2.6M) and downstream heads.
        - Spotify production uses an even larger set (`SPOTIFY_FULL_TARGETS`).

    To opt out of conv adaptation, pass `strict_attention_only=True` — this
    selects target modules by regex so only attention paths match.

    Args:
        r: LoRA rank. 16 for POC, escalate to 64 for production.
        alpha: LoRA scaling. Conventionally 2*r.
        dropout: LoRA dropout. 0.05 keeps regularization light at POC scale.
        target_modules: override the default. Pass `SPOTIFY_FULL_TARGETS`
            (list) to escalate, or a regex string for custom matching.
            When given, it takes precedence over `strict_attention_only`.
        strict_attention_only: if True and `target_modules` is None, use
            the `STRICT_ATTENTION_REGEX` so only attention modules get LoRA.
            If `target_modules` is also given, this flag is ignored and a
            UserWarning is emitted.

    Returns:
        LoraConfig ready for `get_peft_model(base, config)`.
    """
    if target_modules is None:
        target_modules = (
            STRICT_ATTENTION_REGEX
            if strict_attention_only
            else ["q_proj", "k_proj", "v_proj", "out_proj"]
        )
    elif strict_attention_only:
        # Previously this conflict was silently ignored; surface it so a caller
        # who expects attention-only LoRA notices their override wins.
        import warnings

        warnings.warn(
            "strict_attention_only=True is ignored because target_modules was given explicitly",
            UserWarning,
            stacklevel=2,
        )

    return LoraConfig(
        r=r,
        lora_alpha=alpha,
        lora_dropout=dropout,
        target_modules=target_modules,
        bias="none",
        task_type="FEATURE_EXTRACTION",
    )
|
|