| from pathlib import Path | |
| import numpy as np | |
def save_tokens_bin(tokens, path, dtype=np.uint16):
    """Write a flat sequence of token ids to ``path`` as raw binary.

    The array is written with ``ndarray.tofile``: raw bytes in C order with
    no header, so the file records neither dtype nor shape. Read it back with
    the same ``dtype`` (e.g. via ``load_tokens_bin``).

    Args:
        tokens: Anything ``np.asarray`` accepts (list of ints, ndarray, ...).
        path: Destination file; ``~`` is expanded and parent directories are
            created as needed.
        dtype: On-disk element type. Defaults to ``np.uint16``, which only
            holds ids 0..65535.

    Raises:
        OverflowError: If ``dtype`` is an integer type and any token id falls
            outside its representable range. Without this check the cast
            would silently wrap large ids and corrupt the file.
    """
    arr = np.asarray(tokens)
    target = np.dtype(dtype)
    if arr.size and np.issubdtype(target, np.integer):
        info = np.iinfo(target)
        lo, hi = arr.min(), arr.max()
        if lo < info.min or hi > info.max:
            raise OverflowError(
                f"token ids span [{lo}, {hi}], which does not fit in {target} "
                f"([{info.min}, {info.max}]); pass a wider dtype such as np.uint32"
            )
    # Validate before touching the filesystem so a rejected call leaves no
    # stray directories behind.
    path = Path(path).expanduser()
    path.parent.mkdir(parents=True, exist_ok=True)
    arr.astype(target, copy=False).tofile(path)
def load_tokens_bin(path, dtype=np.uint16):
    """Open a raw token file as a read-only, memory-mapped 1-D array.

    Args:
        path: File written as raw ``dtype`` elements with no header, as
            produced by ``save_tokens_bin``; ``~`` is expanded.
        dtype: Element type used to interpret the bytes. It must match the
            dtype the file was written with.

    Returns:
        A read-only ``np.memmap``. For a zero-byte file it instead returns an
        empty read-only ``np.ndarray`` of ``dtype``, because ``mmap`` cannot
        map an empty file.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file size is not a multiple of ``dtype``'s
            itemsize (raised by ``np.memmap``).
    """
    path = Path(path).expanduser()
    if path.stat().st_size == 0:
        empty = np.empty(0, dtype=dtype)
        empty.flags.writeable = False  # mirror the read-only memmap contract
        return empty
    return np.memmap(path, dtype=dtype, mode="r")
class TokenBlockDataset:
    """Random fixed-length (input, target) windows over a flat token file.

    Each sample is a window of ``block_size + 1`` consecutive tokens starting
    at some index ``i``. ``x`` holds the first ``block_size`` tokens and ``y``
    holds the same window shifted by one (next-token targets).
    """

    def __init__(self, bin_path, block_size, dtype=np.uint16):
        """Memory-map ``bin_path`` and remember the window length.

        Args:
            bin_path: Raw token file, read via ``load_tokens_bin``.
            block_size: Number of tokens per input row; must be >= 1.
            dtype: On-disk token dtype; must match how the file was written.

        Raises:
            ValueError: If ``block_size`` < 1.
        """
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self.data = load_tokens_bin(bin_path, dtype=dtype)
        self.block_size = block_size

    def __len__(self):
        """Return the number of valid window start positions.

        A start ``i`` is valid while ``i + block_size + 1 <= len(data)``, so
        the valid starts are ``0 .. len(data) - block_size - 1``, inclusive.
        """
        return max(0, len(self.data) - self.block_size)

    def get_batch(self, batch_size, device, generator=None):
        """Sample ``batch_size`` windows uniformly at random (with replacement).

        Args:
            batch_size: Number of rows to sample.
            device: Target passed to ``Tensor.to``.
            generator: Optional ``torch.Generator`` for reproducible sampling.

        Returns:
            ``(x, y)``: contiguous int64 tensors of shape
            ``(batch_size, block_size)`` on ``device``, where
            ``y[b, t] == x[b, t + 1]`` for every ``t < block_size - 1``.

        Raises:
            ValueError: If the file holds fewer than ``block_size + 1`` tokens.
        """
        import torch  # local import keeps the numpy-only helpers torch-free

        n_windows = len(self)
        if n_windows == 0:
            raise ValueError(
                f"dataset has {len(self.data)} tokens; need at least "
                f"block_size + 1 = {self.block_size + 1} to form one (x, y) pair"
            )
        starts = torch.randint(n_windows, (batch_size,), generator=generator).numpy()
        # Gather every (start + offset) position in a single fancy-index read,
        # so each window is pulled from the memmap once rather than twice.
        positions = starts[:, None] + np.arange(self.block_size + 1)
        windows = torch.from_numpy(np.asarray(self.data[positions], dtype=np.int64))
        # .contiguous() gives x and y their own storage, as separately built
        # tensors would have; callers may rely on .view().
        x = windows[:, :-1].contiguous()
        y = windows[:, 1:].contiguous()
        return x.to(device), y.to(device)