ZeroShot-1B
LLaMA-style 1.19B-param decoder-only LM (24L / 16H / 4 KV heads / 2048 d / 5632 FFN
SwiGLU, 3072 ctx, RoPE, RMSNorm, GQA, tied embeddings, bf16). Three-stage pipeline:
pretrain on FineWeb-Edu, midtrain on a higher-quality mix, then SFT on UltraChat.
Single RTX 5090 (Blackwell sm_120, 32 GB), no DDP, no torch.compile.
File map
| file | role |
|---|---|
| train.py | Everything: model, data, training, inference. 4 subcommands |
| run.sh | Auto-restart wrapper: re-launches on exit code 2 |
| requirements.txt | Python deps (torch nightly cu128 install line in comments) |
Architecture deviations from the brief, justified
The original spec was 24L / 16H / 2048 d / 8192 FFN at 4096 ctx, ~1.56B params. On a 32 GB card with microbatch ≥ 4 that's ~29 GB of activations alone, before weights or grads. The spec explicitly allowed dropping FFN to 5632 or ctx to 3072. To get real margin (not just borderline) I did both:
| variant | FFN | ctx | params | act mem (mb=4, bwd) | fits 32 GB? |
|---|---|---|---|---|---|
| spec as-written | 8192 | 4096 | 1.56B | ~29 GB | no |
| FFN trim only | 5632 | 4096 | 1.19B | ~23 GB | borderline |
| ctx trim only | 8192 | 3072 | 1.56B | ~22 GB | borderline |
| both (this) | 5632 | 3072 | 1.19B | ~21 GB | yes, ~3 GB margin |
Param count = 1.19B (above the 1B floor in the spec). At startup train.py prints
the param count and a VRAM estimate so you can sanity-check before committing 70 hours.
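The 1.19B and 1.56B figures can be reproduced with a few lines of arithmetic. This is a minimal sketch, not the code in train.py: it assumes bias-free linear layers, two RMSNorm weights per block plus a final one, and the unpadded GPT-2 vocabulary of 50,257 tokens. The function name `count_params` is hypothetical.

```python
def count_params(n_layers=24, d_model=2048, n_heads=16, n_kv_heads=4,
                 d_ffn=5632, vocab_size=50257, tied_embeddings=True):
    """Parameter count for a bias-free LLaMA-style decoder (assumed layout)."""
    head_dim = d_model // n_heads
    kv_dim = n_kv_heads * head_dim                         # GQA: K and V are narrower than Q
    attn = 2 * d_model * d_model + 2 * d_model * kv_dim    # Wq, Wo + Wk, Wv
    ffn = 3 * d_model * d_ffn                              # gate, up, down (SwiGLU)
    norms = 2 * d_model                                    # two RMSNorm weights per block
    per_layer = attn + ffn + norms
    embed = vocab_size * d_model
    lm_head = 0 if tied_embeddings else vocab_size * d_model
    return n_layers * per_layer + embed + lm_head + d_model  # + final RMSNorm


this_model = count_params()             # FFN 5632, the variant trained here
spec_model = count_params(d_ffn=8192)   # the spec as written
print(f"{this_model / 1e9:.2f}B  {spec_model / 1e9:.2f}B")
```

Context length does not appear in the count because RoPE adds no learned position parameters; trimming ctx only affects activation memory.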
Budget math (this is the load-bearing part)
Throughput estimate: 25k tok/s pessimistic, 35k tok/s optimistic on a 5090, no-compile, bf16, FlashAttn-via-SDPA, 1.19B LLaMA at ctx 3072. (Spec said 25–40k in the user's experience; using 25–35k as the working range.)
Tokens/step: mb=4 × ga=20 × T=3072 = 245,760 (~0.25M).
| Stage | Steps | Tokens | Wall @25k tok/s | Wall @35k tok/s | Cost worst | Cost best |
|---|---|---|---|---|---|---|
| base | 25,000 | 6.14B | 68.3 hr | 48.8 hr | $24.93 | $17.81 |
| mid | 2,500 | 614M | 6.8 hr | 4.9 hr | $2.49 | $1.78 |
| sft | 2,000 | 196M | 2.2 hr | 1.6 hr | $0.80 | $0.57 |
| total | - | 6.95B | 77.3 hr | 55.3 hr | $28.22 | $20.16 |
Worst case lands at $28.22, leaving $1.78 (roughly 5 hours at $0.365/hr) of restart buffer before the $30 ceiling. Best case finishes in about 55 hours with just under $10 to spare. Either way, fits the 95-hour wall clock with margin.
If actual throughput comes in below 25k tok/s (sustained over the first 1000 steps,
visible in the per-step tok/s log line), drop base steps with --steps 20000 to
re-fit. The log line shows elapsed | eta | $cost every 10 steps so this is easy
to monitor.
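The budget arithmetic above can be reproduced, and re-run with a measured throughput, using a short script. This is a sketch under the numbers quoted in this README ($0.365/hr, mb=4, ctx 3072, ga=20 for base and mid, ga=8 for SFT); the helper names `stage_budget` and `max_steps_within` are hypothetical and are not part of train.py. Because it uses unrounded hours, its dollar figures can differ from the table by a cent or two.

```python
PRICE_PER_HR = 0.365   # USD/hr, the community 5090 rate quoted in this README
CTX = 3072             # tokens per sequence
MICRO_BATCH = 4

# (steps, grad_accum) per stage, taken from the table and defaults in this README
STAGES = {"base": (25_000, 20), "mid": (2_500, 20), "sft": (2_000, 8)}


def stage_budget(steps, grad_accum, tok_per_s):
    """Return (tokens, wall-clock hours, USD) for one stage."""
    tokens = steps * MICRO_BATCH * grad_accum * CTX
    hours = tokens / tok_per_s / 3600
    return tokens, hours, hours * PRICE_PER_HR


def max_steps_within(hours, grad_accum, tok_per_s):
    """Largest step count that fits in `hours` at a measured throughput."""
    tokens_per_step = MICRO_BATCH * grad_accum * CTX
    return int(hours * 3600 * tok_per_s // tokens_per_step)


for tok_per_s in (25_000, 35_000):
    total_hours = sum(stage_budget(s, ga, tok_per_s)[1] for s, ga in STAGES.values())
    print(f"{tok_per_s} tok/s: {total_hours:.1f} hr, ${total_hours * PRICE_PER_HR:.2f}")

# Hypothetical re-fit: throughput measured at 20k tok/s, base stage held to ~68.3 hr
print(max_steps_within(68.3, grad_accum=20, tok_per_s=20_000))
```

The re-fit case shows why `--steps 20000` is the suggested fallback: at 20k tok/s, about 20,000 base steps take the same wall clock as 25,000 steps at 25k tok/s.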
Microbatch: I considered going higher than 4 (mb=5 estimated at ~26 GB activations, mb=6 estimated OOM) but the throughput win is ~1% (Python overhead per fwd-bwd is single-digit ms vs ~hundreds of ms of GPU work) and the risk of OOM at hour 50 of a 70-hour run isn't worth it. Stayed at mb=4.
Setup on a Vast.ai 5090
Pick any Ubuntu 22.04 + CUDA 12.8-capable image (we install our own PyTorch). ~$0.365/hr for a community 5090.
git clone <your repo> zeroshot-1b && cd zeroshot-1b
python -m venv .venv && source .venv/bin/activate
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
pip install -r requirements.txt
# verify Blackwell + bf16
python -c "import torch; print(torch.__version__, torch.cuda.get_device_name(0)); \
print('bf16 ok' if torch.cuda.is_bf16_supported() else 'NO BF16')"
# (optional) HF auth: UltraChat is public but you'll want to be logged in for rate limits
huggingface-cli login
# (optional) WandB
export WANDB_API_KEY=...
chmod +x run.sh
Datasets all stream β nothing is pre-downloaded, so disk usage stays in the few-GB range for HF cache shards.
Run the full pipeline
# Stage 1 - pretrain on FineWeb-Edu (~6B tokens, ~70 hr worst case)
./run.sh base --no-compile
# Stage 2 - midtrain (FineWeb-Edu 90% + Cosmopedia 10%, ~7 hr worst case)
./run.sh mid --checkpoint ckpt_base_final.pt --no-compile
# Stage 3 - SFT on UltraChat (~2 hr worst case)
./run.sh finetune --checkpoint ckpt_mid_final.pt --no-compile
# Resume SFT directly from a mid checkpoint without rerunning mid
./run.sh finetune --checkpoint ckpt_mid_final.pt --no-compile --skip_mid
# Generate from any checkpoint
python train.py generate --checkpoint checkpoints/ckpt_base_final.pt \
--prompt "The history of language models" --max_new 200
python train.py generate --checkpoint checkpoints/ckpt_sft_final.pt \
--prompt "Explain attention to a 12 year old." --chat
run.sh auto-restarts train.py on exit code 2 (data error, OOM, SIGTERM from the
host). Any other failure exits cleanly so you don't loop on real bugs.
Inside a stage, re-launching with no --fresh flag picks up the newest matching ckpt
in ./checkpoints/ and restores model + optimizer + step + RNG + dataset position.
--checkpoint is only for cross-stage weight init β that path resets the optimizer.
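The restart rule described above (retry on exit code 2, stop on anything else) is small enough to sketch. run.sh itself is a shell script; the Python version below is only an illustration of the same control flow. `supervise`, `launch_train`, `max_restarts`, and `delay_s` are hypothetical names and limits, not values taken from run.sh.

```python
import subprocess
import sys
import time

RESTART_CODE = 2   # the exit code this README says run.sh treats as "retry"


def supervise(run_once, restart_code=RESTART_CODE, max_restarts=50, delay_s=5.0):
    """Call run_once() until it returns anything other than restart_code.

    Returns the final exit code. max_restarts and delay_s are illustrative
    safety limits so a persistent failure cannot loop forever.
    """
    restarts = 0
    while True:
        code = run_once()
        if code != restart_code or restarts >= max_restarts:
            return code
        restarts += 1
        time.sleep(delay_s)


def launch_train(args):
    """Run train.py with the given CLI args and return its exit code."""
    return subprocess.call([sys.executable, "train.py", *args])


# Example wiring (not executed here):
#   sys.exit(supervise(lambda: launch_train(["base", "--no-compile"])))
```

Because each relaunch omits --fresh, the restarted process resumes from the newest checkpoint for the stage instead of starting over.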
Resume guarantees
- Time-based, not step-based: checkpoint every 30 minutes of wall clock. Vast instances die at unpredictable wall-clock times, so a time interval bounds the work lost per preemption where a step interval would not.
- Atomic: write to <file>.pt.tmp, then os.replace to the real path. Never leaves a half-written ckpt.
- What's saved: model state, AdamW8bit state, step, samples_consumed, RNG (torch + cuda + numpy + python), train+model configs.
- Last 3 kept per stage prefix; older deleted. Plus _final.pt at end of stage.
- Emergency ckpt on SIGTERM (vast preempt), data error, OOM, or Ctrl-C: saves to ckpt_<stage>_emergency.pt and exits 2, so run.sh restarts.
- Streaming retry: each HF stream wrapped in _retry_iter with exponential backoff (5 retries, base delay 2s → max ~32s). If exhausted: emergency ckpt + exit 2.
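The atomic-write and keep-last-3 guarantees in the list above can be sketched as follows. This is an illustration under stated assumptions, not the code in train.py: it serializes with `pickle.dump` so the example runs without PyTorch (passing `dump=torch.save` would be the natural swap), and the `ckpt_<stage>_step<N>.pt` naming for periodic checkpoints is a guess, since this README only names the `_final` and `_emergency` files.

```python
import os
import pickle
from pathlib import Path


def atomic_save(state, path, dump=pickle.dump):
    """Write `state` to <path>.tmp, then rename over `path` in one step."""
    path = Path(path)
    tmp = path.with_name(path.name + ".tmp")
    with open(tmp, "wb") as f:
        dump(state, f)
        f.flush()
        os.fsync(f.fileno())   # make sure bytes are on disk before the rename
    os.replace(tmp, path)      # atomic when tmp and path share a filesystem


def prune_old(ckpt_dir, stage, keep=3):
    """Delete all but the `keep` newest step checkpoints for one stage.

    Assumes the hypothetical naming ckpt_<stage>_step<N>.pt; the _final and
    _emergency files do not match the pattern and are left alone.
    """
    ckpts = sorted(Path(ckpt_dir).glob(f"ckpt_{stage}_step*.pt"),
                   key=lambda p: int(p.stem.rsplit("step", 1)[1]))
    for old in ckpts[:max(len(ckpts) - keep, 0)]:
        old.unlink()
```

A reader of the checkpoint directory then sees either the previous complete file or the new complete file, never a partial one, which is what makes resuming after a kill safe.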
Defaults chosen (note these)
- Tokens per step: mb=4 × ga=20 × T=3072 = 245,760. SFT uses ga=8 for more frequent updates on shorter docs (~98k tok/step).
- 8-bit AdamW (bitsandbytes.optim.AdamW8bit). Embedding has fp32 override (bnb best practice: 8-bit Adam states on the embedding cause instability). Betas (0.9, 0.95), wd 0.1, no decay on 1D norm params.
- Cosine schedule: warmup 2000, peak 3e-4, min 3e-5 (base). Mid: 1e-4 → 1e-5. SFT: 5e-5 → 5e-6.
- Gradient clip 1.0.
- Mid mix: 90% FineWeb-Edu (sample-10BT) / 10% Cosmopedia (web_samples_v2), blended at the document level.
- SFT dataset: HuggingFaceH4/ultrachat_200k (train_sft split). Loss is masked on user turns; computed on assistant content + the EOT after each turn.
- Chat format: literal \n<|user|>\n / \n<|assistant|>\n text markers, tokenized with the standard GPT-2 tokenizer (no vocab surgery).
- Per-stage data seed: base/mid/sft use different seeds so they don't iterate the same FineWeb shuffle order across stages.
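The base-stage schedule in the list above (warmup 2000, peak 3e-4, min 3e-5 over 25,000 steps) can be sketched as a pure function of the step. The exact warmup shape in train.py is not stated here, so the linear ramp below is an assumption, and `lr_at` is a hypothetical name.

```python
import math


def lr_at(step, total_steps, warmup=2000, peak=3e-4, floor=3e-5):
    """Linear warmup to `peak`, then cosine decay to `floor` at total_steps."""
    if step < warmup:
        return peak * (step + 1) / warmup          # assumed linear ramp
    progress = min((step - warmup) / max(total_steps - warmup, 1), 1.0)
    return floor + 0.5 * (peak - floor) * (1 + math.cos(math.pi * progress))


for step in (0, 2_000, 13_500, 25_000):
    print(step, f"{lr_at(step, total_steps=25_000):.2e}")
```

The mid and SFT stages would reuse the same function with their own peak and floor (1e-4 → 1e-5 and 5e-5 → 5e-6); their warmup lengths are not given in this README, so pass them explicitly.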
Override flags
The flags I expect to need most often, on the train subcommands:
| flag | meaning |
|---|---|
| --lr 2e-4 | peak LR for cosine schedule |
| --min_lr 2e-5 | floor LR |
| --batch_size 3 | drop micro batch (use if mb=4 OOMs on this nightly) |
| --grad_accum 27 | bump accum to keep tokens/step constant after mb drop |
| --steps 20000 | shrink stage if throughput is below estimate |
| --warmup 1000 | warmup steps |
| --ckpt_dir /workspace/ckpts | where checkpoints live (point at fast disk) |
| --fresh | ignore existing ckpts in this stage's prefix |
| --seed 7 | RNG + dataset shuffle seed |
| --no-wandb | disable WandB even if WANDB_API_KEY is set |
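The pairing of --batch_size 3 with --grad_accum 27 in the table above follows from holding tokens per step near 245,760. A small sketch of that arithmetic, with the hypothetical helper `grad_accum_for` and rounding up assumed so the step never gets smaller than the target:

```python
import math


def grad_accum_for(target_tokens, micro_batch, ctx=3072):
    """Smallest grad_accum whose step holds at least target_tokens."""
    return math.ceil(target_tokens / (micro_batch * ctx))


target = 4 * 20 * 3072                      # default tokens/step: 245,760
ga = grad_accum_for(target, micro_batch=3)  # accumulation needed at mb=3
actual = 3 * ga * 3072
print(ga, actual, f"{actual / target - 1:+.2%}")
```

Since 80 sequences per step is not divisible by 3, the match is approximate: the step grows slightly, so a stage at fixed --steps consumes slightly more tokens than the budget table assumes.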
Things deliberately not included
- DDP / FSDP: single GPU only.
- torch.compile: Blackwell + nightly is flaky; spec says don't.
- Activation checkpointing: would let mb go higher but adds ~25-30% compute. Memory analysis above shows we don't need it.
- Eval during training: no time; eval after SFT manually with generate.
- Custom CUDA / Triton kernels: SDPA already dispatches to FlashAttention.