gpt2 / src /data /dataset.py
triton329's picture
Upload folder using huggingface_hub
c21e887 verified
Raw
History Blame Contribute Delete
1.56 kB
import torch
from torch.utils.data import Dataset, DataLoader
from src.data.tokenizer import BaseTokenizer
class GPTDataset(Dataset):
"""Sliding-window dataset for GPT-style next-token prediction.
Args:
text: raw input text
tokenizer: a BaseTokenizer instance
max_len: number of tokens the model sees as context
stride: step size between windows
"""
def __init__(
self, text: str, tokenizer: BaseTokenizer, max_len: int, stride: int
) -> None:
self.max_len = max_len
self.stride = stride
self.tokens = torch.tensor(tokenizer.encode(text), dtype=torch.long)
self.num_samples = max(0, (len(self.tokens) - max_len - 1) // stride + 1)
def __len__(self) -> int:
return self.num_samples
def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
start = idx * self.stride
return (
self.tokens[start : start + self.max_len],
self.tokens[start + 1 : start + self.max_len + 1],
)
def create_gpt_dataloader(
text: str,
tokenizer: BaseTokenizer,
max_len: int = 3,
stride: int = 1,
batch_size: int = 10,
num_workers: int = 0,
pin_memory: bool = False,
) -> DataLoader:
"""Create a DataLoader from raw text using sliding-window chunking."""
dataset = GPTDataset(text, tokenizer, max_len, stride)
return DataLoader(
dataset,
batch_size=batch_size,
shuffle=True,
num_workers=num_workers,
pin_memory=pin_memory,
)