| import torch |
| from torch.utils.data import Dataset |
| from typing import Callable |
|
|
class CharRNNDataset(Dataset):
    """Character-level dataset for training RNN language models.

    The encoded corpus is sliced into overlapping windows of ``seq_len``
    tokens. Sample ``i`` is the pair ``(x, y)`` where ``x`` covers positions
    ``[i, i + seq_len)`` and ``y`` is the same window shifted one position to
    the right, i.e. the next-token target for every position of ``x``.
    """

    def __init__(
        self,
        text: str,
        seq_len: int,
        text_encoder_fn: Callable[[str], torch.Tensor],
    ):
        """
        Args:
            text (str): Raw text corpus.
            seq_len (int): Length of each training sequence; must be >= 1.
            text_encoder_fn (Callable): Function to encode text into a
                sequence of token IDs (expected to be 1-D). A tensor is used
                as-is; any other sequence (e.g. a list of ints) is converted
                with ``torch.as_tensor``.

        Raises:
            ValueError: If ``seq_len`` is smaller than 1.
        """
        if seq_len < 1:
            raise ValueError(f"seq_len must be >= 1, got {seq_len}")
        # as_tensor returns an existing tensor unchanged (no copy), and lets
        # encoders that return plain Python lists of IDs work as well.
        self.full_sequence = torch.as_tensor(text_encoder_fn(text))
        self.seq_len = seq_len

    def __len__(self) -> int:
        """Return the number of (x, y) windows; 0 if the corpus is too short."""
        # Each sample needs seq_len + 1 tokens (input plus shifted target), so
        # a corpus of n tokens yields n - seq_len windows. Clamp at 0 because
        # len() raises ValueError when __len__ returns a negative number.
        return max(0, len(self.full_sequence) - self.seq_len)

    def __getitem__(self, idx: int):
        """Return the input window starting at ``idx`` and its target window.

        Args:
            idx (int): Window start position, in ``[0, len(self))``.

        Returns:
            tuple: ``(x, y)`` where ``x = seq[idx : idx + seq_len]`` and
            ``y = seq[idx + 1 : idx + seq_len + 1]``.

        Raises:
            IndexError: If ``idx`` is negative or ``>= len(self)``.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError("Index out of range")
        end = idx + self.seq_len
        x = self.full_sequence[idx:end]
        y = self.full_sequence[idx + 1:end + 1]
        return x, y
|
|