import torch
from torch.utils.data import Dataset
from typing import Callable, Tuple


class CharRNNDataset(Dataset):
    """Character-level dataset for training RNN language models.

    The corpus is encoded once at construction time. Each item is a pair
    ``(x, y)`` of slices of the encoded corpus, where ``y`` is ``x`` shifted
    one position to the right (the next-token targets).
    """

    def __init__(
        self,
        text: str,
        seq_len: int,
        text_encoder_fn: Callable[[str], torch.Tensor],
    ):
        """Encode ``text`` and remember the window length.

        Args:
            text: Raw text corpus.
            seq_len: Length of each training sequence; must be at least 1.
            text_encoder_fn: Function to encode text into a tensor of IDs.
                NOTE(review): presumably returns a 1-D tensor with one ID per
                character -- confirm against the encoder in use.

        Raises:
            ValueError: If ``seq_len`` is smaller than 1.
        """
        if seq_len < 1:
            raise ValueError(f"seq_len must be a positive integer, got {seq_len}")
        self.full_sequence = text_encoder_fn(text)
        self.seq_len = seq_len

    def __len__(self) -> int:
        """Return the number of available ``(x, y)`` windows.

        A window starting at ``idx`` needs ``seq_len + 1`` encoded elements
        (the extra one is the last target), so a corpus that is too short
        yields an empty dataset.
        """
        # Clamp at zero: returning a negative value from __len__ makes
        # len() raise ValueError instead of reporting an empty dataset.
        return max(0, len(self.full_sequence) - self.seq_len)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the input window at ``idx`` and its shifted-by-one target.

        Args:
            idx: Start offset of the window, in ``[0, len(self))``.

        Returns:
            A tuple ``(x, y)`` where ``x`` is
            ``full_sequence[idx:idx + seq_len]`` and ``y`` is the same slice
            moved one position forward.

        Raises:
            IndexError: If ``idx`` is negative or not smaller than
                ``len(self)``.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError("Index out of range")
        end = idx + self.seq_len
        x = self.full_sequence[idx:end]
        # Targets are the inputs shifted by one step.
        y = self.full_sequence[idx + 1:end + 1]
        return x, y