File size: 940 Bytes
ca2f8ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
from pathlib import Path
import numpy as np

def save_tokens_bin(tokens, path, dtype=np.uint16):
    """Write *tokens* to *path* as a flat, header-less binary file.

    The values are converted to ``dtype`` and dumped with
    ``ndarray.tofile``. ``~`` in *path* is expanded and missing parent
    directories are created.

    Args:
        tokens: Array-like of token ids.
        path: Destination file (``str`` or ``Path``).
        dtype: On-disk element type; defaults to ``np.uint16``.

    Raises:
        OverflowError: If *tokens* holds integers that do not fit in an
            integer ``dtype``. Nothing is written in that case.
    """
    target = np.dtype(dtype)
    arr = np.asarray(tokens)

    # Casting an integer array to a narrower integer dtype wraps silently
    # (70000 -> 4464 for uint16), which would corrupt the file without any
    # error. Check the range up front instead.
    if arr.size and arr.dtype.kind in "iu" and target.kind in "iu":
        limits = np.iinfo(target)
        lowest, highest = int(arr.min()), int(arr.max())
        if lowest < limits.min or highest > limits.max:
            raise OverflowError(
                f"token ids span [{lowest}, {highest}] but {target.name} "
                f"only holds [{limits.min}, {limits.max}]"
            )

    path = Path(path).expanduser()
    path.parent.mkdir(parents=True, exist_ok=True)
    arr.astype(target, copy=False).tofile(path)

def load_tokens_bin(path, dtype=np.uint16):
    """Open a flat binary token file as a read-only 1-D array.

    Args:
        path: File to open (``str`` or ``Path``); ``~`` is expanded.
        dtype: Element type the file was written with; defaults to
            ``np.uint16``.

    Returns:
        A read-only ``np.memmap`` over the file, or an empty read-only
        ``ndarray`` of ``dtype`` when the file has zero bytes.
    """
    resolved = Path(path).expanduser()

    # np.memmap cannot map a zero-byte file (it raises ValueError), yet an
    # empty token list is a valid thing to have saved. Return an ordinary
    # empty array so the save/load round trip works for that case too.
    if resolved.stat().st_size == 0:
        empty = np.empty(0, dtype=dtype)
        empty.flags.writeable = False  # match the read-only memmap mode
        return empty

    return np.memmap(resolved, dtype=dtype, mode="r")

class TokenBlockDataset:
    """Random fixed-length windows over a flat binary token file.

    Each sample is a pair ``(x, y)`` of ``block_size`` consecutive tokens,
    where ``y`` is ``x`` shifted one position to the right.
    """

    def __init__(self, bin_path, block_size, dtype=np.uint16):
        """Open *bin_path* via ``load_tokens_bin``.

        Args:
            bin_path: Path to the binary token file.
            block_size: Number of tokens in each ``x`` / ``y`` window.
            dtype: Element type the file was written with; defaults to
                ``np.uint16``.
        """
        self.data = load_tokens_bin(bin_path, dtype=dtype)
        self.block_size = block_size

    def __len__(self):
        """Return the number of valid window start positions.

        A start ``i`` is valid when ``data[i + 1 : i + 1 + block_size]``
        lies fully inside the data, i.e. ``i <= len(data) - block_size - 1``;
        that gives ``len(data) - block_size`` starts (never negative).
        """
        return max(0, len(self.data) - self.block_size)

    def get_batch(self, batch_size, device):
        """Sample a random batch of input/target windows.

        Args:
            batch_size: Number of windows to draw (with replacement).
            device: Target passed to ``Tensor.to`` for both outputs.

        Returns:
            ``(x, y)``: two int64 tensors of shape
            ``(batch_size, block_size)`` with ``y[b, t]`` holding the token
            that follows ``x[b, t]`` in the data.

        Raises:
            ValueError: If the data is too short to hold a single window.
        """
        # Imported lazily so the module can be used without torch installed.
        import torch

        n_starts = len(self)
        if n_starts == 0:
            raise ValueError(
                f"need at least {self.block_size + 1} tokens for "
                f"block_size={self.block_size}, got {len(self.data)}"
            )

        starts = torch.randint(n_starts, (batch_size,)).numpy()
        # One gather of block_size + 1 tokens per row covers both x and y,
        # instead of slicing the data twice per sample.
        offsets = np.arange(self.block_size + 1)
        window = self.data[starts[:, None] + offsets]

        # Copy into contiguous int64 arrays so the tensors behave like the
        # stacked ones callers received before (e.g. ``.view`` still works).
        x = torch.from_numpy(np.ascontiguousarray(window[:, :-1], dtype=np.int64))
        y = torch.from_numpy(np.ascontiguousarray(window[:, 1:], dtype=np.int64))
        return x.to(device), y.to(device)