File size: 9,782 Bytes
b3112c7 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 | """Mixed-modality batch builder: [tx pseudo-tokens, SEP, text tokens].
This is the load-bearing data primitive for every multi-surface head
in ADR 0003. It converts a batch of `(feature_ids, text_prompt,
labels)` triples into the combined embedding sequence the LFM2.5
backbone consumes, plus the auxiliary tensors each head needs.
The boundary between this module and the model:
- `MixedModalityBatch` is dtype/device-aware tensors plus index
bookkeeping. It does NOT call the backbone.
- The encoder + projector run separately to produce the tx
pseudo-tokens.
- The wrapper's `embed_text` / `embed_sep` produce the text and
SEP embeddings.
- The trainer is responsible for concatenating the three and
invoking the backbone's `forward_mixed`.
Why split this way: keeping the batch builder backbone-free means
unit tests for the batch shape don't need a 350M model loaded.
Layout (per batch element, after concat):
positions [0, 64) tx pseudo-tokens (from encoder+projector)
position 64 SEP token embedding
positions [65, 65 + T_txt) text token embeddings
Total sequence length: 64 + 1 + T_txt, where T_txt varies per batch
element. The collator pads text to the batch max and produces an
attention mask.
The LM training label layout (for the LM loss):
lm_targets : (B, T_total) int64 token ids
lm_target_mask: (B, T_total) bool — True at positions where the
LM loss is computed
Convention: lm_target_mask is True at the text positions only,
False at tx pseudo-token positions and at the SEP. The trainer
shifts targets by one position for the next-token prediction
objective.
"""
from __future__ import annotations
from dataclasses import dataclass
import torch
from transformers import PreTrainedTokenizerBase
@dataclass
class MixedModalityBatch:
"""A single batch ready for the multi-surface backbone.
Attributes:
feature_ids: (B, T_tx, F) int64 — passed to the encoder.
text_input_ids: (B, T_txt_max) int64 — padded text tokens. The
real length per element is given by `text_lengths`.
text_attention_mask: (B, T_txt_max) bool/int — 1 at real text
positions, 0 at padding.
text_lengths: (B,) int64 — real text length per element.
labels_probability: (B,) int64 — probability head class index
(0=unlikely, 1=ambiguous, 2=likely for dispute legitimacy).
May be None for batches that only train the attribution
head or the LM head.
labels_attribution: (B, T_tx) float — per-tx contribution
labels in {0.0, 1.0}. May be None.
labels_lm: (B, T_lm) int64 — reasoning_text target tokens for
teacher-forced LM loss. The trainer aligns these to the
text positions in the combined sequence and shifts by one
for next-token prediction. May be None.
head_target: which head's loss to compute on this batch. The
surface_trainer uses per-batch homogeneous-head sampling,
so each batch is tagged for one head. Values:
"probability" | "attribution" | "lm".
"""
feature_ids: torch.Tensor
text_input_ids: torch.Tensor
text_attention_mask: torch.Tensor
text_lengths: torch.Tensor
head_target: str
# (B,) int64 — position of the disputed transaction. The model
# adds a learnable marker embedding to the encoder output at this
# position so the downstream backbone, attribution head, and LM
# head all know which transaction is being asked about. Required
# by the dispute-legitimacy data contract per ADR 0003 §Surface 1.
disputed_idx: torch.Tensor | None = None
labels_probability: torch.Tensor | None = None
labels_attribution: torch.Tensor | None = None
labels_lm: torch.Tensor | None = None
def to(self, device: torch.device) -> "MixedModalityBatch":
"""Move all tensors to `device` without copying labels that are None."""
def _move(t: torch.Tensor | None) -> torch.Tensor | None:
return None if t is None else t.to(device, non_blocking=True)
return MixedModalityBatch(
feature_ids=self.feature_ids.to(device, non_blocking=True),
text_input_ids=self.text_input_ids.to(device, non_blocking=True),
text_attention_mask=self.text_attention_mask.to(device, non_blocking=True),
text_lengths=self.text_lengths.to(device, non_blocking=True),
head_target=self.head_target,
disputed_idx=_move(self.disputed_idx),
labels_probability=_move(self.labels_probability),
labels_attribution=_move(self.labels_attribution),
labels_lm=_move(self.labels_lm),
)
@property
def batch_size(self) -> int:
return int(self.feature_ids.shape[0])
@property
def num_tx_positions(self) -> int:
return int(self.feature_ids.shape[1])
@property
def t_total(self) -> int:
"""Total combined-sequence length used by the backbone.
= num_tx_positions + 1 (SEP) + T_txt_max.
"""
return self.num_tx_positions + 1 + int(self.text_input_ids.shape[1])
def tokenize_texts(
tokenizer: PreTrainedTokenizerBase,
texts: list[str],
max_length: int = 256,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Tokenize a list of strings with right-padding.
Args:
tokenizer: any HF-style tokenizer (the wrapper's `self.tokenizer`).
texts: list of `B` strings to tokenize.
max_length: cap; texts longer than this are truncated.
Returns:
(input_ids, attention_mask, lengths) where:
input_ids: (B, T_max) int64
attention_mask: (B, T_max) int64 (1 real, 0 pad)
lengths: (B,) int64 — real length per row
"""
enc = tokenizer(
texts,
padding=True,
truncation=True,
max_length=max_length,
return_tensors="pt",
)
input_ids = enc["input_ids"].long()
attention_mask = enc["attention_mask"].long()
lengths = attention_mask.sum(dim=1)
return input_ids, attention_mask, lengths
def build_combined_attention_mask(
batch_size: int,
num_tx_positions: int,
text_attention_mask: torch.Tensor,
device: torch.device,
) -> torch.Tensor:
"""Build the attention mask for the combined [tx, SEP, text] sequence.
The transaction pseudo-tokens are always real (no padding among
them in Phase 1: every sequence is fully 64-tx populated). The
SEP position is always real. Text positions follow the caller's
`text_attention_mask`.
Args:
batch_size: B.
num_tx_positions: 64 in Phase 1.
text_attention_mask: (B, T_txt) int — 1 at real positions.
device: target device for the resulting mask.
Returns:
(B, num_tx_positions + 1 + T_txt) int — combined attention
mask. dtype int64 for safety.
"""
# (B, num_tx_positions) all ones — every tx position is real
tx_mask = torch.ones(
(batch_size, num_tx_positions),
dtype=torch.long,
device=device,
)
# (B, 1) all ones — SEP is real
sep_mask = torch.ones(
(batch_size, 1),
dtype=torch.long,
device=device,
)
# (B, T_txt) text padding mask
txt_mask = text_attention_mask.to(device=device, dtype=torch.long)
# (B, num_tx_positions + 1 + T_txt)
return torch.cat([tx_mask, sep_mask, txt_mask], dim=1)
def build_lm_target_layout(
batch_size: int,
num_tx_positions: int,
text_input_ids: torch.Tensor,
text_attention_mask: torch.Tensor,
ignore_index: int = -100,
) -> torch.Tensor:
"""Build LM targets aligned to the combined sequence.
For teacher-forced next-token prediction over the text portion:
combined_inputs[pos] -> backbone -> hidden[pos] -> lm_logits[pos]
loss compares lm_logits[pos] against combined_inputs[pos + 1]
We construct a target tensor of length T_total where:
- positions [0, num_tx_positions) get ignore_index (no LM loss
on tx pseudo-tokens — they aren't in the vocab anyway).
- position num_tx_positions (SEP) gets ignore_index.
- positions [num_tx_positions + 1, T_total) hold the text token
ids, with padding positions also masked to ignore_index.
The trainer does the standard shift-by-one: loss = CE(
logits[..., :-1, :], targets[..., 1:]
). The first text token's prediction context is the SEP at
position num_tx_positions; the SEP's logits predict text_token[0].
This is the correct alignment for next-token prediction over the
reasoning_text.
Args:
batch_size: B.
num_tx_positions: 64 in Phase 1.
text_input_ids: (B, T_txt) int64.
text_attention_mask: (B, T_txt) int — 1 at real positions.
ignore_index: target value for positions excluded from loss.
CrossEntropyLoss treats this as a no-op.
Returns:
(B, T_total) int64 — LM targets, with ignore_index at all
non-text positions and at text padding positions.
"""
t_txt = text_input_ids.shape[1]
t_total = num_tx_positions + 1 + t_txt
device = text_input_ids.device
targets = torch.full(
(batch_size, t_total),
fill_value=ignore_index,
dtype=torch.long,
device=device,
)
# Fill in text positions
text_start = num_tx_positions + 1
targets[:, text_start:] = text_input_ids
# Mask padding positions to ignore_index
padding_positions = text_attention_mask == 0
targets[:, text_start:][padding_positions] = ignore_index
return targets
|