WCNegentropy committed on
Commit
68ec438
·
verified ·
1 Parent(s): 90e658a

Remove nested directory: BitTransformerLM/full_bits_train.py

Browse files
Files changed (1) hide show
  1. BitTransformerLM/full_bits_train.py +0 -51
BitTransformerLM/full_bits_train.py DELETED
@@ -1,51 +0,0 @@
1
- import pathlib
2
- import torch
3
- from bit_transformer import BitTransformerLM
4
-
5
- DATA_PATH = pathlib.Path('full_bits.pt')
6
-
7
- class BitSeq(torch.utils.data.IterableDataset):
8
- def __init__(self, path: str | pathlib.Path = DATA_PATH, seq: int = 2048) -> None:
9
- self.bits = torch.load(path, mmap=True)
10
- self.seq = seq
11
-
12
- def __len__(self) -> int:
13
- return (self.bits.numel() // self.seq) - 1
14
-
15
- def __iter__(self):
16
- N = (self.bits.numel() // self.seq) - 1
17
- for i in range(N):
18
- s = i * self.seq
19
- yield (
20
- self.bits[s:s+self.seq].long(),
21
- self.bits[s+1:s+self.seq+1].long(),
22
- )
23
-
24
def main() -> None:
    """Single-batch smoke test: load the bit dataset, run one forward pass
    through a small BitTransformerLM, and print the cross-entropy loss.

    No optimizer step is taken; this only verifies the data pipeline and
    model wiring end to end.
    """
    dl = torch.utils.data.DataLoader(
        BitSeq(DATA_PATH, seq=2048),
        batch_size=8,
        num_workers=0,
        pin_memory=False,
    )

    model = BitTransformerLM(
        d_model=64,
        nhead=4,
        num_layers=2,
        dim_feedforward=256,
        max_seq_len=2048,
        reversible=True,
        use_autocast=True,
    )

    loss_fn = torch.nn.CrossEntropyLoss()
    xb, yb = next(iter(dl))
    # Forward pass only — no backward/step follows, so skip building the
    # autograd graph to save memory and time.
    with torch.no_grad():
        # NOTE(review): assumes the model returns (logits, extra) with a
        # trailing dimension of 2 (one logit per bit value) — confirm
        # against BitTransformerLM's forward signature.
        logits, _ = model(xb)
        pred = logits.reshape(-1, 2)
        target = yb.reshape(-1)
        loss = loss_fn(pred, target)
    print('Batch loss:', float(loss))


if __name__ == '__main__':
    main()