| """GRPO RL phase for BLT-Reasoner. |
| |
| Architecture-specific differences from Abstract-CoT GRPO (`experiments/abstract_cot/grpo_train.py`): |
| |
| * **z is deterministic given x and current policy weights.** No discrete |
| sampling in the latent; the only stochasticity is in y. So all rollouts of |
| a prompt share the same z, and per prompt we sample G = `group_size` |
| different y's from `π(y | x, z)` (the prompt is repeated G times, so z is |
| recomputed per copy rather than cached). This is much cheaper than the |
| abstract-vocab version, which had to sample z too. |
| |
| * **Two model instances on the GPU**: policy (trainable, init = SFT ckpt) and |
| reference (frozen, identical init). KL is policy↔reference per y token. We |
| cannot use `peft.disable_adapter()` as the reference, because the SFT ckpt's |
| LoRA + projector + head IS the reference — disabling adapters would compare |
| against vanilla Qwen, which is not what we want. |
| |
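| * **InfoNCE intended as an optional low-weight aux loss** during GRPO so the |
| z geometry doesn't drift while the policy learns to use rewards. This is |
| the BLT-specific safety net. NOTE: not wired in yet — `grpo_loss` below is |
| policy-gradient + KL only; the InfoNCE head is loaded, handed to the |
| optimizer and checkpointed, but contributes no loss term. |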
| |
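| Reward = math verifier on the final number after `####` (falling back to the |
| last number in the text): +1.0 if it matches the gold answer, -0.5 otherwise, |
| minus an optional length penalty. |
| Advantage = reward standardized (mean/std) within the G = `group_size` |
| rollouts of each prompt. |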
| |
| Usage: |
| python -m experiments.blt_reasoner.grpo_train --config configs/grpo_from_step12000.json |
| """ |
| from __future__ import annotations |
|
|
| import argparse |
| import copy |
| import json |
| import math |
| import os |
| import random |
| import re |
| import sys |
| import time |
| from pathlib import Path |
| from typing import List, Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader |
|
|
| from .data import GSM8KDataset, collate_batch |
| from .losses import InfoNCEHead, encode_answer_for_infonce, infonce_loss |
| from .model import ( |
| NEG, |
| BLTConfig, |
| LatentProjector, |
| build_base, |
| build_blt_mask, |
| forward_with_latent, |
| generate_with_latent, |
| ) |
| from .train import _load_state_from_ckpt |
|
|
|
|
# One number: optional sign, either a comma-grouped integer ("1,000,000") or a
# plain digit run, then an optional decimal part. The (?!\d) lookahead stops
# "1,0000" from being read as "1,000" + "0".
_NUMBER = r"-?(?:\d{1,3}(?:,\d{3})+(?!\d)|\d+)(?:\.\d+)?"
# GSM8K-style final-answer marker "#### <number>", tolerating an optional "$".
GSM8K_NUM = re.compile(r"####\s*\$?\s*(" + _NUMBER + r")")
ANY_NUM = re.compile(_NUMBER)


def parse_pred(text: str) -> Optional[str]:
    """Extract the predicted final answer from a decoded rollout.

    Prefers the number after the first `####` marker. If there is none, it
    falls back to the last number anywhere in `text`. Thousands separators
    are removed, so "1,000" is returned as "1000".

    Returns the number as a string, or None if `text` contains no number.
    """
    m = GSM8K_NUM.search(text)
    if m:
        raw = m.group(1)
    else:
        nums = ANY_NUM.findall(text)
        if not nums:
            return None
        raw = nums[-1]
    return raw.replace(",", "")
|
|
|
|
def reward_for(
    decoded: str,
    gold: str,
    *,
    length_pen: float = 0.0,
    n_tokens: int = 0,
    length_norm: float = 192.0,
) -> float:
    """Math-verifier reward for one decoded rollout.

    +1.0   the parsed prediction equals the gold number (within 1e-4)
    -0.5   parseable but wrong, no number found, or gold not numeric
    -length_pen * (n_tokens / length_norm)
           small length penalty to discourage rambling
           (`length_norm` defaults to 192 tokens)

    Thousands separators are stripped from both the prediction and the gold
    answer before comparing, so "1,000" matches "1000".
    """
    base = -0.5
    pred = parse_pred(decoded)
    if pred is not None:
        try:
            pred_val = float(pred.replace(",", ""))
            gold_val = float(gold.replace(",", "").strip())
        except ValueError:
            # Deliberate: an unparseable prediction or gold counts as wrong.
            pass
        else:
            if abs(pred_val - gold_val) < 1e-4:
                base = 1.0
    return base - length_pen * (n_tokens / length_norm)
|
|
|
|
def set_seed(s: int):
    """Seed Python's `random`, torch's CPU generator and all CUDA generators with `s`."""
    seeders = (random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seed_fn in seeders:
        seed_fn(s)
|
|
|
|
def per_token_logp(logits: torch.Tensor, target_ids: torch.Tensor, target_mask: torch.Tensor) -> torch.Tensor:
    """Per-token log-probabilities of `target_ids` under `logits`.

    logits      [B, L, V], already shifted so logits[:, t] predicts target[:, t]
    target_ids  [B, L]
    target_mask [B, L], nonzero where the token is real

    Returns a float32 [B, L] tensor: log p(target) * mask at real positions
    and exactly 0 at masked positions.
    """
    logp = F.log_softmax(logits.float(), dim=-1)
    gathered = logp.gather(-1, target_ids.unsqueeze(-1)).squeeze(-1)
    mask_f = target_mask.float()
    # Select with where() instead of multiplying by the mask: a -inf log-prob
    # at a padded position times 0 would be NaN and poison the sums downstream.
    return torch.where(mask_f != 0, gathered * mask_f, torch.zeros_like(gathered))
|
|
|
|
def forward_logp_for_rollout(model, projector, x_ids, x_attn, y_ids, y_mask, K, *, block_y_to_x=True):
    """Score sampled answers `y_ids` under `model`; return (per-token logps, z).

    z is recomputed from (x_ids, x_attn) with the given model/projector via
    `forward_with_latent`, rather than reused from sampling. The first
    return value has the same shape as `y_ids` and is zero where `y_mask`
    is 0. The second is the latent z returned by `forward_with_latent`.
    """
    logits_y, latent_z, _extra = forward_with_latent(
        model, x_ids, x_attn, y_ids, projector, K,
        block_y_to_x=block_y_to_x, return_z=True,
    )
    token_logp = per_token_logp(logits_y, y_ids, y_mask)
    return token_logp, latent_z
|
|
|
|
def grpo_loss(policy_logp, ref_logp, advantages, y_mask, beta, kl_clamp):
    """GRPO objective: group-baselined policy gradient plus per-token KL penalty.

    policy_logp, ref_logp : [B*G, L_y] per-token logps (masked to 0 at pads)
    advantages            : [B*G] group-normalized advantages (constants)
    y_mask                : [B*G, L_y] 1 where real token
    beta                  : weight of the KL penalty
    kl_clamp              : symmetric clamp on the per-token log-ratio before exp()

    The KL term is the non-negative "k3" estimator of KL(pi || pi_ref) for
    tokens sampled from pi, as used by GRPO:
        r = pi_ref / pi,   KL_tok = r - log r - 1
    so the log-ratio is (ref_logp - policy_logp). It is averaged over all real
    tokens in the batch.

    Returns (loss, info_dict) where info_dict holds detached "pg" and "kl".
    """
    # REINFORCE with a group baseline: sequence logp = sum of token logps.
    seq_logp = policy_logp.sum(dim=-1)
    pg = -(advantages.detach() * seq_logp).mean()

    mask = y_mask.float()
    log_ratio = (ref_logp - policy_logp).clamp(min=-kl_clamp, max=kl_clamp)
    kl_per_tok = log_ratio.exp() - 1.0 - log_ratio
    n_tok = mask.sum().clamp(min=1.0)
    kl_term = (kl_per_tok * mask).sum() / n_tok

    loss = pg + beta * kl_term
    return loss, {"pg": pg.detach(), "kl": kl_term.detach()}
|
|
|
|
def make_optimizer(model, projector, head, cfg):
    """Build the policy optimizer and report whether the 8-bit variant is used.

    There are three param groups, all sharing cfg["weight_decay"]:
      * the model's params with requires_grad (LoRA etc.) at cfg["lr_lora"]
      * every projector param at cfg["lr_proj"]
      * every InfoNCE-head param at cfg["lr_head"]

    bitsandbytes' PagedAdamW8bit is preferred. If importing it raises anything,
    torch.optim.AdamW is used instead.

    Returns (optimizer, used_8bit).
    """
    try:
        from bitsandbytes.optim import PagedAdamW8bit as optim_cls
        paged_8bit = True
    except Exception:
        from torch.optim import AdamW as optim_cls
        paged_8bit = False

    trainable_base = [p for p in model.parameters() if p.requires_grad]
    param_groups = [
        dict(params=trainable_base, lr=cfg["lr_lora"]),
        dict(params=list(projector.parameters()), lr=cfg["lr_proj"]),
        dict(params=list(head.parameters()), lr=cfg["lr_head"]),
    ]
    optimizer = optim_cls(param_groups, weight_decay=cfg["weight_decay"])
    return optimizer, paged_8bit
|
|
|
|
def _rollout_mask(gen, *, eos_id, pad_id):
    """Return a long 0/1 mask of the real tokens in each row of `gen` ([N, L]).

    A position is kept when it is at or before the row's first `eos_id` and is
    either not `pad_id` or is that first EOS. Keeping the first EOS matters
    when `pad_id == eos_id`: a plain `gen != pad_id` test would drop the stop
    token itself, so the policy would get no gradient/KL on when to stop.
    A `None` id disables the corresponding check.
    """
    keep = torch.ones_like(gen, dtype=torch.bool) if pad_id is None else gen != pad_id
    if eos_id is not None:
        is_eos = (gen == eos_id).long()
        # Number of EOS tokens strictly before each position.
        eos_before = is_eos.cumsum(dim=1) - is_eos
        keep = (keep | is_eos.bool()) & (eos_before == 0)
    return keep.long()


def sample_rollouts(
    model, tokenizer, projector,
    x_ids, x_attn, *,
    K_latents, group_size, max_new_tokens, temperature, block_y_to_x,
    eos_id, pad_id,
):
    """Sample `group_size` rollouts per prompt in a single generate call.

    Each prompt row is repeated `group_size` times with `repeat_interleave`,
    so rows i*G .. i*G + G - 1 are the rollouts of prompt i.

    Returns (y_ids, y_mask), both [B*G, L_y_gen] with G = group_size. This
    assumes `generate_with_latent` returns only the generated continuation,
    padded with `pad_id` after EOS (TODO confirm). `y_mask` is 1 up to and
    including each row's first EOS and 0 for padding or anything after it.
    """
    x_ids_rep = x_ids.repeat_interleave(group_size, dim=0)
    x_attn_rep = x_attn.repeat_interleave(group_size, dim=0)

    gen = generate_with_latent(
        model, tokenizer, projector,
        x_ids=x_ids_rep, x_attn=x_attn_rep, K=K_latents,
        block_y_to_x=block_y_to_x,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        eos_token_id=eos_id,
    )
    return gen, _rollout_mask(gen, eos_id=eos_id, pad_id=pad_id)
|
|
|
|
def _build_projector_and_head(cfg, d_model, device, model_dtype):
    """Construct a fresh (LatentProjector, InfoNCEHead) pair from `cfg`.

    Shared by the policy and the frozen reference so both get identical
    architectures before the same SFT checkpoint is loaded into each.
    """
    projector = LatentProjector(
        d_model, init_scale=cfg["proj_init_scale"],
        use_mlp=cfg.get("proj_mlp", False),
        hidden_mult=cfg.get("proj_hidden_mult", 4),
    ).to(device=device, dtype=model_dtype)
    head = InfoNCEHead(d_z=d_model, d_y=d_model, d_out=256).to(device)
    return projector, head


def _save_checkpoint(save_dir, model, projector, head):
    """Write `model` (save_pretrained), projector.pt and head.pt into `save_dir`."""
    save_dir.mkdir(exist_ok=True)
    model.save_pretrained(save_dir / "model")
    torch.save(projector.state_dict(), save_dir / "projector.pt")
    torch.save(head.state_dict(), save_dir / "head.pt")


def main():
    """CLI entry point: load the JSON config given by --config and run GRPO.

    grpo.log and metrics.jsonl under cfg["output_dir"] are opened in append
    mode and are closed even if training raises.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", required=True)
    args = parser.parse_args()
    with open(args.config) as f:
        cfg = json.load(f)
    set_seed(cfg["seed"])

    out_dir = Path(cfg["output_dir"])
    out_dir.mkdir(parents=True, exist_ok=True)
    # buffering=1 -> line-buffered, so progress reaches disk during the run.
    with open(out_dir / "grpo.log", "a", buffering=1) as log_f, \
            open(out_dir / "metrics.jsonl", "a", buffering=1) as met_f:
        _run_grpo(cfg, out_dir, log_f, met_f)


def _run_grpo(cfg, out_dir, log_f, met_f):
    """Build policy + frozen reference, then run the GRPO loop to cfg["max_steps"]."""

    def log(m):
        line = f"[{time.strftime('%H:%M:%S')}] {m}"
        print(line, flush=True)
        log_f.write(line + "\n")

    log(f"config={json.dumps(cfg)}")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    K = cfg["K_latents"]
    G = cfg["group_size"]
    grad_accum = cfg["grad_accum"]
    # With G == 1 the per-group (unbiased) std is NaN, so every batch would be
    # skipped as "no signal" and training would silently do nothing.
    if G < 2:
        raise ValueError(f"group_size must be >= 2 for group-normalized advantages, got {G}")
    if grad_accum < 1:
        raise ValueError(f"grad_accum must be >= 1, got {grad_accum}")

    bcfg = BLTConfig(
        base_model=cfg["base_model"],
        use_lora=cfg["use_lora"],
        lora_r=cfg["lora_r"], lora_alpha=cfg["lora_alpha"],
        lora_dropout=cfg["lora_dropout"],
        lora_target_modules=tuple(cfg["lora_target_modules"]),
        K_latents=cfg["K_latents"], block_y_to_x=cfg["block_y_to_x"],
        proj_init_scale=cfg["proj_init_scale"],
        dtype=cfg["dtype"], attn_impl=cfg["attn_impl"],
    )

    # ---- policy: trainable, initialized from the SFT checkpoint
    log("building policy model …")
    policy_model, tokenizer = build_base(bcfg)
    policy_model.to(device)
    inner_pol = policy_model.get_base_model() if hasattr(policy_model, "get_base_model") else policy_model
    d_model = inner_pol.config.hidden_size
    model_dtype = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}[cfg["dtype"]]
    pol_projector, pol_head = _build_projector_and_head(cfg, d_model, device, model_dtype)
    log(f"loading policy from SFT ckpt {cfg['warmup_ckpt']}")
    _load_state_from_ckpt(policy_model, pol_projector, pol_head, cfg["warmup_ckpt"], device)

    # ---- reference: a second instance loaded from the same checkpoint, frozen
    log("building reference model (frozen copy of policy init) …")
    ref_model, _ = build_base(bcfg)
    ref_model.to(device)
    ref_projector, ref_head = _build_projector_and_head(cfg, d_model, device, model_dtype)
    _load_state_from_ckpt(ref_model, ref_projector, ref_head, cfg["warmup_ckpt"], device)
    for frozen in (ref_model, ref_projector, ref_head):
        for p in frozen.parameters():
            p.requires_grad_(False)
        frozen.eval()

    # ---- optimizer
    opt, used_8bit = make_optimizer(policy_model, pol_projector, pol_head, cfg)
    log(f"optimizer 8bit={used_8bit}")
    # Same parameter set the optimizer sees; requires_grad does not change
    # during training, so build it once.
    trainable = (
        [p for p in policy_model.parameters() if p.requires_grad]
        + list(pol_projector.parameters())
        + list(pol_head.parameters())
    )
    n_trainable = sum(p.numel() for p in trainable)
    log(f"policy trainable params = {n_trainable/1e6:.2f}M")

    # ---- data
    train_ds = GSM8KDataset(split="train", max_examples=cfg.get("rl_data_size"))
    log(f"GSM8K RL pool size = {len(train_ds)}")
    loader = DataLoader(
        train_ds, batch_size=cfg["per_prompt_batch"], shuffle=True, drop_last=True,
        collate_fn=lambda b: collate_batch(
            b, tokenizer,
            max_prompt_len=cfg["max_prompt_len"],
            max_answer_len=cfg["max_answer_len"],
        ),
    )
    # With drop_last=True a pool smaller than one batch yields no batches, and
    # the `while step < max_steps` loop below would spin forever.
    if len(loader) == 0:
        raise ValueError(
            f"RL pool has {len(train_ds)} examples, fewer than per_prompt_batch="
            f"{cfg['per_prompt_batch']}; no full batch can be formed"
        )

    beta = cfg["beta"]
    kl_clamp = cfg.get("kl_clamp", 20.0)
    length_pen = cfg.get("length_penalty_coef", 0.0)
    eos_id = tokenizer.eos_token_id
    pad_id = tokenizer.pad_token_id

    def apply_update():
        """Clip and apply the accumulated gradients (if any), then clear them."""
        if any(p.grad is not None for p in trainable):
            torch.nn.utils.clip_grad_norm_(trainable, cfg["max_grad_norm"])
            opt.step()
        opt.zero_grad(set_to_none=True)

    step = 0
    t0 = time.time()
    reward_hist: List[float] = []

    while step < cfg["max_steps"]:
        for batch in loader:
            if step >= cfg["max_steps"]:
                break

            x_ids = batch.x_ids.to(device)
            x_attn = batch.x_attn.to(device)
            # Strip the "####" marker whether or not a space follows it.
            golds = [s.replace("####", "").strip() for s in batch.final_strs]
            B = x_ids.size(0)

            # 1) Rollouts: G samples per prompt, no grad, eval mode.
            policy_model.eval()
            with torch.no_grad():
                y_ids, y_mask = sample_rollouts(
                    policy_model, tokenizer, pol_projector,
                    x_ids, x_attn,
                    K_latents=K, group_size=G,
                    max_new_tokens=cfg["max_new_tokens"],
                    temperature=cfg["sample_temperature"],
                    block_y_to_x=cfg["block_y_to_x"],
                    eos_id=eos_id, pad_id=pad_id,
                )
            # NOTE(review): train() re-enables any dropout (e.g. LoRA dropout) for
            # the log-prob recomputation below, while the reference stays in eval
            # mode. Confirm this asymmetry is intended.
            policy_model.train()

            # 2) Rewards. Rows are prompt-major: row i*G + g is rollout g of prompt i.
            decoded = tokenizer.batch_decode(y_ids, skip_special_tokens=True)
            n_toks = y_mask.sum(dim=1).tolist()
            rewards = torch.tensor(
                [
                    reward_for(
                        decoded[idx], golds[idx // G],
                        length_pen=length_pen,
                        n_tokens=int(n_toks[idx]),
                    )
                    for idx in range(B * G)
                ],
                device=device, dtype=torch.float32,
            )
            reward_hist.append(rewards.mean().item())

            # 3) Group-normalized advantages.
            rew = rewards.view(B, G)
            rew_std = rew.std(dim=1, keepdim=True)
            adv = ((rew - rew.mean(dim=1, keepdim=True)) / (rew_std + 1e-6)).view(-1)
            n_groups_with_signal = int((rew_std > 1e-6).sum().item())

            # 4) Loss + backward, only when some group has reward variance.
            # Otherwise every advantage is 0 and there is nothing to backprop.
            # The grads are NOT zeroed in that case: earlier micro-steps of this
            # accumulation window may already hold gradients that must still
            # reach opt.step().
            info = None
            if n_groups_with_signal > 0:
                x_ids_r = x_ids.repeat_interleave(G, dim=0)
                x_attn_r = x_attn.repeat_interleave(G, dim=0)
                policy_logp, _z_pol = forward_logp_for_rollout(
                    policy_model, pol_projector, x_ids_r, x_attn_r, y_ids, y_mask, K,
                    block_y_to_x=cfg["block_y_to_x"],
                )
                with torch.no_grad():
                    ref_logp, _ = forward_logp_for_rollout(
                        ref_model, ref_projector, x_ids_r, x_attn_r, y_ids, y_mask, K,
                        block_y_to_x=cfg["block_y_to_x"],
                    )
                loss, info = grpo_loss(policy_logp, ref_logp, adv, y_mask, beta, kl_clamp)
                (loss / grad_accum).backward()

            # 5) Optimizer update at the end of every grad_accum window
            # (skipped steps still count toward the window).
            if (step + 1) % grad_accum == 0:
                apply_update()

            step += 1
            if step % cfg["log_every"] == 0:
                if info is None:
                    log(f"step={step} reward_mean={rewards.mean().item():.3f} "
                        f"n_signal=0 (skipped) elapsed={time.time()-t0:.0f}s")
                    met_f.write(json.dumps({"step": step, "reward_mean": rewards.mean().item(),
                                            "n_groups_with_signal": 0,
                                            "skipped": 1,
                                            "elapsed_s": time.time()-t0}) + "\n")
                else:
                    tail = reward_hist[-50:]
                    log(f"step={step} reward_mean={rewards.mean().item():.3f} "
                        f"reward_avg50={sum(tail)/len(tail):.3f} "
                        f"adv|abs|={adv.abs().mean().item():.3f} "
                        f"pg={info['pg'].item():.4f} kl={info['kl'].item():.4f} "
                        f"n_signal={n_groups_with_signal}/{B} "
                        f"elapsed={time.time()-t0:.0f}s")
                    met_f.write(json.dumps({
                        "step": step,
                        "reward_mean": rewards.mean().item(),
                        "reward_avg50": sum(tail)/len(tail),
                        "adv_abs_mean": adv.abs().mean().item(),
                        "pg": float(info["pg"].item()),
                        "kl": float(info["kl"].item()),
                        "n_groups_with_signal": n_groups_with_signal,
                        "skipped": 0,
                        "elapsed_s": time.time() - t0,
                    }) + "\n")

            # Checked on skipped steps too, so a save_every boundary is never missed.
            if cfg["save_every"] > 0 and step % cfg["save_every"] == 0:
                save_dir = out_dir / f"grpo-step{step}"
                _save_checkpoint(save_dir, policy_model, pol_projector, pol_head)
                log(f"[save] {save_dir}")

    # Apply a trailing partial accumulation window instead of discarding it.
    if step % grad_accum != 0:
        apply_update()

    save_dir = out_dir / "final"
    _save_checkpoint(save_dir, policy_model, pol_projector, pol_head)
    log(f"[done] final saved at {save_dir}")


if __name__ == "__main__":
    main()
|
|