# LLM-1B-Lab / llm_lab/data/dataset.py
# Provenance (from hosting-page residue): commit 0cd5689 by Vjeong,
# "Fix dead split parameter in PackedStreamingDataset._load_dataset"
"""Streaming dataset — sequence packing and validation dataset."""
from typing import Iterator, List, Dict, Optional
import torch
from torch.utils.data import IterableDataset, DataLoader
from llm_lab.config import DataConfig
from .tokenizer import Tokenizer
class PackedStreamingDataset(IterableDataset):
    """Streaming + sequence packing dataset.

    Why sequence packing?
    - Naive approach: truncate each document to max_seq_len with padding -> wastes GPU
    - Sequence packing: concatenate multiple documents to fill max_seq_len -> 100% utilization

    How it works:
        Doc1 (300 tokens) + Doc2 (1500 tokens) + Doc3 (248 tokens) = 2048 tokens
        -> [Doc1][EOS][Doc2][EOS][Doc3][EOS][... no padding, fits exactly]

    Why streaming?
    - FineWeb-Edu 10B samples: tens of GB even when compressed
    - Full download not feasible on Colab disk limit (~200GB)
    - Streaming: reads from the network only as much as needed

    Notes for training:
    - EOS token inserted at document boundaries so the model recognizes end-of-document
    - EOS naturally serves as a boundary marker without cross-attention masking
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        config: DataConfig,
        split: str = "train",
        seed: int = 42,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config
        self.split = split
        self.seed = seed  # Base shuffle seed; per-worker seeds derive from it.
        self.max_seq_len = config.max_seq_len

    def _load_dataset(
        self,
        num_shards: int = 1,
        shard_index: int = 0,
        seed: Optional[int] = None,
    ):
        """Loads the HuggingFace dataset in streaming mode.

        Args:
            num_shards: Total number of shards (= DataLoader num_workers)
            shard_index: The shard index this worker is responsible for (0 ~ num_shards-1)
            seed: Shuffle seed; defaults to ``self.seed`` when omitted.
                ``__iter__`` passes the per-worker seed explicitly here instead
                of the old approach of mutating ``self.seed`` as a side channel,
                so iterating the dataset no longer changes its state.

        Sharding principle:
            With num_shards=4, the stream is split into 4 equal parts so each worker
            processes a distinct 1/4. Shuffling is applied after sharding so there is
            no document overlap between workers.
        """
        from datasets import load_dataset

        if seed is None:
            seed = self.seed
        ds = load_dataset(
            self.config.dataset_name,
            name=self.config.dataset_subset,
            split=self.split,
            streaming=True,  # Key: streaming mode
            trust_remote_code=True,
        )
        # Full partitioning (sharding): worker i processes only 1/num_shards of the stream.
        # Must be applied before shuffling so each worker has a non-overlapping set of documents.
        if num_shards > 1:
            ds = ds.shard(num_shards=num_shards, index=shard_index)
        # Shuffle (approximate buffer-based shuffle in streaming mode)
        ds = ds.shuffle(seed=seed, buffer_size=10_000)
        return ds

    def _tokenize_and_pack(self, dataset) -> Iterator[Dict[str, torch.Tensor]]:
        """Tokenizes documents and packs them into sequences.

        Yields:
            {"input_ids": (max_seq_len,), "targets": (max_seq_len,)}

        targets = input_ids shifted by one position:
            input_ids: [A, B, C, D, E]
            targets:   [B, C, D, E, F]
            -> The model sees A and predicts B, sees B and predicts C, ...
        """
        buffer: List[int] = []  # Carry-over tokens between documents
        for example in dataset:
            text = example[self.config.text_column]
            if not text or not text.strip():
                continue  # Skip empty / whitespace-only documents
            # Tokenize (without special tokens)
            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue
            # Append EOS token (marks document boundary)
            if self.config.use_eos_separator:
                token_ids.append(self.tokenizer.eos_id)
            buffer.extend(token_ids)
            # Emit sequences while the buffer holds a full window.
            # +1 token is needed so targets can be the one-position shift of input_ids.
            while len(buffer) >= self.max_seq_len + 1:
                chunk = buffer[: self.max_seq_len + 1]
                buffer = buffer[self.max_seq_len + 1 :]
                # input_ids: all but the last token; targets: all but the first
                input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
                targets = torch.tensor(chunk[1:], dtype=torch.long)
                yield {"input_ids": input_ids, "targets": targets}

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Iterator called by DataLoader.

        Multi-worker support (full partitioning approach):
        - Previous: all workers read the same stream with different seeds -> possible document duplication
        - Improved: ds.shard() splits the stream into num_workers parts -> no document overlap between workers

        Example (num_workers=4, total N documents):
            Worker 0: docs 0, 4, 8, 12, ... (N/4 docs)
            Worker 1: docs 1, 5, 9, 13, ... (N/4 docs)
            Worker 2: docs 2, 6, 10, 14, ... (N/4 docs)
            Worker 3: docs 3, 7, 11, 15, ... (N/4 docs)
        """
        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            # Full partitioning: assign a shard per worker + independent shuffle seed
            num_shards = worker_info.num_workers
            shard_index = worker_info.id
            worker_seed = self.seed + worker_info.id
        else:
            # Single process: process the full stream without sharding
            num_shards = 1
            shard_index = 0
            worker_seed = self.seed
        # Pass the worker seed explicitly (no mutation of self.seed).
        dataset = self._load_dataset(
            num_shards=num_shards, shard_index=shard_index, seed=worker_seed
        )
        return self._tokenize_and_pack(dataset)
class ValidationDataset:
    """Fixed in-memory validation set.

    Materializes a fixed number of packed sequences from the streaming
    dataset once, up front. Evaluating on identical data every time makes
    loss comparisons between epochs meaningful.
    """

    def __init__(
        self,
        tokenizer: Tokenizer,
        config: DataConfig,
        num_samples: int = 100,
        seed: int = 9999,
    ):
        self.tokenizer = tokenizer
        self.config = config
        self.num_samples = num_samples
        # Packed {"input_ids", "targets"} pairs, filled once by _prepare().
        self.samples: List[Dict[str, torch.Tensor]] = []
        self._prepare(seed)

    def _prepare(self, seed: int):
        """Streams documents, packs them, and stores num_samples sequences."""
        from datasets import load_dataset

        print(f"[Validation] Preparing {self.num_samples} validation samples...")
        stream = load_dataset(
            self.config.dataset_name,
            name=self.config.dataset_subset,
            split=self.config.dataset_split,
            streaming=True,
            trust_remote_code=True,
        )
        # Shuffle with a validation-specific seed so the document order
        # differs from the training stream's ordering.
        stream = stream.shuffle(seed=seed, buffer_size=5_000)

        window = self.config.max_seq_len + 1  # +1 token for shifted targets
        token_buffer: List[int] = []
        for example in stream:
            if len(self.samples) >= self.num_samples:
                break
            text = example[self.config.text_column]
            if not text or not text.strip():
                continue  # drop empty / whitespace-only documents
            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue
            token_ids.append(self.tokenizer.eos_id)  # document boundary marker
            token_buffer.extend(token_ids)
            # Slice full windows off the front of the buffer; input_ids and
            # targets are the same window shifted by one token.
            while len(token_buffer) >= window and len(self.samples) < self.num_samples:
                chunk, token_buffer = token_buffer[:window], token_buffer[window:]
                self.samples.append({
                    "input_ids": torch.tensor(chunk[:-1], dtype=torch.long),
                    "targets": torch.tensor(chunk[1:], dtype=torch.long),
                })
        print(f"[Validation] {len(self.samples)} samples ready")

    def get_dataloader(self, batch_size: int) -> DataLoader:
        """Returns a DataLoader over the pre-materialized samples."""
        # A plain list serves as a map-style dataset; shuffle=False keeps
        # batch composition identical across evaluations.
        return DataLoader(
            self.samples,
            batch_size=batch_size,
            shuffle=False,
            num_workers=0,
            collate_fn=_collate_fn,
        )
class MixedStreamingDataset(IterableDataset):
    """Interleaves multiple streaming datasets by sampling weight.

    Used for Continual Pre-Training (CPT) to mix domain-specific data
    (e.g., code) with general data to prevent catastrophic forgetting.

    How it works:
    - Maintains one packed token stream per source
    - At each yield, randomly selects a source according to mix_weights
    - Each source independently packs and tokenizes its own stream

    Example (Code CPT):
        sources = [FineWeb-Edu (general), StarCoder (Python)]
        weights = [0.2, 0.8] -> 20% general, 80% code
    """

    def __init__(
        self,
        tokenizer: "Tokenizer",
        config: "DataConfig",
        seed: int = 42,
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config
        self.seed = seed
        self.max_seq_len = config.max_seq_len
        # Dataset specs: the primary dataset (top-level config fields)
        # followed by the secondary mix datasets.
        self.dataset_specs = [
            {
                "name": config.dataset_name,
                "subset": config.dataset_subset,
                "split": config.dataset_split,
                "text_column": config.text_column,
            }
        ] + list(config.mix_datasets)
        self.weights = config.mix_weights
        assert len(self.weights) == len(self.dataset_specs), (
            f"mix_weights length ({len(self.weights)}) must match "
            f"number of datasets ({len(self.dataset_specs)})"
        )
        assert abs(sum(self.weights) - 1.0) < 1e-6, (
            f"mix_weights must sum to 1.0, got {sum(self.weights)}"
        )

    def _load_single_dataset(self, spec: dict, num_shards: int, shard_index: int, seed: int):
        """Loads a single HuggingFace dataset in streaming mode.

        Shards before shuffling (same scheme as PackedStreamingDataset) so
        DataLoader workers see non-overlapping documents.
        """
        from datasets import load_dataset

        ds = load_dataset(
            spec["name"],
            name=spec["subset"],
            split=spec["split"],
            streaming=True,
            trust_remote_code=True,
        )
        if num_shards > 1:
            ds = ds.shard(num_shards=num_shards, index=shard_index)
        ds = ds.shuffle(seed=seed, buffer_size=10_000)
        return ds

    def _token_stream(self, dataset, text_column: str) -> Iterator[int]:
        """Yields tokens one-by-one from a dataset stream (with EOS separators)."""
        for example in dataset:
            text = example[text_column]
            if not text or not text.strip():
                continue  # skip empty / whitespace-only documents
            token_ids = self.tokenizer.encode(text, add_special_tokens=False)
            if not token_ids:
                continue
            if self.config.use_eos_separator:
                token_ids.append(self.tokenizer.eos_id)
            yield from token_ids

    def _pack_from_stream(self, token_iter: Iterator[int]) -> Iterator[List[int]]:
        """Packs a token stream into max_seq_len + 1 sized chunks.

        The +1 allows callers to derive targets as input_ids shifted by one.
        A trailing partial chunk (fewer than max_seq_len + 1 tokens) is dropped.
        """
        buffer: List[int] = []
        for tok in token_iter:
            buffer.append(tok)
            if len(buffer) >= self.max_seq_len + 1:
                yield buffer[: self.max_seq_len + 1]
                buffer = buffer[self.max_seq_len + 1 :]

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Yields packed sequences, sampling one source per step by weight.

        Worker handling mirrors PackedStreamingDataset: each DataLoader
        worker gets a distinct shard of every source plus its own RNG seed.
        """
        import random

        worker_info = torch.utils.data.get_worker_info()
        if worker_info is not None:
            num_shards = worker_info.num_workers
            shard_index = worker_info.id
            worker_seed = self.seed + worker_info.id
        else:
            num_shards = 1
            shard_index = 0
            worker_seed = self.seed
        rng = random.Random(worker_seed)

        # Create a packed-sequence iterator for each dataset source.
        # Each source gets its own derived shuffle seed.
        source_iters = []
        for i, spec in enumerate(self.dataset_specs):
            ds = self._load_single_dataset(
                spec, num_shards, shard_index, seed=worker_seed + i * 1000
            )
            token_iter = self._token_stream(ds, spec["text_column"])
            source_iters.append(self._pack_from_stream(token_iter))

        # Pre-fetch one chunk from each source (None = exhausted)
        buffers: List[Optional[List[int]]] = [next(it, None) for it in source_iters]

        # Weighted sampling until every source is exhausted
        while any(b is not None for b in buffers):
            # Zero out weights of exhausted sources
            active_weights = [
                w if buffers[i] is not None else 0.0
                for i, w in enumerate(self.weights)
            ]
            total = sum(active_weights)
            if total == 0:
                break  # only zero-weight sources remain; nothing to draw

            # Weighted random selection. The "w <= 0.0: continue" guard and
            # strict "<" comparison ensure an exhausted (zero-weight) source
            # can never be chosen — the previous "r <= cumulative" check
            # could select source 0 after it ran dry when rng.random()
            # returned exactly 0.0, crashing on its None buffer.
            r = rng.random() * total
            cumulative = 0.0
            chosen = None
            for i, w in enumerate(active_weights):
                if w <= 0.0:
                    continue
                cumulative += w
                if r < cumulative:
                    chosen = i
                    break
            if chosen is None:
                # Float rounding pushed r to/past the final boundary;
                # fall back to the last still-active source.
                chosen = max(i for i, w in enumerate(active_weights) if w > 0.0)

            chunk = buffers[chosen]
            input_ids = torch.tensor(chunk[:-1], dtype=torch.long)
            targets = torch.tensor(chunk[1:], dtype=torch.long)
            yield {"input_ids": input_ids, "targets": targets}

            # Refill the chosen source
            buffers[chosen] = next(source_iters[chosen], None)
def _collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
"""Combines samples in a batch into a single tensor.
Because of sequence packing, all samples have the same length (max_seq_len),
so no additional padding is needed.
"""
return {
"input_ids": torch.stack([s["input_ids"] for s in batch]),
"targets": torch.stack([s["targets"] for s in batch]),
}