Spaces:
Build error
Build error
| import tiktoken | |
| import torch | |
| class DataLoaderLite: | |
| def __init__(self, B, T, file_path, model_type): | |
| self.B = B | |
| self.T = T | |
| # at init load tokens from disk and store them in memory | |
| with open(file_path, "r") as f: | |
| text = f.read() | |
| enc = tiktoken.get_encoding(model_type) | |
| tokens = enc.encode(text) | |
| self.tokens = torch.tensor(tokens) | |
| print(f"loaded {len(self.tokens)} tokens") | |
| print(f"1 epoch = {len(self.tokens) // (B * T)} batches") | |
| # state | |
| self.current_position = 0 | |
| def next_batch(self): | |
| B, T = self.B, self.T | |
| buf = self.tokens[self.current_position : self.current_position + B * T + 1] | |
| x = (buf[:-1]).view(B, T) # inputs | |
| y = (buf[1:]).view(B, T) # targets | |
| # advance the position in the tensor | |
| self.current_position += B * T | |
| # if loading the next batch would be out of bounds, reset | |
| if self.current_position + (B * T + 1) > len(self.tokens): | |
| self.current_position = 0 | |
| return x, y | |