| """Streaming dataset — sequence packing and validation dataset.""" |
|
|
| from typing import Iterator, List, Dict, Optional |
|
|
| import torch |
| from torch.utils.data import IterableDataset, DataLoader |
|
|
| from llm_lab.config import DataConfig |
| from .tokenizer import Tokenizer |
|
|
|
|
class PackedStreamingDataset(IterableDataset):
    """Streaming + sequence packing dataset.

    Why sequence packing?
    - Naive approach: truncate each document to max_seq_len with padding → wastes GPU
    - Sequence packing: concatenate multiple documents to fill max_seq_len → 100% utilization

    How it works:
        Doc1 (300 tokens) + Doc2 (1500 tokens) + Doc3 (248 tokens) = 2048 tokens
        → [Doc1][EOS][Doc2][EOS][Doc3][EOS][... no padding, fits exactly]

    Why streaming?
    - FineWeb-Edu 10B samples: tens of GB even when compressed
    - Full download not feasible on Colab disk limit (~200GB)
    - Streaming: reads from the network only as much as needed

    Notes for training:
    - EOS token inserted at document boundaries so the model recognizes end-of-document
    - EOS naturally serves as a boundary marker without cross-attention masking
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        config: DataConfig,
        split: str = "train",
        seed: int = 42,
    ):
        """Initializes the streaming dataset.

        Args:
            tokenizer: Tokenizer providing ``encode`` and ``eos_id``.
            config: DataConfig supplying dataset name/subset, text column,
                ``max_seq_len`` and ``use_eos_separator``.
            split: HuggingFace dataset split to stream.
            seed: Base shuffle seed; per-worker seeds are derived in __iter__.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config
        self.split = split
        self.seed = seed
        self.max_seq_len = config.max_seq_len

    def _load_dataset(
        self,
        num_shards: int = 1,
        shard_index: int = 0,
        seed: Optional[int] = None,
    ):
        """Loads the HuggingFace dataset in streaming mode.

        Args:
            num_shards: Total number of shards (= DataLoader num_workers)
            shard_index: The shard index this worker is responsible for (0 ~ num_shards-1)
            seed: Shuffle seed for this stream. Defaults to ``self.seed`` when
                omitted (backward compatible); __iter__ passes the derived
                per-worker seed explicitly.

        Sharding principle:
            With num_shards=4, the stream is split into 4 equal parts so each worker
            processes a distinct 1/4. Shuffling is applied after sharding so there is
            no document overlap between workers.
        """
        from datasets import load_dataset

        if seed is None:
            seed = self.seed

        ds = load_dataset(
            self.config.dataset_name,
            name=self.config.dataset_subset,
            split=self.split,
            streaming=True,
            trust_remote_code=True,
        )

        # Shard BEFORE shuffling so each worker owns a disjoint slice of documents.
        if num_shards > 1:
            ds = ds.shard(num_shards=num_shards, index=shard_index)

        # Approximate shuffle: streaming datasets shuffle within a sliding
        # buffer of 10k documents rather than globally.
        ds = ds.shuffle(seed=seed, buffer_size=10_000)

        return ds

    def _tokenize_and_pack(self, dataset) -> Iterator[Dict[str, torch.Tensor]]:
        """Tokenizes documents and packs them into sequences.

        Yields:
            {"input_ids": (max_seq_len,), "targets": (max_seq_len,)}

        targets = input_ids shifted by one position:
            input_ids: [A, B, C, D, E]
            targets:   [B, C, D, E, F]
            → The model sees A and predicts B, sees B and predicts C, ...
        """
        buffer: List[int] = []

        for example in dataset:
            text = example[self.config.text_column]
            if not text or not text.strip():
                continue  # skip empty / whitespace-only documents

            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue  # tokenizer produced nothing usable

            # Mark the document boundary so the model can learn end-of-document.
            if self.config.use_eos_separator:
                token_ids.append(self.tokenizer.eos_id)

            buffer.extend(token_ids)

            # Emit as many full sequences as the buffer holds. We need
            # max_seq_len + 1 tokens per chunk because input/target are the
            # same chunk shifted by one.
            while len(buffer) >= self.max_seq_len + 1:
                chunk = buffer[: self.max_seq_len + 1]
                buffer = buffer[self.max_seq_len + 1 :]

                input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
                targets = torch.tensor(chunk[1:], dtype=torch.long)

                yield {"input_ids": input_ids, "targets": targets}

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Iterator called by DataLoader.

        Multi-worker support (full partitioning approach):
        - Previous: all workers read the same stream with different seeds → possible document duplication
        - Improved: ds.shard() splits the stream into num_workers parts → no document overlap between workers

        Example (num_workers=4, total N documents):
            Worker 0: docs 0, 4, 8, 12, ...  (N/4 docs)
            Worker 1: docs 1, 5, 9, 13, ...  (N/4 docs)
            Worker 2: docs 2, 6, 10, 14, ... (N/4 docs)
            Worker 3: docs 3, 7, 11, 15, ... (N/4 docs)
        """
        worker_info = torch.utils.data.get_worker_info()

        if worker_info is not None:
            # Running inside a DataLoader worker process.
            num_shards = worker_info.num_workers
            shard_index = worker_info.id
            worker_seed = self.seed + worker_info.id
        else:
            # Single-process loading (num_workers=0).
            num_shards = 1
            shard_index = 0
            worker_seed = self.seed

        # Bug fix: previously this method did `self.seed = worker_seed`, which
        # mutated shared state — every fresh __iter__ call (e.g. each epoch)
        # shifted the seed again, so worker seeds drifted across epochs.
        # Pass the derived seed explicitly instead of mutating self.
        dataset = self._load_dataset(
            num_shards=num_shards, shard_index=shard_index, seed=worker_seed
        )

        return self._tokenize_and_pack(dataset)
|
|
|
|
class ValidationDataset:
    """Validation dataset.

    Pre-fetches a fixed amount of data from the streaming dataset and stores it in memory.
    Consistent data across evaluations is necessary for meaningful comparisons between epochs.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        config: DataConfig,
        num_samples: int = 100,
        seed: int = 9999,
    ):
        """Eagerly prepares ``num_samples`` fixed validation sequences.

        Args:
            tokenizer: Tokenizer providing ``encode`` and ``eos_id``.
            config: DataConfig supplying dataset fields and ``max_seq_len``.
            num_samples: Number of packed sequences to cache in memory.
            seed: Shuffle seed; distinct from the training seed so validation
                draws a different slice of the stream.
        """
        self.tokenizer = tokenizer
        self.config = config
        self.num_samples = num_samples
        self.samples: List[Dict[str, torch.Tensor]] = []

        self._prepare(seed)

    def _prepare(self, seed: int):
        """Pre-extracts validation samples from the dataset.

        Streams documents, tokenizes and packs them exactly like the training
        dataset, and stops once ``num_samples`` sequences are collected.
        """
        from datasets import load_dataset

        print(f"[Validation] Preparing {self.num_samples} validation samples...")

        ds = load_dataset(
            self.config.dataset_name,
            name=self.config.dataset_subset,
            split=self.config.dataset_split,
            streaming=True,
            trust_remote_code=True,
        )
        ds = ds.shuffle(seed=seed, buffer_size=5_000)

        buffer: List[int] = []

        for example in ds:
            if len(self.samples) >= self.num_samples:
                break

            text = example[self.config.text_column]
            if not text or not text.strip():
                continue  # skip empty / whitespace-only documents

            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue

            # Consistency fix: honor config.use_eos_separator exactly like the
            # training dataset (PackedStreamingDataset) instead of always
            # appending EOS, so validation sequences match the training
            # packing scheme.
            if self.config.use_eos_separator:
                token_ids.append(self.tokenizer.eos_id)
            buffer.extend(token_ids)

            # Pack max_seq_len + 1 tokens per sample (input/target shifted by one).
            while (
                len(buffer) >= self.config.max_seq_len + 1
                and len(self.samples) < self.num_samples
            ):
                chunk = buffer[: self.config.max_seq_len + 1]
                buffer = buffer[self.config.max_seq_len + 1 :]

                self.samples.append({
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "targets": torch.tensor(chunk[1:], dtype=torch.long),
                })

        print(f"[Validation] {len(self.samples)} samples ready")

    def get_dataloader(self, batch_size: int) -> DataLoader:
        """Returns a validation DataLoader.

        shuffle=False and num_workers=0 keep evaluation order deterministic.
        """
        return DataLoader(
            self.samples,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=_collate_fn,
        )
|
|
|
|
class MixedStreamingDataset(IterableDataset):
    """Interleaves multiple streaming datasets by sampling weight.

    Used for Continual Pre-Training (CPT) to mix domain-specific data
    (e.g., code) with general data to prevent catastrophic forgetting.

    How it works:
    - Maintains one tokenized-and-packed stream per source
      (see _token_stream / _pack_from_stream)
    - At each yield, randomly selects a source according to mix_weights
    - Each source independently packs and tokenizes its own stream

    Example (Code CPT):
        sources = [FineWeb-Edu (general), StarCoder (Python)]
        weights = [0.2, 0.8] → 20% general, 80% code
    """


    def __init__(
        self,
        tokenizer: "Tokenizer",
        config: "DataConfig",
        seed: int = 42,
    ):
        """Builds the source spec list and validates the mixing weights.

        Args:
            tokenizer: Tokenizer shared by all sources.
            config: DataConfig; supplies the primary dataset fields plus
                ``mix_datasets`` (extra source spec dicts) and ``mix_weights``.
            seed: Base RNG seed; per-worker seeds are derived in __iter__.

        Raises:
            AssertionError: if the number of weights does not match the number
                of sources, or the weights do not sum to 1.0.
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config
        self.seed = seed
        self.max_seq_len = config.max_seq_len


        # Source 0 is always the primary dataset from config; extra sources
        # come from config.mix_datasets (dicts with the same keys:
        # name / subset / split / text_column).
        self.dataset_specs = [
            {
                "name": config.dataset_name,
                "subset": config.dataset_subset,
                "split": config.dataset_split,
                "text_column": config.text_column,
            }
        ] + list(config.mix_datasets)


        self.weights = config.mix_weights
        # NOTE: assert is stripped under `python -O`; config validation here is
        # best-effort fail-fast at construction time.
        assert len(self.weights) == len(self.dataset_specs), (
            f"mix_weights length ({len(self.weights)}) must match "
            f"number of datasets ({len(self.dataset_specs)})"
        )
        assert abs(sum(self.weights) - 1.0) < 1e-6, (
            f"mix_weights must sum to 1.0, got {sum(self.weights)}"
        )


    def _load_single_dataset(self, spec: dict, num_shards: int, shard_index: int, seed: int):
        """Loads a single HuggingFace dataset in streaming mode.

        Shards first (when multi-worker) so workers read disjoint documents,
        then applies an approximate buffered shuffle with the given seed.
        """
        from datasets import load_dataset


        ds = load_dataset(
            spec["name"],
            name=spec["subset"],
            split=spec["split"],
            streaming=True,
            trust_remote_code=True,
        )
        if num_shards > 1:
            ds = ds.shard(num_shards=num_shards, index=shard_index)
        ds = ds.shuffle(seed=seed, buffer_size=10_000)
        return ds


    def _token_stream(self, dataset, text_column: str) -> Iterator[int]:
        """Yields tokens one-by-one from a dataset stream (with EOS separators).

        Empty/whitespace-only documents and documents that tokenize to nothing
        are skipped; EOS is appended per document only when
        config.use_eos_separator is set.
        """
        for example in dataset:
            text = example[text_column]
            if not text or not text.strip():
                continue
            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue
            if self.config.use_eos_separator:
                token_ids.append(self.tokenizer.eos_id)
            yield from token_ids


    def _pack_from_stream(self, token_iter: Iterator[int]) -> Iterator[List[int]]:
        """Packs token stream into max_seq_len+1 chunks.

        +1 because each chunk is later split into input_ids (chunk[:-1]) and
        targets (chunk[1:]). A partial tail shorter than max_seq_len+1 is
        silently dropped when the stream ends.
        """
        buffer: List[int] = []
        for tok in token_iter:
            buffer.append(tok)
            if len(buffer) >= self.max_seq_len + 1:
                yield buffer[: self.max_seq_len + 1]
                buffer = buffer[self.max_seq_len + 1 :]


    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Yields mixed samples, choosing a source per step by mix_weights.

        Each DataLoader worker gets its own shard and RNG seed. Exhausted
        sources are dropped from the draw (weight forced to 0) and iteration
        ends when every source is exhausted.
        """
        import random


        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Inside a DataLoader worker: take a disjoint shard + derived seed.
            num_shards = worker_info.num_workers
            shard_index = worker_info.id
            worker_seed = self.seed + worker_info.id
        else:
            # Single-process loading.
            num_shards = 1
            shard_index = 0
            worker_seed = self.seed


        rng = random.Random(worker_seed)


        # Build one packed-chunk iterator per source. Each source gets a
        # distinct shuffle seed (offset by i * 1000) so their streams are
        # shuffled independently.
        source_iters = []
        for i, spec in enumerate(self.dataset_specs):
            ds = self._load_single_dataset(
                spec, num_shards, shard_index, seed=worker_seed + i * 1000
            )
            token_iter = self._token_stream(ds, spec["text_column"])
            pack_iter = self._pack_from_stream(token_iter)
            source_iters.append(pack_iter)


        # Prefetch one chunk per source; None marks an exhausted source.
        buffers: List[Optional[List[int]]] = [None] * len(source_iters)
        for i, it in enumerate(source_iters):
            try:
                buffers[i] = next(it)
            except StopIteration:
                buffers[i] = None


        # Main mixing loop: draw a source by weight, emit its buffered chunk,
        # then refill that source's buffer.
        while any(b is not None for b in buffers):
            # Zero out weights of exhausted sources so they can't be drawn.
            active_weights = [
                w if buffers[i] is not None else 0.0
                for i, w in enumerate(self.weights)
            ]
            total = sum(active_weights)
            if total == 0:
                break


            # Inverse-CDF sampling: r falls in [0, total); walk the cumulative
            # weights and pick the first source whose cumulative covers r.
            r = rng.random() * total
            cumulative = 0.0
            chosen = 0
            for i, w in enumerate(active_weights):
                cumulative += w
                if r <= cumulative:
                    chosen = i
                    break


            # Split the chunk into next-token-prediction input/target pairs.
            chunk = buffers[chosen]
            input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
            targets = torch.tensor(chunk[1:], dtype=torch.long)
            yield {"input_ids": input_ids, "targets": targets}


            # Refill the consumed source; mark it exhausted on StopIteration.
            try:
                buffers[chosen] = next(source_iters[chosen])
            except StopIteration:
                buffers[chosen] = None
|
|
|
| def _collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: |
| """Combines samples in a batch into a single tensor. |
| |
| Because of sequence packing, all samples have the same length (max_seq_len), |
| so no additional padding is needed. |
| """ |
| return { |
| "input_ids": torch.stack([s["input_ids"] for s in batch]), |
| "targets": torch.stack([s["targets"] for s in batch]), |
| } |
|
|