WCNegentropy commited on
Commit
3b00f37
·
verified ·
1 Parent(s): 20b2677

Remove nested directory: BitTransformerLM/wikitext_schedule.py

Browse files
Files changed (1) hide show
  1. BitTransformerLM/wikitext_schedule.py +0 -130
BitTransformerLM/wikitext_schedule.py DELETED
@@ -1,130 +0,0 @@
1
- import numpy as np
2
- import torch
3
- import torch.nn.functional as F
4
- from torch.utils.data import Dataset
5
- from pathlib import Path
6
- from datasets import load_dataset
7
-
8
- from bit_transformer import (
9
- BitTransformerLM,
10
- configure_optimizer,
11
- expand_model,
12
- text_to_bits,
13
- )
14
- from bit_transformer.training import train_loop as basic_train
15
-
16
-
17
- def _build_memmap(lines, path: Path, max_len: int) -> None:
18
- """Precompute bit tensors into a memory-mapped file."""
19
- arr = np.memmap(path, mode="w+", shape=(len(lines), max_len), dtype="uint8")
20
- for idx, text in enumerate(lines):
21
- bits = text_to_bits(text)[:max_len]
22
- if len(bits) < max_len:
23
- bits.extend([0] * (max_len - len(bits)))
24
- arr[idx] = np.array(bits, dtype="uint8")
25
- arr.flush()
26
-
27
-
28
- class MemmapDataset(Dataset):
29
- """Dataset backed by a memory-mapped array."""
30
-
31
- def __init__(self, path: Path, length: int, max_len: int) -> None:
32
- self.path = path
33
- self.length = length
34
- self.max_len = max_len
35
- self._arr = np.memmap(path, mode="r", shape=(length, max_len), dtype="uint8")
36
-
37
- def __len__(self) -> int: # pragma: no cover - trivial
38
- return self.length
39
-
40
- def __getitem__(self, idx: int) -> torch.Tensor:
41
- return torch.from_numpy(self._arr[idx].astype("int64"))
42
-
43
-
44
- def progressive_scale_schedule(steps=12, max_len=64, dataset_size=128):
45
- """Run deterministic scale-up on WikiText data."""
46
- ds = load_dataset("wikitext", "wikitext-2-raw-v1")
47
- train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
48
- valid_lines = [t for t in ds["validation"]["text"] if t.strip()][: dataset_size // 4]
49
-
50
- train_path = Path("wikitext_train.memmap")
51
- valid_path = Path("wikitext_valid.memmap")
52
- _build_memmap(train_lines, train_path, max_len)
53
- _build_memmap(valid_lines, valid_path, max_len)
54
-
55
- train = MemmapDataset(train_path, len(train_lines), max_len)
56
- valid = torch.from_numpy(
57
- np.memmap(valid_path, mode="r", shape=(len(valid_lines), max_len), dtype="uint8")
58
- ).long()
59
-
60
- layers = 1
61
- width = 32
62
- params = dict(
63
- d_model=width,
64
- nhead=4,
65
- num_layers=layers,
66
- dim_feedforward=width * 2,
67
- max_seq_len=max_len,
68
- reversible=True,
69
- chunk_size=max_len,
70
- use_autocast=True,
71
- use_act=True,
72
- act_threshold=0.9,
73
- )
74
- model = BitTransformerLM(**params)
75
- steps_per_epoch = max(1, (len(train) + 7) // 8)
76
- optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=(steps + 1) * steps_per_epoch)
77
-
78
- results = []
79
- for step in range(steps + 1):
80
- basic_train(
81
- model,
82
- train,
83
- epochs=1,
84
- compress_prob=0.5,
85
- log=False,
86
- forward_kwargs=None,
87
- num_workers=2,
88
- )
89
-
90
- with torch.no_grad():
91
- logits, _ = model(valid)
92
- pred = logits[:, :-1, :].reshape(-1, 2)
93
- target = valid[:, 1:].reshape(-1)
94
- val_loss = F.cross_entropy(pred, target).item()
95
- print(f"Step {step} validation loss: {val_loss:.4f}")
96
- results.append((step, val_loss))
97
-
98
- if step < steps:
99
- if step % 2 == 0:
100
- layers *= 2
101
- else:
102
- width *= 2
103
- params = dict(
104
- d_model=width,
105
- nhead=4,
106
- num_layers=layers,
107
- dim_feedforward=width * 2,
108
- max_seq_len=max_len,
109
- reversible=True,
110
- chunk_size=max_len,
111
- use_autocast=True,
112
- use_act=True,
113
- act_threshold=0.9,
114
- )
115
- model = expand_model(model, params)
116
- optimizer, scheduler = configure_optimizer(model, lr=1e-3, total_steps=(steps - step) * steps_per_epoch)
117
- print(f"Scaled model to {layers} layers and width {width}")
118
- return results
119
-
120
-
121
- if __name__ == "__main__":
122
- import argparse
123
-
124
- parser = argparse.ArgumentParser(description="Deterministic scale-up benchmark")
125
- parser.add_argument("--steps", type=int, default=12, help="number of scale-up steps")
126
- parser.add_argument("--max-len", type=int, default=64, help="sequence length")
127
- parser.add_argument("--dataset-size", type=int, default=128, help="number of training lines")
128
- args = parser.parse_args()
129
-
130
- progressive_scale_schedule(steps=args.steps, max_len=args.max_len, dataset_size=args.dataset_size)