lfm2-transaction-encoder / encoder /src /data /mixed_modality.py
cdotsanghvi's picture
initial transaction co-pilot deployment
b3112c7
Raw
History Blame Contribute Delete
9.78 kB
"""Mixed-modality batch builder: [tx pseudo-tokens, SEP, text tokens].
This is the load-bearing data primitive for every multi-surface head
in ADR 0003. It converts a batch of `(feature_ids, text_prompt,
labels)` triples into the combined embedding sequence the LFM2.5
backbone consumes, plus the auxiliary tensors each head needs.
The boundary between this module and the model:
- `MixedModalityBatch` is dtype/device-aware tensors plus index
bookkeeping. It does NOT call the backbone.
- The encoder + projector run separately to produce the tx
pseudo-tokens.
- The wrapper's `embed_text` / `embed_sep` produce the text and
SEP embeddings.
- The trainer is responsible for concatenating the three and
invoking the backbone's `forward_mixed`.
Why split this way: keeping the batch builder backbone-free means
unit tests for the batch shape don't need a 350M model loaded.
Layout (per batch element, after concat):
positions [0, 64) tx pseudo-tokens (from encoder+projector)
position 64 SEP token embedding
positions [65, 65 + T_txt) text token embeddings
Total sequence length: 64 + 1 + T_txt, where T_txt varies per batch
element. The collator pads text to the batch max and produces an
attention mask.
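The LM training label layout (for the LM loss):
lm_targets : (B, T_total) int64 token ids
lm_target_mask: (B, T_total) bool — True at positions where the
LM loss is computed
Convention: lm_target_mask is True at the text positions only,
False at tx pseudo-token positions and at the SEP. The trainer
shifts targets by one position for the next-token prediction
objective.
Note: `build_lm_target_layout` in this module does not return a
separate mask tensor. It writes `ignore_index` into the targets at
the tx pseudo-token positions, at the SEP, and at text padding
positions, so the mask is implicit in `lm_targets`.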
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from transformers import PreTrainedTokenizerBase
@dataclass
class MixedModalityBatch:
    """One head-homogeneous batch, ready for the multi-surface backbone.

    This is plain data plus shape bookkeeping: it stores the transaction
    feature ids, the padded text tokens, and whichever label tensors the
    tagged head needs. Nothing in this class invokes a model.

    Attributes:
        feature_ids: (B, T_tx, F) int64 — passed to the encoder.
        text_input_ids: (B, T_txt_max) int64 — padded text tokens. The
            real length of each row is stored in `text_lengths`.
        text_attention_mask: (B, T_txt_max) bool/int — 1 at real text
            positions, 0 at padding.
        text_lengths: (B,) int64 — real text length per element.
        head_target: which head's loss is computed on this batch. The
            surface_trainer samples one head per batch, so every batch
            carries exactly one tag:
            "probability" | "attribution" | "lm".
        disputed_idx: (B,) int64 — position of the disputed transaction.
            May be None.
        labels_probability: (B,) int64 — probability head class index
            (0=unlikely, 1=ambiguous, 2=likely for dispute legitimacy).
            May be None for batches that only train the attribution
            head or the LM head.
        labels_attribution: (B, T_tx) float — per-tx contribution
            labels in {0.0, 1.0}. May be None.
        labels_lm: (B, T_lm) int64 — reasoning_text target tokens for
            teacher-forced LM loss. The trainer aligns these to the
            text positions in the combined sequence and shifts by one
            for next-token prediction. May be None.
    """

    feature_ids: torch.Tensor
    text_input_ids: torch.Tensor
    text_attention_mask: torch.Tensor
    text_lengths: torch.Tensor
    head_target: str
    # (B,) int64 index of the disputed transaction. The model adds a
    # learnable marker embedding to the encoder output at this position
    # so the backbone, the attribution head, and the LM head all know
    # which transaction the question is about. Required by the
    # dispute-legitimacy data contract per ADR 0003 §Surface 1.
    disputed_idx: torch.Tensor | None = None
    labels_probability: torch.Tensor | None = None
    labels_attribution: torch.Tensor | None = None
    labels_lm: torch.Tensor | None = None

    def to(self, device: torch.device) -> "MixedModalityBatch":
        """Return a new batch with every tensor moved to `device`.

        Optional fields that are None stay None; `head_target` is carried
        over unchanged. Transfers are issued with `non_blocking=True`.
        """
        always_present = (
            "feature_ids",
            "text_input_ids",
            "text_attention_mask",
            "text_lengths",
        )
        maybe_missing = (
            "disputed_idx",
            "labels_probability",
            "labels_attribution",
            "labels_lm",
        )

        # Required tensors are moved unconditionally, so a missing one
        # fails loudly here instead of being passed through as None.
        relocated = {
            name: getattr(self, name).to(device, non_blocking=True)
            for name in always_present
        }
        for name in maybe_missing:
            tensor = getattr(self, name)
            if tensor is None:
                relocated[name] = None
            else:
                relocated[name] = tensor.to(device, non_blocking=True)

        return MixedModalityBatch(head_target=self.head_target, **relocated)

    @property
    def batch_size(self) -> int:
        """Number of elements B, read from the leading dim of `feature_ids`."""
        return int(self.feature_ids.size(0))

    @property
    def num_tx_positions(self) -> int:
        """Number of transaction positions T_tx (dim 1 of `feature_ids`)."""
        return int(self.feature_ids.size(1))

    @property
    def t_total(self) -> int:
        """Total combined-sequence length used by the backbone.

        = num_tx_positions + 1 (SEP) + T_txt_max.
        """
        padded_text_len = int(self.text_input_ids.size(1))
        return self.num_tx_positions + 1 + padded_text_len
def tokenize_texts(
    tokenizer: PreTrainedTokenizerBase,
    texts: list[str],
    max_length: int = 256,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Tokenize a list of strings with right-padding.

    Right-padding is enforced for the duration of the call, regardless
    of the tokenizer's configured `padding_side`. The combined
    [tx, SEP, text] layout places real text tokens immediately after the
    SEP, which only holds when padding sits at the end of each row. The
    tokenizer's original `padding_side` is restored before returning,
    even if tokenization raises.

    Args:
        tokenizer: any HF-style tokenizer (the wrapper's `self.tokenizer`).
            It is called as `tokenizer(texts, padding=..., ...)` and must
            return a mapping with "input_ids" and "attention_mask".
        texts: list of `B` strings to tokenize.
        max_length: cap; texts longer than this are truncated.

    Returns:
        (input_ids, attention_mask, lengths) where:
            input_ids: (B, T_max) int64
            attention_mask: (B, T_max) int64 (1 real, 0 pad)
            lengths: (B,) int64 — real length per row
    """
    # `padding=True` pads on whichever side the tokenizer is configured
    # for. Override it only when it is set and is not already "right", so
    # tokenizer-like callables without the attribute are left untouched.
    configured_side = getattr(tokenizer, "padding_side", None)
    override = configured_side is not None and configured_side != "right"
    if override:
        tokenizer.padding_side = "right"
    try:
        enc = tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )
    finally:
        # Never leak the temporary setting to other users of the tokenizer.
        if override:
            tokenizer.padding_side = configured_side

    input_ids = enc["input_ids"].long()
    attention_mask = enc["attention_mask"].long()
    # Summing the 0/1 mask along the sequence dim gives each row's real length.
    lengths = attention_mask.sum(dim=1)
    return input_ids, attention_mask, lengths
def build_combined_attention_mask(
    batch_size: int,
    num_tx_positions: int,
    text_attention_mask: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Assemble the attention mask over the combined [tx, SEP, text] sequence.

    Every transaction pseudo-token position is marked real (Phase 1
    sequences are fully populated with 64 transactions, so there is no
    tx padding), the single SEP position is marked real, and the text
    portion copies the caller's `text_attention_mask`.

    Args:
        batch_size: B.
        num_tx_positions: 64 in Phase 1.
        text_attention_mask: (B, T_txt) int — 1 at real positions.
        device: target device for the resulting mask.

    Returns:
        (B, num_tx_positions + 1 + T_txt) int64 combined attention mask.
    """
    # All three segments share the same dtype and target device.
    placement = {"dtype": torch.long, "device": device}
    segments = (
        # (B, num_tx_positions): tx pseudo-tokens, always attended to.
        torch.ones((batch_size, num_tx_positions), **placement),
        # (B, 1): the SEP token, always attended to.
        torch.ones((batch_size, 1), **placement),
        # (B, T_txt): the caller's text padding mask, cast to int64.
        text_attention_mask.to(**placement),
    )
    return torch.cat(segments, dim=1)
def build_lm_target_layout(
    batch_size: int,
    num_tx_positions: int,
    text_input_ids: torch.Tensor,
    text_attention_mask: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Lay out LM targets so they line up with the combined sequence.

    Teacher-forced next-token prediction over the text portion works as:
        combined_inputs[pos] -> backbone -> hidden[pos] -> lm_logits[pos]
        loss compares lm_logits[pos] against combined_inputs[pos + 1]

    The returned tensor has length T_total = num_tx_positions + 1 + T_txt:
        - [0, num_tx_positions): ignore_index (tx pseudo-tokens are not
          vocabulary items, so they carry no LM loss).
        - num_tx_positions (the SEP): ignore_index.
        - [num_tx_positions + 1, T_total): the text token ids, except
          that padded text positions are also set to ignore_index.

    This function does NOT shift. The trainer applies the usual
    shift-by-one, loss = CE(logits[..., :-1, :], targets[..., 1:]), so
    the SEP's logits at position num_tx_positions are scored against
    text_token[0] — the intended alignment for the reasoning_text.

    Args:
        batch_size: B.
        num_tx_positions: 64 in Phase 1.
        text_input_ids: (B, T_txt) int64.
        text_attention_mask: (B, T_txt) int — 1 at real positions.
        ignore_index: target value for positions excluded from loss.
            CrossEntropyLoss treats this as a no-op.

    Returns:
        (B, T_total) int64 — LM targets, with ignore_index at all
        non-text positions and at text padding positions. The tensor is
        created on `text_input_ids`'s device.
    """
    # Everything before the text: the tx pseudo-tokens plus the SEP.
    prefix_len = num_tx_positions + 1
    seq_len = prefix_len + text_input_ids.shape[1]

    # Start fully ignored; only real text tokens get overwritten below.
    layout = torch.full(
        (batch_size, seq_len),
        ignore_index,
        dtype=torch.long,
        device=text_input_ids.device,
    )

    # Basic slicing yields a view, so writes here land in `layout`.
    text_region = layout[:, prefix_len:]
    text_region[:] = text_input_ids
    # Re-mask the padded text slots so they contribute no loss.
    text_region[text_attention_mask == 0] = ignore_index
    return layout