CharRNN / src /dataset.py
hoom4n's picture
Upload 18 files
b6447fa verified
raw
history blame contribute delete
939 Bytes
import torch
from torch.utils.data import Dataset
from typing import Callable
class CharRNNDataset(Dataset):
    """Character-level dataset for training RNN language models.

    The corpus is encoded once at construction time. Sample ``i`` is the pair
    ``(x, y)`` where ``x`` is the window of ``seq_len`` IDs starting at
    position ``i`` and ``y`` is the same window shifted one position to the
    right, i.e. the next-token target for every position in ``x``.
    """

    def __init__(self, text: str, seq_len: int, text_encoder_fn: Callable[[str], torch.Tensor]):
        """
        Args:
            text (str): Raw text corpus.
            seq_len (int): Length of each training sequence. Must be >= 1.
            text_encoder_fn (Callable): Function to encode text into a tensor of IDs.

        Raises:
            ValueError: If ``seq_len`` is smaller than 1.
        """
        # A non-positive window would produce empty or inconsistent (x, y)
        # slices and a length larger than the corpus, so reject it up front.
        if seq_len < 1:
            raise ValueError(f"seq_len must be a positive integer, got {seq_len}")
        self.full_sequence = text_encoder_fn(text)
        self.seq_len = seq_len

    def __len__(self) -> int:
        """Return the number of (input, target) windows available.

        Every sample needs ``seq_len + 1`` consecutive IDs (the input window
        plus one extra ID for the shifted target), hence the subtraction.
        The result is clamped at zero because ``len()`` raises ``ValueError``
        on a negative value, which would happen for a corpus shorter than
        ``seq_len``.
        """
        return max(0, len(self.full_sequence) - self.seq_len)

    def __getitem__(self, idx: int):
        """Return the ``(x, y)`` pair starting at position ``idx``.

        Args:
            idx (int): Zero-based sample index in ``[0, len(self))``.
                Negative indices are not supported.

        Returns:
            A tuple ``(x, y)`` of slices of the encoded sequence, each of
            length ``seq_len``; ``y`` is ``x`` shifted right by one position.

        Raises:
            IndexError: If ``idx`` is outside ``[0, len(self))``.
        """
        if idx < 0 or idx >= len(self):
            raise IndexError("Index out of range")
        end = idx + self.seq_len
        x = self.full_sequence[idx:end]
        # Target is the input shifted by one: y[t] is the ID that follows x[t].
        y = self.full_sequence[idx + 1:end + 1]
        return x, y