"""Mixed-modality batch builder: [tx pseudo-tokens, SEP, text tokens].

This is the load-bearing data primitive for every multi-surface head
in ADR 0003. It converts a batch of `(feature_ids, text_prompt,
labels)` triples into the combined embedding sequence the LFM2.5
backbone consumes, plus the auxiliary tensors each head needs.

The boundary between this module and the model:

  - `MixedModalityBatch` holds dtype/device-aware tensors plus index
    bookkeeping. It does NOT call the backbone.
  - The encoder + projector run separately to produce the tx
    pseudo-tokens.
  - The wrapper's `embed_text` / `embed_sep` produce the text and
    SEP embeddings.
  - The trainer is responsible for concatenating the three and
    invoking the backbone's `forward_mixed`.
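
A minimal sketch of the trainer-side concatenation, assuming toy
shapes. Random tensors stand in for the encoder + projector output and
for the wrapper's `embed_sep` / `embed_text`, whose real signatures are
not shown in this module. `D = 8` is an illustrative hidden size, not
the backbone's. The sketch stops before `forward_mixed`, whose
signature is likewise not shown here:

```python
import torch

# Hypothetical stand-ins: in the real pipeline these come from the
# encoder + projector (tx pseudo-tokens) and from the wrapper's
# `embed_sep` / `embed_text`.
B, T_TX, T_TXT, D = 2, 64, 5, 8  # D: toy hidden size, not the backbone's

tx_pseudo = torch.randn(B, T_TX, D)     # (B, 64, D)
sep_vec = torch.randn(D)                # one SEP embedding vector
text_embeds = torch.randn(B, T_TXT, D)  # (B, T_txt_max, D)

# Broadcast the single SEP vector to (B, 1, D) and concatenate on the
# sequence axis: [tx pseudo-tokens, SEP, text tokens].
sep = sep_vec.expand(B, 1, D)
inputs_embeds = torch.cat([tx_pseudo, sep, text_embeds], dim=1)

print(inputs_embeds.shape)  # torch.Size([2, 70, 8])
```

Doing this step in the trainer rather than in the batch builder is what
keeps the batch-shape tests free of the backbone.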

Why split this way: keeping the batch builder backbone-free means
unit tests for the batch shape don't need a 350M model loaded.

Layout (per batch element, after concat):

  positions [0, 64)              tx pseudo-tokens   (from encoder+projector)
  position  64                   SEP token embedding
  positions [65, 65 + T_txt)     text token embeddings

Total sequence length: 64 + 1 + T_txt, where T_txt varies per batch
element. The collator pads text to the batch max and produces an
attention mask.
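
A small worked example of the layout and the combined mask. It assumes
Phase 1's 64 tx positions and two right-padded text rows with real
lengths 3 and 5. It builds the same tensor that
`build_combined_attention_mask` returns, except that the tx ones and
the SEP one are merged into a single prefix here:

```python
import torch

NUM_TX = 64  # Phase 1: every sequence carries exactly 64 tx pseudo-tokens

# Two right-padded text rows with real lengths 3 and 5 (T_txt_max = 5).
text_attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                    [1, 1, 1, 1, 1]])
B, t_txt_max = text_attention_mask.shape

# tx and SEP positions are always real; only the text tail can be padding.
prefix = torch.ones(B, NUM_TX + 1, dtype=torch.long)
combined_mask = torch.cat([prefix, text_attention_mask.long()], dim=1)

t_total = NUM_TX + 1 + t_txt_max
# Text token j of each row sits at combined position text_start + j.
text_start = NUM_TX + 1

print(combined_mask.shape)       # torch.Size([2, 70])
print(combined_mask.sum(dim=1))  # tensor([68, 70])
```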

The LM training label layout (for the LM loss):

  lm_targets    : (B, T_total)    int64 token ids
  lm_target_mask: (B, T_total)    bool: True at positions where the
                                  LM loss is computed

Convention: lm_target_mask is True at real (non-padding) text
positions only, and False at tx pseudo-token positions, at the SEP,
and at text padding. `build_lm_target_layout` encodes the same mask
implicitly by writing `ignore_index` (-100 by default) into every
excluded position of lm_targets. The trainer shifts targets by one
position for the next-token prediction objective.
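
A minimal sketch of the target layout and the shifted loss. It uses
made-up token ids and a toy vocabulary of 32. The target construction
mirrors `build_lm_target_layout`. The loss lines are an assumed version
of the trainer's shift-by-one step, which this module does not contain:

```python
import torch
import torch.nn.functional as F

NUM_TX, IGNORE = 64, -100
text_input_ids = torch.tensor([[11, 12, 13, 0, 0],
                               [21, 22, 23, 24, 25]])
text_attention_mask = torch.tensor([[1, 1, 1, 0, 0],
                                    [1, 1, 1, 1, 1]])
B, t_txt = text_input_ids.shape
t_total = NUM_TX + 1 + t_txt
text_start = NUM_TX + 1

# Targets: ignore_index everywhere except real text positions.
lm_targets = torch.full((B, t_total), IGNORE, dtype=torch.long)
lm_targets[:, text_start:] = text_input_ids
lm_targets[:, text_start:][text_attention_mask == 0] = IGNORE

# Boolean form of the same convention (lm_target_mask above).
lm_target_mask = lm_targets != IGNORE

# Shift by one: logits at pos are scored against the target at pos + 1,
# so the SEP's logits (pos 64) predict text_token[0] (pos 65).
vocab = 32  # toy vocabulary size
logits = torch.randn(B, t_total, vocab)
shift_logits = logits[:, :-1, :].reshape(-1, vocab)
shift_targets = lm_targets[:, 1:].reshape(-1)
loss = F.cross_entropy(shift_logits, shift_targets, ignore_index=IGNORE)
```

Only the 8 real text tokens (3 + 5) contribute to `loss`. After the
shift, the SEP column holds text_token[0], which matches the
alignment described above.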
"""

from __future__ import annotations

from dataclasses import dataclass

import torch
from transformers import PreTrainedTokenizerBase


@dataclass
class MixedModalityBatch:
    """A single batch ready for the multi-surface backbone.

    Attributes:
        feature_ids: (B, T_tx, F) int64 — passed to the encoder.
        text_input_ids: (B, T_txt_max) int64 — padded text tokens. The
            real length per element is given by `text_lengths`.
        text_attention_mask: (B, T_txt_max) bool/int — 1 at real text
            positions, 0 at padding.
        text_lengths: (B,) int64 — real text length per element.
        labels_probability: (B,) int64 — probability head class index
            (0=unlikely, 1=ambiguous, 2=likely for dispute legitimacy).
            May be None for batches that only train the attribution
            head or the LM head.
        labels_attribution: (B, T_tx) float — per-tx contribution
            labels in {0.0, 1.0}. May be None.
        labels_lm: (B, T_lm) int64 — reasoning_text target tokens for
            teacher-forced LM loss. The trainer aligns these to the
            text positions in the combined sequence and shifts by one
            for next-token prediction. May be None.
        head_target: which head's loss to compute on this batch. The
            surface_trainer uses per-batch homogeneous-head sampling,
            so each batch is tagged for one head. Values:
            "probability" | "attribution" | "lm".
    """

    feature_ids: torch.Tensor
    text_input_ids: torch.Tensor
    text_attention_mask: torch.Tensor
    text_lengths: torch.Tensor
    head_target: str
    # (B,) int64 — position of the disputed transaction. The model
    # adds a learnable marker embedding to the encoder output at this
    # position so the downstream backbone, attribution head, and LM
    # head all know which transaction is being asked about. Required
    # by the dispute-legitimacy data contract per ADR 0003 §Surface 1.
    disputed_idx: torch.Tensor | None = None
    labels_probability: torch.Tensor | None = None
    labels_attribution: torch.Tensor | None = None
    labels_lm: torch.Tensor | None = None

    def to(self, device: torch.device) -> "MixedModalityBatch":
        """Return a copy with every tensor moved to `device`.

        Optional fields that are None stay None.
        """

        def _move(t: torch.Tensor | None) -> torch.Tensor | None:
            return None if t is None else t.to(device, non_blocking=True)

        return MixedModalityBatch(
            feature_ids=self.feature_ids.to(device, non_blocking=True),
            text_input_ids=self.text_input_ids.to(device, non_blocking=True),
            text_attention_mask=self.text_attention_mask.to(device, non_blocking=True),
            text_lengths=self.text_lengths.to(device, non_blocking=True),
            head_target=self.head_target,
            disputed_idx=_move(self.disputed_idx),
            labels_probability=_move(self.labels_probability),
            labels_attribution=_move(self.labels_attribution),
            labels_lm=_move(self.labels_lm),
        )

    @property
    def batch_size(self) -> int:
        return int(self.feature_ids.shape[0])

    @property
    def num_tx_positions(self) -> int:
        return int(self.feature_ids.shape[1])

    @property
    def t_total(self) -> int:
        """Total combined-sequence length used by the backbone.

        = num_tx_positions + 1 (SEP) + T_txt_max.
        """
        return self.num_tx_positions + 1 + int(self.text_input_ids.shape[1])


def tokenize_texts(
    tokenizer: PreTrainedTokenizerBase,
    texts: list[str],
    max_length: int = 256,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
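    """Tokenize a list of strings with right-padding.

    Right-padding relies on the tokenizer's `padding_side` being
    "right"; this function does not set it. With left-padding, pad
    positions would sit between the SEP and the text, breaking the
    [tx, SEP, text] layout described in the module docstring.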

    Args:
        tokenizer: any HF-style tokenizer (the wrapper's `self.tokenizer`).
        texts: list of `B` strings to tokenize.
        max_length: maximum number of tokens per text; longer texts
            are truncated.

    Returns:
        (input_ids, attention_mask, lengths) where:
            input_ids:      (B, T_max) int64
            attention_mask: (B, T_max) int64 (1 real, 0 pad)
            lengths:        (B,)        int64 — real length per row
    """
    enc = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    input_ids = enc["input_ids"].long()
    attention_mask = enc["attention_mask"].long()
    lengths = attention_mask.sum(dim=1)
    return input_ids, attention_mask, lengths


def build_combined_attention_mask(
    batch_size: int,
    num_tx_positions: int,
    text_attention_mask: torch.Tensor,
    device: torch.device,
) -> torch.Tensor:
    """Build the attention mask for the combined [tx, SEP, text] sequence.

    The transaction pseudo-tokens are always real (no padding among
    them in Phase 1: every sequence is fully 64-tx populated). The
    SEP position is always real. Text positions follow the caller's
    `text_attention_mask`.

    Args:
        batch_size: B.
        num_tx_positions: 64 in Phase 1.
        text_attention_mask: (B, T_txt) int — 1 at real positions.
        device: target device for the resulting mask.

    Returns:
        (B, num_tx_positions + 1 + T_txt) int64: combined attention
        mask, 1 at real positions and 0 at text padding.
    """
    # (B, num_tx_positions) all ones — every tx position is real
    tx_mask = torch.ones(
        (batch_size, num_tx_positions),
        dtype=torch.long,
        device=device,
    )
    # (B, 1) all ones — SEP is real
    sep_mask = torch.ones(
        (batch_size, 1),
        dtype=torch.long,
        device=device,
    )
    # (B, T_txt) text padding mask
    txt_mask = text_attention_mask.to(device=device, dtype=torch.long)
    # (B, num_tx_positions + 1 + T_txt)
    return torch.cat([tx_mask, sep_mask, txt_mask], dim=1)


def build_lm_target_layout(
    batch_size: int,
    num_tx_positions: int,
    text_input_ids: torch.Tensor,
    text_attention_mask: torch.Tensor,
    ignore_index: int = -100,
) -> torch.Tensor:
    """Build LM targets aligned to the combined sequence.

    For teacher-forced next-token prediction over the text portion:

        combined_inputs[pos] -> backbone -> hidden[pos] -> lm_logits[pos]
        loss compares lm_logits[pos] against the token id at pos + 1

    We construct a target tensor of length T_total where:
      - positions [0, num_tx_positions) get ignore_index (no LM loss
        on tx pseudo-tokens — they aren't in the vocab anyway).
      - position num_tx_positions (SEP) gets ignore_index.
      - positions [num_tx_positions + 1, T_total) hold the text token
        ids, with padding positions also masked to ignore_index.

    The trainer does the standard shift-by-one:
    loss = CE(logits[..., :-1, :], targets[..., 1:]). The SEP at
    position num_tx_positions is the prediction context for the first
    text token, so the SEP's logits predict text_token[0]. This is the
    correct alignment for next-token prediction over the
    reasoning_text.

    Args:
        batch_size: B.
        num_tx_positions: 64 in Phase 1.
        text_input_ids: (B, T_txt) int64.
        text_attention_mask: (B, T_txt) int — 1 at real positions.
        ignore_index: target value for positions excluded from loss.
            CrossEntropyLoss treats this as a no-op.

    Returns:
        (B, T_total) int64 — LM targets, with ignore_index at all
        non-text positions and at text padding positions.
    """
    t_txt = text_input_ids.shape[1]
    t_total = num_tx_positions + 1 + t_txt
    device = text_input_ids.device

    targets = torch.full(
        (batch_size, t_total),
        fill_value=ignore_index,
        dtype=torch.long,
        device=device,
    )
    # Fill in text positions
    text_start = num_tx_positions + 1
    targets[:, text_start:] = text_input_ids
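    # Mask padding positions to ignore_index. Move the mask to the
    # targets' device in case the caller passed it on another device.
    padding_positions = text_attention_mask.to(device) == 0
    # `targets[:, text_start:]` is a view, so this boolean-mask
    # assignment writes through to `targets`.
    targets[:, text_start:][padding_positions] = ignore_index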
    return targets