lfm2-transaction-encoder / encoder /src /model /lfm_multisurface_wrapper.py
cdotsanghvi's picture
initial transaction co-pilot deployment
b3112c7
Raw
History Blame Contribute Delete
17.5 kB
"""LFM2.5 wrapper extended with the LM head and the mixed-modality input path.
Sibling of `lfm_pseudo_token_wrapper.LfmPseudoTokenBackbone`. The
original wraps `Lfm2Model` (no LM head — saves 67M params unused by
the four classification heads). For the multi-surface expansion we
need the LM head active (for reasoning-text generation) AND we need
to concatenate the customer's transaction pseudo-tokens with the
analyst's free-text query so the backbone attends across both
modalities.
Why a sibling instead of an in-place edit:
The existing wrapper is in production for the current demo. Phase
1 of the multi-surface expansion is strictly additive — the
original wrapper continues to back the existing four-head demo;
this sibling backs the new dispute-legitimacy + Co-Pilot surfaces.
When all five surfaces ship and the Co-Pilot replaces the existing
demo, we can decide to consolidate or keep both.
What changes vs the original:
1. Loads `Lfm2ForCausalLM` (which is `Lfm2Model + lm_head`)
instead of `Lfm2Model`. The LM head is tied to embed_tokens,
so storage cost is shared. Memory delta is the lm_head Linear
module's bookkeeping, not its weights.
2. Exposes `embed_tokens` and the SEP token id so callers can
build the mixed-modality input embedding sequence.
3. PEFT `task_type` becomes `CAUSAL_LM` so PEFT preserves the LM
head correctly. We also widen `target_modules` to include
`lm_head` so the LoRA can shape generation.
4. Adds `forward_mixed` which takes pre-built combined input
embeddings (tx_pseudo + SEP + text) and returns BOTH the
hidden states (for the probability + attribution heads) AND
the LM logits at every position (callers slice out the text
positions for the LM loss).
Frozen-base invariant preserved:
`freeze_base()` still sets requires_grad=False on every base
parameter. PEFT adds LoRA layers AFTER freezing; their
requires_grad=True is preserved by PEFT's wrap logic.
LoRA target modules:
Same names as the original (LFM2 uses `q_proj`/`k_proj`/`v_proj`/
`out_proj`/`in_proj`/`w1`/`w2`/`w3`). The default attention-only
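LoRA stays r=16. For the LM head we add `lm_head` to the targets.
PEFT creates a separate LoRA pair (A, B) for the lm_head Linear.
The low-rank delta applies only to the output projection, so the
tied embed_tokens lookup on the input side is unaffected.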
"""
from __future__ import annotations
from pathlib import Path
import torch
import torch.nn as nn
from peft import LoraConfig, get_peft_model
from transformers import AutoTokenizer, Lfm2ForCausalLM, PreTrainedTokenizerBase
# Default LoRA targets for the multi-surface wrapper: an attention-only
# adapter at r=16 (matches the production Phase-1 demo) PLUS the LM
# head Linear at the same rank.
#
# Note the LM head's shape: it maps hidden_size -> vocab_size (it turns
# (B, T, D) hidden states into (B, T, vocab_size) logits), so its LoRA
# pair costs r * (hidden_size + vocab_size) parameters — considerably
# more than a single attention projection's adapter at the same rank.
# NOTE(review): hidden_size is presumably 2048 (1024 for the 350M
# checkpoint); confirm against config.hidden_size.
#
# The doctrine suggests r=8 for the LM head separately, but Phase 1
# uses a single shared rank for simplicity.
DEFAULT_MULTISURFACE_TARGETS = [
    "q_proj",
    "k_proj",
    "v_proj",
    "out_proj",
    "lm_head",
]
class LfmMultiSurfaceBackbone(nn.Module):
    """LFM2.5 wrapper supporting tx pseudo-tokens + text tokens + LM head.

    Forward signatures:
        forward(pseudo_tokens, attention_mask=None) -> hidden
            Same as the original wrapper. Used when no text input is
            needed (the existing four classification heads still work
            via this path). The LM head is not invoked.
        forward_mixed(combined_embeds, attention_mask=None,
                      compute_lm_logits=True)
            -> dict(hidden_states[, lm_logits])
            The mixed-modality forward. `combined_embeds` is
            (B, T_total, D) with tx pseudo-tokens at positions
            [0, num_tx_positions), SEP at position num_tx_positions,
            and text token embeddings at positions
            [num_tx_positions + 1, T_total). When requested,
            `lm_logits` cover EVERY position; callers slice out the
            text positions they want to score for the LM loss.

    Args:
        model_path: HF-format directory or repo ID for LFM2.5.
        lora: LoraConfig instance, or None for frozen-base-only.
        dtype: bfloat16 for GPU training, float32 for CPU smoke.
        device_map: "auto" for GPU, None for CPU.
        trust_remote_code: True for LiquidAI checkpoints.
        freeze_base: Set base params to non-trainable at init.
        sep_token_id: Token id used as the separator between the
            transaction pseudo-tokens and the text tokens. Resolution
            order: this override if given, else the tokenizer's
            sep_token_id, else its bos_token_id, else 0. The SEP
            embedding is taken from the model's own embed_tokens
            table at this id.
    """

    def __init__(
        self,
        model_path: str | Path,
        lora: LoraConfig | None = None,
        dtype: torch.dtype = torch.bfloat16,
        device_map: str | None = "auto",
        trust_remote_code: bool = True,
        freeze_base: bool = True,
        sep_token_id: int | None = None,
    ) -> None:
        super().__init__()
        load_kwargs: dict = {
            "torch_dtype": dtype,
            "trust_remote_code": trust_remote_code,
        }
        if device_map is not None:
            load_kwargs["device_map"] = device_map
        # Lfm2ForCausalLM = Lfm2Model + lm_head (tied to embed_tokens).
        self.base = Lfm2ForCausalLM.from_pretrained(
            str(model_path),
            **load_kwargs,
        )
        # Architectural facts captured early — they're needed by the
        # heads and by the tokenizer. d_lfm = hidden_size; vocab_size
        # is used by the LM loss.
        self.d_lfm: int = self.base.config.hidden_size
        self.vocab_size: int = self.base.config.vocab_size
        # Load the tokenizer to resolve the SEP token id and to give
        # callers a single tokenizer instance keyed to this model.
        self.tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(
            str(model_path),
            trust_remote_code=trust_remote_code,
        )
        # Decide the SEP token. Order: the explicit `sep_token_id`
        # override first; otherwise the tokenizer's own sep token;
        # otherwise BOS; otherwise 0. The SEP token marks the modality
        # boundary inside the embedding sequence; the LFM2.5 tokenizer
        # ships several plausible candidates. Resolved once here.
        if sep_token_id is not None:
            self.sep_token_id = sep_token_id
        elif self.tokenizer.sep_token_id is not None:
            self.sep_token_id = self.tokenizer.sep_token_id
        elif self.tokenizer.bos_token_id is not None:
            self.sep_token_id = self.tokenizer.bos_token_id
        else:
            self.sep_token_id = 0
        if freeze_base:
            self.freeze_base()
        if lora is not None:
            # get_peft_model wraps the base and registers LoRA layers
            # as trainable. task_type=CAUSAL_LM preserves the LM head.
            self.base = get_peft_model(self.base, lora)

    # ---- frozen-base hygiene ----

    def freeze_base(self) -> None:
        """Set requires_grad=False on every parameter of `self.base`.

        `__init__` calls this BEFORE PEFT wrapping. If you call it after
        wrapping, `self.base` is the PEFT model and its LoRA parameters
        get frozen too.
        """
        for p in self.base.parameters():
            p.requires_grad = False

    # ---- embedding access ----

    @property
    def _lfm_for_causal_lm(self) -> nn.Module:
        """Resolve through PEFT wrapping to the underlying Lfm2ForCausalLM."""
        base = self.base
        # Unwrapped: base IS the Lfm2ForCausalLM. Check the type
        # explicitly. Probing for a `base_model` attribute is not enough,
        # because transformers' PreTrainedModel also exposes a
        # `base_model` property, and it returns the inner Lfm2Model.
        if isinstance(base, Lfm2ForCausalLM):
            return base
        # PEFT-wrapped: base = PeftModelForCausalLM,
        # base.base_model = LoraModel, base.base_model.model = Lfm2ForCausalLM.
        wrapper = getattr(base, "base_model", None)
        if wrapper is None:
            return base
        return getattr(wrapper, "model", wrapper)

    @property
    def _inner_lfm_model(self) -> nn.Module:
        """The Lfm2Model (no LM head). For hidden-state-only forwards."""
        # Lfm2ForCausalLM.model is the inner Lfm2Model.
        return self._lfm_for_causal_lm.model

    @property
    def embed_tokens(self) -> nn.Embedding:
        """The model's input-token embedding table.

        Callers use it to embed text tokens before concatenating them
        with the transaction pseudo-tokens.
        """
        return self._inner_lfm_model.embed_tokens

    def embed_text(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed text token ids using the model's own table.

        Args:
            input_ids: (B, T_txt) int64 token ids.

        Returns:
            (B, T_txt, D) embeddings in the model's dtype.
        """
        return self.embed_tokens(input_ids)

    def embed_sep(self, batch_size: int, device: torch.device) -> torch.Tensor:
        """Build the SEP embedding for a batch.

        Returns:
            (B, 1, D) SEP embedding, one slot per batch element.
        """
        sep_id = torch.full(
            (batch_size, 1),
            fill_value=self.sep_token_id,
            dtype=torch.long,
            device=device,
        )
        return self.embed_tokens(sep_id)

    # ---- forward paths ----

    def forward(
        self,
        pseudo_tokens: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Original tx-only forward (parity with the legacy wrapper).

        Used by the existing four classification heads that don't need
        a text input. Returns hidden states only; the LM head is not
        invoked.

        Args:
            pseudo_tokens: (B, T_tx, D) tx pseudo-tokens from the
                projector. Must already be in the base's dtype.
            attention_mask: (B, T_tx) optional padding mask.

        Returns:
            (B, T_tx, D) `last_hidden_state` after the LFM2.5 stack.
        """
        # Skip the LM head by calling the inner Lfm2Model directly.
        # PEFT's LoRA wrap is on the parent Lfm2ForCausalLM, but the
        # adapters on attention modules stay active here: PEFT injects
        # them in place inside Lfm2Model.layers[i]. Only the LoRA on
        # `lm_head` is bypassed, and this forward doesn't use it.
        outputs = self._inner_lfm_model(
            input_ids=None,
            inputs_embeds=pseudo_tokens,
            attention_mask=attention_mask,
            use_cache=False,
        )
        return outputs.last_hidden_state

    def forward_mixed(
        self,
        combined_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        compute_lm_logits: bool = True,
    ) -> dict[str, torch.Tensor]:
        """Mixed-modality forward over [tx_pseudo, SEP, text].

        Args:
            combined_embeds: (B, T_total, D) the pre-concatenated
                embedding sequence built by the MixedModalityBatch.
            attention_mask: (B, T_total) padding mask. 1 for real
                positions, 0 for padding. Required when text lengths
                vary across the batch.
            compute_lm_logits: If True, run the LM head and return
                lm_logits. If False, skip the LM head. This saves a
                vocab-sized matmul when the caller only wants hidden
                states (probability + attribution heads).

        Returns:
            Dict with:
                hidden_states: (B, T_total, D). For the probability +
                    attribution heads.
                lm_logits: (B, T_total, vocab_size), computed at every
                    position, if compute_lm_logits=True; otherwise the
                    key is absent.
        """
        # Run the inner Lfm2Model to get hidden states. This bypasses
        # the LM head Linear; we apply it explicitly below if requested.
        # Going through the inner model (rather than the PEFT-wrapped
        # Lfm2ForCausalLM) keeps the return type a BaseModelOutput
        # with `last_hidden_state` populated.
        model_outputs = self._inner_lfm_model(
            input_ids=None,
            inputs_embeds=combined_embeds,
            attention_mask=attention_mask,
            use_cache=False,
        )
        hidden = model_outputs.last_hidden_state  # (B, T_total, D)
        out: dict[str, torch.Tensor] = {"hidden_states": hidden}
        if compute_lm_logits:
            # Use the PEFT-wrapped lm_head so its LoRA adapter is
            # active. `self._lfm_for_causal_lm.lm_head` is the LM head
            # of the Lfm2ForCausalLM module that PEFT wrapped.
            out["lm_logits"] = self._lfm_for_causal_lm.lm_head(hidden)
        return out

    # ---- generation helpers ----

    @torch.no_grad()
    def generate_reasoning(
        self,
        combined_embeds: torch.Tensor,
        attention_mask: torch.Tensor | None = None,
        max_new_tokens: int = 128,
        temperature: float = 0.0,
    ) -> torch.Tensor:
        """Greedy / temperature-sampled generation from a mixed prefix.

        Args:
            combined_embeds: (B, T_prefix, D) the [tx_pseudo, SEP,
                text_prompt] prefix to condition generation on.
            attention_mask: (B, T_prefix) padding mask for the prefix.
            max_new_tokens: Maximum number of new text tokens to
                produce. Values <= 0 return an empty (B, 0) tensor.
            temperature: 0.0 = greedy; >0 = temperature sampling.

        Returns:
            (B, N) int64 generated token ids on `combined_embeds.device`,
            where N <= max_new_tokens (generation stops early once every
            row has emitted EOS). After a row emits EOS, its remaining
            slots hold the tokenizer's pad id (or EOS if there is no pad
            id). Callers decode via `self.tokenizer.batch_decode`.

        Raises:
            ValueError: if temperature is negative.

        Note:
            We hand-roll generation rather than calling `base.generate`
            because the input is `inputs_embeds`, not `input_ids`.
            HuggingFace's `generate` supports `inputs_embeds` on most
            architectures but has edge cases with hybrid models
            (cache shape, prepare_inputs_for_generation). Hand-rolled
            greedy gives us full control and matches the doctrine
            "eval-time decoder hygiene = temperature=0 greedy."
            No KV cache is used, so each step re-runs the whole
            sequence.

            NOTE(review): the next token is read from the LAST position
            (`[:, -1]`). With a right-padded prefix, that is a pad
            position for shorter rows. This presumably expects
            left-padded (or unpadded) prefixes; confirm with the batch
            builder.
        """
        if temperature < 0.0:
            raise ValueError(f"temperature must be >= 0, got {temperature}")
        batch_size = combined_embeds.shape[0]
        device = combined_embeds.device
        if max_new_tokens <= 0:
            # Nothing to generate. torch.cat([]) below would raise.
            return torch.empty((batch_size, 0), dtype=torch.long, device=device)

        # Hoisted tokenizer lookups (None-safe).
        eos_id = self.tokenizer.eos_token_id
        pad_id = self.tokenizer.pad_token_id
        # Rows that already emitted EOS are filled with this id.
        fill_id = pad_id if pad_id is not None else eos_id
        finished = torch.zeros((batch_size, 1), dtype=torch.bool, device=device)
        embed_device = self.embed_tokens.weight.device

        generated_ids: list[torch.Tensor] = []
        # The embedding sequence grows by one position per step.
        current_embeds = combined_embeds
        current_mask = attention_mask
        for _ in range(max_new_tokens):
            outputs = self.forward_mixed(
                combined_embeds=current_embeds,
                attention_mask=current_mask,
                compute_lm_logits=True,
            )
            # (B, vocab_size): logits at the last position.
            next_logits = outputs["lm_logits"][:, -1, :]
            if temperature == 0.0:
                next_token = next_logits.argmax(dim=-1, keepdim=True)
            else:
                # Upcast so the softmax isn't computed in bf16.
                probs = torch.softmax(next_logits.float() / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            next_token = next_token.to(device)

            if eos_id is not None:
                # Finished rows emit the fill id instead of post-EOS junk.
                next_token = next_token.masked_fill(finished, fill_id)
                finished = finished | (next_token == eos_id)
            generated_ids.append(next_token)

            # Stop once every row has emitted EOS (on this or an earlier step).
            if eos_id is not None and bool(finished.all()):
                break

            # Append the new token's embedding to the running prefix.
            next_embed = self.embed_tokens(next_token.to(embed_device))
            current_embeds = torch.cat(
                [current_embeds, next_embed.to(current_embeds.device)], dim=1
            )
            if current_mask is not None:
                ones = current_mask.new_ones((current_mask.shape[0], 1))
                current_mask = torch.cat([current_mask, ones], dim=1)
        return torch.cat(generated_ids, dim=1)

    # ---- introspection ----

    def trainable_parameters(self) -> int:
        """Number of scalar parameters in `self.base` with requires_grad=True."""
        return sum(p.numel() for p in self.base.parameters() if p.requires_grad)

    def total_parameters(self) -> int:
        """Total number of scalar parameters in `self.base`."""
        return sum(p.numel() for p in self.base.parameters())

    def lora_module_count(self) -> int:
        """Count submodules of `self.base` whose qualified name contains "lora".

        Used by the smoke gate. Doctrine: expect ~92 modules on
        LFM2.5-350M for the standard attention-only target set; <50
        means the target_modules names are wrong.

        NOTE(review): this counts every submodule with "lora" anywhere
        in its dotted name, not one entry per wrapped Linear. PEFT
        typically registers several such submodules per wrapped layer
        (e.g. lora_A / lora_B / lora_dropout containers and their
        per-adapter children). Adding `lm_head` therefore likely raises
        the count by more than +1; confirm against a real run.
        """
        return sum(
            1 for name, _ in self.base.named_modules() if "lora" in name.lower()
        )
def build_multisurface_lora_config(
    r: int = 16,
    alpha: int = 32,
    dropout: float = 0.05,
    target_modules: list[str] | None = None,
) -> LoraConfig:
    """Return the LoraConfig used by the multi-surface backbone.

    The defaults reproduce the Phase-1 dispute-legitimacy recipe:
    rank 16, alpha 32 (= 2r per the LEAP golden path), dropout 0.05,
    targeting the attention projections plus `lm_head`. The config is
    tagged `task_type="CAUSAL_LM"` so PEFT keeps the LM head structure
    intact and generation still works after wrapping. Bias terms are
    not adapted (`bias="none"`).

    Args:
        r: LoRA rank. 16 for Phase 1; ablate to 8 / 32 later.
        alpha: LoRA scaling. The LEAP convention is 2 × r. This value
            does NOT follow `r` automatically, so pass it explicitly
            when changing the rank.
        dropout: LoRA dropout; 0.05 is light regularization at POC scale.
        target_modules: Module names to adapt. None selects the
            multi-surface default list (attention + lm_head).
    """
    resolved_targets = (
        DEFAULT_MULTISURFACE_TARGETS if target_modules is None else target_modules
    )
    return LoraConfig(
        task_type="CAUSAL_LM",
        bias="none",
        target_modules=resolved_targets,
        lora_dropout=dropout,
        lora_alpha=alpha,
        r=r,
    )