File size: 5,421 Bytes
083b138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
"""Data loader for the encoder subproject.

Reuses the parent's `FinetuneDataset` verbatim — the parent's tokenized arrays
at `data/synthetic/` are already shaped `(N, 64, 15)`, which is exactly what
the per-transaction encoder needs. This module is a thin orchestrator that
resolves data paths (via the `encoder/data/synthetic -> ../../data/synthetic`
symlink), builds train/val/test loaders, and exposes a fingerprint-verification
helper so accidental data regeneration breaks fast.

Why we don't define a new Dataset class: the encoder's input contract is
identical to the parent's (`(B, 64, 15) int64` plus fraud + amount_range
labels). The only thing that changes is what the model does with those
tokens. Keeping the Dataset shared guarantees apples-to-apples comparison.
"""

from __future__ import annotations

from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader

from src.training.finetune import FinetuneDataset


def load_data_arrays(
    data_dir: Path | str,
) -> tuple[np.ndarray, np.ndarray, np.ndarray | None, dict[str, np.ndarray]]:
    """Load raw token arrays + split indices from `data_dir`.

    Returns:
        token_ids: (N, 64, 15) int16
        sequence_labels: (N,) int8 (fraud)
        ar_targets: (N,) int8 last-transaction amount_range, or None if file absent
        splits: dict with keys 'train' / 'val' / 'test', each int64 indices
    """
    data_dir = Path(data_dir)
    token_ids = np.load(data_dir / "token_ids.npy")
    sequence_labels = np.load(data_dir / "sequence_labels.npy")
    splits = dict(np.load(data_dir / "split_indices.npz"))

    ar_path = data_dir / "amount_range_labels.npy"
    ar_targets: np.ndarray | None = None
    if ar_path.exists():
        # Parent stores per-transaction amount_range as (N, 64). The head
        # targets the LAST transaction's amount bucket, so we slice [:, -1].
        ar_all = np.load(ar_path)
        ar_targets = ar_all[:, -1]

    return token_ids, sequence_labels, ar_targets, splits


def verify_fingerprint(data_dir: Path | str, expected: str) -> None:
    """Raise if data fingerprint differs from `expected`.

    Catches the silent failure where data has been regenerated under us —
    in which case head-to-head comparison numbers against the parent's
    already-published eval.md.json would not be apples-to-apples.
    """
    fp_path = Path(data_dir) / "fingerprint.txt"
    if not fp_path.exists():
        raise FileNotFoundError(
            f"No fingerprint.txt at {fp_path}. Encoder relies on the parent's "
            f"data/synthetic/ for head-to-head; regenerate via parent's "
            f"`python -m scripts.generate` if missing.",
        )
    actual = fp_path.read_text().strip()
    if actual != expected:
        raise ValueError(
            f"Data fingerprint mismatch:\n"
            f"  expected: {expected}\n"
            f"  actual:   {actual}\n"
            f"Data has been regenerated since this config was pinned. Head-to-head "
            f"comparison against the parent's eval.md.json would not be valid.",
        )


def _subsample_indices(
    indices: np.ndarray,
    label_fraction: float,
    seed: int,
) -> np.ndarray:
    """Return a seeded random subset holding `label_fraction` of `indices`."""
    if label_fraction < 1.0:
        # np.random.RandomState (not Generator) to match the parent's
        # subsampling RNG exactly. Same seed -> identical train subset.
        rng = np.random.RandomState(seed)
        n_keep = max(1, int(len(indices) * label_fraction))
        return rng.choice(indices, n_keep, replace=False)
    return indices


def build_loaders(
    data_dir: Path | str,
    batch_size: int = 32,
    label_fraction: float = 1.0,
    seed: int = 42,
    num_workers: int = 4,
) -> tuple[DataLoader, DataLoader, DataLoader]:
    """Build train/val/test DataLoaders.

    Args:
        data_dir: path to the tokenized synthetic arrays (symlink to parent OK).
        batch_size: applied to all three loaders.
        label_fraction: subsample fraction of `train` indices for the
            label-scarcity sweep (1.0 = full, 0.10 = 10%, 0.01 = 1%). Values
            >= 1.0 keep the full train split; at least one sample is always
            kept. Val and test are never subsampled.
        seed: RNG seed for the train-subset selection. Same seed as the parent's
            scarcity protocol so the head-to-head selects the same training
            subsets across both architectures.
        num_workers: DataLoader worker count for train. Val/test use half
            (integer division, so 1 worker for train means 0 for eval).

    Returns:
        (train_loader, val_loader, test_loader). The train loader shuffles and
        drops the last incomplete batch whenever at least one full batch
        exists; val/test keep every sample in order.

    Raises:
        ValueError: if `label_fraction` is not strictly positive.
    """
    if label_fraction <= 0:
        # Previously a non-positive fraction silently degraded to a
        # one-sample train set; fail loudly instead.
        raise ValueError(f"label_fraction must be > 0, got {label_fraction}")

    token_ids, sequence_labels, ar_targets, splits = load_data_arrays(data_dir)
    train_indices = _subsample_indices(splits["train"], label_fraction, seed)

    train_ds = FinetuneDataset(token_ids, sequence_labels, train_indices, ar_targets)
    val_ds = FinetuneDataset(token_ids, sequence_labels, splits["val"], ar_targets)
    test_ds = FinetuneDataset(token_ids, sequence_labels, splits["test"], ar_targets)

    # Query CUDA once; all three loaders share the same pinning decision.
    pin_memory = torch.cuda.is_available()
    eval_workers = max(0, num_workers // 2)

    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        # With a heavily subsampled train set (e.g. label_fraction=0.01) the
        # subset can be smaller than one batch; unconditional drop_last=True
        # would then yield an EMPTY loader. Only drop the ragged tail when at
        # least one full batch exists, which is identical to the previous
        # behavior in every case where the loader was non-empty.
        drop_last=len(train_indices) >= batch_size,
    )
    val_loader = DataLoader(
        val_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=eval_workers,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        test_ds,
        batch_size=batch_size,
        shuffle=False,
        num_workers=eval_workers,
        pin_memory=pin_memory,
    )
    return train_loader, val_loader, test_loader