Spaces:
Build error
Build error
File size: 1,076 Bytes
6285ada |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
import tiktoken
import torch
class DataLoaderLite:
    """Minimal sequential loader that serves (input, target) token batches.

    The whole text file is tokenized once at construction and kept in memory
    as a 1-D tensor. ``next_batch`` walks through it in contiguous windows of
    ``B * T`` tokens and wraps back to the start when the next window would
    run past the end of the data.
    """

    def __init__(self, B, T, file_path, model_type):
        """Read and tokenize the corpus at ``file_path``.

        Args:
            B: number of rows per batch; must be positive.
            T: number of tokens per row; must be positive.
            file_path: path to the text corpus, read as UTF-8.
            model_type: encoding name passed to ``tiktoken.get_encoding``.

        Raises:
            ValueError: if ``B`` or ``T`` is not positive, or if the corpus
                has fewer than ``B * T + 1`` tokens (the minimum needed for
                one batch, since targets are inputs shifted by one token).
        """
        if B <= 0 or T <= 0:
            raise ValueError(f"B and T must be positive, got B={B}, T={T}")
        self.B = B
        self.T = T
        # At init, load tokens from disk and keep them in memory.
        # An explicit encoding avoids depending on the platform's locale default.
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
        enc = tiktoken.get_encoding(model_type)
        # NOTE: tiktoken's encode() rejects text containing special-token
        # strings (e.g. "<|endoftext|>") by default and raises ValueError.
        tokens = enc.encode(text)
        self.tokens = torch.tensor(tokens, dtype=torch.long)
        # One batch reads B*T inputs plus one extra token for the last target.
        if len(self.tokens) < B * T + 1:
            raise ValueError(
                f"corpus has {len(self.tokens)} tokens; need at least "
                f"{B * T + 1} for a single batch with B={B}, T={T}"
            )
        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")
        # state: index of the first token of the next batch
        self.current_position = 0

    def next_batch(self):
        """Return the next ``(x, y)`` pair, each viewed as shape ``(B, T)``.

        ``y`` is ``x`` shifted one token to the right, so ``y[i, j]`` is the
        token that follows ``x[i, j]`` in the corpus.
        """
        B, T = self.B, self.T
        # Slice B*T + 1 tokens: the extra token supplies the final target.
        buf = self.tokens[self.current_position : self.current_position + B * T + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)  # targets
        # Advance the position in the tensor.
        self.current_position += B * T
        # If loading the next batch would be out of bounds, wrap to the start.
        # Trailing tokens that do not fill a whole window are skipped each pass.
        if self.current_position + (B * T + 1) > len(self.tokens):
            self.current_position = 0
        return x, y
|