WCNegentropy committed (verified)
Commit 651de2e · 1 Parent(s): d3e2188

Remove nested directory: BitTransformerLM/progressive_scaleup.py

BitTransformerLM/progressive_scaleup.py DELETED
@@ -1,216 +0,0 @@
- """Legacy progressive scale-up demo.
-
- This script is retained for historical reference but has been superseded by
- ``integration_schedule.py`` which provides a more flexible scaling workflow.
- """
-
- import argparse
- import warnings
- import torch
- import torch.nn.functional as F
- from bit_transformer import (
-     BitTransformerLM,
-     configure_optimizer,
-     expand_model,
-     text_to_bits,
- )
- from bit_transformer.training import train_loop as basic_train
-
- warnings.warn(
-     "progressive_scaleup.py is deprecated; use integration_schedule.py instead.",
-     DeprecationWarning,
-     stacklevel=2,
- )
-
-
- def progressive_scale_up(
-     eps: float = 0.65,
-     steps: int = 2,
-     width_mult: float = 1.0,
-     forward_kwargs: dict | None = None,
- ) -> None:
-     """Demonstrate automatic scaling of the model on random data."""
-     params = dict(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=16)
-     model = BitTransformerLM(**params)
-     steps_per_epoch = 64 // 8
-     optimizer, scheduler = configure_optimizer(
-         model, lr=1e-3, total_steps=steps * steps_per_epoch
-     )
-
-     train = torch.randint(0, 2, (64, params["max_seq_len"]), dtype=torch.long)
-     valid = torch.randint(0, 2, (16, params["max_seq_len"]), dtype=torch.long)
-
-     for step in range(steps):
-         # one epoch over train
-         basic_train(
-             model,
-             train,
-             epochs=1,
-             compress_prob=0.5,
-             log=False,
-             forward_kwargs=forward_kwargs,
-         )
-
-         with torch.no_grad():
-             logits, _ = model(valid, **(forward_kwargs or {}))
-             pred = logits[:, :-1, :].reshape(-1, 2)
-             target = valid[:, 1:].reshape(-1)
-             val_loss = F.cross_entropy(pred, target).item()
-         print(f"Step {step} validation loss: {val_loss:.4f}")
-         if val_loss < eps:
-             params["num_layers"] *= 2
-             params["d_model"] = int(params["d_model"] * width_mult)
-             params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
-             model = expand_model(model, params)
-             optimizer, scheduler = configure_optimizer(
-                 model, lr=1e-3, total_steps=steps * steps_per_epoch
-             )
-             print(
-                 "Scaled model to", params["num_layers"], "layers and width", params["d_model"]
-             )
-
-
- def progressive_scale_up_text(
-     improve_thresh: float = 0.01,
-     steps: int = 2,
-     width_mult: float = 2.0,
-     max_len: int = 64,
-     dataset_size: int = 512,
-     forward_kwargs: dict | None = None,
- ) -> None:
-     """Scale up using WikiText2 lines converted to bits.
-
-     Parameters
-     ----------
-     improve_thresh: float
-         Relative validation loss improvement required to avoid scaling.
-         If improvement is <= this threshold, model size is increased.
-     steps: int
-         Number of training steps.
-     width_mult: float
-         Multiplier applied when increasing model width.
-     max_len: int
-         Initial sequence length.
-     dataset_size: int
-         Number of training lines to load from WikiText2.
-     forward_kwargs: dict | None
-         Extra keyword arguments for the forward pass.
-     """
-     from datasets import load_dataset
-
-     ds = load_dataset("wikitext", "wikitext-2-raw-v1")
-     train_iter = ds["train"]["text"]
-     valid_iter = ds["validation"]["text"]
-
-     train_lines = []
-     for line in train_iter:
-         train_lines.append(line)
-         if len(train_lines) >= dataset_size:
-             break
-
-     valid_lines = []
-     for line in valid_iter:
-         valid_lines.append(line)
-         if len(valid_lines) >= dataset_size // 4:
-             break
-
-     def lines_to_tensor(lines: list[str], length: int) -> torch.Tensor:
-         seqs = []
-         for text in lines:
-             bits = text_to_bits(text)[:length]
-             if len(bits) < length:
-                 bits.extend([0] * (length - len(bits)))
-             seqs.append(bits)
-         return torch.tensor(seqs, dtype=torch.long)
-
-     train = lines_to_tensor(train_lines, max_len)
-     valid = lines_to_tensor(valid_lines, max_len)
-
-     params = dict(
-         d_model=32,
-         nhead=4,
-         num_layers=1,
-         dim_feedforward=64,
-         max_seq_len=max_len,
-     )
-     model = BitTransformerLM(**params)
-     steps_per_epoch = len(train) // 8
-     optimizer, scheduler = configure_optimizer(
-         model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
-     )
-
-     prev_loss: float | None = None
-     scale_length = True
-
-     for step in range(steps):
-         basic_train(
-             model,
-             train,
-             epochs=1,
-             compress_prob=0.5,
-             log=False,
-             forward_kwargs=forward_kwargs,
-         )
-
-         with torch.no_grad():
-             logits, _ = model(valid, **(forward_kwargs or {}))
-             pred = logits[:, :-1, :].reshape(-1, 2)
-             target = valid[:, 1:].reshape(-1)
-             val_loss = F.cross_entropy(pred, target).item()
-         print(f"Step {step} validation loss: {val_loss:.4f}")
-         if prev_loss is not None:
-             improvement = (prev_loss - val_loss) / max(prev_loss, 1e-8)
-             if improvement <= improve_thresh:
-                 if scale_length:
-                     params["max_seq_len"] *= 2
-                     train = lines_to_tensor(train_lines, params["max_seq_len"])
-                     valid = lines_to_tensor(valid_lines, params["max_seq_len"])
-                     model = model.double_length()
-                     steps_per_epoch = len(train) // 8
-                     scale_type = "length"
-                 else:
-                     params["d_model"] = int(params["d_model"] * width_mult)
-                     params["dim_feedforward"] = int(params["dim_feedforward"] * width_mult)
-                     model = expand_model(model, params)
-                     scale_type = "width"
-                 optimizer, scheduler = configure_optimizer(
-                     model, lr=1e-3, total_steps=steps * max(1, steps_per_epoch)
-                 )
-                 scale_length = not scale_length
-                 param_count = sum(p.numel() for p in model.parameters())
-                 print(
-                     f"Scaled {scale_type}; seq_len={params['max_seq_len']} width={params['d_model']} params={param_count}"
-                 )
-         prev_loss = val_loss
-
-
- if __name__ == "__main__":
-     parser = argparse.ArgumentParser(description="Progressively scale model length and width")
-     parser.add_argument("--steps", type=int, default=2, help="number of training steps")
-     parser.add_argument(
-         "--improve-thresh",
-         type=float,
-         default=0.01,
-         help="relative loss improvement required to avoid scaling",
-     )
-     parser.add_argument(
-         "--width-mult", type=float, default=2.0, help="width multiplier when scaling"
-     )
-     parser.add_argument("--causal", action="store_true", help="use causal attention during training")
-     parser.add_argument("--wikitext", action="store_true", help="use WikiText2 dataset")
-     args = parser.parse_args()
-     if args.wikitext:
-         progressive_scale_up_text(
-             improve_thresh=args.improve_thresh,
-             steps=args.steps,
-             width_mult=args.width_mult,
-             forward_kwargs={"causal": args.causal} if args.causal else None,
-         )
-     else:
-         progressive_scale_up(
-             eps=args.improve_thresh,
-             steps=args.steps,
-             width_mult=args.width_mult,
-             forward_kwargs={"causal": args.causal} if args.causal else None,
-         )
-
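
For reference, the removed script was driven through the small argparse CLI visible at the bottom of the diff. The sketch below reconstructs a typical invocation from those flag definitions only; the module path and the chosen flag values are illustrative, and the replacement workflow in integration_schedule.py may expose different options.

# Hypothetical usage of the now-deleted script, reconstructed from its argparse block:
#
#   python progressive_scaleup.py --steps 4 --improve-thresh 0.01 --width-mult 2.0 --wikitext --causal
#
# Roughly equivalent programmatic call (import path is an assumption and no longer
# resolves after this commit, since the nested file has been removed):
from progressive_scaleup import progressive_scale_up_text

progressive_scale_up_text(
    improve_thresh=0.01,               # scale up when relative val-loss improvement is <= 1%
    steps=4,                           # number of train/evaluate/scale iterations
    width_mult=2.0,                    # multiplier applied on the "width" scaling turns
    forward_kwargs={"causal": True},   # corresponds to the --causal flag
)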