| | import numpy as np |
| | import torch |
| | import torch.nn.functional as F |
| | from torch.utils.data import Dataset |
| | from pathlib import Path |
| | from datasets import load_dataset |
| |
|
| | from bit_transformer import ( |
| | BitTransformerLM, |
| | configure_optimizer, |
| | expand_model, |
| | text_to_bits, |
| | ) |
| | from bit_transformer.training import train_loop as basic_train |
| |
|
| |
|
def _build_memmap(lines, path: Path, max_len: int) -> None:
    """Encode every line to bits and write them row-by-row into a uint8 memmap.

    Args:
        lines: Sequence of text lines; row ``i`` of the file holds ``lines[i]``.
        path: Destination file, created (or overwritten) in ``w+`` mode.
        max_len: Number of bits stored per row. Longer encodings are
            truncated; shorter ones are right-padded with zeros.
    """
    out = np.memmap(path, mode="w+", shape=(len(lines), max_len), dtype="uint8")
    for row, line in enumerate(lines):
        encoded = text_to_bits(line)[:max_len]
        shortfall = max_len - len(encoded)
        if shortfall > 0:
            # Zero-pad so every row has exactly max_len entries.
            encoded.extend([0] * shortfall)
        out[row] = np.array(encoded, dtype="uint8")
    # Push the written pages to disk before the mapping is released.
    out.flush()
| |
|
| |
|
class MemmapDataset(Dataset):
    """Dataset backed by a read-only memory-mapped ``uint8`` array.

    The backing file holds ``length`` rows of ``max_len`` entries each. Every
    item is one row, returned as an ``int64`` tensor.

    The mapping is not pickled. When the dataset is sent to DataLoader worker
    processes (e.g. with the ``spawn`` start method), each worker re-opens the
    file lazily instead of receiving a serialised copy of the whole array.
    """

    def __init__(self, path: Path, length: int, max_len: int) -> None:
        self.path = path
        self.length = length
        self.max_len = max_len
        # Open eagerly so a missing or undersized file fails at construction.
        self._arr = self._open()

    def _open(self) -> np.memmap:
        """Map the backing file read-only with the expected shape."""
        return np.memmap(
            self.path, mode="r", shape=(self.length, self.max_len), dtype="uint8"
        )

    def __getstate__(self) -> dict:
        # Pickling an np.memmap serialises its full contents as a plain array;
        # drop it and let the receiving process map the file itself.
        state = self.__dict__.copy()
        state["_arr"] = None
        return state

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> torch.Tensor:
        if self._arr is None:
            # First access after unpickling in this process.
            self._arr = self._open()
        # astype copies the row into a writable int64 buffer, so the tensor
        # does not alias the read-only mapping.
        return torch.from_numpy(self._arr[idx].astype("int64"))
| |
|
| |
|
def _scale_params(width: int, layers: int, max_len: int) -> dict:
    """Return BitTransformerLM keyword arguments for the given width and depth.

    Shared by the initial construction and every expansion so both stay in sync.
    """
    return dict(
        d_model=width,
        nhead=4,
        num_layers=layers,
        dim_feedforward=width * 2,
        max_seq_len=max_len,
        reversible=True,
        chunk_size=max_len,
        use_autocast=True,
        use_act=True,
        act_threshold=0.9,
    )


def _validation_loss(model, valid: torch.Tensor) -> float:
    """Return the next-bit cross-entropy of ``model`` on ``valid`` (no gradients).

    The model is put in eval mode for the forward pass, and its previous
    train/eval mode is restored afterwards.
    """
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            logits, _ = model(valid)
            # Position t predicts bit t+1: drop the last prediction and the
            # first target. Assumes logits are (batch, seq, 2) — the reshape
            # below requires a final dimension of size 2.
            pred = logits[:, :-1, :].reshape(-1, 2)
            target = valid[:, 1:].reshape(-1)
            return F.cross_entropy(pred, target).item()
    finally:
        model.train(was_training)


def progressive_scale_schedule(steps=12, max_len=64, dataset_size=128):
    """Run deterministic scale-up on WikiText data.

    Runs ``steps + 1`` rounds. Each round trains for one epoch and reports the
    next-bit validation loss. Between rounds the model is grown with
    ``expand_model``: even rounds double the layer count, odd rounds double
    the width.

    Args:
        steps: Number of scale-up operations (``steps + 1`` training rounds).
        max_len: Bits per example; lines are truncated or zero-padded to it.
        dataset_size: Number of non-empty training lines. ``dataset_size // 4``
            validation lines are used, but always at least one.

    Returns:
        List of ``(step, validation_loss)`` tuples, one per round.

    Raises:
        ValueError: If ``steps`` is negative, ``max_len`` < 2, or
            ``dataset_size`` < 1.

    Side effects:
        Loads ``wikitext-2-raw-v1`` through ``datasets.load_dataset``.
        Writes ``wikitext_train.memmap`` and ``wikitext_valid.memmap`` into
        the current working directory.
    """
    if steps < 0:
        raise ValueError(f"steps must be non-negative, got {steps}")
    if max_len < 2:
        # Next-bit loss needs at least one (input, target) pair per row.
        raise ValueError(f"max_len must be at least 2, got {max_len}")
    if dataset_size < 1:
        raise ValueError(f"dataset_size must be positive, got {dataset_size}")

    ds = load_dataset("wikitext", "wikitext-2-raw-v1")
    train_lines = [t for t in ds["train"]["text"] if t.strip()][:dataset_size]
    # At least one validation line: np.memmap cannot map an empty file.
    n_valid = max(1, dataset_size // 4)
    valid_lines = [t for t in ds["validation"]["text"] if t.strip()][:n_valid]

    train_path = Path("wikitext_train.memmap")
    valid_path = Path("wikitext_valid.memmap")
    _build_memmap(train_lines, train_path, max_len)
    _build_memmap(valid_lines, valid_path, max_len)

    train = MemmapDataset(train_path, len(train_lines), max_len)
    valid_mm = np.memmap(
        valid_path, mode="r", shape=(len(valid_lines), max_len), dtype="uint8"
    )
    # Copy into a writable int64 array: torch.from_numpy warns on read-only
    # buffers, and the copy lets the file mapping be released.
    valid = torch.from_numpy(np.array(valid_mm, dtype=np.int64))
    del valid_mm

    layers, width = 1, 32
    model = BitTransformerLM(**_scale_params(width, layers, max_len))
    # Assumes basic_train batches by 8 (its default); this must match the
    # number of optimizer steps it takes per epoch — TODO confirm.
    steps_per_epoch = max(1, (len(train) + 7) // 8)
    optimizer, scheduler = configure_optimizer(
        model, lr=1e-3, total_steps=(steps + 1) * steps_per_epoch
    )

    results = []
    for step in range(steps + 1):
        # Forward the shared optimizer/scheduler so the LR schedule spans all
        # remaining epochs; previously they were built but never used.
        # NOTE(review): assumes bit_transformer.training.train_loop accepts
        # optimizer=/scheduler= keyword arguments — verify.
        basic_train(
            model,
            train,
            epochs=1,
            compress_prob=0.5,
            log=False,
            forward_kwargs=None,
            num_workers=2,
            optimizer=optimizer,
            scheduler=scheduler,
        )

        val_loss = _validation_loss(model, valid)
        print(f"Step {step} validation loss: {val_loss:.4f}")
        results.append((step, val_loss))

        if step < steps:
            if step % 2 == 0:
                layers *= 2
            else:
                width *= 2
            model = expand_model(model, _scale_params(width, layers, max_len))
            # Fresh optimizer for the expanded parameters, scheduled over the
            # (steps - step) epochs still to run.
            optimizer, scheduler = configure_optimizer(
                model, lr=1e-3, total_steps=(steps - step) * steps_per_epoch
            )
            print(f"Scaled model to {layers} layers and width {width}")
    return results
| |
|
| |
|
if __name__ == "__main__":
    import argparse

    cli = argparse.ArgumentParser(description="Deterministic scale-up benchmark")
    # argparse derives the dests steps / max_len / dataset_size, which are
    # exactly the keyword names progressive_scale_schedule takes.
    cli.add_argument("--steps", type=int, default=12, help="number of scale-up steps")
    cli.add_argument("--max-len", type=int, default=64, help="sequence length")
    cli.add_argument("--dataset-size", type=int, default=128, help="number of training lines")
    progressive_scale_schedule(**vars(cli.parse_args()))
| |
|