"""Dataset and DataLoader utilities for TinyStories training. This module provides: 1. TinyStoriesDataset class for loading and processing TinyStories 2. create_dataloaders function for creating train/val DataLoaders 3. Sequence packing for efficient training TinyStories is a synthetic dataset of short stories generated by GPT-3.5/4 using a limited vocabulary suitable for children. Perfect for fast training and testing language models. """ import torch from torch.utils.data import Dataset, DataLoader from datasets import load_dataset from pathlib import Path import pickle import logging from typing import Dict, List, Tuple, Optional from tqdm import tqdm logger = logging.getLogger(__name__) class TinyStoriesDataset(Dataset): """TinyStories dataset with sequence packing for efficient training. TinyStories is a synthetic dataset of short stories generated by GPT-3.5/4 using a limited vocabulary suitable for children. The dataset contains ~2.1M stories and is excellent for: - Fast training (only ~1GB) - Clean, well-formed English - Testing model architecture - Educational purposes This dataset: 1. Loads TinyStories from HuggingFace datasets 2. Tokenizes the text 3. Packs sequences to max_seq_len for efficiency 4. Caches processed data for fast subsequent loading """ def __init__( self, tokenizer, split: str = "train", max_seq_len: int = 512, cache_dir: Optional[str] = None, ): """Initialize TinyStories dataset. Args: tokenizer: Tokenizer instance (must have encode method) split: Dataset split ("train" or "validation") max_seq_len: Maximum sequence length (default: 512, matches official paper) cache_dir: Directory for caching processed data """ self.tokenizer = tokenizer self.split = split self.max_seq_len = max_seq_len self.cache_dir = Path(cache_dir) if cache_dir else Path("./data/cache") self.cache_dir.mkdir(parents=True, exist_ok=True) # Cache file path cache_file = self.cache_dir / f"tinystories_{split}_{max_seq_len}.pkl" # Try to load from cache if cache_file.exists(): logger.info(f"Loading cached dataset from {cache_file}") with open(cache_file, "rb") as f: cache_data = pickle.load(f) self.input_ids = cache_data["input_ids"] self.labels = cache_data["labels"] logger.info(f"Loaded {len(self.input_ids)} sequences from cache") else: # Process dataset logger.info(f"Processing TinyStories {split} split...") self.input_ids, self.labels = self._process_dataset() # Save to cache logger.info(f"Saving processed dataset to {cache_file}") cache_data = { "input_ids": self.input_ids, "labels": self.labels, } with open(cache_file, "wb") as f: pickle.dump(cache_data, f) logger.info(f"Dataset ready: {len(self.input_ids)} sequences") def _process_dataset(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]: """Process TinyStories dataset into packed sequences. Returns: Tuple of (input_ids, labels) lists """ # Load dataset dataset = load_dataset( "roneneldan/TinyStories", split=self.split, ) # Tokenize all text logger.info("Tokenizing dataset...") all_token_ids = [] for example in tqdm(dataset, desc="Tokenizing"): text = example["text"].strip() if len(text) > 0: # Skip empty stories # Encode text if hasattr(self.tokenizer, 'encode'): token_ids = self.tokenizer.encode(text, add_special_tokens=False) else: # Fallback for tokenizers.Tokenizer token_ids = self.tokenizer.tokenizer.encode(text).ids all_token_ids.extend(token_ids) logger.info(f"Total tokens: {len(all_token_ids):,}") # Pack into sequences logger.info("Packing sequences...") input_ids_list = [] labels_list = [] # Pack sequences with stride to maximize data usage for i in range(0, len(all_token_ids) - 1, self.max_seq_len): # Get sequence seq = all_token_ids[i : i + self.max_seq_len] # Skip if too short if len(seq) < 2: continue # Create input_ids and labels # input_ids: [0, 1, 2, ..., n-1] # labels: [1, 2, 3, ..., n] input_ids = torch.tensor(seq[:-1], dtype=torch.long) labels = torch.tensor(seq[1:], dtype=torch.long) # Pad if necessary if len(input_ids) < self.max_seq_len: pad_len = self.max_seq_len - len(input_ids) input_ids = torch.cat([ input_ids, torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long) ]) labels = torch.cat([ labels, torch.full((pad_len,), -100, dtype=torch.long) # -100 is ignored in loss ]) input_ids_list.append(input_ids) labels_list.append(labels) logger.info(f"Created {len(input_ids_list)} packed sequences") return input_ids_list, labels_list def __len__(self) -> int: """Return number of sequences.""" return len(self.input_ids) def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: """Get a single sequence. Args: idx: Sequence index Returns: Dictionary with 'input_ids' and 'labels' """ return { "input_ids": self.input_ids[idx], "labels": self.labels[idx], } def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: """Collate function for DataLoader. Args: batch: List of dictionaries with 'input_ids' and 'labels' Returns: Batched dictionary """ input_ids = torch.stack([item["input_ids"] for item in batch]) labels = torch.stack([item["labels"] for item in batch]) return { "input_ids": input_ids, "labels": labels, } def create_dataloaders( tokenizer, batch_size: int, max_seq_len: int, cache_dir: str, dataset_name: str = "tinystories", num_workers: int = 0, pin_memory: bool = True, drop_last: bool = True, ) -> Tuple[DataLoader, DataLoader]: """Create train and validation DataLoaders for TinyStories. Args: tokenizer: Tokenizer instance batch_size: Batch size per device max_seq_len: Maximum sequence length (512 recommended for TinyStories) cache_dir: Directory for caching processed data dataset_name: Dataset to use (default: "tinystories") num_workers: Number of data loading workers (use 0 for Windows) pin_memory: Whether to pin memory for faster GPU transfer drop_last: Whether to drop last incomplete batch Returns: Tuple of (train_loader, val_loader) """ logger.info("Using TinyStories dataset") logger.info("Creating train dataset...") train_dataset = TinyStoriesDataset( tokenizer=tokenizer, split="train", max_seq_len=max_seq_len, cache_dir=cache_dir, ) logger.info("Creating validation dataset...") val_dataset = TinyStoriesDataset( tokenizer=tokenizer, split="validation", max_seq_len=max_seq_len, cache_dir=cache_dir, ) # Create DataLoaders train_loader = DataLoader( train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last, collate_fn=collate_fn, ) val_loader = DataLoader( val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory, drop_last=False, collate_fn=collate_fn, ) logger.info(f"Train batches: {len(train_loader)}") logger.info(f"Validation batches: {len(val_loader)}") return train_loader, val_loader # Test the dataset if __name__ == "__main__": from .tokenizer import load_tokenizer print("Testing TinyStoriesDataset...") # Load tokenizer (assumes it exists) tokenizer_path = "./tokenizer/wikimini_32k" if Path(tokenizer_path).exists(): tokenizer = load_tokenizer(tokenizer_path) # Create small dataset for testing dataset = TinyStoriesDataset( tokenizer=tokenizer, split="validation", # Use smaller split for testing max_seq_len=128, cache_dir="./data/cache_test", ) print(f"\nDataset size: {len(dataset)}") print(f"Sample batch:") sample = dataset[0] print(f" Input IDs shape: {sample['input_ids'].shape}") print(f" Labels shape: {sample['labels'].shape}") print(f" First 10 input IDs: {sample['input_ids'][:10]}") print(f" First 10 labels: {sample['labels'][:10]}") # Test DataLoader loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn) batch = next(iter(loader)) print(f"\nDataLoader batch:") print(f" Input IDs shape: {batch['input_ids'].shape}") print(f" Labels shape: {batch['labels'].shape}") else: print(f"Tokenizer not found at {tokenizer_path}") print("Please train tokenizer first: python scripts/train_tokenizer.py")