"""Dataset and DataLoader utilities for TinyStories training.
This module provides:
1. TinyStoriesDataset class for loading and processing TinyStories
2. create_dataloaders function for creating train/val DataLoaders
3. Sequence packing for efficient training
TinyStories is a synthetic dataset of short stories generated by GPT-3.5/4
using a limited vocabulary suitable for children. Perfect for fast training
and testing language models.
"""
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
from pathlib import Path
import pickle
import logging
from typing import Dict, List, Tuple, Optional
from tqdm import tqdm
logger = logging.getLogger(__name__)
class TinyStoriesDataset(Dataset):
"""TinyStories dataset with sequence packing for efficient training.
TinyStories is a synthetic dataset of short stories generated by GPT-3.5/4
using a limited vocabulary suitable for children. The dataset contains
~2.1M stories and is excellent for:
- Fast training (only ~1GB)
- Clean, well-formed English
- Testing model architecture
- Educational purposes
This dataset:
1. Loads TinyStories from HuggingFace datasets
2. Tokenizes the text
3. Packs sequences to max_seq_len for efficiency
4. Caches processed data for fast subsequent loading
"""
def __init__(
self,
tokenizer,
split: str = "train",
max_seq_len: int = 512,
cache_dir: Optional[str] = None,
):
"""Initialize TinyStories dataset.
Args:
tokenizer: Tokenizer instance (must have encode method)
split: Dataset split ("train" or "validation")
max_seq_len: Maximum sequence length (default: 512, matches official paper)
cache_dir: Directory for caching processed data
"""
self.tokenizer = tokenizer
self.split = split
self.max_seq_len = max_seq_len
self.cache_dir = Path(cache_dir) if cache_dir else Path("./data/cache")
self.cache_dir.mkdir(parents=True, exist_ok=True)
# Cache file path
cache_file = self.cache_dir / f"tinystories_{split}_{max_seq_len}.pkl"
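        # NOTE: the cache key encodes only the split and max_seq_len, not the
        # tokenizer; if the tokenizer is retrained or swapped, delete the
        # cached .pkl files (or point cache_dir elsewhere) to force reprocessing.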
# Try to load from cache
if cache_file.exists():
logger.info(f"Loading cached dataset from {cache_file}")
with open(cache_file, "rb") as f:
cache_data = pickle.load(f)
self.input_ids = cache_data["input_ids"]
self.labels = cache_data["labels"]
logger.info(f"Loaded {len(self.input_ids)} sequences from cache")
else:
# Process dataset
logger.info(f"Processing TinyStories {split} split...")
self.input_ids, self.labels = self._process_dataset()
# Save to cache
logger.info(f"Saving processed dataset to {cache_file}")
cache_data = {
"input_ids": self.input_ids,
"labels": self.labels,
}
with open(cache_file, "wb") as f:
pickle.dump(cache_data, f)
logger.info(f"Dataset ready: {len(self.input_ids)} sequences")
def _process_dataset(self) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
"""Process TinyStories dataset into packed sequences.
Returns:
Tuple of (input_ids, labels) lists
"""
# Load dataset
dataset = load_dataset(
"roneneldan/TinyStories",
split=self.split,
)
# Tokenize all text
logger.info("Tokenizing dataset...")
all_token_ids = []
for example in tqdm(dataset, desc="Tokenizing"):
text = example["text"].strip()
if len(text) > 0: # Skip empty stories
                # Encode without special tokens; stories are packed back to back
                if hasattr(self.tokenizer, 'encode'):
                    token_ids = self.tokenizer.encode(text, add_special_tokens=False)
                else:
                    # Fallback for wrappers that expose a raw
                    # tokenizers.Tokenizer as .tokenizer
                    token_ids = self.tokenizer.tokenizer.encode(text).ids
all_token_ids.extend(token_ids)
logger.info(f"Total tokens: {len(all_token_ids):,}")
# Pack into sequences
logger.info("Packing sequences...")
input_ids_list = []
labels_list = []
        # Pack into windows of max_seq_len + 1 tokens, stepping by max_seq_len,
        # so each window yields exactly max_seq_len (input, label) pairs after
        # the one-token shift; only the final window may need padding
        for i in range(0, len(all_token_ids) - 1, self.max_seq_len):
            # Take one extra token to provide the shifted labels
            seq = all_token_ids[i : i + self.max_seq_len + 1]
# Skip if too short
if len(seq) < 2:
continue
# Create input_ids and labels
# input_ids: [0, 1, 2, ..., n-1]
# labels: [1, 2, 3, ..., n]
input_ids = torch.tensor(seq[:-1], dtype=torch.long)
labels = torch.tensor(seq[1:], dtype=torch.long)
# Pad if necessary
if len(input_ids) < self.max_seq_len:
pad_len = self.max_seq_len - len(input_ids)
input_ids = torch.cat([
input_ids,
torch.full((pad_len,), self.tokenizer.pad_token_id, dtype=torch.long)
])
labels = torch.cat([
labels,
torch.full((pad_len,), -100, dtype=torch.long) # -100 is ignored in loss
])
input_ids_list.append(input_ids)
labels_list.append(labels)
logger.info(f"Created {len(input_ids_list)} packed sequences")
return input_ids_list, labels_list
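
    # Illustration of the packing scheme (toy values, max_seq_len=4): given
    # tokens [t0 .. t8], the windows are
    #   window 0: tokens[0:5] -> input_ids [t0, t1, t2, t3], labels [t1, t2, t3, t4]
    #   window 1: tokens[4:9] -> input_ids [t4, t5, t6, t7], labels [t5, t6, t7, t8]
    # so every next-token pair in the stream is trained on exactly once.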
def __len__(self) -> int:
"""Return number of sequences."""
return len(self.input_ids)
def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
"""Get a single sequence.
Args:
idx: Sequence index
Returns:
Dictionary with 'input_ids' and 'labels'
"""
return {
"input_ids": self.input_ids[idx],
"labels": self.labels[idx],
}
def collate_fn(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
"""Collate function for DataLoader.
Args:
batch: List of dictionaries with 'input_ids' and 'labels'
Returns:
Batched dictionary
"""
input_ids = torch.stack([item["input_ids"] for item in batch])
labels = torch.stack([item["labels"] for item in batch])
return {
"input_ids": input_ids,
"labels": labels,
}
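
# Example (sketch) of what collate_fn produces, assuming max_seq_len=512:
#   batch = collate_fn([dataset[i] for i in range(4)])
#   batch["input_ids"].shape  -> torch.Size([4, 512])
#   batch["labels"].shape     -> torch.Size([4, 512])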
def create_dataloaders(
tokenizer,
batch_size: int,
max_seq_len: int,
cache_dir: str,
dataset_name: str = "tinystories",
num_workers: int = 0,
pin_memory: bool = True,
drop_last: bool = True,
) -> Tuple[DataLoader, DataLoader]:
"""Create train and validation DataLoaders for TinyStories.
Args:
tokenizer: Tokenizer instance
batch_size: Batch size per device
max_seq_len: Maximum sequence length (512 recommended for TinyStories)
cache_dir: Directory for caching processed data
dataset_name: Dataset to use (default: "tinystories")
num_workers: Number of data loading workers (use 0 for Windows)
pin_memory: Whether to pin memory for faster GPU transfer
drop_last: Whether to drop last incomplete batch
Returns:
Tuple of (train_loader, val_loader)
"""
logger.info("Using TinyStories dataset")
logger.info("Creating train dataset...")
train_dataset = TinyStoriesDataset(
tokenizer=tokenizer,
split="train",
max_seq_len=max_seq_len,
cache_dir=cache_dir,
)
logger.info("Creating validation dataset...")
val_dataset = TinyStoriesDataset(
tokenizer=tokenizer,
split="validation",
max_seq_len=max_seq_len,
cache_dir=cache_dir,
)
# Create DataLoaders
train_loader = DataLoader(
train_dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
drop_last=drop_last,
collate_fn=collate_fn,
)
val_loader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=pin_memory,
drop_last=False,
collate_fn=collate_fn,
)
logger.info(f"Train batches: {len(train_loader)}")
logger.info(f"Validation batches: {len(val_loader)}")
return train_loader, val_loader
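
# Typical usage (sketch; `tokenizer` and `model` are assumed to exist):
#   train_loader, val_loader = create_dataloaders(
#       tokenizer=tokenizer,
#       batch_size=32,
#       max_seq_len=512,
#       cache_dir="./data/cache",
#   )
#   for batch in train_loader:
#       logits = model(batch["input_ids"])
#       ...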
# Test the dataset
if __name__ == "__main__":
    # Relative import: run this file as a module so it resolves, e.g.
    #   python -m <package>.dataset
    from .tokenizer import load_tokenizer
print("Testing TinyStoriesDataset...")
# Load tokenizer (assumes it exists)
tokenizer_path = "./tokenizer/wikimini_32k"
if Path(tokenizer_path).exists():
tokenizer = load_tokenizer(tokenizer_path)
# Create small dataset for testing
dataset = TinyStoriesDataset(
tokenizer=tokenizer,
split="validation", # Use smaller split for testing
max_seq_len=128,
cache_dir="./data/cache_test",
)
print(f"\nDataset size: {len(dataset)}")
print(f"Sample batch:")
sample = dataset[0]
print(f" Input IDs shape: {sample['input_ids'].shape}")
print(f" Labels shape: {sample['labels'].shape}")
print(f" First 10 input IDs: {sample['input_ids'][:10]}")
print(f" First 10 labels: {sample['labels'][:10]}")
# Test DataLoader
loader = DataLoader(dataset, batch_size=4, collate_fn=collate_fn)
batch = next(iter(loader))
print(f"\nDataLoader batch:")
print(f" Input IDs shape: {batch['input_ids'].shape}")
print(f" Labels shape: {batch['labels'].shape}")
else:
print(f"Tokenizer not found at {tokenizer_path}")
print("Please train tokenizer first: python scripts/train_tokenizer.py")