| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
| from typing import Iterator |
|
|
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset, IterableDataset |
|
|
| from sllm.utils import ensure_dir |
|
|
|
|
| class TokenShardWriter: |
| def __init__(self, output_dir: str | Path, prefix: str, shard_size_tokens: int) -> None: |
| self.output_dir = ensure_dir(output_dir) |
| self.prefix = prefix |
| self.shard_size_tokens = shard_size_tokens |
| self.buffer: list[int] = [] |
| self.shard_index = 0 |
| self.shards: list[dict] = [] |
|
|
| def add_tokens(self, tokens: list[int]) -> None: |
| self.buffer.extend(tokens) |
| while len(self.buffer) >= self.shard_size_tokens: |
| chunk = self.buffer[: self.shard_size_tokens] |
| self.buffer = self.buffer[self.shard_size_tokens :] |
| self._write_chunk(chunk) |
|
|
| def finalize(self) -> list[dict]: |
| if self.buffer: |
| self._write_chunk(self.buffer) |
| self.buffer = [] |
| manifest_path = self.output_dir / f"{self.prefix}_manifest.json" |
| with manifest_path.open("w", encoding="utf-8") as handle: |
| json.dump(self.shards, handle, indent=2, ensure_ascii=False) |
| return self.shards |
|
|
| def _write_chunk(self, chunk: list[int]) -> None: |
| shard_name = f"{self.prefix}_{self.shard_index:05d}.bin" |
| shard_path = self.output_dir / shard_name |
| array = np.asarray(chunk, dtype=np.uint16) |
| with shard_path.open("wb") as handle: |
| array.tofile(handle) |
| self.shards.append( |
| { |
| "path": shard_name, |
| "num_tokens": int(array.shape[0]), |
| "dtype": "uint16", |
| } |
| ) |
| self.shard_index += 1 |
|
|
|
|
| class SFTShardWriter: |
| def __init__(self, output_dir: str | Path, prefix: str, seq_len: int) -> None: |
| self.output_dir = ensure_dir(output_dir) |
| self.prefix = prefix |
| self.seq_len = seq_len |
| self.num_examples = 0 |
| self.input_path = self.output_dir / f"{self.prefix}_input_ids.bin" |
| self.label_path = self.output_dir / f"{self.prefix}_labels.bin" |
| self.input_handle = self.input_path.open("wb") |
| self.label_handle = self.label_path.open("wb") |
|
|
| def add_example(self, input_ids: list[int], labels: list[int]) -> None: |
| if len(input_ids) != self.seq_len or len(labels) != self.seq_len: |
| raise ValueError("Packed SFT example must match fixed seq_len.") |
| np.asarray(input_ids, dtype=np.uint16).tofile(self.input_handle) |
| np.asarray(labels, dtype=np.int32).tofile(self.label_handle) |
| self.num_examples += 1 |
|
|
| def finalize(self) -> dict: |
| self.input_handle.close() |
| self.label_handle.close() |
| if self.num_examples == 0: |
| raise RuntimeError("No SFT examples were written.") |
| metadata = { |
| "num_examples": self.num_examples, |
| "seq_len": self.seq_len, |
| "input_ids_path": self.input_path.name, |
| "labels_path": self.label_path.name, |
| } |
| with (self.output_dir / f"{self.prefix}_metadata.json").open("w", encoding="utf-8") as handle: |
| json.dump(metadata, handle, indent=2, ensure_ascii=False) |
| return metadata |
|
|
|
|
| def load_shard_manifest(data_dir: str | Path, split: str) -> list[dict]: |
| data_dir = Path(data_dir) |
| manifest_paths = sorted(data_dir.glob(f"{split}_manifest.json")) |
| if not manifest_paths and data_dir.name == split: |
| manifest_paths = sorted(data_dir.glob("*_manifest.json")) |
| if not manifest_paths and (data_dir / split).exists(): |
| manifest_paths = sorted((data_dir / split).glob("*_manifest.json")) |
| if not manifest_paths: |
| raise FileNotFoundError(f"Shard manifest not found in {data_dir}.") |
|
|
| shards: list[dict] = [] |
| for manifest_path in manifest_paths: |
| with manifest_path.open("r", encoding="utf-8") as handle: |
| items = json.load(handle) |
| for item in items: |
| item["absolute_path"] = str((manifest_path.parent / item["path"]).resolve()) |
| shards.append(item) |
| if not shards: |
| raise RuntimeError(f"Shard manifest {manifest_paths} is empty.") |
| return shards |
|
|
|
|
| class RandomTokenDataset(IterableDataset): |
| def __init__( |
| self, |
| data_dir: str | Path, |
| split: str, |
| seq_len: int, |
| seed: int = 42, |
| ) -> None: |
| super().__init__() |
| self.seq_len = seq_len |
| self.seed = seed |
| self.shards = load_shard_manifest(data_dir, split) |
| self.arrays = [np.memmap(item["absolute_path"], dtype=np.uint16, mode="r") for item in self.shards] |
| capacities = [max(0, int(item["num_tokens"]) - seq_len - 1) for item in self.shards] |
| valid_pairs = [(item, array, capacity) for item, array, capacity in zip(self.shards, self.arrays, capacities) if capacity > 0] |
| if not valid_pairs: |
| raise RuntimeError("No shard contains enough tokens for the selected sequence length.") |
| self.shards, self.arrays, capacities = map(list, zip(*valid_pairs)) |
| weights = np.asarray(capacities, dtype=np.float64) |
| self.probabilities = weights / weights.sum() |
| self.capacities = capacities |
|
|
| def __iter__(self) -> Iterator[dict[str, torch.Tensor]]: |
| worker_info = torch.utils.data.get_worker_info() |
| worker_id = worker_info.id if worker_info is not None else 0 |
| rng = np.random.default_rng(self.seed + worker_id) |
|
|
| while True: |
| shard_index = int(rng.choice(len(self.arrays), p=self.probabilities)) |
| capacity = self.capacities[shard_index] |
| start = int(rng.integers(0, capacity)) |
| array = self.arrays[shard_index] |
| window = np.asarray(array[start : start + self.seq_len + 1], dtype=np.int64) |
| input_ids = torch.from_numpy(window[:-1].copy()).long() |
| labels = torch.from_numpy(window[1:].copy()).long() |
| attention_mask = torch.ones(self.seq_len, dtype=torch.long) |
| yield { |
| "input_ids": input_ids, |
| "labels": labels, |
| "attention_mask": attention_mask, |
| } |
|
|
|
|
| class SequentialEvalDataset(IterableDataset): |
| def __init__( |
| self, |
| data_dir: str | Path, |
| split: str, |
| seq_len: int, |
| max_batches: int, |
| ) -> None: |
| super().__init__() |
| self.seq_len = seq_len |
| self.max_batches = max_batches |
| self.shards = load_shard_manifest(data_dir, split) |
| self.arrays = [np.memmap(item["absolute_path"], dtype=np.uint16, mode="r") for item in self.shards] |
|
|
| def __iter__(self) -> Iterator[dict[str, torch.Tensor]]: |
| yielded = 0 |
| for array in self.arrays: |
| max_start = len(array) - self.seq_len - 1 |
| if max_start <= 0: |
| continue |
| for start in range(0, max_start, self.seq_len): |
| if yielded >= self.max_batches: |
| return |
| window = np.asarray(array[start : start + self.seq_len + 1], dtype=np.int64) |
| if len(window) < self.seq_len + 1: |
| break |
| input_ids = torch.from_numpy(window[:-1].copy()).long() |
| labels = torch.from_numpy(window[1:].copy()).long() |
| attention_mask = torch.ones(self.seq_len, dtype=torch.long) |
| yield { |
| "input_ids": input_ids, |
| "labels": labels, |
| "attention_mask": attention_mask, |
| } |
| yielded += 1 |
|
|
|
|
| class FixedSFTDataset(Dataset): |
| def __init__(self, dataset_dir: str | Path, split: str) -> None: |
| dataset_dir = Path(dataset_dir) |
| metadata_path = dataset_dir / f"{split}_metadata.json" |
| if not metadata_path.exists(): |
| raise FileNotFoundError(f"Metadata file not found: {metadata_path}") |
| with metadata_path.open("r", encoding="utf-8") as handle: |
| metadata = json.load(handle) |
|
|
| self.seq_len = int(metadata["seq_len"]) |
| self.num_examples = int(metadata["num_examples"]) |
| self.input_ids = np.memmap( |
| dataset_dir / metadata["input_ids_path"], |
| dtype=np.uint16, |
| mode="r", |
| shape=(self.num_examples, self.seq_len), |
| ) |
| self.labels = np.memmap( |
| dataset_dir / metadata["labels_path"], |
| dtype=np.int32, |
| mode="r", |
| shape=(self.num_examples, self.seq_len), |
| ) |
|
|
| def __len__(self) -> int: |
| return self.num_examples |
|
|
| def __getitem__(self, index: int) -> dict[str, torch.Tensor]: |
| input_ids = torch.from_numpy(np.asarray(self.input_ids[index], dtype=np.int64).copy()).long() |
| labels = torch.from_numpy(np.asarray(self.labels[index], dtype=np.int64).copy()).long() |
| attention_mask = (input_ids != 0).long() |
| return { |
| "input_ids": input_ids, |
| "labels": labels, |
| "attention_mask": attention_mask, |
| } |
|
|