"""Mixed-modality batch builder: [tx pseudo-tokens, SEP, text tokens]. This is the load-bearing data primitive for every multi-surface head in ADR 0003. It converts a batch of `(feature_ids, text_prompt, labels)` triples into the combined embedding sequence the LFM2.5 backbone consumes, plus the auxiliary tensors each head needs. The boundary between this module and the model: - `MixedModalityBatch` is dtype/device-aware tensors plus index bookkeeping. It does NOT call the backbone. - The encoder + projector run separately to produce the tx pseudo-tokens. - The wrapper's `embed_text` / `embed_sep` produce the text and SEP embeddings. - The trainer is responsible for concatenating the three and invoking the backbone's `forward_mixed`. Why split this way: keeping the batch builder backbone-free means unit tests for the batch shape don't need a 350M model loaded. Layout (per batch element, after concat): positions [0, 64) tx pseudo-tokens (from encoder+projector) position 64 SEP token embedding positions [65, 65 + T_txt) text token embeddings Total sequence length: 64 + 1 + T_txt, where T_txt varies per batch element. The collator pads text to the batch max and produces an attention mask. The LM training label layout (for the LM loss): lm_targets : (B, T_total) int64 token ids lm_target_mask: (B, T_total) bool — True at positions where the LM loss is computed Convention: lm_target_mask is True at the text positions only, False at tx pseudo-token positions and at the SEP. The trainer shifts targets by one position for the next-token prediction objective. """ from __future__ import annotations from dataclasses import dataclass import torch from transformers import PreTrainedTokenizerBase @dataclass class MixedModalityBatch: """A single batch ready for the multi-surface backbone. Attributes: feature_ids: (B, T_tx, F) int64 — passed to the encoder. text_input_ids: (B, T_txt_max) int64 — padded text tokens. The real length per element is given by `text_lengths`. text_attention_mask: (B, T_txt_max) bool/int — 1 at real text positions, 0 at padding. text_lengths: (B,) int64 — real text length per element. labels_probability: (B,) int64 — probability head class index (0=unlikely, 1=ambiguous, 2=likely for dispute legitimacy). May be None for batches that only train the attribution head or the LM head. labels_attribution: (B, T_tx) float — per-tx contribution labels in {0.0, 1.0}. May be None. labels_lm: (B, T_lm) int64 — reasoning_text target tokens for teacher-forced LM loss. The trainer aligns these to the text positions in the combined sequence and shifts by one for next-token prediction. May be None. head_target: which head's loss to compute on this batch. The surface_trainer uses per-batch homogeneous-head sampling, so each batch is tagged for one head. Values: "probability" | "attribution" | "lm". """ feature_ids: torch.Tensor text_input_ids: torch.Tensor text_attention_mask: torch.Tensor text_lengths: torch.Tensor head_target: str # (B,) int64 — position of the disputed transaction. The model # adds a learnable marker embedding to the encoder output at this # position so the downstream backbone, attribution head, and LM # head all know which transaction is being asked about. Required # by the dispute-legitimacy data contract per ADR 0003 §Surface 1. disputed_idx: torch.Tensor | None = None labels_probability: torch.Tensor | None = None labels_attribution: torch.Tensor | None = None labels_lm: torch.Tensor | None = None def to(self, device: torch.device) -> "MixedModalityBatch": """Move all tensors to `device` without copying labels that are None.""" def _move(t: torch.Tensor | None) -> torch.Tensor | None: return None if t is None else t.to(device, non_blocking=True) return MixedModalityBatch( feature_ids=self.feature_ids.to(device, non_blocking=True), text_input_ids=self.text_input_ids.to(device, non_blocking=True), text_attention_mask=self.text_attention_mask.to(device, non_blocking=True), text_lengths=self.text_lengths.to(device, non_blocking=True), head_target=self.head_target, disputed_idx=_move(self.disputed_idx), labels_probability=_move(self.labels_probability), labels_attribution=_move(self.labels_attribution), labels_lm=_move(self.labels_lm), ) @property def batch_size(self) -> int: return int(self.feature_ids.shape[0]) @property def num_tx_positions(self) -> int: return int(self.feature_ids.shape[1]) @property def t_total(self) -> int: """Total combined-sequence length used by the backbone. = num_tx_positions + 1 (SEP) + T_txt_max. """ return self.num_tx_positions + 1 + int(self.text_input_ids.shape[1]) def tokenize_texts( tokenizer: PreTrainedTokenizerBase, texts: list[str], max_length: int = 256, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Tokenize a list of strings with right-padding. Args: tokenizer: any HF-style tokenizer (the wrapper's `self.tokenizer`). texts: list of `B` strings to tokenize. max_length: cap; texts longer than this are truncated. Returns: (input_ids, attention_mask, lengths) where: input_ids: (B, T_max) int64 attention_mask: (B, T_max) int64 (1 real, 0 pad) lengths: (B,) int64 — real length per row """ enc = tokenizer( texts, padding=True, truncation=True, max_length=max_length, return_tensors="pt", ) input_ids = enc["input_ids"].long() attention_mask = enc["attention_mask"].long() lengths = attention_mask.sum(dim=1) return input_ids, attention_mask, lengths def build_combined_attention_mask( batch_size: int, num_tx_positions: int, text_attention_mask: torch.Tensor, device: torch.device, ) -> torch.Tensor: """Build the attention mask for the combined [tx, SEP, text] sequence. The transaction pseudo-tokens are always real (no padding among them in Phase 1: every sequence is fully 64-tx populated). The SEP position is always real. Text positions follow the caller's `text_attention_mask`. Args: batch_size: B. num_tx_positions: 64 in Phase 1. text_attention_mask: (B, T_txt) int — 1 at real positions. device: target device for the resulting mask. Returns: (B, num_tx_positions + 1 + T_txt) int — combined attention mask. dtype int64 for safety. """ # (B, num_tx_positions) all ones — every tx position is real tx_mask = torch.ones( (batch_size, num_tx_positions), dtype=torch.long, device=device, ) # (B, 1) all ones — SEP is real sep_mask = torch.ones( (batch_size, 1), dtype=torch.long, device=device, ) # (B, T_txt) text padding mask txt_mask = text_attention_mask.to(device=device, dtype=torch.long) # (B, num_tx_positions + 1 + T_txt) return torch.cat([tx_mask, sep_mask, txt_mask], dim=1) def build_lm_target_layout( batch_size: int, num_tx_positions: int, text_input_ids: torch.Tensor, text_attention_mask: torch.Tensor, ignore_index: int = -100, ) -> torch.Tensor: """Build LM targets aligned to the combined sequence. For teacher-forced next-token prediction over the text portion: combined_inputs[pos] -> backbone -> hidden[pos] -> lm_logits[pos] loss compares lm_logits[pos] against combined_inputs[pos + 1] We construct a target tensor of length T_total where: - positions [0, num_tx_positions) get ignore_index (no LM loss on tx pseudo-tokens — they aren't in the vocab anyway). - position num_tx_positions (SEP) gets ignore_index. - positions [num_tx_positions + 1, T_total) hold the text token ids, with padding positions also masked to ignore_index. The trainer does the standard shift-by-one: loss = CE( logits[..., :-1, :], targets[..., 1:] ). The first text token's prediction context is the SEP at position num_tx_positions; the SEP's logits predict text_token[0]. This is the correct alignment for next-token prediction over the reasoning_text. Args: batch_size: B. num_tx_positions: 64 in Phase 1. text_input_ids: (B, T_txt) int64. text_attention_mask: (B, T_txt) int — 1 at real positions. ignore_index: target value for positions excluded from loss. CrossEntropyLoss treats this as a no-op. Returns: (B, T_total) int64 — LM targets, with ignore_index at all non-text positions and at text padding positions. """ t_txt = text_input_ids.shape[1] t_total = num_tx_positions + 1 + t_txt device = text_input_ids.device targets = torch.full( (batch_size, t_total), fill_value=ignore_index, dtype=torch.long, device=device, ) # Fill in text positions text_start = num_tx_positions + 1 targets[:, text_start:] = text_input_ids # Mask padding positions to ignore_index padding_positions = text_attention_mask == 0 targets[:, text_start:][padding_positions] = ignore_index return targets