| """LFM2.5 wrapper extended with the LM head and the mixed-modality input path. |
| |
| Sibling of `lfm_pseudo_token_wrapper.LfmPseudoTokenBackbone`. The |
| original wraps `Lfm2Model` (no LM head — saves 67M params unused by |
| the four classification heads). For the multi-surface expansion we |
| need the LM head active (for reasoning-text generation) AND we need |
| to concatenate the customer's transaction pseudo-tokens with the |
| analyst's free-text query so the backbone attends across both |
| modalities. |
| |
| Why a sibling instead of an in-place edit: |
| |
| The existing wrapper is in production for the current demo. Phase |
| 1 of the multi-surface expansion is strictly additive — the |
| original wrapper continues to back the existing four-head demo; |
| this sibling backs the new dispute-legitimacy + Co-Pilot surfaces. |
| When all five surfaces ship and the Co-Pilot replaces the existing |
| demo, we can decide to consolidate or keep both. |
| |
| What changes vs the original: |
| |
| 1. Loads `Lfm2ForCausalLM` (which is `Lfm2Model + lm_head`) |
| instead of `Lfm2Model`. The LM head is tied to embed_tokens, |
| so storage cost is shared. Memory delta is the lm_head Linear |
| module's bookkeeping, not its weights. |
| |
| 2. Exposes `embed_tokens` and the SEP token id so callers can |
| build the mixed-modality input embedding sequence. |
| |
| 3. PEFT `task_type` becomes `CAUSAL_LM` so PEFT preserves the LM |
| head correctly. We also widen `target_modules` to include |
| `lm_head` so the LoRA can shape generation. |
| |
| 4. Adds `forward_mixed` which takes pre-built combined input |
| embeddings (tx_pseudo + SEP + text) and returns BOTH the |
| hidden states (for the probability + attribution heads) AND |
| the LM logits at every position (the caller slices out the |
| text positions for the LM loss). |
| |
| Frozen-base invariant preserved: |
| `freeze_base()` still sets requires_grad=False on every base |
| parameter. PEFT adds LoRA layers AFTER freezing; their |
| requires_grad=True is preserved by PEFT's wrap logic. |
| |
| LoRA target modules: |
| Same names as the original (LFM2 uses `q_proj`/`k_proj`/`v_proj`/ |
| `out_proj`/`in_proj`/`w1`/`w2`/`w3`). The default attention-only |
| LoRA stays r=16. For the LM head we add `lm_head` to the targets. |
| PEFT will create a separate LoRA matrix for the lm_head Linear. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import torch |
| import torch.nn as nn |
| from peft import LoraConfig, get_peft_model |
| from transformers import AutoTokenizer, Lfm2ForCausalLM, PreTrainedTokenizerBase |
|
|
|
|
| |
| |
| |
| |
| |
| |
# Default LoRA target module names for the multi-surface backbone:
# the attention projections plus the LM head, so the adapter can also
# shape generation.
DEFAULT_MULTISURFACE_TARGETS: list[str] = [
    # attention projections
    "q_proj",
    "k_proj",
    "v_proj",
    "out_proj",
    # language-model head
    "lm_head",
]
|
|
|
|
class LfmMultiSurfaceBackbone(nn.Module):
    """LFM2.5 wrapper supporting tx pseudo-tokens + text tokens + LM head.

    Forward signatures:

        forward(pseudo_tokens, attention_mask=None) -> hidden
            Same as the original wrapper. Used when no text input is
            needed (the existing four classification heads still work
            via this path).

        forward_mixed(combined_embeds, attention_mask=None,
                      compute_lm_logits=True)
            -> dict(hidden_states[, lm_logits])
            The mixed-modality forward. `combined_embeds` is
            (B, T_total, D) with tx pseudo-tokens at positions
            [0, num_tx_positions), SEP at position num_tx_positions,
            and text token embeddings at positions
            [num_tx_positions + 1, T_total). When requested, LM logits
            are produced for every position; the caller selects the
            positions it needs (typically the text positions only).

        generate_reasoning(combined_embeds, attention_mask=None,
                           max_new_tokens=128, temperature=0.0)
            -> (B, T_new) generated token ids
            Hand-rolled greedy / temperature-sampled decoding from a
            mixed-modality prefix.

    Args:
        model_path: HF-format directory or repo ID for LFM2.5.
        lora: LoraConfig instance, or None for frozen-base-only.
        dtype: bfloat16 for GPU training, float32 for CPU smoke.
        device_map: "auto" for GPU, None for CPU.
        trust_remote_code: True for LiquidAI checkpoints.
        freeze_base: Set base params to non-trainable at init.
        sep_token_id: Token id used as the separator between the
            transaction pseudo-tokens and the text tokens. When None,
            falls back to the tokenizer's sep token, then its BOS
            token, then id 0. The SEP embedding is taken from the
            model's own embed_tokens table at this id.
    """

    def __init__(
        self,
        model_path: str | Path,
        lora: LoraConfig | None = None,
        dtype: torch.dtype = torch.bfloat16,
        device_map: str | None = "auto",
        trust_remote_code: bool = True,
        freeze_base: bool = True,
        sep_token_id: int | None = None,
    ) -> None:
        super().__init__()

        load_kwargs: dict = {
            "torch_dtype": dtype,
            "trust_remote_code": trust_remote_code,
        }
        # Only forward device_map when the caller set one; None means
        # "let the loader use its default placement" (the CPU path).
        if device_map is not None:
            load_kwargs["device_map"] = device_map

        # Causal-LM variant so the LM head is available for generation.
        self.base = Lfm2ForCausalLM.from_pretrained(
            str(model_path),
            **load_kwargs,
        )

        # Read sizes from the config before any PEFT wrapping happens.
        self.d_lfm: int = self.base.config.hidden_size
        self.vocab_size: int = self.base.config.vocab_size

        self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
            str(model_path),
            trust_remote_code=trust_remote_code,
        )

        # SEP resolution order: explicit override > tokenizer sep token
        # > tokenizer BOS token > id 0 as a last resort.
        if sep_token_id is not None:
            self.sep_token_id = sep_token_id
        elif self.tokenizer.sep_token_id is not None:
            self.sep_token_id = self.tokenizer.sep_token_id
        elif self.tokenizer.bos_token_id is not None:
            self.sep_token_id = self.tokenizer.bos_token_id
        else:
            self.sep_token_id = 0

        # Freeze first, wrap second: the adapter parameters are created
        # after the freeze, so they are not touched by it.
        if freeze_base:
            self.freeze_base()

        if lora is not None:
            self.base = get_peft_model(self.base, lora)

    def freeze_base(self) -> None:
        """Set requires_grad=False on every parameter currently in `self.base`."""
        for p in self.base.parameters():
            p.requires_grad = False

    @staticmethod
    def _owns_lm_head(module: nn.Module) -> bool:
        """Return True when `lm_head` is a direct child module of `module`.

        Direct children are inspected instead of `hasattr`, because
        wrapper modules may forward attribute lookups to the wrapped
        model and would then look like they own the head themselves.
        """
        return any(name == "lm_head" for name, _ in module.named_children())

    @property
    def _lfm_for_causal_lm(self) -> nn.Module:
        """Resolve through PEFT wrapping to the underlying Lfm2ForCausalLM.

        Handles both layouts `self.base` can have:

        * no LoRA:   `self.base` is the causal LM itself;
        * with LoRA: `self.base.base_model.model` is the causal LM.

        The causal LM is identified as the first module on that chain
        that directly owns `lm_head`. Probing for a `base_model`
        attribute is not sufficient: Hugging Face models expose a
        `base_model` property of their own that returns the headless
        inner model, so the unwrapped case would resolve one level too
        deep and lose access to `lm_head`.

        Raises:
            AttributeError: if no module owning `lm_head` is found.
        """
        module: nn.Module = self.base
        for wrapper_attr in ("base_model", "model"):
            if self._owns_lm_head(module):
                return module
            module = getattr(module, wrapper_attr)
        if not self._owns_lm_head(module):
            raise AttributeError(
                "could not locate a module owning `lm_head` under `self.base`"
            )
        return module

    @property
    def _inner_lfm_model(self) -> nn.Module:
        """The Lfm2Model (no LM head). For hidden-state-only forwards."""
        return self._lfm_for_causal_lm.model

    @property
    def embed_tokens(self) -> nn.Embedding:
        """The model's input-token embedding table.

        Used by callers to embed text tokens before concatenation with
        the transaction pseudo-tokens.
        """
        return self._inner_lfm_model.embed_tokens

    def embed_text(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed text token ids using the model's own table.

        Args:
            input_ids: (B, T_txt) int64 token ids.

        Returns:
            (B, T_txt, D) embeddings in the model's dtype.
        """
        return self.embed_tokens(input_ids)

    def embed_sep(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Build the SEP embedding for a batch.

        Args:
            batch_size: number of rows B to produce.
            device: device on which the SEP id tensor is created.

        Returns:
            (B, 1, D) SEP embedding, one slot per batch element.
        """
        sep_id = torch.full(
            (batch_size, 1),
            fill_value=self.sep_token_id,
            dtype=torch.long,
            device=device,
        )
        return self.embed_tokens(sep_id)

    def forward(
        self,
        pseudo_tokens: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Original tx-only forward (parity with the legacy wrapper).

        Used by the existing four classification heads that don't need
        a text input. Returns hidden states only; the LM head is not
        invoked.

        Args:
            pseudo_tokens: (B, T_tx, D) tx pseudo-tokens from the
                projector. Must already be in the base's dtype.
            attention_mask: (B, T_tx) optional padding mask.

        Returns:
            (B, T_tx, D) `last_hidden_state` after the LFM2.5 stack.
        """
        # Call the headless inner model directly so no vocab-sized
        # projection is computed on this path.
        outputs = self._inner_lfm_model(
            input_ids=None,
            inputs_embeds=pseudo_tokens,
            attention_mask=attention_mask,
            use_cache=False,
        )
        return outputs.last_hidden_state

    def forward_mixed(
        self,
        combined_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        compute_lm_logits: bool = True,
    ) -> dict[str, torch.Tensor]:
        """Mixed-modality forward over [tx_pseudo, SEP, text].

        Args:
            combined_embeds: (B, T_total, D) the pre-concatenated
                embedding sequence built by the MixedModalityBatch.
            attention_mask: (B, T_total) padding mask. 1 for real
                positions, 0 for padding. Required when text lengths
                vary across the batch.
            compute_lm_logits: If True, run the LM head and return
                lm_logits. If False, skip the LM head — saves a
                vocab-sized matmul when the caller only wants hidden
                states (probability + attribution heads).

        Returns:
            Dict with:
                hidden_states: (B, T_total, D) — for probability +
                    attribution heads.
                lm_logits: (B, T_total, vocab_size) if
                    compute_lm_logits=True, else absent.
        """
        model_outputs = self._inner_lfm_model(
            input_ids=None,
            inputs_embeds=combined_embeds,
            attention_mask=attention_mask,
            use_cache=False,
        )
        hidden = model_outputs.last_hidden_state

        out: dict[str, torch.Tensor] = {"hidden_states": hidden}
        if compute_lm_logits:
            # Apply the head to the hidden states already computed
            # rather than running the stack a second time.
            out["lm_logits"] = self._lfm_for_causal_lm.lm_head(hidden)
        return out

    @torch.no_grad()
    def generate_reasoning(
        self,
        combined_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        max_new_tokens: int = 128,
        temperature: float = 0.0,
    ) -> torch.Tensor:
        """Greedy / temperature-sampled generation from a mixed prefix.

        Args:
            combined_embeds: (B, T_prefix, D) the [tx_pseudo, SEP,
                text_prompt] prefix to condition generation on.
            attention_mask: (B, T_prefix) padding mask for the prefix.
                The next token is always read from the LAST position,
                so a padded prefix must be left-padded; with right
                padding, shorter rows would be decoded from a padding
                slot.
            max_new_tokens: upper bound on new text tokens to produce.
                Values <= 0 yield an empty (B, 0) result.
            temperature: 0.0 = greedy; >0 = temperature sampling.

        Returns:
            (B, T_new) int64 generated token ids with
            T_new <= max_new_tokens. Decoding stops early once every
            row has emitted EOS. A row that finished earlier is filled
            with the tokenizer's pad id (or EOS when no pad id is set)
            for the remaining steps. Caller decodes via
            `self.tokenizer.batch_decode`.

        Raises:
            ValueError: if `temperature` is negative.

        Note:
            We hand-roll generation rather than calling `base.generate`
            because the input is `inputs_embeds`, not `input_ids`.
            HuggingFace's `generate` supports `inputs_embeds` on most
            architectures but has edge cases with hybrid models
            (cache shape, prepare_inputs_for_generation). Hand-rolled
            greedy gives us full control and matches the doctrine
            "eval-time decoder hygiene = temperature=0 greedy."
            No KV cache is used: each step re-runs the full sequence.
        """
        if temperature < 0.0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")

        batch_size = combined_embeds.shape[0]
        device = combined_embeds.device

        # torch.cat rejects an empty list, so handle "nothing to
        # generate" up front.
        if max_new_tokens <= 0:
            return torch.empty((batch_size, 0), dtype=torch.long, device=device)

        eos_id = self.tokenizer.eos_token_id
        pad_id = getattr(self.tokenizer, "pad_token_id", None)
        # Filler for rows that already finished; only used when an EOS
        # id exists.
        fill_id = eos_id if pad_id is None else pad_id
        finished = torch.zeros((batch_size, 1), dtype=torch.bool, device=device)

        generated_ids: list[torch.Tensor] = []
        current_embeds = combined_embeds
        current_mask = attention_mask

        for _ in range(max_new_tokens):
            outputs = self.forward_mixed(
                combined_embeds=current_embeds,
                attention_mask=current_mask,
                compute_lm_logits=True,
            )
            # Only the final position predicts the next token.
            next_logits = outputs["lm_logits"][:, -1, :]
            if temperature == 0.0:
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                # Sample in float32 so low-precision logits do not
                # distort the probabilities.
                probs = torch.softmax(next_logits.float() / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)

            if eos_id is not None:
                # Rows that already hit EOS must not keep emitting
                # tokens; overwrite their slot with the filler id.
                next_token = next_token.masked_fill(finished, fill_id)
                finished = finished | (next_token == eos_id)
            generated_ids.append(next_token)

            # Stop once every row has produced EOS at some step (not
            # necessarily the same step).
            if eos_id is not None and bool(finished.all()):
                break

            # Append the new token's embedding (and a mask slot) so the
            # next step conditions on it.
            next_embed = self.embed_tokens(next_token)
            current_embeds = torch.cat([current_embeds, next_embed], dim=1)
            if current_mask is not None:
                ones = torch.ones(
                    (current_mask.shape[0], 1),
                    dtype=current_mask.dtype,
                    device=device,
                )
                current_mask = torch.cat([current_mask, ones], dim=1)

        return torch.cat(generated_ids, dim=1)

    def trainable_parameters(self) -> int:
        """Number of scalar parameters in `self.base` with requires_grad=True."""
        return sum(p.numel() for p in self.base.parameters() if p.requires_grad)

    def total_parameters(self) -> int:
        """Number of scalar parameters in `self.base`, trainable or not."""
        return sum(p.numel() for p in self.base.parameters())

    def lora_module_count(self) -> int:
        """Count LoRA-tagged modules. Used by the smoke gate.

        Counts every entry of `self.base.named_modules()` whose
        qualified name contains "lora" (case-insensitive).

        Doctrine: expect ~92 modules on LFM2.5-350M for the standard
        attention-only target set. If we add `lm_head`, expect +1.
        <50 means target_modules names are wrong.
        """
        return sum(
            1 for name, _ in self.base.named_modules() if "lora" in name.lower()
        )
|
|
|
|
def build_multisurface_lora_config(
    r: int = 16,
    alpha: int = 32,
    dropout: float = 0.05,
    target_modules: list[str] | None = None,
) -> LoraConfig:
    """Assemble the LoraConfig used by the multi-surface backbone.

    The defaults are the Phase-1 dispute-legitimacy recipe: rank 16,
    alpha 32 (2 x r, the LEAP golden-path convention), dropout 0.05,
    and attention + lm_head as target modules. The config is always
    built with `bias="none"` and `task_type="CAUSAL_LM"`; the latter
    is what keeps the LM head structure intact under PEFT so that
    generation still works after wrapping.

    Args:
        r: LoRA rank. 16 for Phase 1; ablate to 8 / 32 later.
        alpha: LoRA scaling factor, 2 x r per LEAP convention.
        dropout: LoRA dropout; 0.05 is light regularization at POC
            scale.
        target_modules: Module names to adapt. None selects
            `DEFAULT_MULTISURFACE_TARGETS` (attn + lm_head).

    Returns:
        The populated LoraConfig.
    """
    resolved_targets = (
        DEFAULT_MULTISURFACE_TARGETS if target_modules is None else target_modules
    )
    config_kwargs = {
        "r": r,
        "lora_alpha": alpha,
        "lora_dropout": dropout,
        "target_modules": resolved_targets,
        "bias": "none",
        "task_type": "CAUSAL_LM",
    }
    return LoraConfig(**config_kwargs)
|
|