| """Mixed-modality batch builder: [tx pseudo-tokens, SEP, text tokens]. |
| |
| This is the load-bearing data primitive for every multi-surface head |
| in ADR 0003. It converts a batch of `(feature_ids, text_prompt, |
| labels)` triples into the combined embedding sequence the LFM2.5 |
| backbone consumes, plus the auxiliary tensors each head needs. |
| |
| The boundary between this module and the model: |
| |
| - `MixedModalityBatch` is dtype/device-aware tensors plus index |
| bookkeeping. It does NOT call the backbone. |
| - The encoder + projector run separately to produce the tx |
| pseudo-tokens. |
| - The wrapper's `embed_text` / `embed_sep` produce the text and |
| SEP embeddings. |
| - The trainer is responsible for concatenating the three and |
| invoking the backbone's `forward_mixed`. |
| |
| Why split this way: keeping the batch builder backbone-free means |
| unit tests for the batch shape don't need a 350M model loaded. |
| |
| Layout (per batch element, after concat): |
| |
| positions [0, 64) tx pseudo-tokens (from encoder+projector) |
| position 64 SEP token embedding |
| positions [65, 65 + T_txt) text token embeddings |
| |
| Total sequence length: 64 + 1 + T_txt, where T_txt varies per batch |
| element. The collator pads text to the batch max and produces an |
| attention mask. |
| |
| The LM training label layout (for the LM loss): |
| |
| lm_targets : (B, T_total) int64 token ids |
| lm_target_mask: (B, T_total) bool — True at positions where the |
| LM loss is computed |
| |
| Convention: lm_target_mask is True at the text positions only, |
| False at tx pseudo-token positions and at the SEP. The trainer |
| shifts targets by one position for the next-token prediction |
| objective. |
| """ |
|
|
| from __future__ import annotations |
|
|
| from dataclasses import dataclass |
|
|
| import torch |
| from transformers import PreTrainedTokenizerBase |
|
|
|
|
@dataclass
class MixedModalityBatch:
    """Container for one batch destined for the multi-surface backbone.

    The class only stores tensors and derives a few sizes from their
    shapes; it never invokes a model.

    Attributes:
        feature_ids: (B, T_tx, F) int64 — input to the encoder.
        text_input_ids: (B, T_txt_max) int64 — padded text tokens; the
            true per-row length lives in `text_lengths`.
        text_attention_mask: (B, T_txt_max) bool/int — 1 on real text
            positions, 0 on padding.
        text_lengths: (B,) int64 — number of real text tokens per row.
        head_target: name of the head whose loss is computed for this
            batch. Batches are homogeneous per head; expected values
            are "probability" | "attribution" | "lm".
        disputed_idx: optional tensor, None by default.
            NOTE(review): not described anywhere in this module —
            presumably a per-element index of the disputed
            transaction; confirm with the producer of this field.
        labels_probability: (B,) int64 — class index for the
            probability head (0=unlikely, 1=ambiguous, 2=likely for
            dispute legitimacy). None when this head is not trained
            on the batch.
        labels_attribution: (B, T_tx) float — per-transaction
            contribution labels in {0.0, 1.0}. May be None.
        labels_lm: (B, T_lm) int64 — reasoning_text target tokens for
            the teacher-forced LM loss; alignment to the combined
            sequence and the next-token shift are left to the trainer.
            May be None.
    """

    # Required fields (order matters for positional construction).
    feature_ids: torch.Tensor
    text_input_ids: torch.Tensor
    text_attention_mask: torch.Tensor
    text_lengths: torch.Tensor
    head_target: str

    # Optional per-head supervision; absent entries stay None.
    disputed_idx: torch.Tensor | None = None
    labels_probability: torch.Tensor | None = None
    labels_attribution: torch.Tensor | None = None
    labels_lm: torch.Tensor | None = None

    def to(self, device: torch.device) -> "MixedModalityBatch":
        """Return a new batch with every tensor transferred to `device`.

        Transfers are requested with `non_blocking=True`. Optional
        fields that are None are passed through as None, and
        `head_target` is copied unchanged.
        """
        always_present = (
            "feature_ids",
            "text_input_ids",
            "text_attention_mask",
            "text_lengths",
        )
        maybe_missing = (
            "disputed_idx",
            "labels_probability",
            "labels_attribution",
            "labels_lm",
        )

        # Required tensors are moved unconditionally, in declaration order.
        relocated = {
            name: getattr(self, name).to(device, non_blocking=True)
            for name in always_present
        }
        # Optional tensors are moved only when they exist.
        for name in maybe_missing:
            current = getattr(self, name)
            if current is not None:
                current = current.to(device, non_blocking=True)
            relocated[name] = current

        return MixedModalityBatch(head_target=self.head_target, **relocated)

    @property
    def batch_size(self) -> int:
        """Number of batch elements B (leading dim of `feature_ids`)."""
        leading_dim = self.feature_ids.shape[0]
        return int(leading_dim)

    @property
    def num_tx_positions(self) -> int:
        """Number of transaction positions T_tx (dim 1 of `feature_ids`)."""
        tx_dim = self.feature_ids.shape[1]
        return int(tx_dim)

    @property
    def t_total(self) -> int:
        """Length of the combined sequence fed to the backbone.

        Computed as num_tx_positions + 1 (SEP) + T_txt_max.
        """
        padded_text_len = int(self.text_input_ids.shape[1])
        return self.num_tx_positions + 1 + padded_text_len
|
|
|
|
def tokenize_texts(
    tokenizer: PreTrainedTokenizerBase,
    texts: list[str],
    max_length: int = 256,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Tokenize a batch of strings into padded id/mask tensors.

    The tokenizer is called with `padding=True`, `truncation=True`,
    the given `max_length`, and `return_tensors="pt"`.

    NOTE(review): callers describe the result as right-padded, but this
    function does not set a padding side itself — that comes from the
    tokenizer's own configuration; confirm it is "right".

    Args:
        tokenizer: HF-style tokenizer callable (e.g. the wrapper's
            `self.tokenizer`).
        texts: the `B` strings to encode.
        max_length: truncation cap passed through to the tokenizer.

    Returns:
        A tuple (input_ids, attention_mask, lengths):
            input_ids: (B, T_max) int64
            attention_mask: (B, T_max) int64 (1 real, 0 pad)
            lengths: (B,) int64 — row-wise sum of the attention mask
    """
    encoded = tokenizer(
        texts,
        max_length=max_length,
        truncation=True,
        padding=True,
        return_tensors="pt",
    )

    # Normalise both outputs to int64 regardless of what the tokenizer emits.
    token_ids = encoded["input_ids"].long()
    mask = encoded["attention_mask"].long()

    # Real length per row = number of unmasked positions.
    return token_ids, mask, mask.sum(dim=1)
|
|
|
|
def build_combined_attention_mask(
    batch_size: int,
    num_tx_positions: int,
    text_attention_mask: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Assemble the attention mask for a [tx, SEP, text] sequence.

    Every transaction pseudo-token position and the single SEP position
    are marked as attended (1); the text segment copies the caller's
    `text_attention_mask`, cast to int64 on `device`.

    Args:
        batch_size: B.
        num_tx_positions: number of tx pseudo-token positions
            (64 in Phase 1 according to the module layout).
        text_attention_mask: (B, T_txt) mask, non-zero at real tokens.
        device: device on which the result is allocated.

    Returns:
        (B, num_tx_positions + 1 + T_txt) int64 attention mask.
    """
    # One all-ones segment for the tx block, one of width 1 for SEP.
    always_attended = [
        torch.ones((batch_size, width), dtype=torch.long, device=device)
        for width in (num_tx_positions, 1)
    ]

    # Text follows the caller's mask, aligned in dtype/device for the concat.
    text_segment = text_attention_mask.to(device=device, dtype=torch.long)

    return torch.cat([*always_attended, text_segment], dim=1)
|
|
|
|
def build_lm_target_layout(
    batch_size: int,
    num_tx_positions: int,
    text_input_ids: torch.Tensor,
    text_attention_mask: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Lay out LM target ids over the combined [tx, SEP, text] sequence.

    The returned tensor has length T_total = num_tx_positions + 1 + T_txt
    per row and is filled as follows:

    - positions [0, num_tx_positions): `ignore_index` (tx pseudo-tokens
      carry no LM target);
    - position num_tx_positions (SEP): `ignore_index`;
    - positions [num_tx_positions + 1, T_total): the text token ids,
      except where `text_attention_mask` is 0, which are also set to
      `ignore_index`.

    No shifting happens here. The consumer is expected to apply the
    usual next-token shift, e.g. CE(logits[..., :-1, :],
    targets[..., 1:]), so that the SEP position's logits are scored
    against the first text token.

    Args:
        batch_size: B.
        num_tx_positions: number of tx pseudo-token positions
            (64 in Phase 1 according to the module layout).
        text_input_ids: (B, T_txt) token ids.
        text_attention_mask: (B, T_txt) mask, 0 at padding positions.
        ignore_index: value written wherever no loss should be taken.

    Returns:
        (B, T_total) int64 tensor on the same device as
        `text_input_ids`.
    """
    # The tx block plus the SEP slot precede the text region.
    prefix_len = num_tx_positions + 1
    seq_len = prefix_len + text_input_ids.shape[1]

    # Start with everything excluded from the loss.
    layout = torch.full(
        (batch_size, seq_len),
        ignore_index,
        dtype=torch.long,
        device=text_input_ids.device,
    )

    # `text_region` is a view, so writes below land in `layout`.
    text_region = layout[:, prefix_len:]
    text_region[...] = text_input_ids

    # Re-mask padded text slots so they never contribute to the loss.
    is_padding = text_attention_mask == 0
    text_region[is_padding] = ignore_index

    return layout
|
|