GPT2_Demo / data.py
peeyushsinghal's picture
Gradio App
6285ada verified
import tiktoken
import torch
class DataLoaderLite:
    """Sequential batch loader over a single tokenized text file.

    The whole file is read and tokenized once at construction time and kept
    in memory as a 1-D tensor. Each call to ``next_batch`` returns the next
    contiguous window of ``B * T`` tokens as inputs, together with the same
    window shifted by one token as targets, and wraps back to the start of
    the token stream when a full window no longer fits.
    """

    def __init__(self, B: int, T: int, file_path: str, model_type: str) -> None:
        """Read ``file_path``, tokenize it, and reset the read position.

        Args:
            B: Number of rows in each batch.
            T: Number of tokens in each row.
            file_path: Path of the text file to load.
            model_type: Encoding name handed to ``tiktoken.get_encoding``.
        """
        self.B = B
        self.T = T

        # at init load tokens from disk and store them in memory
        # Decode explicitly as UTF-8 so the token stream does not depend on
        # the platform's default locale encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            text = f.read()
        enc = tiktoken.get_encoding(model_type)
        tokens = enc.encode(text)
        # Pin the dtype: without it an empty token list would be inferred as
        # a float tensor instead of integer token ids.
        self.tokens = torch.tensor(tokens, dtype=torch.long)
        print(f"loaded {len(self.tokens)} tokens")
        print(f"1 epoch = {len(self.tokens) // (B * T)} batches")

        # state
        self.current_position = 0

    def next_batch(self) -> "tuple[torch.Tensor, torch.Tensor]":
        """Return the next ``(inputs, targets)`` pair and advance the cursor.

        Returns:
            ``(x, y)``, both viewed as ``(B, T)``. ``y`` is ``x`` shifted one
            token to the right within the flat token stream.

        Note:
            If the token stream holds fewer than ``B * T + 1`` tokens, the
            slice is too short and the ``view`` call fails.
        """
        B, T = self.B, self.T
        span = B * T
        start = self.current_position

        # One extra token is taken so the shifted targets have a final element.
        buf = self.tokens[start : start + span + 1]
        x = buf[:-1].view(B, T)  # inputs
        y = buf[1:].view(B, T)  # targets

        # advance the position in the tensor
        self.current_position = start + span
        # if loading the next batch would be out of bounds, reset
        if self.current_position + span + 1 > len(self.tokens):
            self.current_position = 0
        return x, y