"""LFM2.5 wrapper extended with the LM head and the mixed-modality input path. Sibling of `lfm_pseudo_token_wrapper.LfmPseudoTokenBackbone`. The original wraps `Lfm2Model` (no LM head — saves 67M params unused by the four classification heads). For the multi-surface expansion we need the LM head active (for reasoning-text generation) AND we need to concatenate the customer's transaction pseudo-tokens with the analyst's free-text query so the backbone attends across both modalities. Why a sibling instead of an in-place edit: The existing wrapper is in production for the current demo. Phase 1 of the multi-surface expansion is strictly additive — the original wrapper continues to back the existing four-head demo; this sibling backs the new dispute-legitimacy + Co-Pilot surfaces. When all five surfaces ship and the Co-Pilot replaces the existing demo, we can decide to consolidate or keep both. What changes vs the original: 1. Loads `Lfm2ForCausalLM` (which is `Lfm2Model + lm_head`) instead of `Lfm2Model`. The LM head is tied to embed_tokens, so storage cost is shared. Memory delta is the lm_head Linear module's bookkeeping, not its weights. 2. Exposes `embed_tokens` and the SEP token id so callers can build the mixed-modality input embedding sequence. 3. PEFT `task_type` becomes `CAUSAL_LM` so PEFT preserves the LM head correctly. We also widen `target_modules` to include `lm_head` so the LoRA can shape generation. 4. Adds `forward_mixed` which takes pre-built combined input embeddings (tx_pseudo + SEP + text) and returns BOTH the hidden states (for the probability + attribution heads) AND the LM logits over the text positions (for the LM loss). Frozen-base invariant preserved: `freeze_base()` still sets requires_grad=False on every base parameter. PEFT adds LoRA layers AFTER freezing; their requires_grad=True is preserved by PEFT's wrap logic. LoRA target modules: Same names as the original (LFM2 uses `q_proj`/`k_proj`/`v_proj`/ `out_proj`/`in_proj`/`w1`/`w2`/`w3`). The default attention-only LoRA stays r=16. For the LM head we add `lm_head` to the targets. PEFT will create a separate LoRA matrix for the lm_head Linear. """ from __future__ import annotations from pathlib import Path import torch import torch.nn as nn from peft import LoraConfig, get_peft_model from transformers import AutoTokenizer, Lfm2ForCausalLM, PreTrainedTokenizerBase # Default LoRA targets for the multi-surface wrapper. Attention-only # adapter at r=16 (matches the production Phase-1 demo) PLUS the LM # head Linear at the same rank. The LM head is a single 2048-out (or # 1024 for 350M) Linear so r=16 is plenty of capacity — the doctrine # suggests r=8 for the LM head separately but Phase 1 uses a single # shared rank for simplicity. DEFAULT_MULTISURFACE_TARGETS = [ "q_proj", "k_proj", "v_proj", "out_proj", "lm_head", ] class LfmMultiSurfaceBackbone(nn.Module): """LFM2.5 wrapper supporting tx pseudo-tokens + text tokens + LM head. Forward signatures: forward(pseudo_tokens, attention_mask=None) -> hidden Same as the original wrapper. Used when no text input is needed (the existing four-head heads still work via this path). forward_mixed(combined_embeds, lm_head_positions, attention_mask) -> dict(hidden_states, lm_logits) The mixed-modality forward. `combined_embeds` is (B, T_total, D) with tx pseudo-tokens at positions [0, num_tx_positions), SEP at position num_tx_positions, and text token embeddings at positions (num_tx_positions + 1, T_total). `lm_head_positions` is the slice of positions where the LM head should produce logits — typically the text positions only. Args: model_path: HF-format directory or repo ID for LFM2.5. lora: LoraConfig instance, or None for frozen-base-only. dtype: bfloat16 for GPU training, float32 for CPU smoke. device_map: "auto" for GPU, None for CPU. trust_remote_code: True for LiquidAI checkpoints. freeze_base: Set base params to non-trainable at init. sep_token_id: Token id used as the separator between the transaction pseudo-tokens and the text tokens. Defaults to the tokenizer's BOS or special sep token; caller can override. The SEP embedding is taken from the model's own embed_tokens table at this id. """ def __init__( self, model_path: str | Path, lora: LoraConfig | None = None, dtype: torch.dtype = torch.bfloat16, device_map: str | None = "auto", trust_remote_code: bool = True, freeze_base: bool = True, sep_token_id: int | None = None, ) -> None: super().__init__() load_kwargs: dict = { "torch_dtype": dtype, "trust_remote_code": trust_remote_code, } if device_map is not None: load_kwargs["device_map"] = device_map # Lfm2ForCausalLM = Lfm2Model + lm_head (tied to embed_tokens). self.base = Lfm2ForCausalLM.from_pretrained( str(model_path), **load_kwargs, ) # Architectural facts captured early — they're needed by the # heads and by the tokenizer. d_lfm = hidden_size; vocab_size # is used by the LM loss. self.d_lfm: int = self.base.config.hidden_size self.vocab_size: int = self.base.config.vocab_size # Load the tokenizer to resolve the SEP token id and to give # callers a single tokenizer instance keyed to this model. self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained( str(model_path), trust_remote_code=trust_remote_code, ) # Decide the SEP token. Strategy: prefer the explicit sep_token # if the tokenizer has one; otherwise fall back to BOS, then to # the explicit override, then to 0. The point of the SEP token # is to mark the modality boundary inside the embedding # sequence; the LFM2.5 tokenizer ships several plausible # candidates. We resolve once at construction time and store. if sep_token_id is not None: self.sep_token_id = sep_token_id elif self.tokenizer.sep_token_id is not None: self.sep_token_id = self.tokenizer.sep_token_id elif self.tokenizer.bos_token_id is not None: self.sep_token_id = self.tokenizer.bos_token_id else: self.sep_token_id = 0 if freeze_base: self.freeze_base() if lora is not None: # get_peft_model wraps the base and registers LoRA layers # as trainable. task_type=CAUSAL_LM preserves the LM head. self.base = get_peft_model(self.base, lora) # ---- frozen-base hygiene ---- def freeze_base(self) -> None: """Set requires_grad=False on every base parameter.""" for p in self.base.parameters(): p.requires_grad = False # ---- embedding access ---- @property def _lfm_for_causal_lm(self) -> nn.Module: """Resolve through PEFT wrapping to the underlying Lfm2ForCausalLM.""" # When wrapped by PEFT: base = PeftModelForCausalLM, # base.base_model = LoraModel, base.base_model.model = Lfm2ForCausalLM. # When not wrapped: base IS the Lfm2ForCausalLM. inner = getattr(self.base, "base_model", None) if inner is not None: inner = getattr(inner, "model", inner) else: inner = self.base return inner @property def _inner_lfm_model(self) -> nn.Module: """The Lfm2Model (no LM head). For hidden-state-only forwards.""" # Lfm2ForCausalLM.model is the inner Lfm2Model. return self._lfm_for_causal_lm.model @property def embed_tokens(self) -> nn.Embedding: """The model's input-token embedding table. Used by callers to embed text tokens before concatenation with the transaction pseudo-tokens. """ return self._inner_lfm_model.embed_tokens def embed_text(self, input_ids: torch.Tensor) -> torch.Tensor: """Embed text token ids using the model's own table. Args: input_ids: (B, T_txt) int64 token ids. Returns: (B, T_txt, D) embeddings in the model's dtype. """ return self.embed_tokens(input_ids) def embed_sep(self, batch_size: int, device: torch.device) -> torch.Tensor: """Build the SEP embedding for a batch. Returns: (B, 1, D) SEP embedding, one slot per batch element. """ sep_id = torch.full( (batch_size, 1), fill_value=self.sep_token_id, dtype=torch.long, device=device, ) return self.embed_tokens(sep_id) # ---- forward paths ---- def forward( self, pseudo_tokens: torch.Tensor, attention_mask: torch.Tensor | None = None, ) -> torch.Tensor: """Original tx-only forward (parity with the legacy wrapper). Used by the existing four classification heads that don't need a text input. Returns hidden states only; the LM head is not invoked. Args: pseudo_tokens: (B, T_tx, D) tx pseudo-tokens from the projector. Must already be in the base's dtype. attention_mask: (B, T_tx) optional padding mask. Returns: (B, T_tx, D) `last_hidden_state` after the LFM2.5 stack. """ # Skip the LM head — call the inner Lfm2Model directly. PEFT's # LoRA wrap is on the parent Lfm2ForCausalLM, but the LoRA # adapters on attention modules are still active here because # the adapters are nested inside `self_attn`, which lives on # Lfm2Model.layers[i]. Only the LoRA on `lm_head` is bypassed, # which is fine because this forward doesn't use the LM head. outputs = self._inner_lfm_model( input_ids=None, inputs_embeds=pseudo_tokens, attention_mask=attention_mask, use_cache=False, ) return outputs.last_hidden_state def forward_mixed( self, combined_embeds: torch.Tensor, attention_mask: torch.Tensor | None = None, compute_lm_logits: bool = True, ) -> dict[str, torch.Tensor]: """Mixed-modality forward over [tx_pseudo, SEP, text]. Args: combined_embeds: (B, T_total, D) the pre-concatenated embedding sequence built by the MixedModalityBatch. attention_mask: (B, T_total) padding mask. 1 for real positions, 0 for padding. Required when text lengths vary across the batch. compute_lm_logits: If True, run the LM head and return lm_logits. If False, skip the LM head — saves a vocab-sized matmul when the caller only wants hidden states (probability + attribution heads). Returns: Dict with: hidden_states: (B, T_total, D) — for probability + attribution heads. lm_logits: (B, T_total, vocab_size) if compute_lm_logits=True, else absent. """ # Run the inner Lfm2Model to get hidden states. This bypasses # the LM head Linear; we apply it explicitly below if requested. # Going through the inner model (rather than the PEFT-wrapped # Lfm2ForCausalLM) keeps the return type as BaseModelOutput # with `last_hidden_state` populated. model_outputs = self._inner_lfm_model( input_ids=None, inputs_embeds=combined_embeds, attention_mask=attention_mask, use_cache=False, ) hidden = model_outputs.last_hidden_state # (B, T_total, D) out: dict[str, torch.Tensor] = {"hidden_states": hidden} if compute_lm_logits: # Use the PEFT-wrapped lm_head so its LoRA adapter is # active. `self._lfm_for_causal_lm.lm_head` is the LM head # of the Lfm2ForCausalLM module that PEFT wrapped. out["lm_logits"] = self._lfm_for_causal_lm.lm_head(hidden) return out # ---- generation helpers ---- @torch.no_grad() def generate_reasoning( self, combined_embeds: torch.Tensor, attention_mask: torch.Tensor | None = None, max_new_tokens: int = 128, temperature: float = 0.0, ) -> torch.Tensor: """Greedy / temperature-sampled generation from a mixed prefix. Args: combined_embeds: (B, T_prefix, D) the [tx_pseudo, SEP, text_prompt] prefix to condition generation on. attention_mask: (B, T_prefix) padding mask for the prefix. max_new_tokens: how many new text tokens to produce. temperature: 0.0 = greedy; >0 = temperature sampling. Returns: (B, max_new_tokens) int64 generated token ids. Caller decodes via `self.tokenizer.batch_decode`. Note: We hand-roll generation rather than calling `base.generate` because the input is `inputs_embeds`, not `input_ids`. HuggingFace's `generate` supports `inputs_embeds` on most architectures but has edge cases with hybrid models (cache shape, prepare_inputs_for_generation). Hand-rolled greedy gives us full control and matches the doctrine "eval-time decoder hygiene = temperature=0 greedy." """ device = combined_embeds.device # Working buffer of next-token ids to append to the prefix. generated_ids: list[torch.Tensor] = [] # The embedding sequence grows by one position per step. # Start with the prefix; append the next token's embedding. current_embeds = combined_embeds current_mask = attention_mask for _ in range(max_new_tokens): outputs = self.forward_mixed( combined_embeds=current_embeds, attention_mask=current_mask, compute_lm_logits=True, ) # (B, vocab_size) — logits at the last position next_logits = outputs["lm_logits"][:, -1, :] if temperature == 0.0: next_token = next_logits.argmax(dim=-1, keepdim=True) else: probs = torch.softmax(next_logits / temperature, dim=-1) next_token = torch.multinomial(probs, num_samples=1) generated_ids.append(next_token) # Stop if every sample emitted EOS. We check the # tokenizer's eos_token_id (None-safe). eos_id = self.tokenizer.eos_token_id if eos_id is not None and (next_token == eos_id).all(): break # Append the new token's embedding to the running prefix. next_embed = self.embed_tokens(next_token) current_embeds = torch.cat([current_embeds, next_embed], dim=1) if current_mask is not None: ones = torch.ones( (current_mask.shape[0], 1), dtype=current_mask.dtype, device=device, ) current_mask = torch.cat([current_mask, ones], dim=1) return torch.cat(generated_ids, dim=1) # ---- introspection ---- def trainable_parameters(self) -> int: return sum(p.numel() for p in self.base.parameters() if p.requires_grad) def total_parameters(self) -> int: return sum(p.numel() for p in self.base.parameters()) def lora_module_count(self) -> int: """Count LoRA-tagged modules. Used by the smoke gate. Doctrine: expect ~92 modules on LFM2.5-350M for the standard attention-only target set. If we add `lm_head`, expect +1. <50 means target_modules names are wrong. """ return sum( 1 for name, _ in self.base.named_modules() if "lora" in name.lower() ) def build_multisurface_lora_config( r: int = 16, alpha: int = 32, dropout: float = 0.05, target_modules: list[str] | None = None, ) -> LoraConfig: """Build a LoraConfig for the multi-surface backbone. Defaults match the Phase-1 dispute-legitimacy recipe: r=16, alpha=32 (= 2r per the LEAP golden path), dropout=0.05, target modules = attention + lm_head. `task_type="CAUSAL_LM"` ensures PEFT preserves the LM head structure so generation works after wrapping. Args: r: LoRA rank. 16 for Phase 1; ablate to 8 / 32 later. alpha: LoRA scaling. 2 × r per LEAP convention. dropout: 0.05 — light regularization at POC scale. target_modules: Override the default LFM2.5 target list. None means use the multi-surface default (attn + lm_head). """ if target_modules is None: target_modules = DEFAULT_MULTISURFACE_TARGETS return LoraConfig( r=r, lora_alpha=alpha, lora_dropout=dropout, target_modules=target_modules, bias="none", task_type="CAUSAL_LM", )