WCNegentropy committed on
Commit
1deb983
·
verified ·
1 Parent(s): d0d145b

Remove nested directory: BitTransformerLM/bit_transformer/training.py

Browse files
BitTransformerLM/bit_transformer/training.py DELETED
@@ -1,250 +0,0 @@
1
- """Common training utilities for BitTransformer models."""
2
-
3
- from __future__ import annotations
4
-
5
- from typing import Callable, Dict, List, Optional
6
- import contextlib
7
- import sys
8
- import warnings
9
- import math
10
-
11
- import torch
12
- import torch.nn.functional as F
13
- from torch.utils.data import DataLoader
14
-
15
- from .compression import compress_bits, pack_bits, unpack_bits
16
- from .optimization import configure_optimizer
17
- from .model import BitTransformerLM
18
- from .utils import set_dropout
19
- from .torch_utils import cpu_autocast
20
-
21
-
22
def cosine_ramp(step: int, start: float, end: float, total_steps: int) -> float:
    """Cosine ramp from ``start`` to ``end`` over ``total_steps``."""
    # Degenerate schedule or past the ramp: clamp to the final value.
    ramp_done = total_steps <= 0 or step >= total_steps
    if ramp_done:
        return end
    phase = math.pi * step / total_steps
    # Half-cosine blend factor rising smoothly from 0 to 1.
    blend = 0.5 * (1.0 - math.cos(phase))
    return start + (end - start) * blend
28
-
29
-
30
def train_loop(
    model: BitTransformerLM,
    data: torch.Tensor,
    *,
    epochs: int = 1,
    extra_steps: int = 0,
    compress_prob: float = 0.5,
    direct_prob: float = 0.0,
    batch_size: int = 8,
    num_workers: int = 0,
    accum_steps: int = 1,
    amp: bool = False,
    compile_model: bool = False,
    log: bool = False,
    forward_kwargs: Optional[Dict] = None,
    optimizer: Optional[torch.optim.Optimizer] = None,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    diffusion: bool = False,
    noise_fn: Optional[Callable[[], float]] = None,
    diffusion_curriculum: bool = False,
    compress_warmup: int = 0,
) -> List[Dict[str, float]]:
    """Generic training loop supporting optional compression and diffusion.

    ``compress_prob`` controls the fraction of batches that are run through
    ``forward_compressed``. ``direct_prob`` instead feeds the model with the
    bit-packed result of ``compress_bits`` after converting back to a bit
    tensor. When enabled, metrics for direct-compressed batches are tracked
    separately.

    When ``diffusion`` is ``True`` the loop performs denoising training. Batches
    are noised by randomly flipping bits with a probability given by
    ``noise_fn`` (defaulting to a uniform draw in ``[0, 0.5]``). When
    ``diffusion_curriculum`` is ``True`` the noise probability decreases
    linearly from ``0.5`` to ``0.0`` over the training epochs. The model is
    then trained to recover the clean sequence using full-context attention
    (``causal=False``).

    Existing ``optimizer`` and ``scheduler`` instances may be supplied to allow
    integration with long-running training sessions, otherwise new ones are
    created automatically.

    Returns a list with one dict of averaged metrics per epoch
    (``raw_loss``, ``raw_acc``, ``compressed_loss``, ``compressed_acc``,
    ``direct_loss``, ``compression_ratio``).
    """
    if compile_model and sys.version_info < (3, 12) and torch.__version__ >= "2.1":
        model = torch.compile(model)
    elif compile_model:
        warnings.warn("torch.compile skipped: requires torch>=2.1 and Python<3.12")

    model.train()
    set_dropout(model, 0.1)

    device = next(model.parameters()).device

    def _amp_ctx():
        # Single source of truth for the mixed-precision context; the original
        # duplicated this 5-line expression at every forward call site.
        if amp and torch.cuda.is_available():
            return torch.cuda.amp.autocast(dtype=torch.bfloat16)
        if amp:
            return cpu_autocast()
        return contextlib.nullcontext()

    loader = DataLoader(
        data,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )
    steps_per_epoch = max(1, len(loader))
    total_updates = math.ceil(epochs * (steps_per_epoch + extra_steps) / accum_steps)
    if optimizer is None or scheduler is None:
        optimizer, scheduler = configure_optimizer(
            model, lr=1e-3, total_steps=total_updates
        )
    metrics: List[Dict[str, float]] = []

    global_step = 0
    for epoch in range(epochs):
        raw_losses: List[float] = []
        raw_accs: List[float] = []
        comp_losses: List[float] = []
        comp_accs: List[float] = []
        comp_ratios: List[float] = []
        direct_losses: List[float] = []

        last_batch = None
        for step, batch in enumerate(loader):
            batch = batch.to(device)
            # Bug fix: save the *device* copy for the extra-steps replay below.
            # The original saved the pre-transfer batch, which fed a CPU tensor
            # to a GPU-resident model when CUDA was in use.
            last_batch = batch
            cur_compress = (
                cosine_ramp(global_step, 0.0, compress_prob, compress_warmup)
                if not diffusion
                else compress_prob
            )
            if diffusion:
                if diffusion_curriculum:
                    # Linear decay of flip probability from 0.5 to 0.0 over epochs.
                    p = 0.5 * (1 - epoch / max(1, epochs - 1))
                else:
                    p = noise_fn() if noise_fn is not None else float(torch.rand(()) * 0.5)
                # Flip each bit independently with probability ``p``; XOR applies
                # the flips, and the clean batch is the denoising target.
                noise = (torch.rand_like(batch.float()) < p).long()
                noisy = batch ^ noise
                with _amp_ctx():
                    logits, _ = model(noisy, causal=False)
                pred = logits.reshape(-1, 2)
                target = batch.reshape(-1)
                loss = F.cross_entropy(pred, target) / accum_steps
                acc = (pred.argmax(dim=-1) == target).float().mean().item()
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)
                loss.backward()
                if (step + 1) % accum_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                global_step += 1
                continue

            # Choose the batch mode: direct-compressed, compressed, or raw.
            r = torch.rand(())
            key = "raw"
            ratio = 1.0
            target = batch[:, 1:].reshape(-1)

            if r < direct_prob:
                # Feed the model the bit-unpacked form of the packed sequence,
                # truncated/padded to the positional-encoding capacity.
                packed = [pack_bits(row.to(torch.uint8)) for row in batch]
                unpacked = [unpack_bits(p, n_bits=batch.size(1)) for p in packed]
                max_len = min(
                    max(u.numel() for u in unpacked),
                    model.pos_enc.pe.size(0),
                )
                padded = [F.pad(u[:max_len], (0, max_len - min(u.numel(), max_len))) for u in unpacked]
                dc_batch = torch.stack(padded).long()
                with _amp_ctx():
                    logits, _ = model(dc_batch, **(forward_kwargs or {}))
                ratio = sum(p.numel() for p in packed) / batch.numel()
                target = dc_batch[:, 1:].reshape(-1)
                key = "direct"
            elif r < direct_prob + cur_compress:
                comp_batch = [compress_bits(row.to(torch.uint8)) for row in batch]
                with _amp_ctx():
                    logits, _ = model.forward_compressed(comp_batch, **(forward_kwargs or {}))
                ratio = sum(c.numel() for c in comp_batch) / batch.numel()
                target = batch[:, 1:].reshape(-1)
                key = "compressed"
            else:
                with _amp_ctx():
                    logits, _ = model(batch, **(forward_kwargs or {}))

            # Next-bit prediction: positions 0..L-2 predict bits 1..L-1.
            pred = logits[:, :-1, :].reshape(-1, 2)
            loss = F.cross_entropy(pred, target) / accum_steps
            acc = (pred.argmax(dim=-1) == target).float().mean().item()

            loss.backward()
            # NOTE(review): gradients from a trailing partial accumulation
            # window are carried into the next epoch rather than flushed here;
            # preserved as-is to keep optimizer-step counts unchanged.
            if (step + 1) % accum_steps == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
            global_step += 1

            if key == "compressed":
                comp_losses.append(loss.item() * accum_steps)
                comp_accs.append(acc)
                comp_ratios.append(ratio)
            elif key == "direct":
                direct_losses.append(loss.item() * accum_steps)
                comp_ratios.append(ratio)
            else:
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)

        # Run extra gradient updates by replaying the final (device) batch.
        if extra_steps > 0 and last_batch is not None and not diffusion:
            for step in range(extra_steps):
                with _amp_ctx():
                    logits, _ = model(last_batch, **(forward_kwargs or {}))
                pred = logits[:, :-1, :].reshape(-1, 2)
                target = last_batch[:, 1:].reshape(-1)
                loss = F.cross_entropy(pred, target) / accum_steps
                acc = (pred.argmax(dim=-1) == target).float().mean().item()
                loss.backward()
                if (step + 1) % accum_steps == 0:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    scheduler.step()
                    optimizer.zero_grad()
                raw_losses.append(loss.item() * accum_steps)
                raw_accs.append(acc)
                global_step += 1

        m = {
            "raw_loss": float(sum(raw_losses) / len(raw_losses)) if raw_losses else 0.0,
            "raw_acc": float(sum(raw_accs) / len(raw_accs)) if raw_accs else 0.0,
            "compressed_loss": float(sum(comp_losses) / len(comp_losses)) if comp_losses else 0.0,
            "compressed_acc": float(sum(comp_accs) / len(comp_accs)) if comp_accs else 0.0,
            "direct_loss": float(sum(direct_losses) / len(direct_losses)) if direct_losses else 0.0,
            "compression_ratio": float(sum(comp_ratios) / len(comp_ratios)) if comp_ratios else 0.0,
        }
        metrics.append(m)

        if log:
            print(
                f"Epoch {epoch} "
                f"raw_loss={m['raw_loss']:.4f} acc={m['raw_acc']:.3f} | "
                f"compressed_loss={m['compressed_loss']:.4f} acc={m['compressed_acc']:.3f} "
                f"direct_loss={m['direct_loss']:.4f} ratio={m['compression_ratio']:.2f}"
            )

    return metrics
249
-
250
- __all__ = ["train_loop"]