|
|
"""Dataset and DataLoader utilities for TinyStories training. |
|
|
|
|
|
This module provides: |
|
|
1. TinyStoriesDataset class for loading and processing TinyStories |
|
|
2. create_dataloaders function for creating train/val DataLoaders |
|
|
3. Sequence packing for efficient training |
|
|
|
|
|
TinyStories is a synthetic dataset of short stories generated by GPT-3.5/4 |
|
|
using a limited vocabulary suitable for children. Perfect for fast training |
|
|
and testing language models. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from datasets import load_dataset |
|
|
from pathlib import Path |
|
|
import pickle |
|
|
import logging |
|
|
from typing import Dict, List, Tuple, Optional |
|
|
from tqdm import tqdm |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class TinyStoriesDataset(Dataset):
    """TinyStories dataset with sequence packing for efficient training.

    TinyStories is a synthetic dataset of short stories generated by GPT-3.5/4
    using a limited vocabulary suitable for children. The dataset contains
    ~2.1M stories and is excellent for:
    - Fast training (only ~1GB)
    - Clean, well-formed English
    - Testing model architecture
    - Educational purposes

    This dataset:
    1. Loads TinyStories from HuggingFace datasets
    2. Tokenizes the text
    3. Packs sequences to max_seq_len for efficiency
    4. Caches processed data for fast subsequent loading
    """

    def __init__(
        self,
        tokenizer,
        split: str = "train",
        max_seq_len: int = 512,
        cache_dir: Optional[str] = None,
    ):
        """Initialize TinyStories dataset.

        Args:
            tokenizer: Tokenizer instance (must have encode method)
            split: Dataset split ("train" or "validation")
            max_seq_len: Maximum sequence length (default: 512, matches official paper)
            cache_dir: Directory for caching processed data
        """
        self.tokenizer = tokenizer
        self.split = split
        self.max_seq_len = max_seq_len
        self.cache_dir = Path(cache_dir) if cache_dir else Path("./data/cache")
        self.cache_dir.mkdir(parents=True, exist_ok=True)

        # "_v2" marks the current packing layout (max_seq_len + 1 raw tokens
        # per chunk, see _process_dataset). Caches written by the previous
        # layout hold differently packed sequences and must not be reused.
        cache_file = self.cache_dir / f"tinystories_{split}_{max_seq_len}_v2.pkl"

        if cache_file.exists():
            logger.info(f"Loading cached dataset from {cache_file}")
            with open(cache_file, "rb") as f:
                cache_data = pickle.load(f)
            self.input_ids = cache_data["input_ids"]
            self.labels = cache_data["labels"]
            logger.info(f"Loaded {len(self.input_ids)} sequences from cache")
        else:
            logger.info(f"Processing TinyStories {split} split...")
            self.input_ids, self.labels = self._process_dataset()

            # Persist so subsequent runs skip the tokenize/pack pass entirely.
            logger.info(f"Saving processed dataset to {cache_file}")
            cache_data = {
                "input_ids": self.input_ids,
                "labels": self.labels,
            }
            with open(cache_file, "wb") as f:
                pickle.dump(cache_data, f)

        logger.info(f"Dataset ready: {len(self.input_ids)} sequences")

    def _process_dataset(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """Process TinyStories dataset into packed sequences.

        Fix vs. the previous version: each chunk now slices max_seq_len + 1
        raw tokens, so after the one-token input/label shift every full
        sequence holds exactly max_seq_len positions. The old code sliced
        max_seq_len tokens, which (a) wasted one position per chunk on
        padding and (b) never trained the next-token pair that spans each
        chunk boundary.

        Returns:
            Tuple of (input_ids, labels) lists
        """
        dataset = load_dataset(
            "roneneldan/TinyStories",
            split=self.split,
        )

        logger.info("Tokenizing dataset...")
        all_token_ids = []

        for example in tqdm(dataset, desc="Tokenizing"):
            text = example["text"].strip()
            if len(text) > 0:
                # HF-style tokenizers expose encode(); the fallback handles a
                # thin wrapper around a raw `tokenizers.Tokenizer` instance.
                if hasattr(self.tokenizer, 'encode'):
                    token_ids = self.tokenizer.encode(text, add_special_tokens=False)
                else:
                    token_ids = self.tokenizer.tokenizer.encode(text).ids

                all_token_ids.extend(token_ids)

        logger.info(f"Total tokens: {len(all_token_ids):,}")

        logger.info("Packing sequences...")
        input_ids_list = []
        labels_list = []

        # Stride by max_seq_len but slice max_seq_len + 1 tokens: consecutive
        # chunks share one boundary token, so every position (including chunk
        # boundaries) has a training target and full chunks need no padding.
        for i in range(0, len(all_token_ids) - 1, self.max_seq_len):
            seq = all_token_ids[i : i + self.max_seq_len + 1]

            # Need at least one (input, target) pair.
            if len(seq) < 2:
                continue

            # Shift by one position for next-token prediction.
            input_ids = torch.tensor(seq[:-1], dtype=torch.long)
            labels = torch.tensor(seq[1:], dtype=torch.long)

            # Only the final, partial chunk can be short now; pad it out.
            if len(input_ids) < self.max_seq_len:
                pad_len = self.max_seq_len - len(input_ids)
                # NOTE(review): assumes the tokenizer exposes pad_token_id —
                # the raw-`tokenizers` fallback above may not; verify.
                input_ids = torch.cat([
                    input_ids,
                    torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
                ])
                # -100 is the default ignore_index of torch's CrossEntropyLoss,
                # so padded positions contribute no loss.
                labels = torch.cat([
                    labels,
                    torch.full((pad_len,), -100, dtype=torch.long)
                ])

            input_ids_list.append(input_ids)
            labels_list.append(labels)

        logger.info(f"Created {len(input_ids_list)} packed sequences")

        return input_ids_list, labels_list

    def __len__(self) -> int:
        """Return number of sequences."""
        return len(self.input_ids)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Get a single sequence.

        Args:
            idx: Sequence index

        Returns:
            Dictionary with 'input_ids' and 'labels'
        """
        return {
            "input_ids": self.input_ids[idx],
            "labels": self.labels[idx],
        }
|
|
|
|
|
|
|
|
def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Collate function for DataLoader.

    Stacks the per-example tensors of each field into a single batched
    tensor per field.

    Args:
        batch: List of dictionaries with 'input_ids' and 'labels'

    Returns:
        Batched dictionary
    """
    # Stack each field across the batch dimension.
    return {
        field: torch.stack([example[field] for example in batch])
        for field in ("input_ids", "labels")
    }
|
|
|
|
|
|
|
|
def create_dataloaders(
    tokenizer,
    batch_size: int,
    max_seq_len: int,
    cache_dir: str,
    dataset_name: str = "tinystories",
    num_workers: int = 0,
    pin_memory: bool = True,
    drop_last: bool = True,
) -> Tuple[DataLoader, DataLoader]:
    """Create train and validation DataLoaders for TinyStories.

    Args:
        tokenizer: Tokenizer instance
        batch_size: Batch size per device
        max_seq_len: Maximum sequence length (512 recommended for TinyStories)
        cache_dir: Directory for caching processed data
        dataset_name: Dataset to use (only "tinystories" is supported)
        num_workers: Number of data loading workers (use 0 for Windows)
        pin_memory: Whether to pin memory for faster GPU transfer
        drop_last: Whether to drop last incomplete batch

    Returns:
        Tuple of (train_loader, val_loader)
    """
    # Fix: dataset_name was accepted but silently ignored. Warn loudly so a
    # caller requesting a different dataset is not misled into thinking the
    # request was honored.
    if dataset_name != "tinystories":
        logger.warning(
            "Unsupported dataset_name %r; only 'tinystories' is available — "
            "falling back to TinyStories.",
            dataset_name,
        )
    logger.info("Using TinyStories dataset")

    logger.info("Creating train dataset...")
    train_dataset = TinyStoriesDataset(
        tokenizer=tokenizer,
        split="train",
        max_seq_len=max_seq_len,
        cache_dir=cache_dir,
    )

    logger.info("Creating validation dataset...")
    val_dataset = TinyStoriesDataset(
        tokenizer=tokenizer,
        split="validation",
        max_seq_len=max_seq_len,
        cache_dir=cache_dir,
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        collate_fn=collate_fn,
    )

    # Validation never shuffles and keeps the final partial batch so every
    # example is evaluated exactly once.
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=False,
        collate_fn=collate_fn,
    )

    logger.info(f"Train batches: {len(train_loader)}")
    logger.info(f"Validation batches: {len(val_loader)}")

    return train_loader, val_loader
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke test: loads a trained tokenizer, builds a small dataset,
    # and pushes one batch through a DataLoader.
    # NOTE(review): the relative import below only resolves when this file is
    # executed as a module inside its package (python -m <pkg>.<module>), not
    # as a plain script — confirm the intended invocation.
    from .tokenizer import load_tokenizer

    print("Testing TinyStoriesDataset...")

    # Path to a previously trained tokenizer; the test is skipped when absent.
    tokenizer_path = "./tokenizer/wikimini_32k"
    if Path(tokenizer_path).exists():
        tokenizer = load_tokenizer(tokenizer_path)

        # Use the smaller validation split and a short context so the smoke
        # test runs quickly; processed data is cached under ./data/cache_test.
        dataset = TinyStoriesDataset(
            tokenizer=tokenizer,
            split="validation",
            max_seq_len=128,
            cache_dir="./data/cache_test",
        )

        print(f"\nDataset size: {len(dataset)}")
        print(f"Sample batch:")
        sample = dataset[0]
        print(f" Input IDs shape: {sample['input_ids'].shape}")
        print(f" Labels shape: {sample['labels'].shape}")
        print(f" First 10 input IDs: {sample['input_ids'][:10]}")
        print(f" First 10 labels: {sample['labels'][:10]}")

        # Round-trip one batch through a DataLoader to exercise collate_fn.
        loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
        batch = next(iter(loader))
        print(f"\nDataLoader batch:")
        print(f" Input IDs shape: {batch['input_ids'].shape}")
        print(f" Labels shape: {batch['labels'].shape}")
    else:
        print(f"Tokenizer not found at {tokenizer_path}")
        print("Please train tokenizer first: python scripts/train_tokenizer.py")
|
|
|