import torch
from torch.utils.data import Dataset
from tokenizers import Tokenizer


class TextDataset(Dataset):
    """
    Dataset for language modeling. Tokenizes text and splits it into
    fixed-length sequences. Each item is a tuple (input_ids, target_ids),
    each of length `block_size`.
    """

    def __init__(self, file_path: str, tokenizer_path: str, block_size: int):
        super().__init__()
        # Load tokenizer
        self.tokenizer = Tokenizer.from_file(tokenizer_path)

        # Read entire text file
        with open(file_path, 'r', encoding='utf-8') as f:
            text = f.read()

        # Tokenize the text to a sequence of token IDs
        token_ids = self.tokenizer.encode(text).ids

        # Determine number of full chunks of length block_size+1
        # (one extra token so inputs and shifted targets both have block_size tokens)
        chunk_length = block_size + 1
        num_chunks = len(token_ids) // chunk_length

        # Truncate token list to a multiple of chunk_length; leftover tokens are dropped
        token_ids = token_ids[:num_chunks * chunk_length]

        # Convert to tensor and reshape to (num_chunks, chunk_length)
        data_tensor = torch.tensor(token_ids, dtype=torch.long)
        self.data = data_tensor.view(num_chunks, chunk_length)
        self.block_size = block_size

    def __len__(self):
        return self.data.size(0)

    def __getitem__(self, idx):
        seq = self.data[idx]      # Tensor of shape (block_size+1,)
        input_ids = seq[:-1]      # All but last token for input
        target_ids = seq[1:]      # All but first token for target
        return input_ids, target_ids
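

# --- Usage sketch (illustrative only, under assumed inputs) ---
# A minimal example of consuming TextDataset with a DataLoader. The paths
# "data/train.txt" and "tokenizer.json" are placeholder assumptions, not
# files defined by this module; substitute your own corpus and a Hugging
# Face `tokenizers` JSON file.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    dataset = TextDataset(
        file_path="data/train.txt",       # assumed corpus location
        tokenizer_path="tokenizer.json",  # assumed tokenizer file
        block_size=128,
    )
    loader = DataLoader(dataset, batch_size=32, shuffle=True)

    # Each batch yields inputs and next-token targets, shifted by one position
    input_ids, target_ids = next(iter(loader))
    print(input_ids.shape)   # torch.Size([32, 128])
    print(target_ids.shape)  # torch.Size([32, 128])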