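"""Shard writers and memory-mapped datasets for sllm training data.

TokenShardWriter and SFTShardWriter serialize token streams and packed SFT
examples to flat .bin files plus JSON manifests; the Dataset classes below
read them back through np.memmap so shards never need to fit in memory.
"""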
from __future__ import annotations
import json
from pathlib import Path
from typing import Iterator
import numpy as np
import torch
from torch.utils.data import Dataset, IterableDataset
from sllm.utils import ensure_dir
class TokenShardWriter:
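    """Buffer raw token ids and flush them to fixed-size uint16 shards.

    Each full shard is written as ``{prefix}_{index:05d}.bin``; finalize()
    flushes the remainder and writes ``{prefix}_manifest.json`` describing
    every shard. Token ids are assumed to fit in uint16 (id < 2**16).
    """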
def __init__(self, output_dir: str | Path, prefix: str, shard_size_tokens: int) -> None:
self.output_dir = ensure_dir(output_dir)
self.prefix = prefix
self.shard_size_tokens = shard_size_tokens
self.buffer: list[int] = []
self.shard_index = 0
self.shards: list[dict] = []
def add_tokens(self, tokens: list[int]) -> None:
self.buffer.extend(tokens)
while len(self.buffer) >= self.shard_size_tokens:
chunk = self.buffer[: self.shard_size_tokens]
self.buffer = self.buffer[self.shard_size_tokens :]
self._write_chunk(chunk)
def finalize(self) -> list[dict]:
if self.buffer:
self._write_chunk(self.buffer)
self.buffer = []
manifest_path = self.output_dir / f"{self.prefix}_manifest.json"
with manifest_path.open("w", encoding="utf-8") as handle:
json.dump(self.shards, handle, indent=2, ensure_ascii=False)
return self.shards
def _write_chunk(self, chunk: list[int]) -> None:
shard_name = f"{self.prefix}_{self.shard_index:05d}.bin"
shard_path = self.output_dir / shard_name
array = np.asarray(chunk, dtype=np.uint16)
with shard_path.open("wb") as handle:
array.tofile(handle)
self.shards.append(
{
"path": shard_name,
"num_tokens": int(array.shape[0]),
"dtype": "uint16",
}
)
self.shard_index += 1
class SFTShardWriter:
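    """Append fixed-length packed SFT examples to two flat binary files.

    Input ids and labels for all examples are streamed into
    ``{prefix}_input_ids.bin`` and ``{prefix}_labels.bin``; finalize()
    closes both files and writes ``{prefix}_metadata.json`` with the shape
    information needed to memory-map them back.
    """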
def __init__(self, output_dir: str | Path, prefix: str, seq_len: int) -> None:
self.output_dir = ensure_dir(output_dir)
self.prefix = prefix
self.seq_len = seq_len
self.num_examples = 0
self.input_path = self.output_dir / f"{self.prefix}_input_ids.bin"
self.label_path = self.output_dir / f"{self.prefix}_labels.bin"
self.input_handle = self.input_path.open("wb")
self.label_handle = self.label_path.open("wb")
def add_example(self, input_ids: list[int], labels: list[int]) -> None:
if len(input_ids) != self.seq_len or len(labels) != self.seq_len:
raise ValueError("Packed SFT example must match fixed seq_len.")
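        # uint16 suffices for token ids, while labels are stored as int32 so
        # negative sentinel values (e.g. a -100 ignore index) survive intact.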
np.asarray(input_ids, dtype=np.uint16).tofile(self.input_handle)
np.asarray(labels, dtype=np.int32).tofile(self.label_handle)
self.num_examples += 1
def finalize(self) -> dict:
self.input_handle.close()
self.label_handle.close()
if self.num_examples == 0:
raise RuntimeError("No SFT examples were written.")
metadata = {
"num_examples": self.num_examples,
"seq_len": self.seq_len,
"input_ids_path": self.input_path.name,
"labels_path": self.label_path.name,
}
with (self.output_dir / f"{self.prefix}_metadata.json").open("w", encoding="utf-8") as handle:
json.dump(metadata, handle, indent=2, ensure_ascii=False)
return metadata
def load_shard_manifest(data_dir: str | Path, split: str) -> list[dict]:
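    """Locate and merge shard manifests for ``split`` under ``data_dir``.

    Tries ``{split}_manifest.json`` in ``data_dir`` first, then any manifest
    when ``data_dir`` itself is the split directory, then ``data_dir/split``.
    Each shard entry gains an ``absolute_path`` resolved relative to the
    manifest that declared it.
    """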
data_dir = Path(data_dir)
manifest_paths = sorted(data_dir.glob(f"{split}_manifest.json"))
if not manifest_paths and data_dir.name == split:
manifest_paths = sorted(data_dir.glob("*_manifest.json"))
if not manifest_paths and (data_dir / split).exists():
manifest_paths = sorted((data_dir / split).glob("*_manifest.json"))
if not manifest_paths:
raise FileNotFoundError(f"Shard manifest not found in {data_dir}.")
shards: list[dict] = []
for manifest_path in manifest_paths:
with manifest_path.open("r", encoding="utf-8") as handle:
items = json.load(handle)
for item in items:
item["absolute_path"] = str((manifest_path.parent / item["path"]).resolve())
shards.append(item)
if not shards:
        raise RuntimeError(f"Shard manifests {manifest_paths} contain no shard entries.")
return shards
class RandomTokenDataset(IterableDataset):
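    """Infinite stream of random (seq_len + 1)-token windows for pretraining.

    Shards are sampled in proportion to the number of valid window starts
    they contain, so long shards are not under-represented, and tokens are
    read through ``np.memmap`` so shards are never loaded into RAM whole.
    """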
def __init__(
self,
data_dir: str | Path,
split: str,
seq_len: int,
seed: int = 42,
) -> None:
super().__init__()
self.seq_len = seq_len
self.seed = seed
self.shards = load_shard_manifest(data_dir, split)
self.arrays = [np.memmap(item["absolute_path"], dtype=np.uint16, mode="r") for item in self.shards]
capacities = [max(0, int(item["num_tokens"]) - seq_len - 1) for item in self.shards]
valid_pairs = [(item, array, capacity) for item, array, capacity in zip(self.shards, self.arrays, capacities) if capacity > 0]
if not valid_pairs:
raise RuntimeError("No shard contains enough tokens for the selected sequence length.")
self.shards, self.arrays, capacities = map(list, zip(*valid_pairs))
weights = np.asarray(capacities, dtype=np.float64)
self.probabilities = weights / weights.sum()
self.capacities = capacities
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
worker_info = torch.utils.data.get_worker_info()
worker_id = worker_info.id if worker_info is not None else 0
rng = np.random.default_rng(self.seed + worker_id)
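        # Each DataLoader worker gets its own RNG stream via the worker id;
        # note the seed does not distinguish ranks in distributed runs.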
while True:
shard_index = int(rng.choice(len(self.arrays), p=self.probabilities))
capacity = self.capacities[shard_index]
start = int(rng.integers(0, capacity))
array = self.arrays[shard_index]
window = np.asarray(array[start : start + self.seq_len + 1], dtype=np.int64)
input_ids = torch.from_numpy(window[:-1].copy()).long()
labels = torch.from_numpy(window[1:].copy()).long()
attention_mask = torch.ones(self.seq_len, dtype=torch.long)
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
}
class SequentialEvalDataset(IterableDataset):
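    """Deterministic evaluation windows, read front to back in seq_len strides.

    Iteration stops once ``max_batches`` windows have been yielded; despite
    the name, the cap counts individual examples, not collated batches.
    """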
def __init__(
self,
data_dir: str | Path,
split: str,
seq_len: int,
max_batches: int,
) -> None:
super().__init__()
self.seq_len = seq_len
self.max_batches = max_batches
self.shards = load_shard_manifest(data_dir, split)
self.arrays = [np.memmap(item["absolute_path"], dtype=np.uint16, mode="r") for item in self.shards]
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
yielded = 0
for array in self.arrays:
max_start = len(array) - self.seq_len - 1
if max_start <= 0:
continue
for start in range(0, max_start, self.seq_len):
if yielded >= self.max_batches:
return
window = np.asarray(array[start : start + self.seq_len + 1], dtype=np.int64)
if len(window) < self.seq_len + 1:
break
input_ids = torch.from_numpy(window[:-1].copy()).long()
labels = torch.from_numpy(window[1:].copy()).long()
attention_mask = torch.ones(self.seq_len, dtype=torch.long)
yield {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
}
yielded += 1
class FixedSFTDataset(Dataset):
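    """Map-style dataset over packed SFT shards written by SFTShardWriter.

    Reads ``{split}_metadata.json`` to recover the array shapes, then
    memory-maps the input-id and label files for O(1) random access.
    """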
def __init__(self, dataset_dir: str | Path, split: str) -> None:
dataset_dir = Path(dataset_dir)
metadata_path = dataset_dir / f"{split}_metadata.json"
if not metadata_path.exists():
raise FileNotFoundError(f"Metadata file not found: {metadata_path}")
with metadata_path.open("r", encoding="utf-8") as handle:
metadata = json.load(handle)
self.seq_len = int(metadata["seq_len"])
self.num_examples = int(metadata["num_examples"])
self.input_ids = np.memmap(
dataset_dir / metadata["input_ids_path"],
dtype=np.uint16,
mode="r",
shape=(self.num_examples, self.seq_len),
)
self.labels = np.memmap(
dataset_dir / metadata["labels_path"],
dtype=np.int32,
mode="r",
shape=(self.num_examples, self.seq_len),
)
def __len__(self) -> int:
return self.num_examples
def __getitem__(self, index: int) -> dict[str, torch.Tensor]:
input_ids = torch.from_numpy(np.asarray(self.input_ids[index], dtype=np.int64).copy()).long()
labels = torch.from_numpy(np.asarray(self.labels[index], dtype=np.int64).copy()).long()
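        # Assumes token id 0 is the pad id: padded positions are masked out.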
attention_mask = (input_ids != 0).long()
return {
"input_ids": input_ids,
"labels": labels,
"attention_mask": attention_mask,
}
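# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module API): writes
# a tiny synthetic "train" split into a temporary directory and streams one
# random window back out. The prefix, sizes, and token values are made up.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        # Write 3,000 random token ids; at 1,000 tokens per shard this
        # produces three shards plus train_manifest.json.
        writer = TokenShardWriter(tmp, prefix="train", shard_size_tokens=1_000)
        rng = np.random.default_rng(0)
        writer.add_tokens(rng.integers(0, 50_000, size=3_000).tolist())
        print(f"wrote {len(writer.finalize())} shards")

        # Stream one random (input_ids, labels) window, as a train loader would.
        dataset = RandomTokenDataset(tmp, split="train", seq_len=128)
        batch = next(iter(dataset))
        print(batch["input_ids"].shape, batch["labels"].shape)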