"""Streaming dataset — sequence packing and validation dataset.""" from typing import Iterator, List, Dict, Optional import torch from torch.utils.data import IterableDataset, DataLoader from llm_lab.config import DataConfig from .tokenizer import Tokenizer class PackedStreamingDataset(IterableDataset): """Streaming + sequence packing dataset. Why sequence packing? - Naive approach: truncate each document to max_seq_len with padding → wastes GPU - Sequence packing: concatenate multiple documents to fill max_seq_len → 100% utilization How it works: Doc1 (300 tokens) + Doc2 (1500 tokens) + Doc3 (248 tokens) = 2048 tokens → [Doc1][EOS][Doc2][EOS][Doc3][EOS][... no padding, fits exactly] Why streaming? - FineWeb-Edu 10B samples: tens of GB even when compressed - Full download not feasible on Colab disk limit (~200GB) - Streaming: reads from the network only as much as needed Notes for training: - EOS token inserted at document boundaries so the model recognizes end-of-document - EOS naturally serves as a boundary marker without cross-attention masking """ def __init__( self, tokenizer: Tokenizer, config: DataConfig, split: str = "train", seed: int = 42, ): super().__init__() self.tokenizer = tokenizer self.config = config self.split = split self.seed = seed self.max_seq_len = config.max_seq_len def _load_dataset(self, num_shards: int = 1, shard_index: int = 0): """Loads the HuggingFace dataset in streaming mode. Args: num_shards: Total number of shards (= DataLoader num_workers) shard_index: The shard index this worker is responsible for (0 ~ num_shards-1) Sharding principle: With num_shards=4, the stream is split into 4 equal parts so each worker processes a distinct 1/4. Shuffling is applied after sharding so there is no document overlap between workers. """ from datasets import load_dataset ds = load_dataset( self.config.dataset_name, name=self.config.dataset_subset, split=self.split, streaming=True, # Key: streaming mode trust_remote_code=True, ) # Full partitioning (sharding): worker i processes only 1/num_shards of the stream # Must be applied before shuffling so each worker has a non-overlapping set of documents if num_shards > 1: ds = ds.shard(num_shards=num_shards, index=shard_index) # Shuffle (approximate buffer-based shuffle in streaming mode) ds = ds.shuffle(seed=self.seed, buffer_size=10_000) return ds def _tokenize_and_pack(self, dataset) -> Iterator[Dict[str, torch.Tensor]]: """Tokenizes documents and packs them into sequences. Yields: {"input_ids": (max_seq_len,), "targets": (max_seq_len,)} targets = input_ids shifted by one position: input_ids: [A, B, C, D, E] targets: [B, C, D, E, F] → The model sees A and predicts B, sees B and predicts C, ... """ buffer: List[int] = [] # Token buffer for example in dataset: text = example[self.config.text_column] if not text or not text.strip(): continue # Tokenize (without special tokens) token_ids = self.tokenizer.encode(text, add_special_tokens=False) if not token_ids: continue # Append EOS token (marks document boundary) if self.config.use_eos_separator: token_ids.append(self.tokenizer.eos_id) # Add to buffer buffer.extend(token_ids) # Generate sequences once the buffer is full enough # +1 is needed to generate targets (input + next token) while len(buffer) >= self.max_seq_len + 1: # Extract max_seq_len + 1 tokens chunk = buffer[: self.max_seq_len + 1] buffer = buffer[self.max_seq_len + 1 :] # input_ids: from the first to the second-to-last token input_ids = torch.tensor(chunk[:-1], dtype=torch.long) # targets: from the second to the last token (shifted by one) targets = torch.tensor(chunk[1:], dtype=torch.long) yield {"input_ids": input_ids, "targets": targets} def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: """Iterator called by DataLoader. Multi-worker support (full partitioning approach): - Previous: all workers read the same stream with different seeds → possible document duplication - Improved: ds.shard() splits the stream into num_workers parts → no document overlap between workers Example (num_workers=4, total N documents): Worker 0: docs 0, 4, 8, 12, ... (N/4 docs) Worker 1: docs 1, 5, 9, 13, ... (N/4 docs) Worker 2: docs 2, 6, 10, 14, ... (N/4 docs) Worker 3: docs 3, 7, 11, 15, ... (N/4 docs) """ worker_info = torch.utils.data.get_worker_info() if worker_info is not None: # Full partitioning: assign a shard per worker + independent shuffle seed num_shards = worker_info.num_workers shard_index = worker_info.id worker_seed = self.seed + worker_info.id else: # Single process: process the full stream without sharding num_shards = 1 shard_index = 0 worker_seed = self.seed self.seed = worker_seed dataset = self._load_dataset(num_shards=num_shards, shard_index=shard_index) return self._tokenize_and_pack(dataset) class ValidationDataset: """Validation dataset. Pre-fetches a fixed amount of data from the streaming dataset and stores it in memory. Consistent data across evaluations is necessary for meaningful comparisons between epochs. """ def __init__( self, tokenizer: Tokenizer, config: DataConfig, num_samples: int = 100, seed: int = 9999, ): self.tokenizer = tokenizer self.config = config self.num_samples = num_samples self.samples: List[Dict[str, torch.Tensor]] = [] self._prepare(seed) def _prepare(self, seed: int): """Pre-extracts validation samples from the dataset.""" from datasets import load_dataset print(f"[Validation] Preparing {self.num_samples} validation samples...") ds = load_dataset( self.config.dataset_name, name=self.config.dataset_subset, split=self.config.dataset_split, streaming=True, trust_remote_code=True, ) # Use a different seed and skip the beginning to avoid overlap with training data ds = ds.shuffle(seed=seed, buffer_size=5_000) buffer: List[int] = [] count = 0 for example in ds: if count >= self.num_samples: break text = example[self.config.text_column] if not text or not text.strip(): continue token_ids = self.tokenizer.encode(text, add_special_tokens=False) if not token_ids: continue token_ids.append(self.tokenizer.eos_id) buffer.extend(token_ids) while len(buffer) >= self.config.max_seq_len + 1 and count < self.num_samples: chunk = buffer[: self.config.max_seq_len + 1] buffer = buffer[self.config.max_seq_len + 1 :] self.samples.append({ "input_ids": torch.tensor(chunk[:-1], dtype=torch.long), "targets": torch.tensor(chunk[1:], dtype=torch.long), }) count += 1 print(f"[Validation] {len(self.samples)} samples ready") def get_dataloader(self, batch_size: int) -> DataLoader: """Returns a validation DataLoader.""" return DataLoader( self.samples, batch_size=batch_size, shuffle=False, num_workers=0, collate_fn=_collate_fn, ) class MixedStreamingDataset(IterableDataset): """Interleaves multiple streaming datasets by sampling weight. Used for Continual Pre-Training (CPT) to mix domain-specific data (e.g., code) with general data to prevent catastrophic forgetting. How it works: - Maintains one PackedStreamingDataset per source - At each yield, randomly selects a source according to mix_weights - Each source independently packs and tokenizes its own stream Example (Code CPT): sources = [FineWeb-Edu (general), StarCoder (Python)] weights = [0.2, 0.8] → 20% general, 80% code """ def __init__( self, tokenizer: "Tokenizer", config: "DataConfig", seed: int = 42, ): super().__init__() self.tokenizer = tokenizer self.config = config self.seed = seed self.max_seq_len = config.max_seq_len # Build dataset specs: primary + secondary datasets self.dataset_specs = [ { "name": config.dataset_name, "subset": config.dataset_subset, "split": config.dataset_split, "text_column": config.text_column, } ] + list(config.mix_datasets) self.weights = config.mix_weights assert len(self.weights) == len(self.dataset_specs), ( f"mix_weights length ({len(self.weights)}) must match " f"number of datasets ({len(self.dataset_specs)})" ) assert abs(sum(self.weights) - 1.0) < 1e-6, ( f"mix_weights must sum to 1.0, got {sum(self.weights)}" ) def _load_single_dataset(self, spec: dict, num_shards: int, shard_index: int, seed: int): """Loads a single HuggingFace dataset in streaming mode.""" from datasets import load_dataset ds = load_dataset( spec["name"], name=spec["subset"], split=spec["split"], streaming=True, trust_remote_code=True, ) if num_shards > 1: ds = ds.shard(num_shards=num_shards, index=shard_index) ds = ds.shuffle(seed=seed, buffer_size=10_000) return ds def _token_stream(self, dataset, text_column: str) -> Iterator[int]: """Yields tokens one-by-one from a dataset stream (with EOS separators).""" for example in dataset: text = example[text_column] if not text or not text.strip(): continue token_ids = self.tokenizer.encode(text, add_special_tokens=False) if not token_ids: continue if self.config.use_eos_separator: token_ids.append(self.tokenizer.eos_id) yield from token_ids def _pack_from_stream(self, token_iter: Iterator[int]) -> Iterator[List[int]]: """Packs token stream into max_seq_len+1 chunks.""" buffer: List[int] = [] for tok in token_iter: buffer.append(tok) if len(buffer) >= self.max_seq_len + 1: yield buffer[: self.max_seq_len + 1] buffer = buffer[self.max_seq_len + 1 :] def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]: import random worker_info = torch.utils.data.get_worker_info() if worker_info is not None: num_shards = worker_info.num_workers shard_index = worker_info.id worker_seed = self.seed + worker_info.id else: num_shards = 1 shard_index = 0 worker_seed = self.seed rng = random.Random(worker_seed) # Create a packed-sequence iterator for each dataset source source_iters = [] for i, spec in enumerate(self.dataset_specs): ds = self._load_single_dataset( spec, num_shards, shard_index, seed=worker_seed + i * 1000 ) token_iter = self._token_stream(ds, spec["text_column"]) pack_iter = self._pack_from_stream(token_iter) source_iters.append(pack_iter) # Pre-fetch one chunk from each source (None = exhausted) buffers: List[Optional[List[int]]] = [None] * len(source_iters) for i, it in enumerate(source_iters): try: buffers[i] = next(it) except StopIteration: buffers[i] = None # Weighted round-robin sampling while any(b is not None for b in buffers): # Build active weights (zero out exhausted sources) active_weights = [ w if buffers[i] is not None else 0.0 for i, w in enumerate(self.weights) ] total = sum(active_weights) if total == 0: break # Weighted random selection r = rng.random() * total cumulative = 0.0 chosen = 0 for i, w in enumerate(active_weights): cumulative += w if r <= cumulative: chosen = i break chunk = buffers[chosen] input_ids = torch.tensor(chunk[:-1], dtype=torch.long) targets = torch.tensor(chunk[1:], dtype=torch.long) yield {"input_ids": input_ids, "targets": targets} # Refill the chosen source try: buffers[chosen] = next(source_iters[chosen]) except StopIteration: buffers[chosen] = None def _collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: """Combines samples in a batch into a single tensor. Because of sequence packing, all samples have the same length (max_seq_len), so no additional padding is needed. """ return { "input_ids": torch.stack([s["input_ids"] for s in batch]), "targets": torch.stack([s["targets"] for s in batch]), }