Spaces:
Running on Zero
Running on Zero
dexifried commited on
Commit ·
a6ecbb1
0
Parent(s):
Sovereign Deploy: Live-Streaming H200 Terminal
Browse files- README.md +11 -0
- app.py +91 -0
- autoresearch/.gitignore +23 -0
- autoresearch/.python-version +1 -0
- autoresearch/README.md +91 -0
- autoresearch/prepare.py +389 -0
- autoresearch/program.md +114 -0
- autoresearch/pyproject.toml +27 -0
- autoresearch/train.py +630 -0
- autoresearch/uv.lock +0 -0
- requirements.txt +6 -0
README.md
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Dex Evolution Outpost
|
| 3 |
+
emoji: 🧬
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: indigo
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.0.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
---
|
app.py
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
# --- ⚔️ THE SOVEREIGN PATCH ---
|
| 5 |
+
import huggingface_hub
|
| 6 |
+
if not hasattr(huggingface_hub, 'HfFolder'):
|
| 7 |
+
class SovereignHfFolder:
|
| 8 |
+
@staticmethod
|
| 9 |
+
def get_token(): return None
|
| 10 |
+
huggingface_hub.HfFolder = SovereignHfFolder
|
| 11 |
+
sys.modules['huggingface_hub.HfFolder'] = SovereignHfFolder
|
| 12 |
+
|
| 13 |
+
import subprocess
|
| 14 |
+
import gradio as gr
|
| 15 |
+
import spaces
|
| 16 |
+
import time
|
| 17 |
+
|
| 18 |
+
REPO_DIR = os.path.join(os.getcwd(), "autoresearch")
|
| 19 |
+
|
| 20 |
+
@spaces.GPU(duration=300) # Maximum ZeroGPU burst (5 Minutes)
|
| 21 |
+
def execute_on_h200(command, openrouter_key):
|
| 22 |
+
if not command:
|
| 23 |
+
yield "❌ Please enter a command."
|
| 24 |
+
return
|
| 25 |
+
|
| 26 |
+
env = os.environ.copy()
|
| 27 |
+
if openrouter_key:
|
| 28 |
+
env["OPENAI_API_KEY"] = openrouter_key
|
| 29 |
+
env["OPENAI_BASE_URL"] = "https://openrouter.ai/api/v1"
|
| 30 |
+
|
| 31 |
+
# Force Python to stream logs instantly instead of buffering them
|
| 32 |
+
env["PYTHONUNBUFFERED"] = "1"
|
| 33 |
+
|
| 34 |
+
output_log = f"[*] Attaching H200 and executing: {command}\n\n"
|
| 35 |
+
yield output_log
|
| 36 |
+
|
| 37 |
+
try:
|
| 38 |
+
# Popen allows us to stream the output live
|
| 39 |
+
process = subprocess.Popen(
|
| 40 |
+
command,
|
| 41 |
+
shell=True,
|
| 42 |
+
cwd=REPO_DIR,
|
| 43 |
+
env=env,
|
| 44 |
+
stdout=subprocess.PIPE,
|
| 45 |
+
stderr=subprocess.STDOUT,
|
| 46 |
+
text=True,
|
| 47 |
+
bufsize=1,
|
| 48 |
+
universal_newlines=True
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
# Stream the output line-by-line to the Gradio UI
|
| 52 |
+
for line in iter(process.stdout.readline, ''):
|
| 53 |
+
output_log += line
|
| 54 |
+
yield output_log
|
| 55 |
+
|
| 56 |
+
process.wait()
|
| 57 |
+
output_log += f"\n--- EXIT CODE: {process.returncode} ---"
|
| 58 |
+
yield output_log
|
| 59 |
+
|
| 60 |
+
except Exception as e:
|
| 61 |
+
yield output_log + f"\n❌ CRASH: {str(e)}"
|
| 62 |
+
|
| 63 |
+
def get_readme():
|
| 64 |
+
try:
|
| 65 |
+
with open(os.path.join(REPO_DIR, "README.md"), "r") as f:
|
| 66 |
+
return f.read()
|
| 67 |
+
except:
|
| 68 |
+
return "README not found."
|
| 69 |
+
|
| 70 |
+
with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
|
| 71 |
+
gr.Markdown("# 🧬 Dex Sovereign H200 Terminal (Live Stream Edition)")
|
| 72 |
+
gr.Markdown("Direct command-line execution on the Hugging Face ZeroGPU. Output streams live.")
|
| 73 |
+
|
| 74 |
+
with gr.Row():
|
| 75 |
+
with gr.Column(scale=2):
|
| 76 |
+
cmd_input = gr.Textbox(
|
| 77 |
+
value="ls -la && python prepare.py",
|
| 78 |
+
label="Execute Command on H200"
|
| 79 |
+
)
|
| 80 |
+
or_key = gr.Textbox(label="OpenRouter Key (Optional, bypasses OpenAI)", type="password")
|
| 81 |
+
btn = gr.Button("🚀 Run Command on GPU", variant="primary")
|
| 82 |
+
with gr.Column(scale=3):
|
| 83 |
+
output = gr.Textbox(label="Live Terminal Output", lines=20, max_lines=40)
|
| 84 |
+
|
| 85 |
+
gr.Markdown("### 📖 Framework Documentation (README.md)")
|
| 86 |
+
gr.Markdown(get_readme())
|
| 87 |
+
|
| 88 |
+
btn.click(fn=execute_on_h200, inputs=[cmd_input, or_key], outputs=[output])
|
| 89 |
+
|
| 90 |
+
if __name__ == "__main__":
|
| 91 |
+
demo.launch()
|
autoresearch/.gitignore
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Python-generated files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[oc]
|
| 4 |
+
build/
|
| 5 |
+
dist/
|
| 6 |
+
wheels/
|
| 7 |
+
*.egg-info
|
| 8 |
+
|
| 9 |
+
# Virtual environments
|
| 10 |
+
.venv
|
| 11 |
+
worktrees/
|
| 12 |
+
results/
|
| 13 |
+
queue/
|
| 14 |
+
|
| 15 |
+
# Agent prompt files (generated per-session by launchers)
|
| 16 |
+
CLAUDE.md
|
| 17 |
+
AGENTS.md
|
| 18 |
+
|
| 19 |
+
# Experimental code/artifacts
|
| 20 |
+
dev/
|
| 21 |
+
|
| 22 |
+
# Results file
|
| 23 |
+
results.tsv
|
autoresearch/.python-version
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
3.10
|
autoresearch/README.md
ADDED
|
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# autoresearch
|
| 2 |
+
|
| 3 |
+

|
| 4 |
+
|
| 5 |
+
*One day, frontier AI research used to be done by meat computers in between eating, sleeping, having other fun, and synchronizing once in a while using sound wave interconnect in the ritual of "group meeting". That era is long gone. Research is now entirely the domain of autonomous swarms of AI agents running across compute cluster megastructures in the skies. The agents claim that we are now in the 10,205th generation of the code base, in any case no one could tell if that's right or wrong as the "code" is now a self-modifying binary that has grown beyond human comprehension. This repo is the story of how it all began. -@karpathy, March 2026*.
|
| 6 |
+
|
| 7 |
+
The idea: give an AI agent a small but real LLM training setup and let it experiment autonomously overnight. It modifies the code, trains for 5 minutes, checks if the result improved, keeps or discards, and repeats. You wake up in the morning to a log of experiments and (hopefully) a better model. The training code here is a simplified single-GPU implementation of [nanochat](https://github.com/karpathy/nanochat). The core idea is that you're not touching any of the Python files like you normally would as a researcher. Instead, you are programming the `program.md` Markdown files that provide context to the AI agents and set up your autonomous research org. The default `program.md` in this repo is intentionally kept as a bare bones baseline, though it's obvious how one would iterate on it over time to find the "research org code" that achieves the fastest research progress, how you'd add more agents to the mix, etc. A bit more context on this project is here in this [tweet](https://x.com/karpathy/status/2029701092347630069).
|
| 8 |
+
|
| 9 |
+
## How it works
|
| 10 |
+
|
| 11 |
+
The repo is deliberately kept small and only really has three files that matter:
|
| 12 |
+
|
| 13 |
+
- **`prepare.py`** — fixed constants, one-time data prep (downloads training data, trains a BPE tokenizer), and runtime utilities (dataloader, evaluation). Not modified.
|
| 14 |
+
- **`train.py`** — the single file the agent edits. Contains the full GPT model, optimizer (Muon + AdamW), and training loop. Everything is fair game: architecture, hyperparameters, optimizer, batch size, etc. **This file is edited and iterated on by the agent**.
|
| 15 |
+
- **`program.md`** — baseline instructions for one agent. Point your agent here and let it go. **This file is edited and iterated on by the human**.
|
| 16 |
+
|
| 17 |
+
By design, training runs for a **fixed 5-minute time budget** (wall clock, excluding startup/compilation), regardless of the details of your compute. The metric is **val_bpb** (validation bits per byte) — lower is better, and vocab-size-independent so architectural changes are fairly compared.
|
| 18 |
+
|
| 19 |
+
If you are new to neural networks, this ["Dummy's Guide"](https://x.com/hooeem/status/2030720614752039185) looks pretty good for a lot more context.
|
| 20 |
+
|
| 21 |
+
## Quick start
|
| 22 |
+
|
| 23 |
+
**Requirements:** A single NVIDIA GPU (tested on H100), Python 3.10+, [uv](https://docs.astral.sh/uv/).
|
| 24 |
+
|
| 25 |
+
```bash
|
| 26 |
+
|
| 27 |
+
# 1. Install uv project manager (if you don't already have it)
|
| 28 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 29 |
+
|
| 30 |
+
# 2. Install dependencies
|
| 31 |
+
uv sync
|
| 32 |
+
|
| 33 |
+
# 3. Download data and train tokenizer (one-time, ~2 min)
|
| 34 |
+
uv run prepare.py
|
| 35 |
+
|
| 36 |
+
# 4. Manually run a single training experiment (~5 min)
|
| 37 |
+
uv run train.py
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
If the above commands all work ok, your setup is working and you can go into autonomous research mode.
|
| 41 |
+
|
| 42 |
+
## Running the agent
|
| 43 |
+
|
| 44 |
+
Simply spin up your Claude/Codex or whatever you want in this repo (and disable all permissions), then you can prompt something like:
|
| 45 |
+
|
| 46 |
+
```
|
| 47 |
+
Hi have a look at program.md and let's kick off a new experiment! let's do the setup first.
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
The `program.md` file is essentially a super lightweight "skill".
|
| 51 |
+
|
| 52 |
+
## Project structure
|
| 53 |
+
|
| 54 |
+
```
|
| 55 |
+
prepare.py — constants, data prep + runtime utilities (do not modify)
|
| 56 |
+
train.py — model, optimizer, training loop (agent modifies this)
|
| 57 |
+
program.md — agent instructions
|
| 58 |
+
pyproject.toml — dependencies
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
## Design choices
|
| 62 |
+
|
| 63 |
+
- **Single file to modify.** The agent only touches `train.py`. This keeps the scope manageable and diffs reviewable.
|
| 64 |
+
- **Fixed time budget.** Training always runs for exactly 5 minutes, regardless of your specific platform. This means you can expect approx 12 experiments/hour and approx 100 experiments while you sleep. There are two upsides of this design decision. First, this makes experiments directly comparable regardless of what the agent changes (model size, batch size, architecture, etc). Second, this means that autoresearch will find the most optimal model for your platform in that time budget. The downside is that your runs (and results) become not comparable to other people running on other compute platforms.
|
| 65 |
+
- **Self-contained.** No external dependencies beyond PyTorch and a few small packages. No distributed training, no complex configs. One GPU, one file, one metric.
|
| 66 |
+
|
| 67 |
+
## Platform support
|
| 68 |
+
|
| 69 |
+
This code currently requires that you have a single NVIDIA GPU. In principle it is quite possible to support CPU, MPS and other platforms but this would also bloat the code. I'm not 100% sure that I want to take this on personally right now. People can reference (or have their agents reference) the full/parent nanochat repository that has wider platform support and shows the various solutions (e.g. a Flash Attention 3 kernels fallback implementation, generic device support, autodetection, etc.), feel free to create forks or discussions for other platforms and I'm happy to link to them here in the README in some new notable forks section or etc.
|
| 70 |
+
|
| 71 |
+
Seeing as there seems to be a lot of interest in tinkering with autoresearch on much smaller compute platforms than an H100, a few extra words. If you're going to try running autoresearch on smaller computers (Macbooks etc.), I'd recommend one of the forks below. On top of this, here are some recommendations for how to tune the defaults for much smaller models for aspiring forks:
|
| 72 |
+
|
| 73 |
+
1. To get half-decent results I'd use a dataset with a lot less entropy, e.g. this [TinyStories dataset](https://huggingface.co/datasets/karpathy/tinystories-gpt4-clean). These are GPT-4 generated short stories. Because the data is a lot narrower in scope, you will see reasonable results with a lot smaller models (if you try to sample from them after training).
|
| 74 |
+
2. You might experiment with decreasing `vocab_size`, e.g. from 8192 down to 4096, 2048, 1024, or even - simply byte-level tokenizer with 256 possibly bytes after utf-8 encoding.
|
| 75 |
+
3. In `prepare.py`, you'll want to lower `MAX_SEQ_LEN` a lot, depending on the computer even down to 256 etc. As you lower `MAX_SEQ_LEN`, you may want to experiment with increasing `DEVICE_BATCH_SIZE` in `train.py` slightly to compensate. The number of tokens per fwd/bwd pass is the product of these two.
|
| 76 |
+
4. Also in `prepare.py`, you'll want to decrease `EVAL_TOKENS` so that your validation loss is evaluated on a lot less data.
|
| 77 |
+
5. In `train.py`, the primary single knob that controls model complexity is the `DEPTH` (default 8, here). A lot of variables are just functions of this, so e.g. lower it down to e.g. 4.
|
| 78 |
+
6. You'll want to most likely use `WINDOW_PATTERN` of just "L", because "SSSL" uses alternating banded attention pattern that may be very inefficient for you. Try it.
|
| 79 |
+
7. You'll want to lower `TOTAL_BATCH_SIZE` a lot, but keep it powers of 2, e.g. down to `2**14` (~16K) or so even, hard to tell.
|
| 80 |
+
|
| 81 |
+
I think these would be the reasonable hyperparameters to play with. Ask your favorite coding agent for help and copy paste them this guide, as well as the full source code.
|
| 82 |
+
|
| 83 |
+
## Notable forks
|
| 84 |
+
|
| 85 |
+
- [miolini/autoresearch-macos](https://github.com/miolini/autoresearch-macos) (MacOS)
|
| 86 |
+
- [trevin-creator/autoresearch-mlx](https://github.com/trevin-creator/autoresearch-mlx) (MacOS)
|
| 87 |
+
- [jsegov/autoresearch-win-rtx](https://github.com/jsegov/autoresearch-win-rtx) (Windows)
|
| 88 |
+
|
| 89 |
+
## License
|
| 90 |
+
|
| 91 |
+
MIT
|
autoresearch/prepare.py
ADDED
|
@@ -0,0 +1,389 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
One-time data preparation for autoresearch experiments.
|
| 3 |
+
Downloads data shards and trains a BPE tokenizer.
|
| 4 |
+
|
| 5 |
+
Usage:
|
| 6 |
+
python prepare.py # full prep (download + tokenizer)
|
| 7 |
+
python prepare.py --num-shards 8 # download only 8 shards (for testing)
|
| 8 |
+
|
| 9 |
+
Data and tokenizer are stored in ~/.cache/autoresearch/.
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import os
|
| 13 |
+
import sys
|
| 14 |
+
import time
|
| 15 |
+
import math
|
| 16 |
+
import argparse
|
| 17 |
+
import pickle
|
| 18 |
+
from multiprocessing import Pool
|
| 19 |
+
|
| 20 |
+
import requests
|
| 21 |
+
import pyarrow.parquet as pq
|
| 22 |
+
import rustbpe
|
| 23 |
+
import tiktoken
|
| 24 |
+
import torch
|
| 25 |
+
|
| 26 |
+
# ---------------------------------------------------------------------------
|
| 27 |
+
# Constants (fixed, do not modify)
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
|
| 30 |
+
MAX_SEQ_LEN = 2048 # context length
|
| 31 |
+
TIME_BUDGET = 300 # training time budget in seconds (5 minutes)
|
| 32 |
+
EVAL_TOKENS = 40 * 524288 # number of tokens for val eval
|
| 33 |
+
|
| 34 |
+
# ---------------------------------------------------------------------------
|
| 35 |
+
# Configuration
|
| 36 |
+
# ---------------------------------------------------------------------------
|
| 37 |
+
|
| 38 |
+
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "autoresearch")
|
| 39 |
+
DATA_DIR = os.path.join(CACHE_DIR, "data")
|
| 40 |
+
TOKENIZER_DIR = os.path.join(CACHE_DIR, "tokenizer")
|
| 41 |
+
BASE_URL = "https://huggingface.co/datasets/karpathy/climbmix-400b-shuffle/resolve/main"
|
| 42 |
+
MAX_SHARD = 6542 # the last datashard is shard_06542.parquet
|
| 43 |
+
VAL_SHARD = MAX_SHARD # pinned validation shard (shard_06542)
|
| 44 |
+
VAL_FILENAME = f"shard_{VAL_SHARD:05d}.parquet"
|
| 45 |
+
VOCAB_SIZE = 8192
|
| 46 |
+
|
| 47 |
+
# BPE split pattern (GPT-4 style, with \p{N}{1,2} instead of {1,3})
|
| 48 |
+
SPLIT_PATTERN = r"""'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}+|\p{N}{1,2}| ?[^\s\p{L}\p{N}]++[\r\n]*|\s*[\r\n]|\s+(?!\S)|\s+"""
|
| 49 |
+
|
| 50 |
+
SPECIAL_TOKENS = [f"<|reserved_{i}|>" for i in range(4)]
|
| 51 |
+
BOS_TOKEN = "<|reserved_0|>"
|
| 52 |
+
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
# Data download
|
| 55 |
+
# ---------------------------------------------------------------------------
|
| 56 |
+
|
| 57 |
+
def download_single_shard(index):
|
| 58 |
+
"""Download one parquet shard with retries. Returns True on success."""
|
| 59 |
+
filename = f"shard_{index:05d}.parquet"
|
| 60 |
+
filepath = os.path.join(DATA_DIR, filename)
|
| 61 |
+
if os.path.exists(filepath):
|
| 62 |
+
return True
|
| 63 |
+
|
| 64 |
+
url = f"{BASE_URL}/{filename}"
|
| 65 |
+
max_attempts = 5
|
| 66 |
+
for attempt in range(1, max_attempts + 1):
|
| 67 |
+
try:
|
| 68 |
+
response = requests.get(url, stream=True, timeout=30)
|
| 69 |
+
response.raise_for_status()
|
| 70 |
+
temp_path = filepath + ".tmp"
|
| 71 |
+
with open(temp_path, "wb") as f:
|
| 72 |
+
for chunk in response.iter_content(chunk_size=1024 * 1024):
|
| 73 |
+
if chunk:
|
| 74 |
+
f.write(chunk)
|
| 75 |
+
os.rename(temp_path, filepath)
|
| 76 |
+
print(f" Downloaded {filename}")
|
| 77 |
+
return True
|
| 78 |
+
except (requests.RequestException, IOError) as e:
|
| 79 |
+
print(f" Attempt {attempt}/{max_attempts} failed for {filename}: {e}")
|
| 80 |
+
for path in [filepath + ".tmp", filepath]:
|
| 81 |
+
if os.path.exists(path):
|
| 82 |
+
try:
|
| 83 |
+
os.remove(path)
|
| 84 |
+
except OSError:
|
| 85 |
+
pass
|
| 86 |
+
if attempt < max_attempts:
|
| 87 |
+
time.sleep(2 ** attempt)
|
| 88 |
+
return False
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
def download_data(num_shards, download_workers=8):
|
| 92 |
+
"""Download training shards + pinned validation shard."""
|
| 93 |
+
os.makedirs(DATA_DIR, exist_ok=True)
|
| 94 |
+
num_train = min(num_shards, MAX_SHARD)
|
| 95 |
+
ids = list(range(num_train))
|
| 96 |
+
if VAL_SHARD not in ids:
|
| 97 |
+
ids.append(VAL_SHARD)
|
| 98 |
+
|
| 99 |
+
# Count what's already downloaded
|
| 100 |
+
existing = sum(1 for i in ids if os.path.exists(os.path.join(DATA_DIR, f"shard_{i:05d}.parquet")))
|
| 101 |
+
if existing == len(ids):
|
| 102 |
+
print(f"Data: all {len(ids)} shards already downloaded at {DATA_DIR}")
|
| 103 |
+
return
|
| 104 |
+
|
| 105 |
+
needed = len(ids) - existing
|
| 106 |
+
print(f"Data: downloading {needed} shards ({existing} already exist)...")
|
| 107 |
+
|
| 108 |
+
workers = max(1, min(download_workers, needed))
|
| 109 |
+
with Pool(processes=workers) as pool:
|
| 110 |
+
results = pool.map(download_single_shard, ids)
|
| 111 |
+
|
| 112 |
+
ok = sum(1 for r in results if r)
|
| 113 |
+
print(f"Data: {ok}/{len(ids)} shards ready at {DATA_DIR}")
|
| 114 |
+
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
# Tokenizer training
|
| 117 |
+
# ---------------------------------------------------------------------------
|
| 118 |
+
|
| 119 |
+
def list_parquet_files():
|
| 120 |
+
"""Return sorted list of parquet file paths in the data directory."""
|
| 121 |
+
files = sorted(f for f in os.listdir(DATA_DIR) if f.endswith(".parquet") and not f.endswith(".tmp"))
|
| 122 |
+
return [os.path.join(DATA_DIR, f) for f in files]
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def text_iterator(max_chars=1_000_000_000, doc_cap=10_000):
|
| 126 |
+
"""Yield documents from training split (all shards except pinned val shard)."""
|
| 127 |
+
parquet_paths = [p for p in list_parquet_files() if not p.endswith(VAL_FILENAME)]
|
| 128 |
+
nchars = 0
|
| 129 |
+
for filepath in parquet_paths:
|
| 130 |
+
pf = pq.ParquetFile(filepath)
|
| 131 |
+
for rg_idx in range(pf.num_row_groups):
|
| 132 |
+
rg = pf.read_row_group(rg_idx)
|
| 133 |
+
for text in rg.column("text").to_pylist():
|
| 134 |
+
doc = text[:doc_cap] if len(text) > doc_cap else text
|
| 135 |
+
nchars += len(doc)
|
| 136 |
+
yield doc
|
| 137 |
+
if nchars >= max_chars:
|
| 138 |
+
return
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def train_tokenizer():
|
| 142 |
+
"""Train BPE tokenizer using rustbpe, save as tiktoken pickle."""
|
| 143 |
+
tokenizer_pkl = os.path.join(TOKENIZER_DIR, "tokenizer.pkl")
|
| 144 |
+
token_bytes_path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
|
| 145 |
+
|
| 146 |
+
if os.path.exists(tokenizer_pkl) and os.path.exists(token_bytes_path):
|
| 147 |
+
print(f"Tokenizer: already trained at {TOKENIZER_DIR}")
|
| 148 |
+
return
|
| 149 |
+
|
| 150 |
+
os.makedirs(TOKENIZER_DIR, exist_ok=True)
|
| 151 |
+
|
| 152 |
+
parquet_files = list_parquet_files()
|
| 153 |
+
if len(parquet_files) < 2:
|
| 154 |
+
print("Tokenizer: need at least 2 data shards (1 train + 1 val). Download more data first.")
|
| 155 |
+
sys.exit(1)
|
| 156 |
+
|
| 157 |
+
# --- Train with rustbpe ---
|
| 158 |
+
print("Tokenizer: training BPE tokenizer...")
|
| 159 |
+
t0 = time.time()
|
| 160 |
+
|
| 161 |
+
tokenizer = rustbpe.Tokenizer()
|
| 162 |
+
vocab_size_no_special = VOCAB_SIZE - len(SPECIAL_TOKENS)
|
| 163 |
+
tokenizer.train_from_iterator(text_iterator(), vocab_size_no_special, pattern=SPLIT_PATTERN)
|
| 164 |
+
|
| 165 |
+
# Build tiktoken encoding from trained merges
|
| 166 |
+
pattern = tokenizer.get_pattern()
|
| 167 |
+
mergeable_ranks = {bytes(k): v for k, v in tokenizer.get_mergeable_ranks()}
|
| 168 |
+
tokens_offset = len(mergeable_ranks)
|
| 169 |
+
special_tokens = {name: tokens_offset + i for i, name in enumerate(SPECIAL_TOKENS)}
|
| 170 |
+
enc = tiktoken.Encoding(
|
| 171 |
+
name="rustbpe",
|
| 172 |
+
pat_str=pattern,
|
| 173 |
+
mergeable_ranks=mergeable_ranks,
|
| 174 |
+
special_tokens=special_tokens,
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# Save tokenizer
|
| 178 |
+
with open(tokenizer_pkl, "wb") as f:
|
| 179 |
+
pickle.dump(enc, f)
|
| 180 |
+
|
| 181 |
+
t1 = time.time()
|
| 182 |
+
print(f"Tokenizer: trained in {t1 - t0:.1f}s, saved to {tokenizer_pkl}")
|
| 183 |
+
|
| 184 |
+
# --- Build token_bytes lookup for BPB evaluation ---
|
| 185 |
+
print("Tokenizer: building token_bytes lookup...")
|
| 186 |
+
special_set = set(SPECIAL_TOKENS)
|
| 187 |
+
token_bytes_list = []
|
| 188 |
+
for token_id in range(enc.n_vocab):
|
| 189 |
+
token_str = enc.decode([token_id])
|
| 190 |
+
if token_str in special_set:
|
| 191 |
+
token_bytes_list.append(0)
|
| 192 |
+
else:
|
| 193 |
+
token_bytes_list.append(len(token_str.encode("utf-8")))
|
| 194 |
+
token_bytes_tensor = torch.tensor(token_bytes_list, dtype=torch.int32)
|
| 195 |
+
torch.save(token_bytes_tensor, token_bytes_path)
|
| 196 |
+
print(f"Tokenizer: saved token_bytes to {token_bytes_path}")
|
| 197 |
+
|
| 198 |
+
# Sanity check
|
| 199 |
+
test = "Hello world! Numbers: 123. Unicode: 你好"
|
| 200 |
+
encoded = enc.encode_ordinary(test)
|
| 201 |
+
decoded = enc.decode(encoded)
|
| 202 |
+
assert decoded == test, f"Tokenizer roundtrip failed: {test!r} -> {decoded!r}"
|
| 203 |
+
print(f"Tokenizer: sanity check passed (vocab_size={enc.n_vocab})")
|
| 204 |
+
|
| 205 |
+
# ---------------------------------------------------------------------------
|
| 206 |
+
# Runtime utilities (imported by train.py)
|
| 207 |
+
# ---------------------------------------------------------------------------
|
| 208 |
+
|
| 209 |
+
class Tokenizer:
|
| 210 |
+
"""Minimal tokenizer wrapper. Training is handled above."""
|
| 211 |
+
|
| 212 |
+
def __init__(self, enc):
|
| 213 |
+
self.enc = enc
|
| 214 |
+
self.bos_token_id = enc.encode_single_token(BOS_TOKEN)
|
| 215 |
+
|
| 216 |
+
@classmethod
|
| 217 |
+
def from_directory(cls, tokenizer_dir=TOKENIZER_DIR):
|
| 218 |
+
with open(os.path.join(tokenizer_dir, "tokenizer.pkl"), "rb") as f:
|
| 219 |
+
enc = pickle.load(f)
|
| 220 |
+
return cls(enc)
|
| 221 |
+
|
| 222 |
+
def get_vocab_size(self):
|
| 223 |
+
return self.enc.n_vocab
|
| 224 |
+
|
| 225 |
+
def get_bos_token_id(self):
|
| 226 |
+
return self.bos_token_id
|
| 227 |
+
|
| 228 |
+
def encode(self, text, prepend=None, num_threads=8):
|
| 229 |
+
if prepend is not None:
|
| 230 |
+
prepend_id = prepend if isinstance(prepend, int) else self.enc.encode_single_token(prepend)
|
| 231 |
+
if isinstance(text, str):
|
| 232 |
+
ids = self.enc.encode_ordinary(text)
|
| 233 |
+
if prepend is not None:
|
| 234 |
+
ids.insert(0, prepend_id)
|
| 235 |
+
elif isinstance(text, list):
|
| 236 |
+
ids = self.enc.encode_ordinary_batch(text, num_threads=num_threads)
|
| 237 |
+
if prepend is not None:
|
| 238 |
+
for row in ids:
|
| 239 |
+
row.insert(0, prepend_id)
|
| 240 |
+
else:
|
| 241 |
+
raise ValueError(f"Invalid input type: {type(text)}")
|
| 242 |
+
return ids
|
| 243 |
+
|
| 244 |
+
def decode(self, ids):
|
| 245 |
+
return self.enc.decode(ids)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def get_token_bytes(device="cpu"):
|
| 249 |
+
path = os.path.join(TOKENIZER_DIR, "token_bytes.pt")
|
| 250 |
+
with open(path, "rb") as f:
|
| 251 |
+
return torch.load(f, map_location=device)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def _document_batches(split, tokenizer_batch_size=128):
|
| 255 |
+
"""Infinite iterator over document batches from parquet files."""
|
| 256 |
+
parquet_paths = list_parquet_files()
|
| 257 |
+
assert len(parquet_paths) > 0, "No parquet files found. Run prepare.py first."
|
| 258 |
+
val_path = os.path.join(DATA_DIR, VAL_FILENAME)
|
| 259 |
+
if split == "train":
|
| 260 |
+
parquet_paths = [p for p in parquet_paths if p != val_path]
|
| 261 |
+
assert len(parquet_paths) > 0, "No training shards found."
|
| 262 |
+
else:
|
| 263 |
+
parquet_paths = [val_path]
|
| 264 |
+
epoch = 1
|
| 265 |
+
while True:
|
| 266 |
+
for filepath in parquet_paths:
|
| 267 |
+
pf = pq.ParquetFile(filepath)
|
| 268 |
+
for rg_idx in range(pf.num_row_groups):
|
| 269 |
+
rg = pf.read_row_group(rg_idx)
|
| 270 |
+
batch = rg.column('text').to_pylist()
|
| 271 |
+
for i in range(0, len(batch), tokenizer_batch_size):
|
| 272 |
+
yield batch[i:i+tokenizer_batch_size], epoch
|
| 273 |
+
epoch += 1
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def make_dataloader(tokenizer, B, T, split, buffer_size=1000):
|
| 277 |
+
"""
|
| 278 |
+
BOS-aligned dataloader with best-fit packing.
|
| 279 |
+
Every row starts with BOS. Documents packed using best-fit to minimize cropping.
|
| 280 |
+
When no document fits remaining space, crops shortest doc to fill exactly.
|
| 281 |
+
100% utilization (no padding).
|
| 282 |
+
"""
|
| 283 |
+
assert split in ["train", "val"]
|
| 284 |
+
row_capacity = T + 1
|
| 285 |
+
batches = _document_batches(split)
|
| 286 |
+
bos_token = tokenizer.get_bos_token_id()
|
| 287 |
+
doc_buffer = []
|
| 288 |
+
epoch = 1
|
| 289 |
+
|
| 290 |
+
def refill_buffer():
|
| 291 |
+
nonlocal epoch
|
| 292 |
+
doc_batch, epoch = next(batches)
|
| 293 |
+
token_lists = tokenizer.encode(doc_batch, prepend=bos_token)
|
| 294 |
+
doc_buffer.extend(token_lists)
|
| 295 |
+
|
| 296 |
+
# Pre-allocate buffers: [inputs (B*T) | targets (B*T)]
|
| 297 |
+
row_buffer = torch.empty((B, row_capacity), dtype=torch.long)
|
| 298 |
+
cpu_buffer = torch.empty(2 * B * T, dtype=torch.long, pin_memory=True)
|
| 299 |
+
gpu_buffer = torch.empty(2 * B * T, dtype=torch.long, device="cuda")
|
| 300 |
+
cpu_inputs = cpu_buffer[:B * T].view(B, T)
|
| 301 |
+
cpu_targets = cpu_buffer[B * T:].view(B, T)
|
| 302 |
+
inputs = gpu_buffer[:B * T].view(B, T)
|
| 303 |
+
targets = gpu_buffer[B * T:].view(B, T)
|
| 304 |
+
|
| 305 |
+
while True:
|
| 306 |
+
for row_idx in range(B):
|
| 307 |
+
pos = 0
|
| 308 |
+
while pos < row_capacity:
|
| 309 |
+
while len(doc_buffer) < buffer_size:
|
| 310 |
+
refill_buffer()
|
| 311 |
+
|
| 312 |
+
remaining = row_capacity - pos
|
| 313 |
+
|
| 314 |
+
# Find largest doc that fits entirely
|
| 315 |
+
best_idx = -1
|
| 316 |
+
best_len = 0
|
| 317 |
+
for i, doc in enumerate(doc_buffer):
|
| 318 |
+
doc_len = len(doc)
|
| 319 |
+
if doc_len <= remaining and doc_len > best_len:
|
| 320 |
+
best_idx = i
|
| 321 |
+
best_len = doc_len
|
| 322 |
+
|
| 323 |
+
if best_idx >= 0:
|
| 324 |
+
doc = doc_buffer.pop(best_idx)
|
| 325 |
+
row_buffer[row_idx, pos:pos + len(doc)] = torch.tensor(doc, dtype=torch.long)
|
| 326 |
+
pos += len(doc)
|
| 327 |
+
else:
|
| 328 |
+
# No doc fits — crop shortest to fill remaining
|
| 329 |
+
shortest_idx = min(range(len(doc_buffer)), key=lambda i: len(doc_buffer[i]))
|
| 330 |
+
doc = doc_buffer.pop(shortest_idx)
|
| 331 |
+
row_buffer[row_idx, pos:pos + remaining] = torch.tensor(doc[:remaining], dtype=torch.long)
|
| 332 |
+
pos += remaining
|
| 333 |
+
|
| 334 |
+
cpu_inputs.copy_(row_buffer[:, :-1])
|
| 335 |
+
cpu_targets.copy_(row_buffer[:, 1:])
|
| 336 |
+
gpu_buffer.copy_(cpu_buffer, non_blocking=True)
|
| 337 |
+
yield inputs, targets, epoch
|
| 338 |
+
|
| 339 |
+
# ---------------------------------------------------------------------------
|
| 340 |
+
# Evaluation (DO NOT CHANGE — this is the fixed metric)
|
| 341 |
+
# ---------------------------------------------------------------------------
|
| 342 |
+
|
| 343 |
+
@torch.no_grad()
|
| 344 |
+
def evaluate_bpb(model, tokenizer, batch_size):
|
| 345 |
+
"""
|
| 346 |
+
Bits per byte (BPB): vocab size-independent evaluation metric.
|
| 347 |
+
Sums per-token cross-entropy (in nats), sums target byte lengths,
|
| 348 |
+
then converts nats/byte to bits/byte. Special tokens (byte length 0)
|
| 349 |
+
are excluded from both sums.
|
| 350 |
+
Uses fixed MAX_SEQ_LEN so results are comparable across configs.
|
| 351 |
+
"""
|
| 352 |
+
token_bytes = get_token_bytes(device="cuda")
|
| 353 |
+
val_loader = make_dataloader(tokenizer, batch_size, MAX_SEQ_LEN, "val")
|
| 354 |
+
steps = EVAL_TOKENS // (batch_size * MAX_SEQ_LEN)
|
| 355 |
+
total_nats = 0.0
|
| 356 |
+
total_bytes = 0
|
| 357 |
+
for _ in range(steps):
|
| 358 |
+
x, y, _ = next(val_loader)
|
| 359 |
+
loss_flat = model(x, y, reduction='none').view(-1)
|
| 360 |
+
y_flat = y.view(-1)
|
| 361 |
+
nbytes = token_bytes[y_flat]
|
| 362 |
+
mask = nbytes > 0
|
| 363 |
+
total_nats += (loss_flat * mask).sum().item()
|
| 364 |
+
total_bytes += nbytes.sum().item()
|
| 365 |
+
return total_nats / (math.log(2) * total_bytes)
|
| 366 |
+
|
| 367 |
+
# ---------------------------------------------------------------------------
|
| 368 |
+
# Main
|
| 369 |
+
# ---------------------------------------------------------------------------
|
| 370 |
+
|
| 371 |
+
if __name__ == "__main__":
|
| 372 |
+
parser = argparse.ArgumentParser(description="Prepare data and tokenizer for autoresearch")
|
| 373 |
+
parser.add_argument("--num-shards", type=int, default=10, help="Number of training shards to download (-1 = all). Val shard is always pinned.")
|
| 374 |
+
parser.add_argument("--download-workers", type=int, default=8, help="Number of parallel download workers")
|
| 375 |
+
args = parser.parse_args()
|
| 376 |
+
|
| 377 |
+
num_shards = MAX_SHARD if args.num_shards == -1 else args.num_shards
|
| 378 |
+
|
| 379 |
+
print(f"Cache directory: {CACHE_DIR}")
|
| 380 |
+
print()
|
| 381 |
+
|
| 382 |
+
# Step 1: Download data
|
| 383 |
+
download_data(num_shards, download_workers=args.download_workers)
|
| 384 |
+
print()
|
| 385 |
+
|
| 386 |
+
# Step 2: Train tokenizer
|
| 387 |
+
train_tokenizer()
|
| 388 |
+
print()
|
| 389 |
+
print("Done! Ready to train.")
|
autoresearch/program.md
ADDED
|
@@ -0,0 +1,114 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# autoresearch
|
| 2 |
+
|
| 3 |
+
This is an experiment to have the LLM do its own research.
|
| 4 |
+
|
| 5 |
+
## Setup
|
| 6 |
+
|
| 7 |
+
To set up a new experiment, work with the user to:
|
| 8 |
+
|
| 9 |
+
1. **Agree on a run tag**: propose a tag based on today's date (e.g. `mar5`). The branch `autoresearch/<tag>` must not already exist — this is a fresh run.
|
| 10 |
+
2. **Create the branch**: `git checkout -b autoresearch/<tag>` from current master.
|
| 11 |
+
3. **Read the in-scope files**: The repo is small. Read these files for full context:
|
| 12 |
+
- `README.md` — repository context.
|
| 13 |
+
- `prepare.py` — fixed constants, data prep, tokenizer, dataloader, evaluation. Do not modify.
|
| 14 |
+
- `train.py` — the file you modify. Model architecture, optimizer, training loop.
|
| 15 |
+
4. **Verify data exists**: Check that `~/.cache/autoresearch/` contains data shards and a tokenizer. If not, tell the human to run `uv run prepare.py`.
|
| 16 |
+
5. **Initialize results.tsv**: Create `results.tsv` with just the header row. The baseline will be recorded after the first run.
|
| 17 |
+
6. **Confirm and go**: Confirm setup looks good.
|
| 18 |
+
|
| 19 |
+
Once you get confirmation, kick off the experimentation.
|
| 20 |
+
|
| 21 |
+
## Experimentation
|
| 22 |
+
|
| 23 |
+
Each experiment runs on a single GPU. The training script runs for a **fixed time budget of 5 minutes** (wall clock training time, excluding startup/compilation). You launch it simply as: `uv run train.py`.
|
| 24 |
+
|
| 25 |
+
**What you CAN do:**
|
| 26 |
+
- Modify `train.py` — this is the only file you edit. Everything is fair game: model architecture, optimizer, hyperparameters, training loop, batch size, model size, etc.
|
| 27 |
+
|
| 28 |
+
**What you CANNOT do:**
|
| 29 |
+
- Modify `prepare.py`. It is read-only. It contains the fixed evaluation, data loading, tokenizer, and training constants (time budget, sequence length, etc).
|
| 30 |
+
- Install new packages or add dependencies. You can only use what's already in `pyproject.toml`.
|
| 31 |
+
- Modify the evaluation harness. The `evaluate_bpb` function in `prepare.py` is the ground truth metric.
|
| 32 |
+
|
| 33 |
+
**The goal is simple: get the lowest val_bpb.** Since the time budget is fixed, you don't need to worry about training time — it's always 5 minutes. Everything is fair game: change the architecture, the optimizer, the hyperparameters, the batch size, the model size. The only constraint is that the code runs without crashing and finishes within the time budget.
|
| 34 |
+
|
| 35 |
+
**VRAM** is a soft constraint. Some increase is acceptable for meaningful val_bpb gains, but it should not blow up dramatically.
|
| 36 |
+
|
| 37 |
+
**Simplicity criterion**: All else being equal, simpler is better. A small improvement that adds ugly complexity is not worth it. Conversely, removing something and getting equal or better results is a great outcome — that's a simplification win. When evaluating whether to keep a change, weigh the complexity cost against the improvement magnitude. A 0.001 val_bpb improvement that adds 20 lines of hacky code? Probably not worth it. A 0.001 val_bpb improvement from deleting code? Definitely keep. An improvement of ~0 but much simpler code? Keep.
|
| 38 |
+
|
| 39 |
+
**The first run**: Your very first run should always be to establish the baseline, so you will run the training script as is.
|
| 40 |
+
|
| 41 |
+
## Output format
|
| 42 |
+
|
| 43 |
+
Once the script finishes it prints a summary like this:
|
| 44 |
+
|
| 45 |
+
```
|
| 46 |
+
---
|
| 47 |
+
val_bpb: 0.997900
|
| 48 |
+
training_seconds: 300.1
|
| 49 |
+
total_seconds: 325.9
|
| 50 |
+
peak_vram_mb: 45060.2
|
| 51 |
+
mfu_percent: 39.80
|
| 52 |
+
total_tokens_M: 499.6
|
| 53 |
+
num_steps: 953
|
| 54 |
+
num_params_M: 50.3
|
| 55 |
+
depth: 8
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
+
Note that the script is configured to always stop after 5 minutes, so depending on the computing platform of this computer the numbers might look different. You can extract the key metric from the log file:
|
| 59 |
+
|
| 60 |
+
```
|
| 61 |
+
grep "^val_bpb:" run.log
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
## Logging results
|
| 65 |
+
|
| 66 |
+
When an experiment is done, log it to `results.tsv` (tab-separated, NOT comma-separated — commas break in descriptions).
|
| 67 |
+
|
| 68 |
+
The TSV has a header row and 5 columns:
|
| 69 |
+
|
| 70 |
+
```
|
| 71 |
+
commit val_bpb memory_gb status description
|
| 72 |
+
```
|
| 73 |
+
|
| 74 |
+
1. git commit hash (short, 7 chars)
|
| 75 |
+
2. val_bpb achieved (e.g. 1.234567) — use 0.000000 for crashes
|
| 76 |
+
3. peak memory in GB, round to .1f (e.g. 12.3 — divide peak_vram_mb by 1024) — use 0.0 for crashes
|
| 77 |
+
4. status: `keep`, `discard`, or `crash`
|
| 78 |
+
5. short text description of what this experiment tried
|
| 79 |
+
|
| 80 |
+
Example:
|
| 81 |
+
|
| 82 |
+
```
|
| 83 |
+
commit val_bpb memory_gb status description
|
| 84 |
+
a1b2c3d 0.997900 44.0 keep baseline
|
| 85 |
+
b2c3d4e 0.993200 44.2 keep increase LR to 0.04
|
| 86 |
+
c3d4e5f 1.005000 44.0 discard switch to GeLU activation
|
| 87 |
+
d4e5f6g 0.000000 0.0 crash double model width (OOM)
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
## The experiment loop
|
| 91 |
+
|
| 92 |
+
The experiment runs on a dedicated branch (e.g. `autoresearch/mar5` or `autoresearch/mar5-gpu0`).
|
| 93 |
+
|
| 94 |
+
LOOP FOREVER:
|
| 95 |
+
|
| 96 |
+
1. Look at the git state: the current branch/commit we're on
|
| 97 |
+
2. Tune `train.py` with an experimental idea by directly hacking the code.
|
| 98 |
+
3. git commit
|
| 99 |
+
4. Run the experiment: `uv run train.py > run.log 2>&1` (redirect everything — do NOT use tee or let output flood your context)
|
| 100 |
+
5. Read out the results: `grep "^val_bpb:\|^peak_vram_mb:" run.log`
|
| 101 |
+
6. If the grep output is empty, the run crashed. Run `tail -n 50 run.log` to read the Python stack trace and attempt a fix. If you can't get things to work after more than a few attempts, give up.
|
| 102 |
+
7. Record the results in the tsv (NOTE: do not commit the results.tsv file, leave it untracked by git)
|
| 103 |
+
8. If val_bpb improved (lower), you "advance" the branch, keeping the git commit
|
| 104 |
+
9. If val_bpb is equal or worse, you git reset back to where you started
|
| 105 |
+
|
| 106 |
+
The idea is that you are a completely autonomous researcher trying things out. If they work, keep. If they don't, discard. And you're advancing the branch so that you can iterate. If you feel like you're getting stuck in some way, you can rewind but you should probably do this very very sparingly (if ever).
|
| 107 |
+
|
| 108 |
+
**Timeout**: Each experiment should take ~5 minutes total (+ a few seconds for startup and eval overhead). If a run exceeds 10 minutes, kill it and treat it as a failure (discard and revert).
|
| 109 |
+
|
| 110 |
+
**Crashes**: If a run crashes (OOM, or a bug, or etc.), use your judgment: If it's something dumb and easy to fix (e.g. a typo, a missing import), fix it and re-run. If the idea itself is fundamentally broken, just skip it, log "crash" as the status in the tsv, and move on.
|
| 111 |
+
|
| 112 |
+
**NEVER STOP**: Once the experiment loop has begun (after the initial setup), do NOT pause to ask the human if you should continue. Do NOT ask "should I keep going?" or "is this a good stopping point?". The human might be asleep, or gone from a computer and expects you to continue working *indefinitely* until you are manually stopped. You are autonomous. If you run out of ideas, think harder — read papers referenced in the code, re-read the in-scope files for new angles, try combining previous near-misses, try more radical architectural changes. The loop runs until the human interrupts you, period.
|
| 113 |
+
|
| 114 |
+
As an example use case, a user might leave you running while they sleep. If each experiment takes you ~5 minutes then you can run approx 12/hour, for a total of about 100 over the duration of the average human sleep. The user then wakes up to experimental results, all completed by you while they slept!
|
autoresearch/pyproject.toml
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "autoresearch"
|
| 3 |
+
version = "0.1.0"
|
| 4 |
+
description = "Autonomous pretraining research swarm"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.10"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"kernels>=0.11.7",
|
| 9 |
+
"matplotlib>=3.10.8",
|
| 10 |
+
"numpy>=2.2.6",
|
| 11 |
+
"pandas>=2.3.3",
|
| 12 |
+
"pyarrow>=21.0.0",
|
| 13 |
+
"requests>=2.32.0",
|
| 14 |
+
"rustbpe>=0.1.0",
|
| 15 |
+
"tiktoken>=0.11.0",
|
| 16 |
+
"torch==2.9.1",
|
| 17 |
+
]
|
| 18 |
+
|
| 19 |
+
[tool.uv.sources]
|
| 20 |
+
torch = [
|
| 21 |
+
{ index = "pytorch-cu128" },
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
[[tool.uv.index]]
|
| 25 |
+
name = "pytorch-cu128"
|
| 26 |
+
url = "https://download.pytorch.org/whl/cu128"
|
| 27 |
+
explicit = true
|
autoresearch/train.py
ADDED
|
@@ -0,0 +1,630 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Autoresearch pretraining script. Single-GPU, single-file.
|
| 3 |
+
Cherry-picked and simplified from nanochat.
|
| 4 |
+
Usage: uv run train.py
|
| 5 |
+
"""
|
| 6 |
+
|
| 7 |
+
import os
|
| 8 |
+
os.environ["PYTORCH_ALLOC_CONF"] = "expandable_segments:True"
|
| 9 |
+
os.environ["HF_HUB_DISABLE_PROGRESS_BARS"] = "1"
|
| 10 |
+
|
| 11 |
+
import gc
|
| 12 |
+
import math
|
| 13 |
+
import time
|
| 14 |
+
from dataclasses import dataclass, asdict
|
| 15 |
+
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
|
| 20 |
+
from kernels import get_kernel
|
| 21 |
+
cap = torch.cuda.get_device_capability()
|
| 22 |
+
# varunneal's FA3 is Hopper only, use kernels-community on non-Hopper GPUs
|
| 23 |
+
repo = "varunneal/flash-attention-3" if cap == (9, 0) else "kernels-community/flash-attn3"
|
| 24 |
+
fa3 = get_kernel(repo).flash_attn_interface
|
| 25 |
+
|
| 26 |
+
from prepare import MAX_SEQ_LEN, TIME_BUDGET, Tokenizer, make_dataloader, evaluate_bpb
|
| 27 |
+
|
| 28 |
+
# ---------------------------------------------------------------------------
|
| 29 |
+
# GPT Model
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class GPTConfig:
|
| 34 |
+
sequence_len: int = 2048
|
| 35 |
+
vocab_size: int = 32768
|
| 36 |
+
n_layer: int = 12
|
| 37 |
+
n_head: int = 6
|
| 38 |
+
n_kv_head: int = 6
|
| 39 |
+
n_embd: int = 768
|
| 40 |
+
window_pattern: str = "SSSL"
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def norm(x):
|
| 44 |
+
return F.rms_norm(x, (x.size(-1),))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def has_ve(layer_idx, n_layer):
|
| 48 |
+
"""Returns True if layer should have Value Embedding (alternating, last always included)."""
|
| 49 |
+
return layer_idx % 2 == (n_layer - 1) % 2
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def apply_rotary_emb(x, cos, sin):
|
| 53 |
+
assert x.ndim == 4
|
| 54 |
+
d = x.shape[3] // 2
|
| 55 |
+
x1, x2 = x[..., :d], x[..., d:]
|
| 56 |
+
y1 = x1 * cos + x2 * sin
|
| 57 |
+
y2 = x1 * (-sin) + x2 * cos
|
| 58 |
+
return torch.cat([y1, y2], 3)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class CausalSelfAttention(nn.Module):
|
| 62 |
+
def __init__(self, config, layer_idx):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.n_head = config.n_head
|
| 65 |
+
self.n_kv_head = config.n_kv_head
|
| 66 |
+
self.n_embd = config.n_embd
|
| 67 |
+
self.head_dim = self.n_embd // self.n_head
|
| 68 |
+
assert self.n_embd % self.n_head == 0
|
| 69 |
+
assert self.n_kv_head <= self.n_head and self.n_head % self.n_kv_head == 0
|
| 70 |
+
self.c_q = nn.Linear(self.n_embd, self.n_head * self.head_dim, bias=False)
|
| 71 |
+
self.c_k = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
| 72 |
+
self.c_v = nn.Linear(self.n_embd, self.n_kv_head * self.head_dim, bias=False)
|
| 73 |
+
self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=False)
|
| 74 |
+
self.ve_gate_channels = 32
|
| 75 |
+
self.ve_gate = nn.Linear(self.ve_gate_channels, self.n_kv_head, bias=False) if has_ve(layer_idx, config.n_layer) else None
|
| 76 |
+
|
| 77 |
+
def forward(self, x, ve, cos_sin, window_size):
|
| 78 |
+
B, T, C = x.size()
|
| 79 |
+
q = self.c_q(x).view(B, T, self.n_head, self.head_dim)
|
| 80 |
+
k = self.c_k(x).view(B, T, self.n_kv_head, self.head_dim)
|
| 81 |
+
v = self.c_v(x).view(B, T, self.n_kv_head, self.head_dim)
|
| 82 |
+
|
| 83 |
+
# Value residual (ResFormer): mix in value embedding with input-dependent gate per head
|
| 84 |
+
if ve is not None:
|
| 85 |
+
ve = ve.view(B, T, self.n_kv_head, self.head_dim)
|
| 86 |
+
gate = 2 * torch.sigmoid(self.ve_gate(x[..., :self.ve_gate_channels]))
|
| 87 |
+
v = v + gate.unsqueeze(-1) * ve
|
| 88 |
+
|
| 89 |
+
cos, sin = cos_sin
|
| 90 |
+
q, k = apply_rotary_emb(q, cos, sin), apply_rotary_emb(k, cos, sin)
|
| 91 |
+
q, k = norm(q), norm(k)
|
| 92 |
+
|
| 93 |
+
y = fa3.flash_attn_func(q, k, v, causal=True, window_size=window_size)
|
| 94 |
+
y = y.contiguous().view(B, T, -1)
|
| 95 |
+
y = self.c_proj(y)
|
| 96 |
+
return y
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class MLP(nn.Module):
|
| 100 |
+
def __init__(self, config):
|
| 101 |
+
super().__init__()
|
| 102 |
+
self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=False)
|
| 103 |
+
self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=False)
|
| 104 |
+
|
| 105 |
+
def forward(self, x):
|
| 106 |
+
x = self.c_fc(x)
|
| 107 |
+
x = F.relu(x).square()
|
| 108 |
+
x = self.c_proj(x)
|
| 109 |
+
return x
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class Block(nn.Module):
|
| 113 |
+
def __init__(self, config, layer_idx):
|
| 114 |
+
super().__init__()
|
| 115 |
+
self.attn = CausalSelfAttention(config, layer_idx)
|
| 116 |
+
self.mlp = MLP(config)
|
| 117 |
+
|
| 118 |
+
def forward(self, x, ve, cos_sin, window_size):
|
| 119 |
+
x = x + self.attn(norm(x), ve, cos_sin, window_size)
|
| 120 |
+
x = x + self.mlp(norm(x))
|
| 121 |
+
return x
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
class GPT(nn.Module):
|
| 125 |
+
def __init__(self, config):
|
| 126 |
+
super().__init__()
|
| 127 |
+
self.config = config
|
| 128 |
+
self.window_sizes = self._compute_window_sizes(config)
|
| 129 |
+
self.transformer = nn.ModuleDict({
|
| 130 |
+
"wte": nn.Embedding(config.vocab_size, config.n_embd),
|
| 131 |
+
"h": nn.ModuleList([Block(config, i) for i in range(config.n_layer)]),
|
| 132 |
+
})
|
| 133 |
+
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
|
| 134 |
+
self.resid_lambdas = nn.Parameter(torch.ones(config.n_layer))
|
| 135 |
+
self.x0_lambdas = nn.Parameter(torch.zeros(config.n_layer))
|
| 136 |
+
# Value embeddings
|
| 137 |
+
head_dim = config.n_embd // config.n_head
|
| 138 |
+
kv_dim = config.n_kv_head * head_dim
|
| 139 |
+
self.value_embeds = nn.ModuleDict({
|
| 140 |
+
str(i): nn.Embedding(config.vocab_size, kv_dim)
|
| 141 |
+
for i in range(config.n_layer) if has_ve(i, config.n_layer)
|
| 142 |
+
})
|
| 143 |
+
# Rotary embeddings
|
| 144 |
+
self.rotary_seq_len = config.sequence_len * 10
|
| 145 |
+
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
| 146 |
+
self.register_buffer("cos", cos, persistent=False)
|
| 147 |
+
self.register_buffer("sin", sin, persistent=False)
|
| 148 |
+
|
| 149 |
+
@torch.no_grad()
|
| 150 |
+
def init_weights(self):
|
| 151 |
+
# Embedding and unembedding
|
| 152 |
+
torch.nn.init.normal_(self.transformer.wte.weight, mean=0.0, std=1.0)
|
| 153 |
+
torch.nn.init.normal_(self.lm_head.weight, mean=0.0, std=0.001)
|
| 154 |
+
# Transformer blocks
|
| 155 |
+
n_embd = self.config.n_embd
|
| 156 |
+
s = 3**0.5 * n_embd**-0.5
|
| 157 |
+
for block in self.transformer.h:
|
| 158 |
+
torch.nn.init.uniform_(block.attn.c_q.weight, -s, s)
|
| 159 |
+
torch.nn.init.uniform_(block.attn.c_k.weight, -s, s)
|
| 160 |
+
torch.nn.init.uniform_(block.attn.c_v.weight, -s, s)
|
| 161 |
+
torch.nn.init.zeros_(block.attn.c_proj.weight)
|
| 162 |
+
torch.nn.init.uniform_(block.mlp.c_fc.weight, -s, s)
|
| 163 |
+
torch.nn.init.zeros_(block.mlp.c_proj.weight)
|
| 164 |
+
# Per-layer scalars
|
| 165 |
+
self.resid_lambdas.fill_(1.0)
|
| 166 |
+
self.x0_lambdas.fill_(0.1)
|
| 167 |
+
# Value embeddings
|
| 168 |
+
for ve in self.value_embeds.values():
|
| 169 |
+
torch.nn.init.uniform_(ve.weight, -s, s)
|
| 170 |
+
# Gate weights init to zero (sigmoid(0)=0.5, scaled by 2 -> 1.0 = neutral)
|
| 171 |
+
for block in self.transformer.h:
|
| 172 |
+
if block.attn.ve_gate is not None:
|
| 173 |
+
torch.nn.init.zeros_(block.attn.ve_gate.weight)
|
| 174 |
+
# Rotary embeddings
|
| 175 |
+
head_dim = self.config.n_embd // self.config.n_head
|
| 176 |
+
cos, sin = self._precompute_rotary_embeddings(self.rotary_seq_len, head_dim)
|
| 177 |
+
self.cos, self.sin = cos, sin
|
| 178 |
+
# Cast embeddings to bf16
|
| 179 |
+
self.transformer.wte.to(dtype=torch.bfloat16)
|
| 180 |
+
for ve in self.value_embeds.values():
|
| 181 |
+
ve.to(dtype=torch.bfloat16)
|
| 182 |
+
|
| 183 |
+
def _precompute_rotary_embeddings(self, seq_len, head_dim, base=10000, device=None):
|
| 184 |
+
if device is None:
|
| 185 |
+
device = self.transformer.wte.weight.device
|
| 186 |
+
channel_range = torch.arange(0, head_dim, 2, dtype=torch.float32, device=device)
|
| 187 |
+
inv_freq = 1.0 / (base ** (channel_range / head_dim))
|
| 188 |
+
t = torch.arange(seq_len, dtype=torch.float32, device=device)
|
| 189 |
+
freqs = torch.outer(t, inv_freq)
|
| 190 |
+
cos, sin = freqs.cos(), freqs.sin()
|
| 191 |
+
cos, sin = cos.bfloat16(), sin.bfloat16()
|
| 192 |
+
cos, sin = cos[None, :, None, :], sin[None, :, None, :]
|
| 193 |
+
return cos, sin
|
| 194 |
+
|
| 195 |
+
def _compute_window_sizes(self, config):
|
| 196 |
+
pattern = config.window_pattern.upper()
|
| 197 |
+
assert all(c in "SL" for c in pattern)
|
| 198 |
+
long_window = config.sequence_len
|
| 199 |
+
short_window = long_window // 2
|
| 200 |
+
char_to_window = {"L": (long_window, 0), "S": (short_window, 0)}
|
| 201 |
+
window_sizes = []
|
| 202 |
+
for layer_idx in range(config.n_layer):
|
| 203 |
+
char = pattern[layer_idx % len(pattern)]
|
| 204 |
+
window_sizes.append(char_to_window[char])
|
| 205 |
+
window_sizes[-1] = (long_window, 0)
|
| 206 |
+
return window_sizes
|
| 207 |
+
|
| 208 |
+
def estimate_flops(self):
|
| 209 |
+
"""Estimated FLOPs per token (forward + backward)."""
|
| 210 |
+
nparams = sum(p.numel() for p in self.parameters())
|
| 211 |
+
value_embeds_numel = sum(ve.weight.numel() for ve in self.value_embeds.values())
|
| 212 |
+
nparams_exclude = (self.transformer.wte.weight.numel() + value_embeds_numel +
|
| 213 |
+
self.resid_lambdas.numel() + self.x0_lambdas.numel())
|
| 214 |
+
h = self.config.n_head
|
| 215 |
+
q = self.config.n_embd // self.config.n_head
|
| 216 |
+
t = self.config.sequence_len
|
| 217 |
+
attn_flops = 0
|
| 218 |
+
for window_size in self.window_sizes:
|
| 219 |
+
window = window_size[0]
|
| 220 |
+
effective_seq = t if window < 0 else min(window, t)
|
| 221 |
+
attn_flops += 12 * h * q * effective_seq
|
| 222 |
+
return 6 * (nparams - nparams_exclude) + attn_flops
|
| 223 |
+
|
| 224 |
+
def num_scaling_params(self):
|
| 225 |
+
wte = sum(p.numel() for p in self.transformer.wte.parameters())
|
| 226 |
+
value_embeds = sum(p.numel() for p in self.value_embeds.parameters())
|
| 227 |
+
lm_head = sum(p.numel() for p in self.lm_head.parameters())
|
| 228 |
+
transformer_matrices = sum(p.numel() for p in self.transformer.h.parameters())
|
| 229 |
+
scalars = self.resid_lambdas.numel() + self.x0_lambdas.numel()
|
| 230 |
+
total = wte + value_embeds + lm_head + transformer_matrices + scalars
|
| 231 |
+
return {
|
| 232 |
+
'wte': wte, 'value_embeds': value_embeds, 'lm_head': lm_head,
|
| 233 |
+
'transformer_matrices': transformer_matrices, 'scalars': scalars, 'total': total,
|
| 234 |
+
}
|
| 235 |
+
|
| 236 |
+
def setup_optimizer(self, unembedding_lr=0.004, embedding_lr=0.2, matrix_lr=0.02,
|
| 237 |
+
weight_decay=0.0, adam_betas=(0.8, 0.95), scalar_lr=0.5):
|
| 238 |
+
model_dim = self.config.n_embd
|
| 239 |
+
matrix_params = list(self.transformer.h.parameters())
|
| 240 |
+
value_embeds_params = list(self.value_embeds.parameters())
|
| 241 |
+
embedding_params = list(self.transformer.wte.parameters())
|
| 242 |
+
lm_head_params = list(self.lm_head.parameters())
|
| 243 |
+
resid_params = [self.resid_lambdas]
|
| 244 |
+
x0_params = [self.x0_lambdas]
|
| 245 |
+
assert len(list(self.parameters())) == (len(matrix_params) + len(embedding_params) +
|
| 246 |
+
len(lm_head_params) + len(value_embeds_params) + len(resid_params) + len(x0_params))
|
| 247 |
+
# Scale LR ∝ 1/√dmodel (tuned at 768 dim)
|
| 248 |
+
dmodel_lr_scale = (model_dim / 768) ** -0.5
|
| 249 |
+
print(f"Scaling AdamW LRs by 1/sqrt({model_dim}/768) = {dmodel_lr_scale:.6f}")
|
| 250 |
+
param_groups = [
|
| 251 |
+
dict(kind='adamw', params=lm_head_params, lr=unembedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
| 252 |
+
dict(kind='adamw', params=embedding_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
| 253 |
+
dict(kind='adamw', params=value_embeds_params, lr=embedding_lr * dmodel_lr_scale, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
| 254 |
+
dict(kind='adamw', params=resid_params, lr=scalar_lr * 0.01, betas=adam_betas, eps=1e-10, weight_decay=0.0),
|
| 255 |
+
dict(kind='adamw', params=x0_params, lr=scalar_lr, betas=(0.96, 0.95), eps=1e-10, weight_decay=0.0),
|
| 256 |
+
]
|
| 257 |
+
for shape in sorted({p.shape for p in matrix_params}):
|
| 258 |
+
group_params = [p for p in matrix_params if p.shape == shape]
|
| 259 |
+
param_groups.append(dict(
|
| 260 |
+
kind='muon', params=group_params, lr=matrix_lr,
|
| 261 |
+
momentum=0.95, ns_steps=5, beta2=0.95, weight_decay=weight_decay,
|
| 262 |
+
))
|
| 263 |
+
optimizer = MuonAdamW(param_groups)
|
| 264 |
+
for group in optimizer.param_groups:
|
| 265 |
+
group["initial_lr"] = group["lr"]
|
| 266 |
+
return optimizer
|
| 267 |
+
|
| 268 |
+
def forward(self, idx, targets=None, reduction='mean'):
|
| 269 |
+
B, T = idx.size()
|
| 270 |
+
assert T <= self.cos.size(1)
|
| 271 |
+
cos_sin = self.cos[:, :T], self.sin[:, :T]
|
| 272 |
+
|
| 273 |
+
x = self.transformer.wte(idx)
|
| 274 |
+
x = norm(x)
|
| 275 |
+
x0 = x
|
| 276 |
+
for i, block in enumerate(self.transformer.h):
|
| 277 |
+
x = self.resid_lambdas[i] * x + self.x0_lambdas[i] * x0
|
| 278 |
+
ve = self.value_embeds[str(i)](idx) if str(i) in self.value_embeds else None
|
| 279 |
+
x = block(x, ve, cos_sin, self.window_sizes[i])
|
| 280 |
+
x = norm(x)
|
| 281 |
+
|
| 282 |
+
softcap = 15
|
| 283 |
+
logits = self.lm_head(x)
|
| 284 |
+
logits = logits.float()
|
| 285 |
+
logits = softcap * torch.tanh(logits / softcap)
|
| 286 |
+
|
| 287 |
+
if targets is not None:
|
| 288 |
+
loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1),
|
| 289 |
+
ignore_index=-1, reduction=reduction)
|
| 290 |
+
return loss
|
| 291 |
+
return logits
|
| 292 |
+
|
| 293 |
+
# ---------------------------------------------------------------------------
|
| 294 |
+
# Optimizer (MuonAdamW, single GPU only)
|
| 295 |
+
# ---------------------------------------------------------------------------
|
| 296 |
+
|
| 297 |
+
polar_express_coeffs = [
|
| 298 |
+
(8.156554524902461, -22.48329292557795, 15.878769915207462),
|
| 299 |
+
(4.042929935166739, -2.808917465908714, 0.5000178451051316),
|
| 300 |
+
(3.8916678022926607, -2.772484153217685, 0.5060648178503393),
|
| 301 |
+
(3.285753657755655, -2.3681294933425376, 0.46449024233003106),
|
| 302 |
+
(2.3465413258596377, -1.7097828382687081, 0.42323551169305323),
|
| 303 |
+
]
|
| 304 |
+
|
| 305 |
+
@torch.compile(dynamic=False, fullgraph=True)
|
| 306 |
+
def adamw_step_fused(p, grad, exp_avg, exp_avg_sq, step_t, lr_t, beta1_t, beta2_t, eps_t, wd_t):
|
| 307 |
+
p.mul_(1 - lr_t * wd_t)
|
| 308 |
+
exp_avg.lerp_(grad, 1 - beta1_t)
|
| 309 |
+
exp_avg_sq.lerp_(grad.square(), 1 - beta2_t)
|
| 310 |
+
bias1 = 1 - beta1_t ** step_t
|
| 311 |
+
bias2 = 1 - beta2_t ** step_t
|
| 312 |
+
denom = (exp_avg_sq / bias2).sqrt() + eps_t
|
| 313 |
+
step_size = lr_t / bias1
|
| 314 |
+
p.add_(exp_avg / denom, alpha=-step_size)
|
| 315 |
+
|
| 316 |
+
@torch.compile(dynamic=False, fullgraph=True)
|
| 317 |
+
def muon_step_fused(stacked_grads, stacked_params, momentum_buffer, second_momentum_buffer,
|
| 318 |
+
momentum_t, lr_t, wd_t, beta2_t, ns_steps, red_dim):
|
| 319 |
+
# Nesterov momentum
|
| 320 |
+
momentum = momentum_t.to(stacked_grads.dtype)
|
| 321 |
+
momentum_buffer.lerp_(stacked_grads, 1 - momentum)
|
| 322 |
+
g = stacked_grads.lerp_(momentum_buffer, momentum)
|
| 323 |
+
# Polar express orthogonalization
|
| 324 |
+
X = g.bfloat16()
|
| 325 |
+
X = X / (X.norm(dim=(-2, -1), keepdim=True) * 1.02 + 1e-6)
|
| 326 |
+
if g.size(-2) > g.size(-1):
|
| 327 |
+
for a, b, c in polar_express_coeffs[:ns_steps]:
|
| 328 |
+
A = X.mT @ X
|
| 329 |
+
B = b * A + c * (A @ A)
|
| 330 |
+
X = a * X + X @ B
|
| 331 |
+
else:
|
| 332 |
+
for a, b, c in polar_express_coeffs[:ns_steps]:
|
| 333 |
+
A = X @ X.mT
|
| 334 |
+
B = b * A + c * (A @ A)
|
| 335 |
+
X = a * X + B @ X
|
| 336 |
+
g = X
|
| 337 |
+
# NorMuon variance reduction
|
| 338 |
+
beta2 = beta2_t.to(g.dtype)
|
| 339 |
+
v_mean = g.float().square().mean(dim=red_dim, keepdim=True)
|
| 340 |
+
red_dim_size = g.size(red_dim)
|
| 341 |
+
v_norm_sq = v_mean.sum(dim=(-2, -1), keepdim=True) * red_dim_size
|
| 342 |
+
v_norm = v_norm_sq.sqrt()
|
| 343 |
+
second_momentum_buffer.lerp_(v_mean.to(dtype=second_momentum_buffer.dtype), 1 - beta2)
|
| 344 |
+
step_size = second_momentum_buffer.clamp_min(1e-10).rsqrt()
|
| 345 |
+
scaled_sq_sum = (v_mean * red_dim_size) * step_size.float().square()
|
| 346 |
+
v_norm_new = scaled_sq_sum.sum(dim=(-2, -1), keepdim=True).sqrt()
|
| 347 |
+
final_scale = step_size * (v_norm / v_norm_new.clamp_min(1e-10))
|
| 348 |
+
g = g * final_scale.to(g.dtype)
|
| 349 |
+
# Cautious weight decay + parameter update
|
| 350 |
+
lr = lr_t.to(g.dtype)
|
| 351 |
+
wd = wd_t.to(g.dtype)
|
| 352 |
+
mask = (g * stacked_params) >= 0
|
| 353 |
+
stacked_params.sub_(lr * g + lr * wd * stacked_params * mask)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class MuonAdamW(torch.optim.Optimizer):
|
| 357 |
+
"""Combined optimizer: Muon for 2D matrix params, AdamW for others."""
|
| 358 |
+
|
| 359 |
+
def __init__(self, param_groups):
|
| 360 |
+
super().__init__(param_groups, defaults={})
|
| 361 |
+
# 0-D CPU tensors to avoid torch.compile recompilation when values change
|
| 362 |
+
self._adamw_step_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 363 |
+
self._adamw_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 364 |
+
self._adamw_beta1_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 365 |
+
self._adamw_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 366 |
+
self._adamw_eps_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 367 |
+
self._adamw_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 368 |
+
self._muon_momentum_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 369 |
+
self._muon_lr_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 370 |
+
self._muon_wd_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 371 |
+
self._muon_beta2_t = torch.tensor(0.0, dtype=torch.float32, device="cpu")
|
| 372 |
+
|
| 373 |
+
def _step_adamw(self, group):
|
| 374 |
+
for p in group['params']:
|
| 375 |
+
if p.grad is None:
|
| 376 |
+
continue
|
| 377 |
+
grad = p.grad
|
| 378 |
+
state = self.state[p]
|
| 379 |
+
if not state:
|
| 380 |
+
state['step'] = 0
|
| 381 |
+
state['exp_avg'] = torch.zeros_like(p)
|
| 382 |
+
state['exp_avg_sq'] = torch.zeros_like(p)
|
| 383 |
+
state['step'] += 1
|
| 384 |
+
self._adamw_step_t.fill_(state['step'])
|
| 385 |
+
self._adamw_lr_t.fill_(group['lr'])
|
| 386 |
+
self._adamw_beta1_t.fill_(group['betas'][0])
|
| 387 |
+
self._adamw_beta2_t.fill_(group['betas'][1])
|
| 388 |
+
self._adamw_eps_t.fill_(group['eps'])
|
| 389 |
+
self._adamw_wd_t.fill_(group['weight_decay'])
|
| 390 |
+
adamw_step_fused(p, grad, state['exp_avg'], state['exp_avg_sq'],
|
| 391 |
+
self._adamw_step_t, self._adamw_lr_t, self._adamw_beta1_t,
|
| 392 |
+
self._adamw_beta2_t, self._adamw_eps_t, self._adamw_wd_t)
|
| 393 |
+
|
| 394 |
+
def _step_muon(self, group):
|
| 395 |
+
params = group['params']
|
| 396 |
+
if not params:
|
| 397 |
+
return
|
| 398 |
+
p = params[0]
|
| 399 |
+
state = self.state[p]
|
| 400 |
+
num_params = len(params)
|
| 401 |
+
shape, device, dtype = p.shape, p.device, p.dtype
|
| 402 |
+
if "momentum_buffer" not in state:
|
| 403 |
+
state["momentum_buffer"] = torch.zeros(num_params, *shape, dtype=dtype, device=device)
|
| 404 |
+
if "second_momentum_buffer" not in state:
|
| 405 |
+
state_shape = (num_params, shape[-2], 1) if shape[-2] >= shape[-1] else (num_params, 1, shape[-1])
|
| 406 |
+
state["second_momentum_buffer"] = torch.zeros(state_shape, dtype=dtype, device=device)
|
| 407 |
+
red_dim = -1 if shape[-2] >= shape[-1] else -2
|
| 408 |
+
stacked_grads = torch.stack([p.grad for p in params])
|
| 409 |
+
stacked_params = torch.stack(params)
|
| 410 |
+
self._muon_momentum_t.fill_(group["momentum"])
|
| 411 |
+
self._muon_beta2_t.fill_(group["beta2"] if group["beta2"] is not None else 0.0)
|
| 412 |
+
self._muon_lr_t.fill_(group["lr"] * max(1.0, shape[-2] / shape[-1])**0.5)
|
| 413 |
+
self._muon_wd_t.fill_(group["weight_decay"])
|
| 414 |
+
muon_step_fused(stacked_grads, stacked_params,
|
| 415 |
+
state["momentum_buffer"], state["second_momentum_buffer"],
|
| 416 |
+
self._muon_momentum_t, self._muon_lr_t, self._muon_wd_t,
|
| 417 |
+
self._muon_beta2_t, group["ns_steps"], red_dim)
|
| 418 |
+
torch._foreach_copy_(params, list(stacked_params.unbind(0)))
|
| 419 |
+
|
| 420 |
+
@torch.no_grad()
|
| 421 |
+
def step(self):
|
| 422 |
+
for group in self.param_groups:
|
| 423 |
+
if group['kind'] == 'adamw':
|
| 424 |
+
self._step_adamw(group)
|
| 425 |
+
elif group['kind'] == 'muon':
|
| 426 |
+
self._step_muon(group)
|
| 427 |
+
|
| 428 |
+
# ---------------------------------------------------------------------------
|
| 429 |
+
# Hyperparameters (edit these directly, no CLI flags needed)
|
| 430 |
+
# ---------------------------------------------------------------------------
|
| 431 |
+
|
| 432 |
+
# Model architecture
|
| 433 |
+
ASPECT_RATIO = 64 # model_dim = depth * ASPECT_RATIO
|
| 434 |
+
HEAD_DIM = 128 # target head dimension for attention
|
| 435 |
+
WINDOW_PATTERN = "SSSL" # sliding window pattern: L=full, S=half context
|
| 436 |
+
|
| 437 |
+
# Optimization
|
| 438 |
+
TOTAL_BATCH_SIZE = 2**19 # ~524K tokens per optimizer step
|
| 439 |
+
EMBEDDING_LR = 0.6 # learning rate for token embeddings (Adam)
|
| 440 |
+
UNEMBEDDING_LR = 0.004 # learning rate for lm_head (Adam)
|
| 441 |
+
MATRIX_LR = 0.04 # learning rate for matrix parameters (Muon)
|
| 442 |
+
SCALAR_LR = 0.5 # learning rate for per-layer scalars (Adam)
|
| 443 |
+
WEIGHT_DECAY = 0.2 # cautious weight decay for Muon
|
| 444 |
+
ADAM_BETAS = (0.8, 0.95) # Adam beta1, beta2
|
| 445 |
+
WARMUP_RATIO = 0.0 # fraction of time budget for LR warmup
|
| 446 |
+
WARMDOWN_RATIO = 0.5 # fraction of time budget for LR warmdown
|
| 447 |
+
FINAL_LR_FRAC = 0.0 # final LR as fraction of initial
|
| 448 |
+
|
| 449 |
+
# Model size
|
| 450 |
+
DEPTH = 8 # number of transformer layers
|
| 451 |
+
DEVICE_BATCH_SIZE = 128 # per-device batch size (reduce if OOM)
|
| 452 |
+
|
| 453 |
+
# ---------------------------------------------------------------------------
|
| 454 |
+
# Setup: tokenizer, model, optimizer, dataloader
|
| 455 |
+
# ---------------------------------------------------------------------------
|
| 456 |
+
|
| 457 |
+
t_start = time.time()
|
| 458 |
+
torch.manual_seed(42)
|
| 459 |
+
torch.cuda.manual_seed(42)
|
| 460 |
+
torch.set_float32_matmul_precision("high")
|
| 461 |
+
device = torch.device("cuda")
|
| 462 |
+
autocast_ctx = torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 463 |
+
H100_BF16_PEAK_FLOPS = 989.5e12
|
| 464 |
+
|
| 465 |
+
tokenizer = Tokenizer.from_directory()
|
| 466 |
+
vocab_size = tokenizer.get_vocab_size()
|
| 467 |
+
print(f"Vocab size: {vocab_size:,}")
|
| 468 |
+
|
| 469 |
+
def build_model_config(depth):
|
| 470 |
+
base_dim = depth * ASPECT_RATIO
|
| 471 |
+
model_dim = ((base_dim + HEAD_DIM - 1) // HEAD_DIM) * HEAD_DIM
|
| 472 |
+
num_heads = model_dim // HEAD_DIM
|
| 473 |
+
return GPTConfig(
|
| 474 |
+
sequence_len=MAX_SEQ_LEN, vocab_size=vocab_size,
|
| 475 |
+
n_layer=depth, n_head=num_heads, n_kv_head=num_heads, n_embd=model_dim,
|
| 476 |
+
window_pattern=WINDOW_PATTERN,
|
| 477 |
+
)
|
| 478 |
+
|
| 479 |
+
config = build_model_config(DEPTH)
|
| 480 |
+
print(f"Model config: {asdict(config)}")
|
| 481 |
+
|
| 482 |
+
with torch.device("meta"):
|
| 483 |
+
model = GPT(config)
|
| 484 |
+
model.to_empty(device=device)
|
| 485 |
+
model.init_weights()
|
| 486 |
+
|
| 487 |
+
param_counts = model.num_scaling_params()
|
| 488 |
+
print("Parameter counts:")
|
| 489 |
+
for key, value in param_counts.items():
|
| 490 |
+
print(f" {key:24s}: {value:,}")
|
| 491 |
+
num_params = param_counts['total']
|
| 492 |
+
num_flops_per_token = model.estimate_flops()
|
| 493 |
+
print(f"Estimated FLOPs per token: {num_flops_per_token:e}")
|
| 494 |
+
|
| 495 |
+
tokens_per_fwdbwd = DEVICE_BATCH_SIZE * MAX_SEQ_LEN
|
| 496 |
+
assert TOTAL_BATCH_SIZE % tokens_per_fwdbwd == 0
|
| 497 |
+
grad_accum_steps = TOTAL_BATCH_SIZE // tokens_per_fwdbwd
|
| 498 |
+
|
| 499 |
+
optimizer = model.setup_optimizer(
|
| 500 |
+
unembedding_lr=UNEMBEDDING_LR,
|
| 501 |
+
embedding_lr=EMBEDDING_LR,
|
| 502 |
+
scalar_lr=SCALAR_LR,
|
| 503 |
+
adam_betas=ADAM_BETAS,
|
| 504 |
+
matrix_lr=MATRIX_LR,
|
| 505 |
+
weight_decay=WEIGHT_DECAY,
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
model = torch.compile(model, dynamic=False)
|
| 509 |
+
|
| 510 |
+
train_loader = make_dataloader(tokenizer, DEVICE_BATCH_SIZE, MAX_SEQ_LEN, "train")
|
| 511 |
+
x, y, epoch = next(train_loader) # prefetch first batch
|
| 512 |
+
|
| 513 |
+
print(f"Time budget: {TIME_BUDGET}s")
|
| 514 |
+
print(f"Gradient accumulation steps: {grad_accum_steps}")
|
| 515 |
+
|
| 516 |
+
# Schedules (all based on progress = training_time / TIME_BUDGET)
|
| 517 |
+
|
| 518 |
+
def get_lr_multiplier(progress):
|
| 519 |
+
if progress < WARMUP_RATIO:
|
| 520 |
+
return progress / WARMUP_RATIO if WARMUP_RATIO > 0 else 1.0
|
| 521 |
+
elif progress < 1.0 - WARMDOWN_RATIO:
|
| 522 |
+
return 1.0
|
| 523 |
+
else:
|
| 524 |
+
cooldown = (1.0 - progress) / WARMDOWN_RATIO
|
| 525 |
+
return cooldown * 1.0 + (1 - cooldown) * FINAL_LR_FRAC
|
| 526 |
+
|
| 527 |
+
def get_muon_momentum(step):
|
| 528 |
+
frac = min(step / 300, 1)
|
| 529 |
+
return (1 - frac) * 0.85 + frac * 0.95
|
| 530 |
+
|
| 531 |
+
def get_weight_decay(progress):
|
| 532 |
+
return WEIGHT_DECAY * (1 - progress)
|
| 533 |
+
|
| 534 |
+
# ---------------------------------------------------------------------------
|
| 535 |
+
# Training loop
|
| 536 |
+
# ---------------------------------------------------------------------------
|
| 537 |
+
|
| 538 |
+
t_start_training = time.time()
|
| 539 |
+
smooth_train_loss = 0
|
| 540 |
+
total_training_time = 0
|
| 541 |
+
step = 0
|
| 542 |
+
|
| 543 |
+
while True:
|
| 544 |
+
torch.cuda.synchronize()
|
| 545 |
+
t0 = time.time()
|
| 546 |
+
for micro_step in range(grad_accum_steps):
|
| 547 |
+
with autocast_ctx:
|
| 548 |
+
loss = model(x, y)
|
| 549 |
+
train_loss = loss.detach()
|
| 550 |
+
loss = loss / grad_accum_steps
|
| 551 |
+
loss.backward()
|
| 552 |
+
x, y, epoch = next(train_loader)
|
| 553 |
+
|
| 554 |
+
# Progress and schedules
|
| 555 |
+
progress = min(total_training_time / TIME_BUDGET, 1.0)
|
| 556 |
+
lrm = get_lr_multiplier(progress)
|
| 557 |
+
muon_momentum = get_muon_momentum(step)
|
| 558 |
+
muon_weight_decay = get_weight_decay(progress)
|
| 559 |
+
for group in optimizer.param_groups:
|
| 560 |
+
group["lr"] = group["initial_lr"] * lrm
|
| 561 |
+
if group['kind'] == 'muon':
|
| 562 |
+
group["momentum"] = muon_momentum
|
| 563 |
+
group["weight_decay"] = muon_weight_decay
|
| 564 |
+
optimizer.step()
|
| 565 |
+
model.zero_grad(set_to_none=True)
|
| 566 |
+
|
| 567 |
+
train_loss_f = train_loss.item()
|
| 568 |
+
|
| 569 |
+
# Fast fail: abort if loss is exploding or NaN
|
| 570 |
+
if math.isnan(train_loss_f) or train_loss_f > 100:
|
| 571 |
+
print("FAIL")
|
| 572 |
+
exit(1)
|
| 573 |
+
|
| 574 |
+
torch.cuda.synchronize()
|
| 575 |
+
t1 = time.time()
|
| 576 |
+
dt = t1 - t0
|
| 577 |
+
|
| 578 |
+
if step > 10:
|
| 579 |
+
total_training_time += dt
|
| 580 |
+
|
| 581 |
+
# Logging
|
| 582 |
+
ema_beta = 0.9
|
| 583 |
+
smooth_train_loss = ema_beta * smooth_train_loss + (1 - ema_beta) * train_loss_f
|
| 584 |
+
debiased_smooth_loss = smooth_train_loss / (1 - ema_beta**(step + 1))
|
| 585 |
+
pct_done = 100 * progress
|
| 586 |
+
tok_per_sec = int(TOTAL_BATCH_SIZE / dt)
|
| 587 |
+
mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE / dt / H100_BF16_PEAK_FLOPS
|
| 588 |
+
remaining = max(0, TIME_BUDGET - total_training_time)
|
| 589 |
+
|
| 590 |
+
print(f"\rstep {step:05d} ({pct_done:.1f}%) | loss: {debiased_smooth_loss:.6f} | lrm: {lrm:.2f} | dt: {dt*1000:.0f}ms | tok/sec: {tok_per_sec:,} | mfu: {mfu:.1f}% | epoch: {epoch} | remaining: {remaining:.0f}s ", end="", flush=True)
|
| 591 |
+
|
| 592 |
+
# GC management (Python's GC causes ~500ms stalls)
|
| 593 |
+
if step == 0:
|
| 594 |
+
gc.collect()
|
| 595 |
+
gc.freeze()
|
| 596 |
+
gc.disable()
|
| 597 |
+
elif (step + 1) % 5000 == 0:
|
| 598 |
+
gc.collect()
|
| 599 |
+
|
| 600 |
+
step += 1
|
| 601 |
+
|
| 602 |
+
# Time's up — but only stop after warmup steps so we don't count compilation
|
| 603 |
+
if step > 10 and total_training_time >= TIME_BUDGET:
|
| 604 |
+
break
|
| 605 |
+
|
| 606 |
+
print() # newline after \r training log
|
| 607 |
+
|
| 608 |
+
total_tokens = step * TOTAL_BATCH_SIZE
|
| 609 |
+
|
| 610 |
+
# Final eval
|
| 611 |
+
model.eval()
|
| 612 |
+
with autocast_ctx:
|
| 613 |
+
val_bpb = evaluate_bpb(model, tokenizer, DEVICE_BATCH_SIZE)
|
| 614 |
+
|
| 615 |
+
# Final summary
|
| 616 |
+
t_end = time.time()
|
| 617 |
+
startup_time = t_start_training - t_start
|
| 618 |
+
steady_state_mfu = 100 * num_flops_per_token * TOTAL_BATCH_SIZE * (step - 10) / total_training_time / H100_BF16_PEAK_FLOPS if total_training_time > 0 else 0
|
| 619 |
+
peak_vram_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
|
| 620 |
+
|
| 621 |
+
print("---")
|
| 622 |
+
print(f"val_bpb: {val_bpb:.6f}")
|
| 623 |
+
print(f"training_seconds: {total_training_time:.1f}")
|
| 624 |
+
print(f"total_seconds: {t_end - t_start:.1f}")
|
| 625 |
+
print(f"peak_vram_mb: {peak_vram_mb:.1f}")
|
| 626 |
+
print(f"mfu_percent: {steady_state_mfu:.2f}")
|
| 627 |
+
print(f"total_tokens_M: {total_tokens / 1e6:.1f}")
|
| 628 |
+
print(f"num_steps: {step}")
|
| 629 |
+
print(f"num_params_M: {num_params / 1e6:.1f}")
|
| 630 |
+
print(f"depth: {DEPTH}")
|
autoresearch/uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio>=5.0.0
|
| 2 |
+
spaces
|
| 3 |
+
torch
|
| 4 |
+
transformers
|
| 5 |
+
accelerate
|
| 6 |
+
openai
|