dexifried committed on
Commit a6ecbb1 · 0 Parent(s):

Sovereign Deploy: Live-Streaming H200 Terminal

README.md ADDED
@@ -0,0 +1,11 @@
1
+ ---
2
+ title: Dex Evolution Outpost
3
+ emoji: 🧬
4
+ colorFrom: blue
5
+ colorTo: indigo
6
+ sdk: gradio
7
+ sdk_version: 5.0.1
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ ---
app.py ADDED
@@ -0,0 +1,91 @@
1
+ import sys
2
+ import os
3
+
4
+ # --- ⚔️ THE SOVEREIGN PATCH ---
5
+ import huggingface_hub
6
+ if not hasattr(huggingface_hub, 'HfFolder'):
7
+ class SovereignHfFolder:
8
+ @staticmethod
9
+ def get_token(): return None
10
+ huggingface_hub.HfFolder = SovereignHfFolder
11
+ sys.modules['huggingface_hub.HfFolder'] = SovereignHfFolder
12
+
13
+ import subprocess
14
+ import gradio as gr
15
+ import spaces
16
+ import time
17
+
18
+ REPO_DIR = os.path.join(os.getcwd(), "autoresearch")
19
+
20
+ @spaces.GPU(duration=300) # Maximum ZeroGPU burst (5 Minutes)
21
+ def execute_on_h200(command, openrouter_key):
22
+ if not command:
23
+ yield "❌ Please enter a command."
24
+ return
25
+
26
+ env = os.environ.copy()
27
+ if openrouter_key:
28
+ env["OPENAI_API_KEY"] = openrouter_key
29
+ env["OPENAI_BASE_URL"] = "https://openrouter.ai/api/v1"
30
+
31
+ # Force Python to stream logs instantly instead of buffering them
32
+ env["PYTHONUNBUFFERED"] = "1"
33
+
34
+ output_log = f"[*] Attaching H200 and executing: {command}\n\n"
35
+ yield output_log
36
+
37
+ try:
38
+ # Popen allows us to stream the output live
39
+ process = subprocess.Popen(
40
+ command,
41
+ shell=True,
42
+ cwd=REPO_DIR,
43
+ env=env,
44
+ stdout=subprocess.PIPE,
45
+ stderr=subprocess.STDOUT,
46
+ text=True,
47
+ bufsize=1,
48
+ universal_newlines=True
49
+ )
50
+
51
+ # Stream the output line-by-line to the Gradio UI
52
+ for line in iter(process.stdout.readline, ''):
53
+ output_log += line
54
+ yield output_log
55
+
56
+ process.wait()
57
+ output_log += f"\n--- EXIT CODE: {process.returncode} ---"
58
+ yield output_log
59
+
60
+ except Exception as e:
61
+ yield output_log + f"\n❌ CRASH: {str(e)}"
62
+
63
+ def get_readme():
64
+ try:
65
+ with open(os.path.join(REPO_DIR, "README.md"), "r") as f:
66
+ return f.read()
67
+ except Exception:
68
+ return "README not found."
69
+
70
+ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
71
+ gr.Markdown("# 🧬 Dex Sovereign H200 Terminal (Live Stream Edition)")
72
+ gr.Markdown("Direct command-line execution on the Hugging Face ZeroGPU. Output streams live.")
73
+
74
+ with gr.Row():
75
+ with gr.Column(scale=2):
76
+ cmd_input = gr.Textbox(
77
+ value="ls -la && python prepare.py",
78
+ label="Execute Command on H200"
79
+ )
80
+ or_key = gr.Textbox(label="OpenRouter Key (Optional, bypasses OpenAI)", type="password")
81
+ btn = gr.Button("🚀 Run Command on GPU", variant="primary")
82
+ with gr.Column(scale=3):
83
+ output = gr.Textbox(label="Live Terminal Output", lines=20, max_lines=40)
84
+
85
+ gr.Markdown("### 📖 Framework Documentation (README.md)")
86
+ gr.Markdown(get_readme())
87
+
88
+ btn.click(fn=execute_on_h200, inputs=[cmd_input, or_key], outputs=[output])
89
+
90
+ if __name__ == "__main__":
91
+ demo.launch()
autoresearch/.gitignore ADDED
@@ -0,0 +1,23 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+
9
+ # Virtual environments
10
+ .venv
11
+ worktrees/
12
+ results/
13
+ queue/
14
+
15
+ # Agent prompt files (generated per-session by launchers)
16
+ CLAUDE.md
17
+ AGENTS.md
18
+
19
+ # Experimental code/artifacts
20
+ dev/
21
+
22
+ # Results file
23
+ results.tsv
autoresearch/.python-version ADDED
@@ -0,0 +1 @@
1
+ 3.10
autoresearch/README.md ADDED
@@ -0,0 +1,91 @@
1
+ # autoresearch
2
+
3
+ ![teaser](progress.png)
4
+
5
+ *One day, frontier AI research used to be done by meat computers in between eating, sleeping, having other fun, and synchronizing once in a while using sound wave interconnect in the ritual of "group meeting". That era is long gone. Research is now entirely the domain of autonomous swarms of AI agents running across compute cluster megastructures in the skies. The agents claim that we are now in the 10,205th generation of the code base, in any case no one could tell if that's right or wrong as the "code" is now a self-modifying binary that has grown beyond human comprehension. This repo is the story of how it all began. -@karpathy, March 2026*.
6
+
7
+ The idea: give an AI agent a small but real LLM training setup and let it experiment autonomously overnight. It modifies the code, trains for 5 minutes, checks if the result improved, keeps or discards, and repeats. You wake up in the morning to a log of experiments and (hopefully) a better model. The training code here is a simplified single-GPU implementation of [nanochat](https://github.com/karpathy/nanochat). The core idea is that you're not touching any of the Python files like you normally would as a researcher. Instead, you are programming the `program.md` Markdown file that provides context to the AI agents and sets up your autonomous research org. The default `program.md` in this repo is intentionally kept as a bare-bones baseline, though it's obvious how one would iterate on it over time to find the "research org code" that achieves the fastest research progress, how you'd add more agents to the mix, etc. A bit more context on this project is in this [tweet](https://x.com/karpathy/status/2029701092347630069).
8
+
9
+ ## How it works
10
+
11
+ The repo is deliberately kept small and only really has three files that matter:
12
+
13
+ - **`prepare.py`** — fixed constants, one-time data prep (downloads training data, trains a BPE tokenizer), and runtime utilities (dataloader, evaluation). Not modified.
14
+ - **`train.py`** — the single file the agent edits. Contains the full GPT model, optimizer (Muon + AdamW), and training loop. Everything is fair game: architecture, hyperparameters, optimizer, batch size, etc. **This file is edited and iterated on by the agent**.
15
+ - **`program.md`** — baseline instructions for one agent. Point your agent here and let it go. **This file is edited and iterated on by the human**.
16
+
17
+ By design, training runs for a **fixed 5-minute time budget** (wall clock, excluding startup/compilation), regardless of the details of your compute. The metric is **val_bpb** (validation bits per byte) — lower is better, and vocab-size-independent so architectural changes are fairly compared.
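+ For intuition, here is a minimal sketch of how bits per byte falls out of per-token cross-entropy; it mirrors the logic of `evaluate_bpb` in `prepare.py`, but the numbers below are made up:
+
+ ```python
+ import math
+
+ # Hypothetical per-token losses (in nats) and target-token byte lengths.
+ # A byte length of 0 marks a special token, which is excluded from the loss sum.
+ losses_nats = [2.3, 1.9, 2.7, 0.8]
+ token_bytes = [4, 3, 5, 0]
+
+ total_nats  = sum(l for l, b in zip(losses_nats, token_bytes) if b > 0)
+ total_bytes = sum(token_bytes)
+
+ # nats per byte -> bits per byte
+ val_bpb = total_nats / (math.log(2) * total_bytes)
+ print(f"val_bpb: {val_bpb:.6f}")
+ ```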
18
+
19
+ If you are new to neural networks, this ["Dummy's Guide"](https://x.com/hooeem/status/2030720614752039185) looks pretty good for a lot more context.
20
+
21
+ ## Quick start
22
+
23
+ **Requirements:** A single NVIDIA GPU (tested on H100), Python 3.10+, [uv](https://docs.astral.sh/uv/).
24
+
25
+ ```bash
26
+
27
+ # 1. Install uv project manager (if you don't already have it)
28
+ curl -LsSf https://astral.sh/uv/install.sh | sh
29
+
30
+ # 2. Install dependencies
31
+ uv sync
32
+
33
+ # 3. Download data and train tokenizer (one-time, ~2 min)
34
+ uv run prepare.py
35
+
36
+ # 4. Manually run a single training experiment (~5 min)
37
+ uv run train.py
38
+ ```
39
+
40
+ If the above commands all run without errors, your setup is working and you can go into autonomous research mode.
41
+
42
+ ## Running the agent
43
+
44
+ Simply spin up your Claude/Codex or whatever agent you want in this repo (with permission prompts disabled), then you can prompt something like:
45
+
46
+ ```
47
+ Hi have a look at program.md and let's kick off a new experiment! let's do the setup first.
48
+ ```
49
+
50
+ The `program.md` file is essentially a super lightweight "skill".
51
+
52
+ ## Project structure
53
+
54
+ ```
55
+ prepare.py — constants, data prep + runtime utilities (do not modify)
56
+ train.py — model, optimizer, training loop (agent modifies this)
57
+ program.md — agent instructions
58
+ pyproject.toml — dependencies
59
+ ```
60
+
61
+ ## Design choices
62
+
63
+ - **Single file to modify.** The agent only touches `train.py`. This keeps the scope manageable and diffs reviewable.
64
+ - **Fixed time budget.** Training always runs for exactly 5 minutes, regardless of your specific platform. This means you can expect approx 12 experiments/hour and approx 100 experiments while you sleep. There are two upsides to this design decision. First, it makes experiments directly comparable regardless of what the agent changes (model size, batch size, architecture, etc). Second, it means that autoresearch will find the best model for your platform within that time budget. The downside is that your runs (and results) are not comparable to those of people running on other compute platforms.
65
+ - **Self-contained.** No external dependencies beyond PyTorch and a few small packages. No distributed training, no complex configs. One GPU, one file, one metric.
66
+
67
+ ## Platform support
68
+
69
+ This code currently requires a single NVIDIA GPU. In principle it is quite possible to support CPU, MPS and other platforms, but this would also bloat the code, and I'm not 100% sure that I want to take this on personally right now. People can reference (or have their agents reference) the full/parent nanochat repository, which has wider platform support and shows the various solutions (e.g. a Flash Attention 3 kernel fallback implementation, generic device support, autodetection, etc.). Feel free to create forks or discussions for other platforms and I'm happy to link to them here in the README in a new notable forks section.
70
+
71
+ Seeing as there seems to be a lot of interest in tinkering with autoresearch on much smaller compute platforms than an H100, here are a few extra words. If you're going to try running autoresearch on smaller computers (MacBooks etc.), I'd recommend one of the forks below. On top of this, for aspiring forks, here are some recommendations for how to tune the defaults for much smaller models:
72
+
73
+ 1. To get half-decent results I'd use a dataset with a lot less entropy, e.g. this [TinyStories dataset](https://huggingface.co/datasets/karpathy/tinystories-gpt4-clean). These are GPT-4-generated short stories. Because the data is a lot narrower in scope, you will see reasonable results with much smaller models (if you try to sample from them after training).
74
+ 2. You might experiment with decreasing `vocab_size`, e.g. from 8192 down to 4096, 2048, 1024, or even a simple byte-level tokenizer with 256 possible bytes after UTF-8 encoding.
75
+ 3. In `prepare.py`, you'll want to lower `MAX_SEQ_LEN` a lot, depending on the computer even down to 256 or so. As you lower `MAX_SEQ_LEN`, you may want to experiment with increasing `DEVICE_BATCH_SIZE` in `train.py` slightly to compensate. The number of tokens per fwd/bwd pass is the product of these two.
76
+ 4. Also in `prepare.py`, you'll want to decrease `EVAL_TOKENS` so that your validation loss is evaluated on a lot less data.
77
+ 5. In `train.py`, the primary knob that controls model complexity is `DEPTH` (default 8). Many other variables are simply functions of it, so you can e.g. lower it down to 4.
78
+ 6. You'll most likely want to use a `WINDOW_PATTERN` of just "L", because "SSSL" uses an alternating banded attention pattern that may be very inefficient for you. Try it.
79
+ 7. You'll want to lower `TOTAL_BATCH_SIZE` a lot, but keep it a power of 2, e.g. down to `2**14` (~16K) or so; it's hard to say exactly.
80
+
81
+ I think these are the reasonable hyperparameters to play with; a sketch of what such overrides might look like follows below. Ask your favorite coding agent for help and copy-paste it this guide, as well as the full source code.
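+ As a purely illustrative starting point, the suggestions above might translate into constant overrides like these. The names are the real constants from `prepare.py` and `train.py`, but the values are assumptions to tune per machine, not tested defaults (note that `train.py` asserts `TOTAL_BATCH_SIZE` is divisible by `DEVICE_BATCH_SIZE * MAX_SEQ_LEN`):
+
+ ```python
+ # Hypothetical small-machine settings (illustrative values, not tested defaults)
+
+ # in prepare.py
+ MAX_SEQ_LEN = 256          # down from 2048
+ EVAL_TOKENS = 2 * 524288   # evaluate on much less data (default 40 * 524288)
+ VOCAB_SIZE  = 2048         # or 256 for a byte-level tokenizer (default 8192)
+
+ # in train.py
+ DEPTH             = 4      # down from 8; most other sizes are derived from this
+ WINDOW_PATTERN    = "L"    # full-context attention only, instead of "SSSL"
+ DEVICE_BATCH_SIZE = 64     # 64 * 256 = 16384 tokens per fwd/bwd pass
+ TOTAL_BATCH_SIZE  = 2**14  # keep a power of 2, divisible by 64 * 256
+ ```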
82
+
83
+ ## Notable forks
84
+
85
+ - [miolini/autoresearch-macos](https://github.com/miolini/autoresearch-macos) (MacOS)
86
+ - [trevin-creator/autoresearch-mlx](https://github.com/trevin-creator/autoresearch-mlx) (MacOS)
87
+ - [jsegov/autoresearch-win-rtx](https://github.com/jsegov/autoresearch-win-rtx) (Windows)
88
+
89
+ ## License
90
+
91
+ MIT
autoresearch/prepare.py ADDED
@@ -0,0 +1,389 @@
1
+ """
2
+ One-time data preparation for autoresearch experiments.
3
+ Downloads data shards and trains a BPE tokenizer.
4
+
5
+ Usage:
6
+ python prepare.py # full prep (download + tokenizer)
7
+ python prepare.py --num-shards 8 # download only 8 shards (for testing)
8
+
9
+ Data and tokenizer are stored in ~/.cache/autoresearch/.
10
+ """
11
+
12
+ import os
13
+ import sys
14
+ import time
15
+ import math
16
+ import argparse
17
+ import pickle
18
+ from multiprocessing import Pool
19
+
20
+ import requests
21
+ import pyarrow.parquet as pq
22
+ import rustbpe
23
+ import tiktoken
24
+ import torch
25
+
26
+ # ---------------------------------------------------------------------------
27
+ # Constants (fixed, do not modify)
28
+ # ---------------------------------------------------------------------------
29
+
30
+ MAX_SEQ_LEN = 2048 # context length
31
+ TIME_BUDGET = 300 # training time budget in seconds (5 minutes)
32
+ EVAL_TOKENS = 40 * 524288 # number of tokens for val eval
33
+
34
+ # ---------------------------------------------------------------------------
35
+ # Configuration
36
+ # ---------------------------------------------------------------------------
37
+
38
+ CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
39
+ DATA_DIR = os.path.join(CACHE_DIR, "data")
40
+ TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
41
+ BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
42
+ MAX_SHARD = 6542 # the last datashard is shard_06542.parquet
43
+ VAL_SHARD = MAX_SHARD # pinned validation shard (shard_06542)
44
+ VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet"
45
+ VOCAB_SIZE = 8192
46
+
47
+ # BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3})
48
+ SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
49
+
50
+ SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(4)]
51
+ BOS_TOKEN = "<|reserved_0|>"
52
+
53
+ # ---------------------------------------------------------------------------
54
+ # Data download
55
+ # ---------------------------------------------------------------------------
56
+
57
+ def download_single_shard(index):
58
+ """Download one parquet shard with retries. Returns True on success."""
59
+ filename = f"shard_{index:05d}.parquet"
60
+ filepath = os.path.join(DATA_DIR, filename)
61
+ if os.path.exists(filepath):
62
+ return True
63
+
64
+ url = f"{BASE_URL}/{filename}"
65
+ max_attempts = 5
66
+ for attempt in range(1, max_attempts + 1):
67
+ try:
68
+ response = requests.get(url, stream=True, timeout=30)
69
+ response.raise_for_status()
70
+ temp_path = filepath + ".tmp"
71
+ with open(temp_path, "wb") as f:
72
+ for chunk in response.iter_content(chunk_size=1024 * 1024):
73
+ if chunk:
74
+ f.write(chunk)
75
+ os.rename(temp_path, filepath)
76
+ print(f" Downloaded {filename}")
77
+ return True
78
+ except (requests.RequestException, IOError) as e:
79
+ print(f" Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
80
+ for path in [filepath + ".tmp", filepath]:
81
+ if os.path.exists(path):
82
+ try:
83
+ os.remove(path)
84
+ except OSError:
85
+ pass
86
+ if attempt < max_attempts:
87
+ time.sleep(2 ** attempt)
88
+ return False
89
+
90
+
91
+ def download_data(num_shards, download_workers=8):
92
+ """Download training shards + pinned validation shard."""
93
+ os.makedirs(DATA_DIR, exist_ok=True)
94
+ num_train = min(num_shards, MAX_SHARD)
95
+ ids = list(range(num_train))
96
+ if VAL_SHARD not in ids:
97
+ ids.append(VAL_SHARD)
98
+
99
+ # Count what's already downloaded
100
+ existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet")))
101
+ if existing == len(ids):
102
+ print(f"Data: all {len(ids)} shards already downloaded at {DATA_DIR}")
103
+ return
104
+
105
+ needed = len(ids) - existing
106
+ print(f"Data: downloading {needed} shards ({existing} already exist)...")
107
+
108
+ workers = max(1, min(download_workers, needed))
109
+ with Pool(processes=workers) as pool:
110
+ results = pool.map(download_single_shard, ids)
111
+
112
+ ok = sum(1 for r in results if r)
113
+ print(f"Data: {ok}/{len(ids)} shards ready at {DATA_DIR}")
114
+
115
+ # ---------------------------------------------------------------------------
116
+ # Tokenizer training
117
+ # ---------------------------------------------------------------------------
118
+
119
+ def list_parquet_files():
120
+ """Return sorted list of parquet file paths in the data directory."""
121
+ files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".parquet") and not f.endswith(".tmp"))
122
+ return [os.path.join(DATA_DIR, f) for f in files]
123
+
124
+
125
+ def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
126
+ """Yield documents from training split (all shards except pinned val shard)."""
127
+ parquet_paths = [p for p in list_parquet_files() if not p.endswith(VAL_FILENAME)]
128
+ nchars = 0
129
+ for filepath in parquet_paths:
130
+ pf = pq.ParquetFile(filepath)
131
+ for rg_idx in range(pf.num_row_groups):
132
+ rg = pf.read_row_group(rg_idx)
133
+ for text in rg.column("text").to_pylist():
134
+ doc = text[:doc_cap] if len(text) > doc_cap else text
135
+ nchars += len(doc)
136
+ yield doc
137
+ if nchars >= max_chars:
138
+ return
139
+
140
+
141
+ def train_tokenizer():
142
+ """Train BPE tokenizer using rustbpe, save as tiktoken pickle."""
143
+ tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
144
+ token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
145
+
146
+ if os.path.exists(tokenizer_pkl) and os.path.exists(token_bytes_path):
147
+ print(f"Tokenizer: already trained at {TOKENIZER_DIR}")
148
+ return
149
+
150
+ os.makedirs(TOKENIZER_DIR, exist_ok=True)
151
+
152
+ parquet_files = list_parquet_files()
153
+ if len(parquet_files) < 2:
154
+ print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.")
155
+ sys.exit(1)
156
+
157
+ # --- Train with rustbpe ---
158
+ print("Tokenizer: training BPE tokenizer...")
159
+ t0 = time.time()
160
+
161
+ tokenizer = rustbpe.Tokenizer()
162
+ vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)
163
+ tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN)
164
+
165
+ # Build tiktoken encoding from trained merges
166
+ pattern = tokenizer.get_pattern()
167
+ mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()}
168
+ tokens_offset = len(mergeable_ranks)
169
+ special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
170
+ enc = tiktoken.Encoding(
171
+ name="rustbpe",
172
+ pat_str=pattern,
173
+ mergeable_ranks=mergeable_ranks,
174
+ special_tokens=special_tokens,
175
+ )
176
+
177
+ # Save tokenizer
178
+ with open(tokenizer_pkl, "wb") as f:
179
+ pickle.dump(enc, f)
180
+
181
+ t1 = time.time()
182
+ print(f"Tokenizer: trained in {t1 - t0:.1f}s, saved to {tokenizer_pkl}")
183
+
184
+ # --- Build token_bytes lookup for BPB evaluation ---
185
+ print("Tokenizer: building token_bytes lookup...")
186
+ special_set = set(SPECIAL_TOKENS)
187
+ token_bytes_list = []
188
+ for token_id in range(enc.n_vocab):
189
+ token_str = enc.decode([token_id])
190
+ if token_str in special_set:
191
+ token_bytes_list.append(0)
192
+ else:
193
+ token_bytes_list.append(len(token_str.encode("utf-8")))
194
+ token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32)
195
+ torch.save(token_bytes_tensor, token_bytes_path)
196
+ print(f"Tokenizer: saved token_bytes to {token_bytes_path}")
197
+
198
+ # Sanity check
199
+ test = "Hello world! Numbers: 123. Unicode: 你好"
200
+ encoded = enc.encode_ordinary(test)
201
+ decoded = enc.decode(encoded)
202
+ assert decoded == test, f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}"
203
+ print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})")
204
+
205
+ # ---------------------------------------------------------------------------
206
+ # Runtime utilities (imported by train.py)
207
+ # ---------------------------------------------------------------------------
208
+
209
+ class Tokenizer:
210
+ """Minimal tokenizer wrapper. Training is handled above."""
211
+
212
+ def __init__(self, enc):
213
+ self.enc = enc
214
+ self.bos_token_id = enc.encode_single_token(BOS_TOKEN)
215
+
216
+ @classmethod
217
+ def from_directory(cls, tokenizer_dir=TOKENIZER_DIR):
218
+ with open(os.path.join(tokenizer_dir, "tokenizer.pkl"), "rb") as f:
219
+ enc = pickle.load(f)
220
+ return cls(enc)
221
+
222
+ def get_vocab_size(self):
223
+ return self.enc.n_vocab
224
+
225
+ def get_bos_token_id(self):
226
+ return self.bos_token_id
227
+
228
+ def encode(self, text, prepend=None, num_threads=8):
229
+ if prepend is not None:
230
+ prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
231
+ if isinstance(text, str):
232
+ ids = self.enc.encode_ordinary(text)
233
+ if prepend is not None:
234
+ ids.insert(0, prepend_id)
235
+ elif isinstance(text, list):
236
+ ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
237
+ if prepend is not None:
238
+ for row in ids:
239
+ row.insert(0, prepend_id)
240
+ else:
241
+ raise ValueError(f"Invalid input type: {type(text)}")
242
+ return ids
243
+
244
+ def decode(self, ids):
245
+ return self.enc.decode(ids)
246
+
247
+
248
+ def get_token_bytes(device="cpu"):
249
+ path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
250
+ with open(path, "rb") as f:
251
+ return torch.load(f, map_location=device)
252
+
253
+
254
+ def _document_batches(split, tokenizer_batch_size=128):
255
+ """Infinite iterator over document batches from parquet files."""
256
+ parquet_paths = list_parquet_files()
257
+ assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first."
258
+ val_path = os.path.join(DATA_DIR, VAL_FILENAME)
259
+ if split == "train":
260
+ parquet_paths = [p for p in parquet_paths if p != val_path]
261
+ assert len(parquet_paths) > 0, "No training shards found."
262
+ else:
263
+ parquet_paths = [val_path]
264
+ epoch = 1
265
+ while True:
266
+ for filepath in parquet_paths:
267
+ pf = pq.ParquetFile(filepath)
268
+ for rg_idx in range(pf.num_row_groups):
269
+ rg = pf.read_row_group(rg_idx)
270
+ batch = rg.column('text').to_pylist()
271
+ for i in range(0, len(batch), tokenizer_batch_size):
272
+ yield batch[i:i+tokenizer_batch_size], epoch
273
+ epoch += 1
274
+
275
+
276
+ def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
277
+ """
278
+ BOS-aligned dataloader with best-fit packing.
279
+ Every row starts with BOS. Documents packed using best-fit to minimize cropping.
280
+ When no document fits remaining space, crops shortest doc to fill exactly.
281
+ 100% utilization (no padding).
282
+ """
283
+ assert split in ["train", "val"]
284
+ row_capacity = T + 1
285
+ batches = _document_batches(split)
286
+ bos_token = tokenizer.get_bos_token_id()
287
+ doc_buffer = []
288
+ epoch = 1
289
+
290
+ def refill_buffer():
291
+ nonlocal epoch
292
+ doc_batch, epoch = next(batches)
293
+ token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
294
+ doc_buffer.extend(token_lists)
295
+
296
+ # Pre-allocate buffers: [inputs (B*T) | targets (B*T)]
297
+ row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
298
+ cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True)
299
+ gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device="cuda")
300
+ cpu_inputs = cpu_buffer[:B * T].view(B, T)
301
+ cpu_targets = cpu_buffer[B * T:].view(B, T)
302
+ inputs = gpu_buffer[:B * T].view(B, T)
303
+ targets = gpu_buffer[B * T:].view(B, T)
304
+
305
+ while True:
306
+ for row_idx in range(B):
307
+ pos = 0
308
+ while pos < row_capacity:
309
+ while len(doc_buffer) < buffer_size:
310
+ refill_buffer()
311
+
312
+ remaining = row_capacity - pos
313
+
314
+ # Find largest doc that fits entirely
315
+ best_idx = -1
316
+ best_len = 0
317
+ for i, doc in enumerate(doc_buffer):
318
+ doc_len = len(doc)
319
+ if doc_len <= remaining and doc_len > best_len:
320
+ best_idx = i
321
+ best_len = doc_len
322
+
323
+ if best_idx >= 0:
324
+ doc = doc_buffer.pop(best_idx)
325
+ row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
326
+ pos += len(doc)
327
+ else:
328
+ # No doc fits — crop shortest to fill remaining
329
+ shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
330
+ doc = doc_buffer.pop(shortest_idx)
331
+ row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
332
+ pos += remaining
333
+
334
+ cpu_inputs.copy_(row_buffer[:, :-1])
335
+ cpu_targets.copy_(row_buffer[:, 1:])
336
+ gpu_buffer.copy_(cpu_buffer, non_blocking=True)
337
+ yield inputs, targets, epoch
338
+
339
+ # ---------------------------------------------------------------------------
340
+ # Evaluation (DO NOT CHANGE — this is the fixed metric)
341
+ # ---------------------------------------------------------------------------
342
+
343
+ @torch.no_grad()
344
+ def evaluate_bpb(model, tokenizer, batch_size):
345
+ """
346
+ Bits per byte (BPB): vocab size-independent evaluation metric.
347
+ Sums per-token cross-entropy (in nats), sums target byte lengths,
348
+ then converts nats/byte to bits/byte. Special tokens (byte length 0)
349
+ are excluded from both sums.
350
+ Uses fixed MAX_SEQ_LEN so results are comparable across configs.
351
+ """
352
+ token_bytes = get_token_bytes(device="cuda")
353
+ val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
354
+ steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
355
+ total_nats = 0.0
356
+ total_bytes = 0
357
+ for _ in range(steps):
358
+ x, y, _ = next(val_loader)
359
+ loss_flat = model(x, y, reduction='none').view(-1)
360
+ y_flat = y.view(-1)
361
+ nbytes = token_bytes[y_flat]
362
+ mask = nbytes > 0
363
+ total_nats += (loss_flat * mask).sum().item()
364
+ total_bytes += nbytes.sum().item()
365
+ return total_nats / (math.log(2) * total_bytes)
366
+
367
+ # ---------------------------------------------------------------------------
368
+ # Main
369
+ # ---------------------------------------------------------------------------
370
+
371
+ if __name__ == "__main__":
372
+ parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
373
+ parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.")
374
+ parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers")
375
+ args = parser.parse_args()
376
+
377
+ num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards
378
+
379
+ print(f"Cache directory: {CACHE_DIR}")
380
+ print()
381
+
382
+ # Step 1: Download data
383
+ download_data(num_shards, download_workers=args.download_workers)
384
+ print()
385
+
386
+ # Step 2: Train tokenizer
387
+ train_tokenizer()
388
+ print()
389
+ print("Done! Ready to train.")
autoresearch/program.md ADDED
@@ -0,0 +1,114 @@
1
+ # autoresearch
2
+
3
+ This is an experiment to have the LLM do its own research.
4
+
5
+ ## Setup
6
+
7
+ To set up a new experiment, work with the user to:
8
+
9
+ 1. **Agree on a run tag**: propose a tag based on today's date (e.g. `mar5`). The branch `autoresearch/<tag>` must not already exist — this is a fresh run.
10
+ 2. **Create the branch**: `git checkout -b autoresearch/<tag>` from current master.
11
+ 3. **Read the in-scope files**: The repo is small. Read these files for full context:
12
+ - `README.md` — repository context.
13
+ - `prepare.py` — fixed constants, data prep, tokenizer, dataloader, evaluation. Do not modify.
14
+ - `train.py` — the file you modify. Model architecture, optimizer, training loop.
15
+ 4. **Verify data exists**: Check that `~/.cache/autoresearch/` contains data shards and a tokenizer. If not, tell the human to run `uv run prepare.py`.
16
+ 5. **Initialize results.tsv**: Create `results.tsv` with just the header row. The baseline will be recorded after the first run.
17
+ 6. **Confirm and go**: Confirm setup looks good.
18
+
19
+ Once you get confirmation, kick off the experimentation.
20
+
21
+ ## Experimentation
22
+
23
+ Each experiment runs on a single GPU. The training script runs for a **fixed time budget of 5 minutes** (wall clock training time, excluding startup/compilation). You launch it simply as: `uv run train.py`.
24
+
25
+ **What you CAN do:**
26
+ - Modify `train.py` — this is the only file you edit. Everything is fair game: model architecture, optimizer, hyperparameters, training loop, batch size, model size, etc.
27
+
28
+ **What you CANNOT do:**
29
+ - Modify `prepare.py`. It is read-only. It contains the fixed evaluation, data loading, tokenizer, and training constants (time budget, sequence length, etc).
30
+ - Install new packages or add dependencies. You can only use what's already in `pyproject.toml`.
31
+ - Modify the evaluation harness. The `evaluate_bpb` function in `prepare.py` is the ground truth metric.
32
+
33
+ **The goal is simple: get the lowest val_bpb.** Since the time budget is fixed, you don't need to worry about training time — it's always 5 minutes. Everything is fair game: change the architecture, the optimizer, the hyperparameters, the batch size, the model size. The only constraint is that the code runs without crashing and finishes within the time budget.
34
+
35
+ **VRAM** is a soft constraint. Some increase is acceptable for meaningful val_bpb gains, but it should not blow up dramatically.
36
+
37
+ **Simplicity criterion**: All else being equal, simpler is better. A small improvement that adds ugly complexity is not worth it. Conversely, removing something and getting equal or better results is a great outcome — that's a simplification win. When evaluating whether to keep a change, weigh the complexity cost against the improvement magnitude. A 0.001 val_bpb improvement that adds 20 lines of hacky code? Probably not worth it. A 0.001 val_bpb improvement from deleting code? Definitely keep. An improvement of ~0 but much simpler code? Keep.
38
+
39
+ **The first run**: Your very first run should always be to establish the baseline, so you will run the training script as is.
40
+
41
+ ## Output format
42
+
43
+ Once the script finishes it prints a summary like this:
44
+
45
+ ```
46
+ ---
47
+ val_bpb: 0.997900
48
+ training_seconds: 300.1
49
+ total_seconds: 325.9
50
+ peak_vram_mb: 45060.2
51
+ mfu_percent: 39.80
52
+ total_tokens_M: 499.6
53
+ num_steps: 953
54
+ num_params_M: 50.3
55
+ depth: 8
56
+ ```
57
+
58
+ Note that the script is configured to always stop after 5 minutes, so depending on this machine's compute platform the numbers might look different. You can extract the key metric from the log file:
59
+
60
+ ```
61
+ grep "^val_bpb:" run.log
62
+ ```
63
+
64
+ ## Logging results
65
+
66
+ When an experiment is done, log it to `results.tsv` (tab-separated, NOT comma-separated — commas break in descriptions).
67
+
68
+ The TSV has a header row and 5 columns:
69
+
70
+ ```
71
+ commit val_bpb memory_gb status description
72
+ ```
73
+
74
+ 1. git commit hash (short, 7 chars)
75
+ 2. val_bpb achieved (e.g. 1.234567) — use 0.000000 for crashes
76
+ 3. peak memory in GB, round to .1f (e.g. 12.3 — divide peak_vram_mb by 1024) — use 0.0 for crashes
77
+ 4. status: `keep`, `discard`, or `crash`
78
+ 5. short text description of what this experiment tried
79
+
80
+ Example:
81
+
82
+ ```
83
+ commit val_bpb memory_gb status description
84
+ a1b2c3d 0.997900 44.0 keep baseline
85
+ b2c3d4e 0.993200 44.2 keep increase LR to 0.04
86
+ c3d4e5f 1.005000 44.0 discard switch to GeLU activation
87
+ d4e5f6g 0.000000 0.0 crash double model width (OOM)
88
+ ```
89
+
90
+ ## The experiment loop
91
+
92
+ The experiment runs on a dedicated branch (e.g. `autoresearch/mar5` or `autoresearch/mar5-gpu0`).
93
+
94
+ LOOP FOREVER:
95
+
96
+ 1. Look at the git state: the current branch/commit we're on
97
+ 2. Tune `train.py` with an experimental idea by directly hacking the code.
98
+ 3. git commit
99
+ 4. Run the experiment: `uv run train.py > run.log 2>&1` (redirect everything — do NOT use tee or let output flood your context)
100
+ 5. Read out the results: `grep "^val_bpb:\|^peak_vram_mb:" run.log`
101
+ 6. If the grep output is empty, the run crashed. Run `tail -n 50 run.log` to read the Python stack trace and attempt a fix. If you can't get things to work after more than a few attempts, give up.
102
+ 7. Record the results in the tsv (NOTE: do not commit the results.tsv file, leave it untracked by git)
103
+ 8. If val_bpb improved (lower), you "advance" the branch, keeping the git commit
104
+ 9. If val_bpb is equal or worse, you git reset back to where you started
105
+
106
+ The idea is that you are a completely autonomous researcher trying things out. If they work, keep. If they don't, discard. And you're advancing the branch so that you can iterate. If you feel like you're getting stuck in some way, you can rewind but you should probably do this very very sparingly (if ever).
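+ For illustration only, a single iteration of this loop might look roughly like the sketch below. The agent normally performs these steps itself via shell commands and its own edits; the helper function and the baseline value are hypothetical:
+
+ ```
+ # Sketch of one experiment-loop iteration (illustrative only)
+ import re
+ import subprocess
+
+ def run(cmd):
+     return subprocess.run(cmd, shell=True, capture_output=True, text=True)
+
+ best_bpb = 0.9979  # val_bpb of the current best commit (hypothetical)
+
+ # steps 2-4: edit train.py, commit, run the experiment
+ run("git commit -am 'experiment: <short description>'")
+ run("uv run train.py > run.log 2>&1")
+
+ # steps 5-6: read out the result; no val_bpb line in the log means the run crashed
+ log = open("run.log").read()
+ m = re.search(r"^val_bpb: ([0-9.]+)", log, re.MULTILINE)
+ val_bpb = float(m.group(1)) if m else 0.0
+
+ # step 7: append a tab-separated row to results.tsv
+ #         (commit, val_bpb, memory_gb, status, description)
+
+ # steps 8-9: keep the commit if val_bpb improved, otherwise reset back
+ if m is None or val_bpb >= best_bpb:
+     run("git reset --hard HEAD~1")
+ ```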
107
+
108
+ **Timeout**: Each experiment should take ~5 minutes total (+ a few seconds for startup and eval overhead). If a run exceeds 10 minutes, kill it and treat it as a failure (discard and revert).
109
+
110
+ **Crashes**: If a run crashes (OOM, a bug, etc.), use your judgment: if it's something dumb and easy to fix (e.g. a typo, a missing import), fix it and re-run. If the idea itself is fundamentally broken, just skip it, log "crash" as the status in the tsv, and move on.
111
+
112
+ **NEVER STOP**: Once the experiment loop has begun (after the initial setup), do NOT pause to ask the human if you should continue. Do NOT ask "should I keep going?" or "is this a good stopping point?". The human might be asleep, or gone from a computer and expects you to continue working *indefinitely* until you are manually stopped. You are autonomous. If you run out of ideas, think harder — read papers referenced in the code, re-read the in-scope files for new angles, try combining previous near-misses, try more radical architectural changes. The loop runs until the human interrupts you, period.
113
+
114
+ As an example use case, a user might leave you running while they sleep. If each experiment takes you ~5 minutes then you can run approx 12/hour, for a total of about 100 over the duration of the average human sleep. The user then wakes up to experimental results, all completed by you while they slept!
autoresearch/pyproject.toml ADDED
@@ -0,0 +1,27 @@
1
+ [project]
2
+ name = "autoresearch"
3
+ version = "0.1.0"
4
+ description = "Autonomous pretraining research swarm"
5
+ readme = "README.md"
6
+ requires-python = ">=3.10"
7
+ dependencies = [
8
+ "kernels>=0.11.7",
9
+ "matplotlib>=3.10.8",
10
+ "numpy>=2.2.6",
11
+ "pandas>=2.3.3",
12
+ "pyarrow>=21.0.0",
13
+ "requests>=2.32.0",
14
+ "rustbpe>=0.1.0",
15
+ "tiktoken>=0.11.0",
16
+ "torch==2.9.1",
17
+ ]
18
+
19
+ [tool.uv.sources]
20
+ torch = [
21
+ { index = "pytorch-cu128" },
22
+ ]
23
+
24
+ [[tool.uv.index]]
25
+ name = "pytorch-cu128"
26
+ url = "https://download.pytorch.org/whl/cu128"
27
+ explicit = true
autoresearch/train.py ADDED
@@ -0,0 +1,630 @@
1
+ """
2
+ Autoresearch pretraining script. Single-GPU, single-file.
3
+ Cherry-picked and simplified from nanochat.
4
+ Usage: uv run train.py
5
+ """
6
+
7
+ import os
8
+ os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
9
+ os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
10
+
11
+ import gc
12
+ import math
13
+ import time
14
+ from dataclasses import dataclass, asdict
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ from kernels import get_kernel
21
+ cap = torch.cuda.get_device_capability()
22
+ # varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs
23
+ repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
24
+ fa3 = get_kernel(repo).flash_attn_interface
25
+
26
+ from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb
27
+
28
+ # ---------------------------------------------------------------------------
29
+ # GPT Model
30
+ # ---------------------------------------------------------------------------
31
+
32
+ @dataclass
33
+ class GPTConfig:
34
+ sequence_len: int = 2048
35
+ vocab_size: int = 32768
36
+ n_layer: int = 12
37
+ n_head: int = 6
38
+ n_kv_head: int = 6
39
+ n_embd: int = 768
40
+ window_pattern: str = "SSSL"
41
+
42
+
43
+ def norm(x):
44
+ return F.rms_norm(x, (x.size(-1),))
45
+
46
+
47
+ def has_ve(layer_idx, n_layer):
48
+ """Returns True if layer should have Value Embedding (alternating, last always included)."""
49
+ return layer_idx % 2 == (n_layer - 1) % 2
50
+
51
+
52
+ def apply_rotary_emb(x, cos, sin):
53
+ assert x.ndim == 4
54
+ d = x.shape[3] // 2
55
+ x1, x2 = x[..., :d], x[..., d:]
56
+ y1 = x1 * cos + x2 * sin
57
+ y2 = x1 * (-sin) + x2 * cos
58
+ return torch.cat([y1, y2], 3)
59
+
60
+
61
+ class CausalSelfAttention(nn.Module):
62
+ def __init__(self, config, layer_idx):
63
+ super().__init__()
64
+ self.n_head = config.n_head
65
+ self.n_kv_head = config.n_kv_head
66
+ self.n_embd = config.n_embd
67
+ self.head_dim = self.n_embd // self.n_head
68
+ assert self.n_embd % self.n_head == 0
69
+ assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
70
+ self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
71
+ self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
72
+ self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
73
+ self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
74
+ self.ve_gate_channels = 32
75
+ self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
76
+
77
+ def forward(self, x, ve, cos_sin, window_size):
78
+ B, T, C = x.size()
79
+ q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
80
+ k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
81
+ v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
82
+
83
+ # Value residual (ResFormer): mix in value embedding with input-dependent gate per head
84
+ if ve is not None:
85
+ ve = ve.view(B, T, self.n_kv_head, self.head_dim)
86
+ gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels]))
87
+ v = v + gate.unsqueeze(-1) * ve
88
+
89
+ cos, sin = cos_sin
90
+ q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
91
+ q, k = norm(q), norm(k)
92
+
93
+ y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
94
+ y = y.contiguous().view(B, T, -1)
95
+ y = self.c_proj(y)
96
+ return y
97
+
98
+
99
+ class MLP(nn.Module):
100
+ def __init__(self, config):
101
+ super().__init__()
102
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
103
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
104
+
105
+ def forward(self, x):
106
+ x = self.c_fc(x)
107
+ x = F.relu(x).square()
108
+ x = self.c_proj(x)
109
+ return x
110
+
111
+
112
+ class Block(nn.Module):
113
+ def __init__(self, config, layer_idx):
114
+ super().__init__()
115
+ self.attn = CausalSelfAttention(config, layer_idx)
116
+ self.mlp = MLP(config)
117
+
118
+ def forward(self, x, ve, cos_sin, window_size):
119
+ x = x + self.attn(norm(x), ve, cos_sin, window_size)
120
+ x = x + self.mlp(norm(x))
121
+ return x
122
+
123
+
124
+ class GPT(nn.Module):
125
+ def __init__(self, config):
126
+ super().__init__()
127
+ self.config = config
128
+ self.window_sizes = self._compute_window_sizes(config)
129
+ self.transformer = nn.ModuleDict({
130
+ "wte": nn.Embedding(config.vocab_size, config.n_embd),
131
+ "h": nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
132
+ })
133
+ self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
134
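+ # Per-layer scalars: resid_lambdas scales the residual stream and x0_lambdas mixes the normalized token embedding (x0) back in at each layer (see forward)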
+ self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
135
+ self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
136
+ # Value embeddings
137
+ head_dim = config.n_embd // config.n_head
138
+ kv_dim = config.n_kv_head * head_dim
139
+ self.value_embeds = nn.ModuleDict({
140
+ str(i): nn.Embedding(config.vocab_size, kv_dim)
141
+ for i in range(config.n_layer) if has_ve(i, config.n_layer)
142
+ })
143
+ # Rotary embeddings
144
+ self.rotary_seq_len = config.sequence_len * 10
145
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
146
+ self.register_buffer("cos", cos, persistent=False)
147
+ self.register_buffer("sin", sin, persistent=False)
148
+
149
+ @torch.no_grad()
150
+ def init_weights(self):
151
+ # Embedding and unembedding
152
+ torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
153
+ torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
154
+ # Transformer blocks
155
+ n_embd = self.config.n_embd
156
+ s = 3**0.5 * n_embd**-0.5
157
+ for block in self.transformer.h:
158
+ torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
159
+ torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
160
+ torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
161
+ torch.nn.init.zeros_(block.attn.c_proj.weight)
162
+ torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
163
+ torch.nn.init.zeros_(block.mlp.c_proj.weight)
164
+ # Per-layer scalars
165
+ self.resid_lambdas.fill_(1.0)
166
+ self.x0_lambdas.fill_(0.1)
167
+ # Value embeddings
168
+ for ve in self.value_embeds.values():
169
+ torch.nn.init.uniform_(ve.weight, -s, s)
170
+ # Gate weights init to zero (sigmoid(0)=0.5, scaled by 2 -> 1.0 = neutral)
171
+ for block in self.transformer.h:
172
+ if block.attn.ve_gate is not None:
173
+ torch.nn.init.zeros_(block.attn.ve_gate.weight)
174
+ # Rotary embeddings
175
+ head_dim = self.config.n_embd // self.config.n_head
176
+ cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
177
+ self.cos, self.sin = cos, sin
178
+ # Cast embeddings to bf16
179
+ self.transformer.wte.to(dtype=torch.bfloat16)
180
+ for ve in self.value_embeds.values():
181
+ ve.to(dtype=torch.bfloat16)
182
+
183
+ def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
184
+ if device is None:
185
+ device = self.transformer.wte.weight.device
186
+ channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
187
+ inv_freq = 1.0 / (base ** (channel_range / head_dim))
188
+ t = torch.arange(seq_len, dtype=torch.float32, device=device)
189
+ freqs = torch.outer(t, inv_freq)
190
+ cos, sin = freqs.cos(), freqs.sin()
191
+ cos, sin = cos.bfloat16(), sin.bfloat16()
192
+ cos, sin = cos[None, :, None, :], sin[None, :, None, :]
193
+ return cos, sin
194
+
195
+ def _compute_window_sizes(self, config):
196
+ pattern = config.window_pattern.upper()
197
+ assert all(c in "SL" for c in pattern)
198
+ long_window = config.sequence_len
199
+ short_window = long_window // 2
200
+ char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
201
+ window_sizes = []
202
+ for layer_idx in range(config.n_layer):
203
+ char = pattern[layer_idx % len(pattern)]
204
+ window_sizes.append(char_to_window[char])
205
+ window_sizes[-1] = (long_window, 0)
206
+ return window_sizes
207
+
208
+ def estimate_flops(self):
209
+ """Estimated FLOPs per token (forward + backward)."""
210
+ nparams = sum(p.numel() for p in self.parameters())
211
+ value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
212
+ nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
213
+ self.resid_lambdas.numel() + self.x0_lambdas.numel())
214
+ h = self.config.n_head
215
+ q = self.config.n_embd // self.config.n_head
216
+ t = self.config.sequence_len
217
+ attn_flops = 0
218
+ for window_size in self.window_sizes:
219
+ window = window_size[0]
220
+ effective_seq = t if window < 0 else min(window, t)
221
+ attn_flops += 12 * h * q * effective_seq
222
+ return 6 * (nparams - nparams_exclude) + attn_flops
223
+
224
+ def num_scaling_params(self):
225
+ wte = sum(p.numel() for p in self.transformer.wte.parameters())
226
+ value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
227
+ lm_head = sum(p.numel() for p in self.lm_head.parameters())
228
+ transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
229
+ scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
230
+ total = wte + value_embeds + lm_head + transformer_matrices + scalars
231
+ return {
232
+ 'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head,
233
+ 'transformer_matrices': transformer_matrices, 'scalars': scalars, 'total': total,
234
+ }
235
+
236
+ def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
237
+ weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
238
+ model_dim = self.config.n_embd
239
+ matrix_params = list(self.transformer.h.parameters())
240
+ value_embeds_params = list(self.value_embeds.parameters())
241
+ embedding_params = list(self.transformer.wte.parameters())
242
+ lm_head_params = list(self.lm_head.parameters())
243
+ resid_params = [self.resid_lambdas]
244
+ x0_params = [self.x0_lambdas]
245
+ assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) +
246
+ len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params))
247
+ # Scale LR ∝ 1/√dmodel (tuned at 768 dim)
248
+ dmodel_lr_scale = (model_dim / 768) ** -0.5
249
+ print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
250
+ param_groups = [
251
+ dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
252
+ dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
253
+ dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
254
+ dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
255
+ dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0),
256
+ ]
257
+ for shape in sorted({p.shape for p in matrix_params}):
258
+ group_params = [p for p in matrix_params if p.shape == shape]
259
+ param_groups.append(dict(
260
+ kind='muon', params=group_params, lr=matrix_lr,
261
+ momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
262
+ ))
263
+ optimizer = MuonAdamW(param_groups)
264
+ for group in optimizer.param_groups:
265
+ group["initial_lr"] = group["lr"]
266
+ return optimizer
267
+
268
+ def forward(self, idx, targets=None, reduction='mean'):
269
+ B, T = idx.size()
270
+ assert T <= self.cos.size(1)
271
+ cos_sin = self.cos[:, :T], self.sin[:, :T]
272
+
273
+ x = self.transformer.wte(idx)
274
+ x = norm(x)
275
+ x0 = x
276
+ for i, block in enumerate(self.transformer.h):
277
+ x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
278
+ ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
279
+ x = block(x, ve, cos_sin, self.window_sizes[i])
280
+ x = norm(x)
281
+
282
+ softcap = 15
283
+ logits = self.lm_head(x)
284
+ logits = logits.float()
285
+ logits = softcap * torch.tanh(logits / softcap)
286
+
287
+ if targets is not None:
288
+ loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
289
+ ignore_index=-1, reduction=reduction)
290
+ return loss
291
+ return logits
292
+
293
+ # ---------------------------------------------------------------------------
294
+ # Optimizer (MuonAdamW, single GPU only)
295
+ # ---------------------------------------------------------------------------
296
+
297
+ polar_express_coeffs = [
298
+ (8.156554524902461, -22.48329292557795, 15.878769915207462),
299
+ (4.042929935166739, -2.808917465908714, 0.5000178451051316),
300
+ (3.8916678022926607, -2.772484153217685, 0.5060648178503393),
301
+ (3.285753657755655, -2.3681294933425376, 0.46449024233003106),
302
+ (2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
303
+ ]
304
+
305
+ @torch.compile(dynamic=False, fullgraph=True)
306
+ def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
307
+ p.mul_(1 - lr_t * wd_t)
308
+ exp_avg.lerp_(grad, 1 - beta1_t)
309
+ exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
310
+ bias1 = 1 - beta1_t ** step_t
311
+ bias2 = 1 - beta2_t ** step_t
312
+ denom = (exp_avg_sq / bias2).sqrt() + eps_t
313
+ step_size = lr_t / bias1
314
+ p.add_(exp_avg / denom, alpha=-step_size)
315
+
316
+ @torch.compile(dynamic=False, fullgraph=True)
317
+ def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
318
+ momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
319
+ # Nesterov momentum
320
+ momentum = momentum_t.to(stacked_grads.dtype)
321
+ momentum_buffer.lerp_(stacked_grads, 1 - momentum)
322
+ g = stacked_grads.lerp_(momentum_buffer, momentum)
323
+ # Polar express orthogonalization
324
+ X = g.bfloat16()
325
+ X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
326
+ if g.size(-2) > g.size(-1):
327
+ for a, b, c in polar_express_coeffs[:ns_steps]:
328
+ A = X.mT @ X
329
+ B = b * A + c * (A @ A)
330
+ X = a * X + X @ B
331
+ else:
332
+ for a, b, c in polar_express_coeffs[:ns_steps]:
333
+ A = X @ X.mT
334
+ B = b * A + c * (A @ A)
335
+ X = a * X + B @ X
336
+ g = X
337
+ # NorMuon variance reduction
338
+ beta2 = beta2_t.to(g.dtype)
339
+ v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
340
+ red_dim_size = g.size(red_dim)
341
+ v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
342
+ v_norm = v_norm_sq.sqrt()
343
+ second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
344
+ step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
345
+ scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
346
+ v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
347
+ final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
348
+ g = g * final_scale.to(g.dtype)
349
+ # Cautious weight decay + parameter update
350
+ lr = lr_t.to(g.dtype)
351
+ wd = wd_t.to(g.dtype)
352
+ mask = (g * stacked_params) >= 0
353
+ stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
354
+
355
+
356
+ class MuonAdamW(torch.optim.Optimizer):
357
+ """Combined optimizer: Muon for 2D matrix params, AdamW for others."""
358
+
359
+ def __init__(self, param_groups):
360
+ super().__init__(param_groups, defaults={})
361
+ # 0-D CPU tensors to avoid torch.compile recompilation when values change
362
+ self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
363
+ self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
364
+ self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
365
+ self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
366
+ self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
367
+ self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
368
+ self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
369
+ self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
370
+ self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
371
+ self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
372
+
373
+ def _step_adamw(self, group):
374
+ for p in group['params']:
375
+ if p.grad is None:
376
+ continue
377
+ grad = p.grad
378
+ state = self.state[p]
379
+ if not state:
380
+ state['step'] = 0
381
+ state['exp_avg'] = torch.zeros_like(p)
382
+ state['exp_avg_sq'] = torch.zeros_like(p)
383
+ state['step'] += 1
384
+ self._adamw_step_t.fill_(state['step'])
385
+ self._adamw_lr_t.fill_(group['lr'])
386
+ self._adamw_beta1_t.fill_(group['betas'][0])
387
+ self._adamw_beta2_t.fill_(group['betas'][1])
388
+ self._adamw_eps_t.fill_(group['eps'])
389
+ self._adamw_wd_t.fill_(group['weight_decay'])
390
+ adamw_step_fused(p, grad, state['exp_avg'], state['exp_avg_sq'],
391
+ self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
392
+ self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
393
+
394
+ def _step_muon(self, group):
395
+ params = group['params']
396
+ if not params:
397
+ return
398
+ p = params[0]
399
+ state = self.state[p]
400
+ num_params = len(params)
401
+ shape, device, dtype = p.shape, p.device, p.dtype
402
+ if "momentum_buffer" not in state:
403
+ state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
404
+ if "second_momentum_buffer" not in state:
405
+ state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
406
+ state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
407
+ red_dim = -1 if shape[-2] >= shape[-1] else -2
408
+ stacked_grads = torch.stack([p.grad for p in params])
409
+ stacked_params = torch.stack(params)
410
+ self._muon_momentum_t.fill_(group["momentum"])
411
+ self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
412
+ self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
413
+ self._muon_wd_t.fill_(group["weight_decay"])
414
+ muon_step_fused(stacked_grads, stacked_params,
415
+ state["momentum_buffer"], state["second_momentum_buffer"],
416
+ self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
417
+ self._muon_beta2_t, group["ns_steps"], red_dim)
418
+ torch._foreach_copy_(params, list(stacked_params.unbind(0)))
419
+
420
+ @torch.no_grad()
421
+ def step(self):
422
+ for group in self.param_groups:
423
+ if group['kind'] == 'adamw':
424
+ self._step_adamw(group)
425
+ elif group['kind'] == 'muon':
426
+ self._step_muon(group)
427
+
+ # ---------------------------------------------------------------------------
+ # Hyperparameters (edit these directly, no CLI flags needed)
+ # ---------------------------------------------------------------------------
+
+ # Model architecture
+ ASPECT_RATIO = 64 # model_dim = depth * ASPECT_RATIO
+ HEAD_DIM = 128 # target head dimension for attention
+ WINDOW_PATTERN = "SSSL" # sliding window pattern: L=full, S=half context
+
+ # Optimization
+ TOTAL_BATCH_SIZE = 2**19 # ~524K tokens per optimizer step
+ EMBEDDING_LR = 0.6 # learning rate for token embeddings (Adam)
+ UNEMBEDDING_LR = 0.004 # learning rate for lm_head (Adam)
+ MATRIX_LR = 0.04 # learning rate for matrix parameters (Muon)
+ SCALAR_LR = 0.5 # learning rate for per-layer scalars (Adam)
+ WEIGHT_DECAY = 0.2 # cautious weight decay for Muon
+ ADAM_BETAS = (0.8, 0.95) # Adam beta1, beta2
+ WARMUP_RATIO = 0.0 # fraction of time budget for LR warmup
+ WARMDOWN_RATIO = 0.5 # fraction of time budget for LR warmdown
+ FINAL_LR_FRAC = 0.0 # final LR as fraction of initial
+
+ # Model size
+ DEPTH = 8 # number of transformer layers
+ DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM)
+
+ # ---------------------------------------------------------------------------
+ # Setup: tokenizer, model, optimizer, dataloader
+ # ---------------------------------------------------------------------------
+
+ t_start = time.time()
+ torch.manual_seed(42)
+ torch.cuda.manual_seed(42)
+ torch.set_float32_matmul_precision("high")
+ device = torch.device("cuda")
+ autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
+ H100_BF16_PEAK_FLOPS = 989.5e12
+
+ tokenizer = Tokenizer.from_directory()
+ vocab_size = tokenizer.get_vocab_size()
+ print(f"Vocab size: {vocab_size:,}")
+
+ def build_model_config(depth):
+ base_dim = depth * ASPECT_RATIO
+ model_dim = ((base_dim + HEAD_DIM - 1) // HEAD_DIM) * HEAD_DIM
+ num_heads = model_dim // HEAD_DIM
+ return GPTConfig(
+ sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size,
+ n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
+ window_pattern=WINDOW_PATTERN,
+ )
+
+ config = build_model_config(DEPTH)
+ print(f"Model config: {asdict(config)}")
+
+ with torch.device("meta"):
+ model = GPT(config)
+ model.to_empty(device=device)
+ model.init_weights()
+
+ param_counts = model.num_scaling_params()
+ print("Parameter counts:")
+ for key, value in param_counts.items():
+ print(f" {key:24s}: {value:,}")
+ num_params = param_counts['total']
+ num_flops_per_token = model.estimate_flops()
+ print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
+
+ tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
+ assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
+ grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
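+ # Illustration (MAX_SEQ_LEN is defined earlier in this file; assuming, say, 2048):
+ # tokens_per_fwdbwd = 128 * 2048 = 262,144, so grad_accum_steps = 524,288 // 262,144 = 2.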
+
+ optimizer = model.setup_optimizer(
+ unembedding_lr=UNEMBEDDING_LR,
+ embedding_lr=EMBEDDING_LR,
+ scalar_lr=SCALAR_LR,
+ adam_betas=ADAM_BETAS,
+ matrix_lr=MATRIX_LR,
+ weight_decay=WEIGHT_DECAY,
+ )
+
+ model = torch.compile(model, dynamic=False)
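+ # dynamic=False makes torch.compile specialize on fixed input shapes, which matches
+ # the constant (DEVICE_BATCH_SIZE, MAX_SEQ_LEN) batches produced by the dataloader below.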
+
+ train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
+ x, y, epoch = next(train_loader) # prefetch first batch
+
+ print(f"Time budget: {TIME_BUDGET}s")
+ print(f"Gradient accumulation steps: {grad_accum_steps}")
+
+ # Schedules (all based on progress = training_time / TIME_BUDGET)
+
+ def get_lr_multiplier(progress):
+ if progress < WARMUP_RATIO:
+ return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
+ elif progress < 1.0 - WARMDOWN_RATIO:
+ return 1.0
+ else:
+ cooldown = (1.0 - progress) / WARMDOWN_RATIO
+ return cooldown * 1.0 + (1 - cooldown) * FINAL_LR_FRAC
+
+ def get_muon_momentum(step):
+ frac = min(step / 300, 1)
+ return (1 - frac) * 0.85 + frac * 0.95
+
+ def get_weight_decay(progress):
+ return WEIGHT_DECAY * (1 - progress)
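+ # Example with the defaults above (WARMUP_RATIO=0.0, WARMDOWN_RATIO=0.5, FINAL_LR_FRAC=0.0):
+ # the LR multiplier holds at 1.0 for the first half of the time budget, then decays
+ # linearly to 0 (at progress=0.75, cooldown=0.5 and the multiplier is 0.5). Muon momentum
+ # ramps from 0.85 to 0.95 over the first 300 steps, and Muon weight decay falls linearly
+ # from 0.2 to 0 over the run.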
+
+ # ---------------------------------------------------------------------------
+ # Training loop
+ # ---------------------------------------------------------------------------
+
+ t_start_training = time.time()
+ smooth_train_loss = 0
+ total_training_time = 0
+ step = 0
+
+ while True:
+ torch.cuda.synchronize()
+ t0 = time.time()
+ for micro_step in range(grad_accum_steps):
+ with autocast_ctx:
+ loss = model(x, y)
+ train_loss = loss.detach()
+ loss = loss / grad_accum_steps
+ loss.backward()
+ x, y, epoch = next(train_loader)
+
+ # Progress and schedules
+ progress = min(total_training_time / TIME_BUDGET, 1.0)
+ lrm = get_lr_multiplier(progress)
+ muon_momentum = get_muon_momentum(step)
+ muon_weight_decay = get_weight_decay(progress)
+ for group in optimizer.param_groups:
+ group["lr"] = group["initial_lr"] * lrm
+ if group['kind'] == 'muon':
+ group["momentum"] = muon_momentum
+ group["weight_decay"] = muon_weight_decay
+ optimizer.step()
+ model.zero_grad(set_to_none=True)
+
+ train_loss_f = train_loss.item()
+
+ # Fast fail: abort if loss is exploding or NaN
+ if math.isnan(train_loss_f) or train_loss_f > 100:
+ print("FAIL")
+ exit(1)
+
+ torch.cuda.synchronize()
+ t1 = time.time()
+ dt = t1 - t0
+
+ if step > 10:
+ total_training_time += dt
+
+ # Logging
+ ema_beta = 0.9
+ smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
+ debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
+ pct_done = 100 * progress
+ tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
+ mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / H100_BF16_PEAK_FLOPS
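+ # MFU here = achieved bf16 FLOP/s (num_flops_per_token * tokens this step / dt) as a
+ # percentage of the H100 dense bf16 peak (989.5 TFLOP/s).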
588
+ remaining = max(0, TIME_BUDGET - total_training_time)
589
+
590
+ print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", end="", flush=True)
591
+
592
+ # GC management (Python's GC causes ~500ms stalls)
593
+ if step == 0:
594
+ gc.collect()
595
+ gc.freeze()
596
+ gc.disable()
597
+ elif (step + 1) % 5000 == 0:
598
+ gc.collect()
599
+
600
+ step += 1
601
+
602
+ # Time's up — but only stop after warmup steps so we don't count compilation
603
+ if step > 10 and total_training_time >= TIME_BUDGET:
604
+ break
605
+
606
+ print() # newline after \r training log
607
+
608
+ total_tokens = step * TOTAL_BATCH_SIZE
609
+
610
+ # Final eval
611
+ model.eval()
612
+ with autocast_ctx:
613
+ val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
614
+
615
+ # Final summary
616
+ t_end = time.time()
617
+ startup_time = t_start_training - t_start
618
+ steady_state_mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10) / total_training_time / H100_BF16_PEAK_FLOPS if total_training_time > 0 else 0
619
+ peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
620
+
621
+ print("---")
622
+ print(f"val_bpb: {val_bpb:.6f}")
623
+ print(f"training_seconds: {total_training_time:.1f}")
624
+ print(f"total_seconds: {t_end - t_start:.1f}")
625
+ print(f"peak_vram_mb: {peak_vram_mb:.1f}")
626
+ print(f"mfu_percent: {steady_state_mfu:.2f}")
627
+ print(f"total_tokens_M: {total_tokens / 1e6:.1f}")
628
+ print(f"num_steps: {step}")
629
+ print(f"num_params_M: {num_params / 1e6:.1f}")
630
+ print(f"depth: {DEPTH}")
autoresearch/uv.lock ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ gradio>=5.0.0
+ spaces
+ torch
+ transformers
+ accelerate
+ openai