| | |
| | """ |
| | Nanomind pretraining script for decoder-only causal LM on JSONL.gz data. |
| | |
| | - Expects input file with one JSON object per line containing a `text` field. |
| | - Streams, tokenizes, and packs sequences to a fixed length for efficient training. |
| | - Uses a small LLaMA-style config by default (RMSNorm + SwiGLU + RoPE, MQA). |
| | |
| | Usage example: |
| | python /workspace/nanomind/train.py \ |
| | --data_path /workspace/nanomind_data/pretrain_1m.jsonl.gz \ |
| | --out_dir /workspace/nanomind_runs/run1 \ |
| | --tokenizer_name hf-internal-testing/llama-tokenizer \ |
| | --seq_len 4096 --global_batch_size 256 \ |
| | --lr 1e-3 --warmup_steps 2000 --max_steps 50000 --bf16 |
| | """ |
| |
|
| | import os |
| | import io |
| | import gc |
| | import gzip |
| | import json |
| | import math |
| | import time |
| | import random |
| | import argparse |
| | from pathlib import Path |
| | from typing import Iterator, List, Dict, Optional |
| |
|
| | import torch |
| | from torch import nn |
| | from torch.utils.data import IterableDataset, DataLoader |
| |
|
| | from transformers import ( |
| | AutoTokenizer, |
| | LlamaConfig, |
| | LlamaForCausalLM, |
| | get_cosine_schedule_with_warmup, |
| | ) |
| |
|
| |
|
class JsonlPackedDataset(IterableDataset):
    """Stream a JSONL(.gz) file of ``{"text": ...}`` objects as packed token blocks.

    Each usable document is tokenized and optionally wrapped in BOS/EOS. The
    result is appended to a running token buffer, which is cut into
    non-overlapping blocks of exactly ``seq_len`` tokens. A document may span
    several blocks, and one block may contain several documents. Tokens that do
    not fill a whole block wait for more text and are discarded when the stream
    ends.

    Under a multi-worker ``DataLoader``, lines are sharded round-robin across
    workers. Each worker therefore packs a disjoint subset of the file instead of
    every worker emitting the same stream.

    Args:
        data_path: Path to a ``.jsonl`` file, or a ``.jsonl.gz`` file (gzip is
            chosen by the ``.gz`` suffix).
        tokenizer: Object exposing ``encode(text, add_special_tokens=...)`` and,
            optionally, ``bos_token_id`` / ``eos_token_id``.
        seq_len: Number of tokens per yielded block (must be positive).
        shuffle_lines: Stored for API compatibility. Shuffling is not implemented;
            lines are read in file order.
        add_bos_eos: Wrap each document in BOS ... EOS when the tokenizer defines
            both ids. Otherwise use the tokenizer's own special-token handling.
        repeat: Re-read the file indefinitely instead of stopping after one pass.
        buffer_tokens_limit: Cap on buffered tokens (must be >= ``seq_len``).
            When one document pushes the buffer past it, the oldest tokens are
            dropped.

    Raises:
        ValueError: if ``seq_len`` is not positive or ``buffer_tokens_limit`` is
            smaller than ``seq_len``.
    """

    def __init__(
        self,
        data_path: str,
        tokenizer,
        seq_len: int,
        shuffle_lines: bool = False,
        add_bos_eos: bool = True,
        repeat: bool = True,
        buffer_tokens_limit: int = 4_000_000,
    ) -> None:
        super().__init__()
        self.data_path = str(data_path)
        self.tokenizer = tokenizer
        self.seq_len = int(seq_len)
        self.shuffle_lines = bool(shuffle_lines)
        self.add_bos_eos = bool(add_bos_eos)
        self.repeat = bool(repeat)
        self.buffer_tokens_limit = int(buffer_tokens_limit)
        if self.seq_len <= 0:
            # A non-positive block size would make the packing loop spin forever.
            raise ValueError(f"seq_len must be positive, got {seq_len}")
        if self.buffer_tokens_limit < self.seq_len:
            # The buffer could then never hold a full block, so nothing would ever be yielded.
            raise ValueError(
                f"buffer_tokens_limit ({buffer_tokens_limit}) must be >= seq_len ({seq_len})"
            )

    @staticmethod
    def _worker_shard() -> tuple:
        """Return ``(worker_id, num_workers)`` for the current DataLoader worker.

        Outside a worker process (``num_workers=0`` or direct iteration) this is
        ``(0, 1)``, meaning the whole file.
        """
        info = torch.utils.data.get_worker_info()
        if info is None:
            return 0, 1
        return info.id, info.num_workers

    def _read_once(self) -> Iterator[str]:
        """Yield this worker's share of the raw lines from a single pass over the file."""
        worker_id, num_workers = self._worker_shard()
        open_fn = gzip.open if self.data_path.endswith(".gz") else open
        with open_fn(self.data_path, "rt", encoding="utf-8") as f:
            for line_no, line in enumerate(f):
                # Round-robin sharding keeps the workers' line sets disjoint.
                if line_no % num_workers == worker_id:
                    yield line

    def _line_iter(self) -> Iterator[str]:
        """Yield raw lines, restarting at EOF while ``repeat`` is set.

        Retained for backward compatibility; ``_yield_blocks`` reads passes via
        ``_read_once`` directly. Stops after a pass that produced no lines (an
        empty file or empty worker shard) instead of re-opening the file in a
        busy loop.
        """
        while True:
            saw_line = False
            for line in self._read_once():
                saw_line = True
                yield line
            if not self.repeat or not saw_line:
                return

    def _encode_line(
        self, raw_line: str, bos_id: Optional[int], eos_id: Optional[int]
    ) -> List[int]:
        """Parse one JSONL line and return its token ids, or ``[]`` if unusable.

        Blank lines, invalid JSON, non-object JSON values, and missing, non-string
        or short (< 10 characters) ``text`` fields are skipped rather than raised.
        """
        raw_line = raw_line.strip()
        if not raw_line:
            return []
        try:
            obj = json.loads(raw_line)
        except json.JSONDecodeError:
            return []
        if not isinstance(obj, dict):
            return []
        text = obj.get("text")
        if not isinstance(text, str) or len(text) < 10:
            return []

        if self.add_bos_eos and bos_id is not None and eos_id is not None:
            encoded = self.tokenizer.encode(text, add_special_tokens=False)
            if not encoded:
                return []
            # Explicit BOS/EOS delimit documents inside a packed block.
            return [bos_id, *encoded, eos_id]
        return list(self.tokenizer.encode(text, add_special_tokens=True))

    @staticmethod
    def _make_example(block: List[int]) -> Dict[str, torch.Tensor]:
        """Wrap one packed block as a causal-LM example (no padding, labels = inputs)."""
        input_ids = torch.tensor(block, dtype=torch.long)
        return {
            "input_ids": input_ids,
            "attention_mask": torch.ones_like(input_ids),
            "labels": input_ids.clone(),
        }

    def _yield_blocks(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Tokenize documents and yield packed ``seq_len``-token training examples.

        The token buffer is local to each call, so re-iterating the dataset
        starts fresh instead of inheriting leftovers from an earlier iterator.
        """
        bos_id = getattr(self.tokenizer, "bos_token_id", None)
        eos_id = getattr(self.tokenizer, "eos_token_id", None)
        seq_len = self.seq_len
        limit = self.buffer_tokens_limit
        buffer: List[int] = []

        while True:
            added_tokens = False
            for raw_line in self._read_once():
                tokens = self._encode_line(raw_line, bos_id, eos_id)
                if not tokens:
                    continue
                added_tokens = True
                buffer.extend(tokens)

                # Memory cap: if one huge document overflows the buffer, keep only
                # the newest `limit` tokens and discard the oldest.
                if len(buffer) > limit:
                    del buffer[: len(buffer) - limit]

                # Emit every full block, then drop them all with a single deletion.
                # Deleting block-by-block from the list front is quadratic.
                n_full = len(buffer) // seq_len
                for start in range(0, n_full * seq_len, seq_len):
                    yield self._make_example(buffer[start:start + seq_len])
                del buffer[: n_full * seq_len]

            # Stop after one pass unless repeating. Also stop when a whole pass
            # contributed no tokens, which would otherwise loop forever.
            if not self.repeat or not added_tokens:
                return

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Return a fresh generator of packed examples (independent per call)."""
        return self._yield_blocks()
| |
|
| |
|
def build_model_and_tokenizer(
    tokenizer_name: Optional[str],
    tokenizer_dir: Optional[str],
    model_name: Optional[str],
    vocab_size_override: Optional[int],
    hidden_size: int,
    n_layers: int,
    n_heads: int,
    n_kv_heads: int,
    rope_theta: float,
    max_position_embeddings: int,
) -> tuple:
    """Load a tokenizer and build (or load) a ``LlamaForCausalLM`` sized to it.

    Args:
        tokenizer_name: Hugging Face hub id of the tokenizer. Takes precedence
            over ``tokenizer_dir``.
        tokenizer_dir: Local directory containing a saved tokenizer.
        model_name: If given, continue pretraining from this checkpoint. The
            architecture arguments below are then ignored.
        vocab_size_override: Explicit embedding size; defaults to
            ``len(tokenizer)``. Must not be smaller than the tokenizer vocabulary.
        hidden_size, n_layers, n_heads, n_kv_heads, rope_theta,
        max_position_embeddings: Architecture of a freshly initialised model.
            ``n_kv_heads < n_heads`` gives grouped-/multi-query attention.

    Returns:
        ``(model, tokenizer)``. The tokenizer always has a pad token set.

    Raises:
        ValueError: no tokenizer source was given, the vocab override is smaller
            than the tokenizer, or (fresh model only) the head counts are
            inconsistent with each other or with ``hidden_size``.
    """
    tokenizer_source = tokenizer_name or tokenizer_dir
    if not tokenizer_source:
        raise ValueError("Provide --tokenizer_name or --tokenizer_dir")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source, use_fast=True)

    if tokenizer.pad_token_id is None:
        if tokenizer.eos_token_id is not None:
            tokenizer.pad_token = tokenizer.eos_token
        else:
            # Adding a token grows len(tokenizer); the embedding size below follows it.
            tokenizer.add_special_tokens({"pad_token": "[PAD]"})

    tokenizer_vocab = len(tokenizer)
    vocab_size = vocab_size_override or tokenizer_vocab
    if vocab_size < tokenizer_vocab:
        # Token ids >= vocab_size would index past the embedding table at run time.
        raise ValueError(
            f"vocab_size_override={vocab_size_override} is smaller than the "
            f"tokenizer vocabulary ({tokenizer_vocab})"
        )

    if model_name:
        # Continued pretraining: keep the checkpoint's architecture and only
        # match its embedding table to the vocabulary size.
        model = LlamaForCausalLM.from_pretrained(model_name)
        if model.get_input_embeddings().weight.shape[0] != vocab_size:
            model.resize_token_embeddings(vocab_size)
        return model, tokenizer

    if n_heads <= 0 or n_kv_heads <= 0 or n_heads % n_kv_heads != 0:
        raise ValueError(
            f"n_heads ({n_heads}) must be a positive multiple of n_kv_heads ({n_kv_heads})"
        )
    if hidden_size % n_heads != 0:
        raise ValueError(
            f"hidden_size ({hidden_size}) must be divisible by n_heads ({n_heads})"
        )

    config = LlamaConfig(
        vocab_size=vocab_size,
        hidden_size=hidden_size,
        # NOTE(review): MLP width is 2.2x hidden, not rounded to a hardware-friendly
        # multiple; kept as-is so model shapes match earlier runs.
        intermediate_size=int(hidden_size * 2.2),
        num_hidden_layers=n_layers,
        num_attention_heads=n_heads,
        num_key_value_heads=n_kv_heads,
        rms_norm_eps=1e-5,
        rope_theta=rope_theta,
        max_position_embeddings=max_position_embeddings,
        tie_word_embeddings=True,
    )
    return LlamaForCausalLM(config), tokenizer
| |
|
| |
|
def get_dataloader(
    data_path: str,
    tokenizer,
    seq_len: int,
    micro_batch_size: int,
    num_workers: int,
) -> DataLoader:
    """Build the training DataLoader over packed ``seq_len``-token blocks.

    The dataset streams ``data_path`` forever (``repeat=True``) in file order and
    wraps every document in BOS/EOS when the tokenizer defines both. With
    ``num_workers > 0``, each worker process runs its own iterator over the
    dataset. Incomplete trailing batches are dropped.

    Args:
        data_path: JSONL(.gz) file with a ``text`` field per line.
        tokenizer: Tokenizer used to encode each document.
        seq_len: Tokens per packed sequence.
        micro_batch_size: Sequences per yielded batch.
        num_workers: DataLoader worker processes (0 = load in the main process).

    Returns:
        A ``DataLoader`` yielding dicts with ``input_ids``, ``attention_mask``
        and ``labels`` tensors of shape ``(micro_batch_size, seq_len)``.
    """
    dataset = JsonlPackedDataset(
        data_path=data_path,
        tokenizer=tokenizer,
        seq_len=seq_len,
        shuffle_lines=False,
        add_bos_eos=True,
        repeat=True,
    )
    return DataLoader(
        dataset,
        batch_size=micro_batch_size,
        num_workers=num_workers,
        # Pinned host memory only speeds up host->GPU copies. Without CUDA it is
        # unused and PyTorch emits a warning, so only request it when CUDA exists.
        pin_memory=torch.cuda.is_available(),
        drop_last=True,
        collate_fn=_collate_batch,
    )
| |
|
| |
|
| | def _collate_batch(features: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]: |
| | |
| | input_ids = torch.stack([f["input_ids"] for f in features], dim=0) |
| | attention_mask = torch.stack([f["attention_mask"] for f in features], dim=0) |
| | labels = torch.stack([f["labels"] for f in features], dim=0) |
| | return {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels} |
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Define the command-line interface and parse ``sys.argv``.

    Options are registered in four groups: data loading, tokenizer/model source,
    the architecture of a freshly initialised model, and optimisation /
    checkpointing. ``--data_path`` and ``--out_dir`` are required.
    """
    parser = argparse.ArgumentParser()
    option_specs = (
        # Data
        ("--data_path", dict(required=True, help="Path to JSONL(.gz) with {text}")),
        ("--seq_len", dict(type=int, default=4096)),
        ("--num_workers", dict(type=int, default=2)),
        # Tokenizer / model source
        ("--tokenizer_name", dict(default=None, help="HF tokenizer name")),
        ("--tokenizer_dir", dict(default=None, help="Local dir of HF tokenizer")),
        ("--model_name", dict(default=None, help="HF model name to continue from (CPT)")),
        ("--vocab_size_override", dict(type=int, default=None)),
        # Architecture (used only when building a model from scratch)
        ("--hidden_size", dict(type=int, default=768)),
        ("--n_layers", dict(type=int, default=24)),
        ("--n_heads", dict(type=int, default=12)),
        ("--n_kv_heads", dict(type=int, default=1)),
        ("--rope_theta", dict(type=float, default=1e6)),
        ("--max_position_embeddings", dict(type=int, default=4096)),
        # Optimisation and checkpointing
        ("--out_dir", dict(required=True)),
        ("--global_batch_size", dict(type=int, default=256)),
        ("--micro_batch_size", dict(type=int, default=None, help="Per-step batch size before grad accumulation")),
        ("--lr", dict(type=float, default=1e-3)),
        ("--weight_decay", dict(type=float, default=0.05)),
        ("--warmup_steps", dict(type=int, default=2000)),
        ("--max_steps", dict(type=int, default=50_000)),
        ("--save_every", dict(type=int, default=2000)),
        ("--clip_grad", dict(type=float, default=1.0)),
        ("--bf16", dict(action="store_true")),
        ("--seed", dict(type=int, default=42)),
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
| |
|
| |
|
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` and PyTorch's CPU and all-CUDA-device RNGs.

    Also exports ``PYTHONHASHSEED``. Python reads that variable only at
    interpreter start-up, so it affects processes launched afterwards, not the
    hash randomisation of the current one.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    for seed_fn in (torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(seed)
| |
|
| |
|
def main() -> None:
    """Parse CLI args, build model, tokenizer and data, then run the training loop.

    Each optimizer step accumulates gradients over ``grad_accum`` micro-batches
    of ``micro_bs`` packed sequences. It then optionally clips the global grad
    norm and advances the warmup+cosine LR schedule. Progress is printed every
    ``log_every`` steps. Checkpoints are written every ``--save_every`` steps and
    at the last step, and a final copy is saved under ``<out_dir>/final``.

    Raises:
        RuntimeError: if the training data yields no batches at all.
    """
    args = parse_args()
    set_seed(args.seed)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    model, tokenizer = build_model_and_tokenizer(
        tokenizer_name=args.tokenizer_name,
        tokenizer_dir=args.tokenizer_dir,
        model_name=args.model_name,
        vocab_size_override=args.vocab_size_override,
        hidden_size=args.hidden_size,
        n_layers=args.n_layers,
        n_heads=args.n_heads,
        n_kv_heads=args.n_kv_heads,
        rope_theta=args.rope_theta,
        max_position_embeddings=args.max_position_embeddings,
    )
    model = model.to(device)

    # Default micro-batch: 1/8 of the global batch, at least 1.
    micro_bs = args.micro_batch_size or min(max(1, args.global_batch_size // 8), args.global_batch_size)
    grad_accum = max(1, args.global_batch_size // micro_bs)
    # Sequences actually consumed per optimizer step. This differs from
    # --global_batch_size when that is not a multiple of the micro-batch size.
    effective_batch = micro_bs * grad_accum
    if effective_batch != args.global_batch_size:
        print(
            f"warning: global_batch_size={args.global_batch_size} is not a multiple of "
            f"micro_batch_size={micro_bs}; using an effective batch of {effective_batch}",
            flush=True,
        )

    train_loader = get_dataloader(
        data_path=args.data_path,
        tokenizer=tokenizer,
        seq_len=args.seq_len,
        micro_batch_size=micro_bs,
        num_workers=args.num_workers,
    )

    optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.95))
    scheduler = get_cosine_schedule_with_warmup(
        optimizer=optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps,
    )

    # bf16 autocast is used only on CUDA; otherwise training runs in full precision.
    use_bf16 = args.bf16 and device.type == "cuda"

    model.train()
    log_every = 10
    step = 0
    # Accumulated on-device so adding each micro-step's loss does not force a
    # GPU->CPU sync; it is read back only when logging.
    running_loss = torch.zeros((), device=device)
    tokens_per_step = effective_batch * args.seq_len
    last_log = time.perf_counter()

    data_iter = iter(train_loader)
    while step < args.max_steps:
        optimizer.zero_grad(set_to_none=True)
        for _ in range(grad_accum):
            try:
                batch = next(data_iter)
            except StopIteration:
                # Restart the stream; if it is still empty, the data is unusable.
                data_iter = iter(train_loader)
                batch = next(data_iter, None)
                if batch is None:
                    raise RuntimeError(f"No training batches could be built from {args.data_path}")

            input_ids = batch["input_ids"].to(device, non_blocking=True)
            attention_mask = batch["attention_mask"].to(device, non_blocking=True)
            labels = batch["labels"].to(device, non_blocking=True)

            with torch.autocast(device_type=device.type, dtype=torch.bfloat16, enabled=use_bf16):
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
                # Scale so the accumulated gradient is the mean over micro-batches.
                loss = outputs.loss / grad_accum

            loss.backward()
            running_loss += loss.detach()

        # clip_grad <= 0 disables clipping. A max_norm of 0 would zero every gradient.
        if args.clip_grad > 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad)
        optimizer.step()
        scheduler.step()
        step += 1

        if step % log_every == 0:
            now = time.perf_counter()
            dt = now - last_log
            last_log = now
            # running_loss holds one mean loss per optimizer step since the last log.
            avg_loss = running_loss.item() / log_every
            running_loss.zero_()
            ppl = math.exp(avg_loss) if avg_loss < 30 else float("inf")
            # dt spans log_every optimizer steps, so count tokens for all of them.
            tokens_sec = tokens_per_step * log_every / dt if dt > 0 else 0.0
            print(
                f"step {step:6d} | loss {avg_loss:.4f} | ppl {ppl:.2f} | tokens/s {tokens_sec:,.0f} | lr {scheduler.get_last_lr()[0]:.2e}",
                flush=True,
            )

        if step % args.save_every == 0 or step == args.max_steps:
            ckpt_dir = out_dir / f"step_{step:06d}"
            ckpt_dir.mkdir(parents=True, exist_ok=True)
            model.save_pretrained(ckpt_dir)
            tokenizer.save_pretrained(ckpt_dir)

        # Periodic cleanup of Python garbage and cached CUDA allocator blocks.
        if step % 100 == 0:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    model.save_pretrained(out_dir / "final")
    tokenizer.save_pretrained(out_dir / "final")


if __name__ == "__main__":
    main()
| |
|
| |
|
| |
|