abhishek4607 commited on
Commit
c1eccb6
·
verified ·
1 Parent(s): 401cd1f

Delete dataloader.py

Browse files
Files changed (1) hide show
  1. dataloader.py +0 -49
dataloader.py DELETED
@@ -1,49 +0,0 @@
1
- import os
2
- import numpy as np
3
- import torch
4
-
5
- script_dir = os.path.dirname(__file__)
6
-
7
-
8
- class DataLoaderLite:
9
- """ A simple dataloader for FineWebEdu-10B dataset """
10
-
11
- def __init__(self, B, T, process_rank, num_processes, split='train'):
12
- super().__init__()
13
- self.B, self.T = B, T
14
- self.process_rank = process_rank
15
- self.num_processes = num_processes
16
- assert split in {'train', 'val'}
17
-
18
- # get the shard filenames
19
- data_root = os.path.join(script_dir, "../data/edu_fineweb10B")
20
- shard_filenames = os.listdir(data_root)
21
- shard_filenames = sorted([filename for filename in shard_filenames if split in filename])
22
- self.shard_filepaths = [os.path.join(data_root, filename) for filename in shard_filenames]
23
- assert len(self.shard_filepaths) > 0, f'no shards found for split {split}'
24
- master_process = process_rank == 0
25
- if master_process:
26
- print(f'found {len(self.shard_filepaths)} shards for split {split}')
27
- self.reset()
28
-
29
- def load_tokens(self, filepath):
30
- tokens = torch.tensor(np.load(filepath).astype(np.int32), dtype=torch.long)
31
- return tokens
32
-
33
- def reset(self):
34
- # state, init at shard 0
35
- self.curr_shard = 0
36
- self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard])
37
- self.curr_pos = self.B * self.T * self.process_rank
38
-
39
- def next_batch(self):
40
- B, T = self.B, self.T
41
- batch = self.tokens[self.curr_pos : self.curr_pos + B*T + 1]
42
- x_batch = batch[:-1].view(B, T)
43
- y_batch = batch[1:].view(B, T)
44
- self.curr_pos += B * T * self.num_processes
45
- if self.curr_pos + (B * T + 1) > len(self.tokens):
46
- self.curr_shard = (self.curr_shard + 1) % len(self.shard_filepaths)
47
- self.tokens = self.load_tokens(self.shard_filepaths[self.curr_shard])
48
- self.curr_pos = self.B * self.T * self.process_rank
49
- return x_batch, y_batch