ZeroShot-1B

LLaMA-style 1.19B-param decoder-only LM (24L / 16H / 4 KV heads / 2048 d / 5632 FFN SwiGLU, 3072 ctx, RoPE, RMSNorm, GQA, tied embeddings, bf16). Three-stage pipeline: pretrain on FineWeb-Edu, midtrain on a higher-quality mix, then SFT on UltraChat. Single RTX 5090 (Blackwell sm_120, 32 GB), no DDP, no torch.compile.

File map

file              role
train.py          Everything: model, data, training, inference. 4 subcommands
run.sh            Auto-restart wrapper: re-launches on exit code 2
requirements.txt  Python deps (torch nightly cu128 install line in comments)

Architecture deviations from the brief, justified

The original spec was 24L / 16H / 2048 d / 8192 FFN at 4096 ctx, ~1.56B params. On a 32 GB card with microbatch ≥ 4 that's ~29 GB of activations alone, before weights or grads. The spec explicitly allowed dropping FFN to 5632 or ctx to 3072. To get real margin (not just borderline) I did both:

variant          FFN   ctx   params  act mem (mb=4, bwd)  fits 32 GB?
spec as-written  8192  4096  1.56B   ~29 GB               ❌
FFN trim only    5632  4096  1.19B   ~23 GB               borderline
ctx trim only    8192  3072  1.56B   ~22 GB               borderline
both (this)      5632  3072  1.19B   ~21 GB               ✓ ~3 GB margin

Param count = 1.19B (above the 1B floor in the spec). At startup train.py prints the param count and a VRAM estimate so you can sanity-check before committing 70 hours.
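
The 1.19B and 1.56B figures can be reproduced by hand from the shapes listed at the top. This is a minimal sketch, assuming the unpadded GPT-2 vocabulary of 50,257 tokens, no bias terms, and a head dim of 2048 / 16 = 128. `count_params` is an illustrative helper, not the function train.py uses at startup.

```python
def count_params(n_layer=24, d_model=2048, n_head=16, n_kv_head=4,
                 d_ffn=5632, vocab=50257):
    """Parameter count for a LLaMA-style decoder with GQA and tied embeddings."""
    head_dim = d_model // n_head
    kv_dim = n_kv_head * head_dim                        # GQA: K/V use fewer heads
    attn = 2 * d_model * d_model + 2 * d_model * kv_dim  # Wq, Wo, then Wk, Wv
    ffn = 3 * d_model * d_ffn                            # SwiGLU: gate, up, down
    norms = 2 * d_model                                  # two RMSNorm weights per block
    embed = vocab * d_model                              # tied with the LM head, counted once
    return n_layer * (attn + ffn + norms) + embed + d_model  # last term: final RMSNorm


print(f"this model:      {count_params() / 1e9:.2f}B")
print(f"spec (FFN 8192): {count_params(d_ffn=8192) / 1e9:.2f}B")
```

Under those assumptions the two lines print 1.19B and 1.56B. Context length does not appear in the formula because RoPE adds no learned parameters, which is why the ctx-only trim leaves the count at 1.56B.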

Budget math (this is the load-bearing part)

Throughput estimate: 25k tok/s pessimistic, 35k tok/s optimistic on a 5090, no-compile, bf16, FlashAttn-via-SDPA, 1.19B LLaMA at ctx 3072. (Spec said 25–40k in the user's experience; using 25–35k as the working range.)

Tokens/step: mb=4 × ga=20 × T=3072 = 245,760 (~0.25M).

Stage  Steps   Tokens  Wall @25k tok/s  Wall @35k tok/s  Cost worst  Cost best
base   25,000  6.14B   68.3 hr          48.8 hr          $24.93      $17.81
mid    2,500   614M    6.8 hr           4.9 hr           $2.49       $1.78
sft    2,000   196M    2.2 hr           1.6 hr           $0.80       $0.57
total  -       6.95B   77.3 hr          55.3 hr          $28.22      $20.16

Worst case lands at $28.22, which leaves $1.78 (~5 hours at $0.365/hr) of restart buffer before the $30 ceiling. Best case finishes in ~55 hours with about $10 to spare. Either way, the run fits the 95-hour wall clock with margin.
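
The table can be recomputed from the step counts, the tokens-per-step formula, and the hourly price. A short sketch, assuming the $0.365/hr rate and the SFT setting of ga=8; the names `STAGES` and `budget` are illustrative and do not come from train.py.

```python
SEQ_LEN = 3072
MICRO_BATCH = 4
PRICE_PER_HOUR = 0.365  # USD, community 5090 rate quoted in this README

# (steps, grad_accum) per stage
STAGES = {"base": (25_000, 20), "mid": (2_500, 20), "sft": (2_000, 8)}


def budget(tok_per_s):
    """Return (total_tokens, hours, dollars) for the full three-stage pipeline."""
    tokens = sum(steps * MICRO_BATCH * ga * SEQ_LEN for steps, ga in STAGES.values())
    hours = tokens / tok_per_s / 3600
    return tokens, hours, hours * PRICE_PER_HOUR


for rate in (25_000, 35_000):
    tokens, hours, cost = budget(rate)
    print(f"{rate} tok/s: {tokens / 1e9:.3f}B tokens, {hours:.1f} hr, ${cost:.2f}")
```

The totals can differ from the table in the last digit, because the table sums per-stage values that were already rounded.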

If actual throughput comes in below 25k tok/s (sustained over the first 1000 steps, visible in the per-step tok/s log line), drop base steps with --steps 20000 to re-fit. The log line shows elapsed | eta | $cost every 10 steps so this is easy to monitor.
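
One way to pick the new `--steps` value is to scale the planned step count by the ratio of measured to assumed throughput, which keeps wall-clock time (and cost) at the original worst-case plan. A minimal sketch; `refit_steps` is a hypothetical helper, not something train.py provides.

```python
def refit_steps(planned_steps, assumed_tok_s, measured_tok_s):
    """Scale the step count so wall-clock time matches the original plan."""
    return planned_steps * measured_tok_s // assumed_tok_s


# Base stage planned at 25k tok/s, but the log shows a sustained 20k tok/s.
print(refit_steps(25_000, 25_000, 20_000))  # 20000
```

So `--steps 20000` is the value that restores the 68-hour base stage if throughput settles at 20k tok/s.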

Microbatch: I considered going higher than 4 (mb=5 estimated at ~26 GB activations, mb=6 estimated OOM) but the throughput win is ~1% (Python overhead per fwd-bwd is single-digit ms vs ~hundreds of ms of GPU work) and the risk of OOM at hour 50 of a 70-hour run isn't worth it. Stayed at mb=4.

Setup on a Vast.ai 5090

Pick any Ubuntu 22.04 + CUDA 12.8-capable image (we install our own PyTorch). ~$0.365/hr for a community 5090.

git clone <your repo> zeroshot-1b && cd zeroshot-1b

python -m venv .venv && source .venv/bin/activate
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu128
pip install -r requirements.txt

# verify Blackwell + bf16
python -c "import torch; print(torch.__version__, torch.cuda.get_device_name(0)); \
           print('bf16 ok' if torch.cuda.is_bf16_supported() else 'NO BF16')"

# (optional) HF auth: UltraChat is public but you'll want to be logged in for rate limits
huggingface-cli login

# (optional) WandB
export WANDB_API_KEY=...

chmod +x run.sh

Datasets all stream β€” nothing is pre-downloaded, so disk usage stays in the few-GB range for HF cache shards.

Run the full pipeline

# Stage 1: pretrain on FineWeb-Edu (~6B tokens, ~70 hr worst case)
./run.sh base --no-compile

# Stage 2: midtrain (FineWeb-Edu 90% + Cosmopedia 10%, ~7 hr worst case)
./run.sh mid --checkpoint ckpt_base_final.pt --no-compile

# Stage 3: SFT on UltraChat (~2 hr worst case)
./run.sh finetune --checkpoint ckpt_mid_final.pt --no-compile

# Resume SFT directly from a mid checkpoint without rerunning mid
./run.sh finetune --checkpoint ckpt_mid_final.pt --no-compile --skip_mid

# Generate from any checkpoint
python train.py generate --checkpoint checkpoints/ckpt_base_final.pt \
       --prompt "The history of language models" --max_new 200
python train.py generate --checkpoint checkpoints/ckpt_sft_final.pt \
       --prompt "Explain attention to a 12 year old." --chat

run.sh auto-restarts train.py on exit code 2 (data error, OOM, SIGTERM from the host). Any other failure exits cleanly so you don't loop on real bugs.
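
run.sh itself is a shell script and is not reproduced here. The restart policy described above can be sketched in Python as follows; `supervise`, the injectable `run` argument, and the `max_restarts` cap are illustrative assumptions rather than part of the repo.

```python
import subprocess

RESTART_CODE = 2  # exit code train.py uses for recoverable failures


def supervise(cmd, run=subprocess.call, max_restarts=100):
    """Re-run cmd while it exits with RESTART_CODE; return the final exit code."""
    for _ in range(max_restarts + 1):
        code = run(cmd)
        if code != RESTART_CODE:
            return code   # 0 = stage finished, anything else = real bug, do not loop
    return RESTART_CODE   # still failing after max_restarts, give up


# Example (not executed here):
#   supervise(["python", "train.py", "base", "--no-compile"])
```

Passing `run` as a parameter keeps the loop testable without launching a real training process.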

Inside a stage, re-launching with no --fresh flag picks up the newest matching ckpt in ./checkpoints/ and restores model + optimizer + step + RNG + dataset position. --checkpoint is only for cross-stage weight init; that path resets the optimizer.

Resume guarantees

  • Time-based, not step-based: checkpoint every 30 minutes of wall clock. Vast instances die at unpredictable times, not at predictable step counts, so a time interval caps lost work at 30 minutes regardless of throughput.
  • Atomic: write to <file>.pt.tmp, then os.replace to the real path. Never leaves a half-written ckpt.
  • What's saved: model state, AdamW8bit state, step, samples_consumed, RNG (torch + cuda + numpy + python), train+model configs.
  • Last 3 kept per stage prefix; older deleted. Plus _final.pt at end of stage.
  • Emergency ckpt on SIGTERM (vast preempt), data error, OOM, or Ctrl-C → saves to ckpt_<stage>_emergency.pt and exits 2 → run.sh restarts.
  • Streaming retry: each HF stream wrapped in _retry_iter with exponential backoff (5 retries, base delay 2s → max ~32s). If exhausted: emergency ckpt + exit 2.
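
Two of the guarantees above, the atomic write with last-3 rotation and the backoff-wrapped stream, can be sketched in a few lines. This is not the code in train.py: the `ckpt_<stage>_<step>.pt` naming, the `write_fn` callback (which would wrap `torch.save` in practice), and the `make_iter(skip)` factory are all assumptions made for illustration.

```python
import os
import time
from pathlib import Path


def atomic_save(path, write_fn, keep=3):
    """Write to <file>.pt.tmp, publish with os.replace, then prune old step ckpts."""
    path = Path(path)
    tmp = path.with_name(path.name + ".tmp")
    write_fn(tmp)                     # e.g. lambda p: torch.save(state, p)
    os.replace(tmp, path)             # atomic rename on the same filesystem
    # Assumed naming: ckpt_<stage>_<step>.pt; _final/_emergency files are never pruned.
    prefix = path.name.rsplit("_", 1)[0] + "_"
    numbered = sorted(
        (p for p in path.parent.glob(prefix + "*.pt")
         if p.stem.rsplit("_", 1)[1].isdigit()),
        key=lambda p: int(p.stem.rsplit("_", 1)[1]),
    )
    for old in numbered[:-keep]:
        old.unlink()


def retry_iter(make_iter, retries=5, base_delay=2.0, sleep=time.sleep):
    """Yield from make_iter(skip), reopening with exponential backoff on OSError."""
    yielded, attempt = 0, 0
    while True:
        try:
            for item in make_iter(yielded):   # reopen at the first unseen item
                yield item
                yielded += 1
                attempt = 0                   # progress resets the backoff
            return
        except OSError:
            if attempt >= retries:
                raise                         # caller saves emergency ckpt, exits 2
            sleep(base_delay * 2 ** attempt)  # 2, 4, 8, 16, 32 seconds
            attempt += 1
```

A reader that opens the temporary file never sees it under the real name, so a crash mid-write leaves at worst a stray `.tmp` file next to the previous good checkpoint.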

Defaults chosen (note these)

  • Tokens per step: mb=4 × ga=20 × T=3072 = 245,760. SFT uses ga=8 for more frequent updates on shorter docs (~98k tok/step).
  • 8-bit AdamW (bitsandbytes.optim.AdamW8bit). Embedding has fp32 override (bnb best practice: 8-bit Adam states on the embedding cause instability). Betas (0.9, 0.95), wd 0.1, no decay on 1D norm params.
  • Cosine schedule: warmup 2000, peak 3e-4, min 3e-5 (base). Mid: 1e-4 → 1e-5. SFT: 5e-5 → 5e-6.
  • Gradient clip 1.0.
  • Mid mix: 90% FineWeb-Edu (sample-10BT) / 10% Cosmopedia (web_samples_v2), blended at the document level.
  • SFT dataset: HuggingFaceH4/ultrachat_200k (train_sft split). Loss is masked on user turns; computed on assistant content + the EOT after each turn.
  • Chat format: literal \n<|user|>\n / \n<|assistant|>\n text markers, tokenized with the standard GPT-2 tokenizer (no vocab surgery).
  • Per-stage data seed: base/mid/sft use different seeds so they don't iterate the same FineWeb shuffle order across stages.
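
The cosine schedule in the list above can be written as a pure function of the step. A minimal sketch using the base-stage defaults; the exact warmup convention (here `step / warmup`, starting from zero) and the function name `lr_at` are assumptions, since train.py is not shown.

```python
import math


def lr_at(step, total_steps, warmup=2000, peak=3e-4, min_lr=3e-5):
    """Linear warmup to peak, then cosine decay to min_lr at total_steps."""
    if step < warmup:
        return peak * step / warmup
    progress = min(1.0, (step - warmup) / max(1, total_steps - warmup))
    return min_lr + 0.5 * (peak - min_lr) * (1.0 + math.cos(math.pi * progress))


for s in (0, 1000, 2000, 13_500, 25_000):
    print(s, f"{lr_at(s, 25_000):.2e}")
```

The mid and SFT stages would call the same function with their own peak and floor, for example `lr_at(step, 2500, peak=1e-4, min_lr=1e-5)`; their warmup lengths are not stated above, so pass whatever `--warmup` is set to.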

Override flags

The flags I expect to need most often, on the train subcommands:

flag                         meaning
--lr 2e-4                    peak LR for cosine schedule
--min_lr 2e-5                floor LR
--batch_size 3               drop micro batch (use if mb=4 OOMs on this nightly)
--grad_accum 27              bump accum to keep tokens/step roughly constant after mb drop
--steps 20000                shrink stage if throughput is below estimate
--warmup 1000                warmup steps
--ckpt_dir /workspace/ckpts  where checkpoints live (point at fast disk)
--fresh                      ignore existing ckpts in this stage's prefix
--seed 7                     RNG + dataset shuffle seed
--no-wandb                   disable WandB even if WANDB_API_KEY is set
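
The `--grad_accum 27` value paired with `--batch_size 3` follows from rounding up to the default tokens per step. A short worked example; `grad_accum_for` is an illustrative helper, not a train.py function.

```python
import math

SEQ_LEN = 3072
TARGET_TOKENS_PER_STEP = 4 * 20 * SEQ_LEN  # 245,760 with the default mb=4, ga=20


def grad_accum_for(micro_batch):
    """Smallest grad_accum that reaches at least the target tokens per step."""
    return math.ceil(TARGET_TOKENS_PER_STEP / (micro_batch * SEQ_LEN))


ga = grad_accum_for(3)
print(ga, 3 * ga * SEQ_LEN)  # 27 248832
```

The match is approximate: 3 × 27 × 3072 = 248,832 tokens per step, about 1.25% more than the default, so total tokens at a fixed `--steps` rise by the same fraction.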

Things deliberately not included

  • DDP / FSDP: single GPU only.
  • torch.compile: Blackwell + nightly is flaky; spec says don't.
  • Activation checkpointing: would let mb go higher but adds ~25-30% compute. Memory analysis above shows we don't need it.
  • Eval during training: no time; eval after SFT manually with generate.
  • Custom CUDA / Triton kernels: SDPA already dispatches to FlashAttention.