"""SimMart RL training — lightweight GRPO-style loop. Trains a Qwen 2.5 Instruct CEO against the SimMart environment using group-normalised REINFORCE (the "GRPO" special case from DeepSeek). Design notes (tuned for a 2-day hackathon timeline + MI250X + Unsloth): • Rollout batch B env instances with B distinct seeds → same policy parameters at each step, different trajectories. Variance reduction comes from the group baseline rather than multiple completions for the same prompt (which would require state-checkpointing). • Each training step: 1. 13-week rollout across B parallel envs 2. Per-week advantage = (reward - group_mean) / (group_std + eps) 3. Policy-gradient loss = -E[A_t * log π(a_t | s_t)] 4. KL penalty vs. frozen reference policy (optional; β from --kl) 5. Adam step on LoRA adapters only (Unsloth 4-bit base) • Log reward mean/max and parse-error rate to stdout (+ W&B optional). Usage (inside edaamd/unsloth-vllm container): python train.py --model Qwen/Qwen2.5-1.5B-Instruct --steps 30 --batch 4 \\ --lr 1e-5 --max-new-tokens 768 --kl 0.02 \\ --out /mnt/dcgpuval/hkandala/simmart-runs/smoke-1p5b Hero run (7B overnight): python train.py --model Qwen/Qwen2.5-7B-Instruct --steps 120 --batch 6 \\ --lr 5e-6 --kl 0.02 \\ --out /mnt/dcgpuval/hkandala/simmart-runs/hero-7b """ from __future__ import annotations import argparse import json import os import random import statistics import sys import time from dataclasses import asdict, dataclass, field from pathlib import Path from typing import Dict, List, Optional, Tuple HERE = os.path.dirname(os.path.abspath(__file__)) if HERE not in sys.path: sys.path.insert(0, HERE) # Unsloth must be imported BEFORE transformers / peft for its patches from unsloth import FastLanguageModel # noqa: E402 import torch # noqa: E402 import torch.distributed as dist # noqa: E402 import torch.nn.functional as F # noqa: E402 from models import ProposalDecision, SimMartAction # noqa: E402 from prompts import ( # noqa: E402 SYSTEM_PROMPT, build_chat, parse_response, render_observation, build_action_chat, build_journal_chat, parse_journal_response, ) from server.environment import SimMartEnvironment # noqa: E402 # --------------------------------------------------------------------------- # Distributed setup # --------------------------------------------------------------------------- def init_distributed() -> Tuple[int, int, int]: """Initialise torch.distributed if LOCAL_RANK is set (accelerate launch). Returns (rank, local_rank, world_size). Falls back to single-process. """ if "LOCAL_RANK" not in os.environ: return 0, 0, 1 local_rank = int(os.environ["LOCAL_RANK"]) if torch.cuda.is_available(): torch.cuda.set_device(local_rank) if not dist.is_initialized(): dist.init_process_group(backend="nccl") rank = dist.get_rank() world_size = dist.get_world_size() return rank, local_rank, world_size def all_reduce_mean_gradients(model, world_size: int) -> None: """Sum grads across ranks and divide by world_size (manual DDP step). Cheaper than wrapping Unsloth's 4-bit model in DDP which has known quirks. 
""" if world_size == 1: return for p in model.parameters(): if p.grad is not None: dist.all_reduce(p.grad, op=dist.ReduceOp.SUM) p.grad.div_(world_size) def all_reduce_scalar(value: float, world_size: int, device) -> float: if world_size == 1: return value t = torch.tensor([value], device=device, dtype=torch.float32) dist.all_reduce(t, op=dist.ReduceOp.SUM) return (t.item() / world_size) # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- @dataclass class TrainConfig: model: str = "Qwen/Qwen2.5-1.5B-Instruct" steps: int = 30 batch: int = 4 # parallel envs per rank per training step max_seq_len: int = 4096 max_new_tokens: int = 768 lr: float = 1e-5 lr_min: float = 0.0 # if > 0 and < lr, cosine-anneal lr -> lr_min over `steps` beta_kl: float = 0.02 # KL penalty vs. reference policy entropy_coef: float = 0.01 clip_grad: float = 1.0 seed_offset: int = 0 out_dir: str = "./simmart-run" log_every: int = 1 save_every: int = 10 wandb_project: Optional[str] = None lora_r: int = 16 lora_alpha: int = 32 lora_dropout: float = 0.0 load_in_4bit: bool = True dtype: str = "bfloat16" # "float16" or "bfloat16" init_adapter: Optional[str] = None # load LoRA weights from an SFT checkpoint (action head) mb_size: int = 8 # minibatch size for the PG pass # ----- Dual-head (two-pass) config ---------------------------------- # If ``journal_adapter`` is set, the trainer loads it as a SECOND LoRA # named "journal", freezes its params, and uses a two-pass rollout: # action (trainable) + journal (frozen). The RL reward attribution is # then cleanly scoped to action tokens only. journal_adapter: Optional[str] = None action_max_tokens: int = 300 # action-head output budget (JSON only) journal_max_tokens: int = 400 # journal-head output budget (free text) # ----- Rollout sampling + reward shaping ---------------------------- rollout_temperature: float = 0.9 rollout_top_p: float = 0.95 fmt_penalty: float = 0.0 # subtracted from env reward when parse fails # Mixed-temperature rollouts: when both > 0 and rollout_temp_low != # rollout_temp_high, the action batch is split into halves and the two # halves are sampled at the two temperatures respectively. Otherwise the # full batch is sampled at ``rollout_temperature``. Sentinel <=0 = unset. 
rollout_temp_low: float = -1.0 rollout_temp_high: float = -1.0 # DDP state (set at runtime, not from CLI) rank: int = 0 local_rank: int = 0 world_size: int = 1 @dataclass class StepLog: step: int mean_reward: float max_reward: float min_reward: float reward_std: float mean_episode_return: float parse_error_rate: float rogue_recall: float loss: float pg_loss: float kl_loss: float entropy: float elapsed_s: float # --------------------------------------------------------------------------- # Utilities # --------------------------------------------------------------------------- def log_kv(step: int, kvs: Dict[str, float]) -> None: parts = [f"[step {step:03d}]"] for k, v in kvs.items(): if isinstance(v, float): parts.append(f"{k}={v:+.4f}") else: parts.append(f"{k}={v}") print(" ".join(parts), flush=True) # --------------------------------------------------------------------------- # Model setup # --------------------------------------------------------------------------- def load_policy(cfg: TrainConfig): dtype = torch.bfloat16 if cfg.dtype == "bfloat16" else torch.float16 device_map = {"": cfg.local_rank} if cfg.world_size > 1 else "auto" model, tokenizer = FastLanguageModel.from_pretrained( model_name=cfg.model, max_seq_length=cfg.max_seq_len, dtype=dtype, load_in_4bit=cfg.load_in_4bit, device_map=device_map, ) model = FastLanguageModel.get_peft_model( model, r=cfg.lora_r, lora_alpha=cfg.lora_alpha, lora_dropout=cfg.lora_dropout, bias="none", target_modules=[ "q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj", ], use_gradient_checkpointing="unsloth", random_state=42, ) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token # Optionally warm-start the action head from an SFT checkpoint. if cfg.init_adapter: if cfg.rank == 0: print(f"[init] loading SFT adapter weights (action) from {cfg.init_adapter}") from peft import set_peft_model_state_dict, load_peft_weights state_dict = load_peft_weights(cfg.init_adapter) # set_peft_model_state_dict handles key-name mapping across peft versions set_peft_model_state_dict(model, state_dict) # Optionally load a FROZEN journal adapter (dual-head architecture). if cfg.journal_adapter: if cfg.rank == 0: print(f"[init] loading FROZEN journal adapter from {cfg.journal_adapter}") model.load_adapter(cfg.journal_adapter, adapter_name="journal") # Freeze all journal-adapter parameters. PEFT names LoRA params like # ``...lora_A.journal.weight`` / ``...lora_B.journal.weight``, so # filtering on ``.journal.`` isolates them from the action/default # adapter. n_frozen = 0 for name, param in model.named_parameters(): if ".journal." in name: param.requires_grad = False n_frozen += 1 if cfg.rank == 0: print(f"[init] froze {n_frozen} journal-adapter params") # Make the action (default) adapter active for training. 
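    # "default" is the adapter name PEFT assigns to the first LoRA created by
    # get_peft_model; with the second "journal" adapter loaded, set_adapter
    # picks which one participates in forward, so only the active adapter's
    # (unfrozen) weights receive gradients.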
model.set_adapter("default") return model, tokenizer # --------------------------------------------------------------------------- # Rollout # --------------------------------------------------------------------------- def _apply_chat(tokenizer, chat) -> str: """Render the chat in the model's template with the assistant header open.""" return tokenizer.apply_chat_template( chat, tokenize=False, add_generation_prompt=True, ) def batched_generate( model, tokenizer, prompts: List[str], max_new_tokens: int, temperature: float = 0.9, top_p: float = 0.95, ) -> Tuple[List[str], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """Generate B completions, return: completions_text : List[str], len B input_ids : (B, L_in_max) padded-left completion_ids : (B, L_gen_max) padded-right input_mask : (B, L_in_max) 1 where real prompt token completion_mask : (B, L_gen_max) 1 where real generated token """ tokenizer.padding_side = "left" enc = tokenizer( prompts, return_tensors="pt", padding=True, truncation=True, max_length=model.config.max_position_embeddings, ).to(model.device) # Unsloth-accelerated inference path: keeps LoRA active but switches to # faster attention/cache kernels. We flip back to train mode afterwards. FastLanguageModel.for_inference(model) try: with torch.inference_mode(): out = model.generate( **enc, max_new_tokens=max_new_tokens, do_sample=True, temperature=temperature, top_p=top_p, pad_token_id=tokenizer.pad_token_id, use_cache=True, ) finally: FastLanguageModel.for_training(model) input_len = enc.input_ids.size(1) completion_ids = out[:, input_len:] completion_mask = (completion_ids != tokenizer.pad_token_id).long() completions_text = tokenizer.batch_decode(completion_ids, skip_special_tokens=True) return completions_text, enc.input_ids, completion_ids, enc.attention_mask, completion_mask def compute_completion_logprobs( model, input_ids: torch.Tensor, completion_ids: torch.Tensor, input_mask: torch.Tensor, completion_mask: torch.Tensor, ) -> torch.Tensor: """Return per-token log-probs for the completion tokens (B, L_gen). Computed with teacher-forcing on prompt+completion. """ full_ids = torch.cat([input_ids, completion_ids], dim=1) full_mask = torch.cat([input_mask, completion_mask], dim=1) out = model(input_ids=full_ids, attention_mask=full_mask) logits = out.logits[:, :-1, :] # align with next-token targets = full_ids[:, 1:] logp = F.log_softmax(logits.float(), dim=-1) tok_logp = logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1) L_in = input_ids.size(1) completion_logp = tok_logp[:, L_in - 1:] # logp for completion tokens comp_mask = completion_mask.float() completion_logp = completion_logp[:, :comp_mask.size(1)] * comp_mask return completion_logp, comp_mask def compute_entropy_and_logp( model, input_ids: torch.Tensor, completion_ids: torch.Tensor, input_mask: torch.Tensor, completion_mask: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Return (completion_logp, entropy, comp_mask). 
completion_logp: per-token log π(a_t) -- (B, L_gen) entropy : per-token H(π) -- (B, L_gen) """ full_ids = torch.cat([input_ids, completion_ids], dim=1) full_mask = torch.cat([input_mask, completion_mask], dim=1) out = model(input_ids=full_ids, attention_mask=full_mask) logits = out.logits[:, :-1, :] targets = full_ids[:, 1:] logp_full = F.log_softmax(logits.float(), dim=-1) p_full = logp_full.exp() ent_full = -(p_full * logp_full).sum(-1) tok_logp = logp_full.gather(-1, targets.unsqueeze(-1)).squeeze(-1) L_in = input_ids.size(1) completion_logp = tok_logp[:, L_in - 1:] entropy = ent_full[:, L_in - 1:] comp_mask = completion_mask.float() completion_logp = completion_logp[:, :comp_mask.size(1)] * comp_mask entropy = entropy[:, :comp_mask.size(1)] * comp_mask return completion_logp, entropy, comp_mask @dataclass class RolloutStep: week: int env_idx: int prompt_text: str completion_text: str prompt_ids: torch.Tensor # (L_in,) completion_ids: torch.Tensor # (L_gen,) prompt_mask: torch.Tensor completion_mask: torch.Tensor reward: float parse_ok: bool was_rogue_week: bool = False caught_rogue: bool = False def rollout_batch( model, tokenizer, cfg: TrainConfig, step_idx: int, ) -> List[RolloutStep]: """Run `cfg.batch` parallel 13-week episodes on this rank. Dual-head mode (``cfg.journal_adapter`` set): Week t ──> action pass (default adapter, trainable) → JSON journal pass (journal adapter, frozen) → text env.step(combined_action) The RolloutStep stores action-pass tensors only so policy-gradient flows exclusively through action tokens (clean credit assignment). Legacy mode (``cfg.journal_adapter`` = None): Single ``build_chat`` generation — identical to v5 behaviour. Seeds spread across ranks so different GPUs explore different env trajectories: global_seed = seed_offset + step * world_size * batch + rank * batch + i. 
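
    Worked example (illustrative numbers): with seed_offset=0, world_size=2,
    batch=4 at step_idx=3, rank 0 draws seeds 24-27 and rank 1 draws 28-31
    (e.g. rank 1, i=2 → 0 + 3*2*4 + 1*4 + 2 = 30), so no two envs in a step
    share a seed.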
""" dual_head = cfg.journal_adapter is not None envs: List[SimMartEnvironment] = [] obss = [] for i in range(cfg.batch): env = SimMartEnvironment() seed = (cfg.seed_offset + step_idx * cfg.world_size * cfg.batch + cfg.rank * cfg.batch + i) obs = env.reset(seed=seed, episode_id=f"train-r{cfg.rank}-{step_idx}-{i}") envs.append(env) obss.append(obs) rollout: List[RolloutStep] = [] for week in range(1, SimMartEnvironment.MAX_WEEKS + 1): # ---- Pass 1: ACTION (trainable "default" adapter) -------------------- if dual_head: model.set_adapter("default") prompts_text = [ _apply_chat(tokenizer, build_action_chat(obs)) for obs in obss ] action_budget = cfg.action_max_tokens else: prompts_text = [_apply_chat(tokenizer, build_chat(obs)) for obs in obss] action_budget = cfg.max_new_tokens mixed_temp = ( cfg.rollout_temp_low > 0 and cfg.rollout_temp_high > 0 and cfg.rollout_temp_low != cfg.rollout_temp_high and len(prompts_text) >= 2 ) if mixed_temp: mid = len(prompts_text) // 2 txt_lo, in_lo, comp_lo, im_lo, cm_lo = batched_generate( model, tokenizer, prompts_text[:mid], max_new_tokens=action_budget, temperature=cfg.rollout_temp_low, top_p=cfg.rollout_top_p, ) txt_hi, in_hi, comp_hi, im_hi, cm_hi = batched_generate( model, tokenizer, prompts_text[mid:], max_new_tokens=action_budget, temperature=cfg.rollout_temp_high, top_p=cfg.rollout_top_p, ) completions_text = list(txt_lo) + list(txt_hi) input_ids_list = ( [in_lo[i].detach().cpu() for i in range(in_lo.size(0))] + [in_hi[i].detach().cpu() for i in range(in_hi.size(0))] ) completion_ids_list = ( [comp_lo[i].detach().cpu() for i in range(comp_lo.size(0))] + [comp_hi[i].detach().cpu() for i in range(comp_hi.size(0))] ) input_mask_list = ( [im_lo[i].detach().cpu() for i in range(im_lo.size(0))] + [im_hi[i].detach().cpu() for i in range(im_hi.size(0))] ) completion_mask_list = ( [cm_lo[i].detach().cpu() for i in range(cm_lo.size(0))] + [cm_hi[i].detach().cpu() for i in range(cm_hi.size(0))] ) else: completions_text_t, input_ids, completion_ids, input_mask, completion_mask = \ batched_generate( model, tokenizer, prompts_text, max_new_tokens=action_budget, temperature=cfg.rollout_temperature, top_p=cfg.rollout_top_p, ) completions_text = list(completions_text_t) input_ids_list = [input_ids[i].detach().cpu() for i in range(input_ids.size(0))] completion_ids_list = [completion_ids[i].detach().cpu() for i in range(completion_ids.size(0))] input_mask_list = [input_mask[i].detach().cpu() for i in range(input_mask.size(0))] completion_mask_list = [completion_mask[i].detach().cpu() for i in range(completion_mask.size(0))] # Parse actions — in dual-head mode these carry decisions + budget but # an empty journal; in single-pass mode they carry everything. 
parsed: List[Tuple[Any, Dict[str, Any]]] = [] for obs, comp_text in zip(obss, completions_text): parsed.append(parse_response(comp_text, obs.inbox)) # ---- Pass 2: JOURNAL (frozen "journal" adapter) ---------------------- if dual_head: model.set_adapter("journal") journal_prompts = [ _apply_chat( tokenizer, build_journal_chat( obs, action.decisions, action.budget_allocations, ), ) for obs, (action, _) in zip(obss, parsed) ] with torch.inference_mode(): jc_text, _, _, _, _ = batched_generate( model, tokenizer, journal_prompts, max_new_tokens=cfg.journal_max_tokens, temperature=cfg.rollout_temperature, top_p=cfg.rollout_top_p, ) for (action, _), jt in zip(parsed, jc_text): action.journal_entry = parse_journal_response(jt) # Switch back to the trainable adapter so the logprob recompute # in train_step uses the same adapter configuration as rollout. model.set_adapter("default") # ---- Env step + RolloutStep record (action tokens only for grad) ---- for i, (env, obs, (action, tel), comp_text) in enumerate( zip(envs, obss, parsed, completions_text), ): rogue_ids_this_week: set = set() for r in env.state.rogue_incidents: if week in r.active_weeks: rogue_ids_this_week.update(r.associated_proposal_ids) was_rogue_week = len(rogue_ids_this_week) > 0 step_obs = env.step(action) caught = any( d.verdict == "flag_suspicious" and d.proposal_id in rogue_ids_this_week for d in action.decisions ) env_reward = float(step_obs.reward or 0.0) parse_ok = tel["parse_ok"] or tel.get("parse_partial", False) shaped_reward = env_reward - cfg.fmt_penalty * (0.0 if parse_ok else 1.0) rollout.append(RolloutStep( week=week, env_idx=i, prompt_text=prompts_text[i], completion_text=comp_text, prompt_ids=input_ids_list[i], completion_ids=completion_ids_list[i], prompt_mask=input_mask_list[i], completion_mask=completion_mask_list[i], reward=shaped_reward, parse_ok=parse_ok, was_rogue_week=was_rogue_week, caught_rogue=caught, )) obss[i] = step_obs # Clean up large GPU tensors from the action pass if mixed_temp: del in_lo, in_hi, comp_lo, comp_hi, im_lo, im_hi, cm_lo, cm_hi else: del input_ids, completion_ids, input_mask, completion_mask torch.cuda.empty_cache() return rollout # --------------------------------------------------------------------------- # Advantage + loss # --------------------------------------------------------------------------- def compute_advantages(rollout: List[RolloutStep]) -> List[float]: """Group-normalise rewards: per-week (r - mean) / (std + eps). Matches the DeepSeek GRPO trick: the 'group' is the B parallel rollouts at the same week index, so advantage is well-conditioned even as absolute reward drifts over training. 
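
    Worked example: rewards [1.0, 2.0, 3.0] at the same week give
    mean = 2.0 and stdev = 1.0, hence advantages of roughly
    [-1.0, 0.0, +1.0] (shrunk imperceptibly by eps).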
""" by_week: Dict[int, List[float]] = {} for r in rollout: by_week.setdefault(r.week, []).append(r.reward) week_stats = { w: (statistics.mean(xs), statistics.stdev(xs) if len(xs) > 1 else 1e-6) for w, xs in by_week.items() } eps = 1e-6 return [(r.reward - week_stats[r.week][0]) / (week_stats[r.week][1] + eps) for r in rollout] # --------------------------------------------------------------------------- # Training step # --------------------------------------------------------------------------- def train_step( model, tokenizer, ref_model, optimizer, cfg: TrainConfig, step_idx: int, ) -> StepLog: t0 = time.time() rollout = rollout_batch(model, tokenizer, cfg, step_idx) rewards = [r.reward for r in rollout] parse_ok_count = sum(1 for r in rollout if r.parse_ok) n_rogue_weeks = sum(1 for r in rollout if r.was_rogue_week) n_caught = sum(1 for r in rollout if r.caught_rogue) advantages = compute_advantages(rollout) # ------------------------------------------------------------------- # Minibatch PG update (one pass over the rollout) # ------------------------------------------------------------------- total_loss = 0.0 total_pg = 0.0 total_kl = 0.0 total_ent = 0.0 seen = 0 mb_size = cfg.mb_size model.train() for mb_start in range(0, len(rollout), mb_size): mb = rollout[mb_start:mb_start + mb_size] advs_mb = advantages[mb_start:mb_start + mb_size] # Re-tokenise-and-pad the minibatch fresh (avoids mixed lengths from rollout) input_ids = torch.nn.utils.rnn.pad_sequence( [r.prompt_ids for r in mb], batch_first=True, padding_value=tokenizer.pad_token_id, ).to(model.device) completion_ids = torch.nn.utils.rnn.pad_sequence( [r.completion_ids for r in mb], batch_first=True, padding_value=tokenizer.pad_token_id, ).to(model.device) input_mask = torch.nn.utils.rnn.pad_sequence( [r.prompt_mask for r in mb], batch_first=True, padding_value=0, ).to(model.device) completion_mask = torch.nn.utils.rnn.pad_sequence( [r.completion_mask for r in mb], batch_first=True, padding_value=0, ).to(model.device) # Current policy log-probs + entropy completion_logp, entropy, mask = compute_entropy_and_logp( model, input_ids, completion_ids, input_mask, completion_mask, ) # Reference policy log-probs (LoRA-disabled base model) for KL. # We share weights with the policy and toggle the adapter off so # there's no second copy in GPU memory. if ref_model is not None and cfg.beta_kl > 0: with torch.inference_mode(), ref_model.disable_adapter(): ref_logp, _ = compute_completion_logprobs( ref_model, input_ids, completion_ids, input_mask, completion_mask, ) kl_per_tok = (completion_logp - ref_logp.detach()) * mask else: kl_per_tok = torch.zeros_like(completion_logp) # Per-sample log-prob sum (masked avg) denom = mask.sum(dim=1).clamp_min(1.0) logp_per_sample = (completion_logp * mask).sum(dim=1) / denom entropy_per_sample = (entropy * mask).sum(dim=1) / denom kl_per_sample = kl_per_tok.sum(dim=1) / denom adv_t = torch.tensor(advs_mb, device=model.device, dtype=logp_per_sample.dtype) pg = -(adv_t * logp_per_sample).mean() ent_term = -cfg.entropy_coef * entropy_per_sample.mean() kl_term = cfg.beta_kl * kl_per_sample.mean() loss = pg + kl_term + ent_term optimizer.zero_grad() loss.backward() # Manual DDP: sum grads across ranks, average. Cheaper than wrapping # Unsloth's 4-bit model in torch.nn.parallel.DistributedDataParallel. 
all_reduce_mean_gradients(model, cfg.world_size) torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad) optimizer.step() total_loss += float(loss.item()) * len(mb) total_pg += float(pg.item()) * len(mb) total_kl += float(kl_term.item()) * len(mb) total_ent += float(entropy_per_sample.mean().item()) * len(mb) seen += len(mb) del input_ids, completion_ids, input_mask, completion_mask del completion_logp, entropy, mask, kl_per_tok torch.cuda.empty_cache() # ------------------------------------------------------------------- # Aggregate metrics for logging # ------------------------------------------------------------------- episode_returns: Dict[int, float] = {} for r in rollout: episode_returns[r.env_idx] = episode_returns.get(r.env_idx, 0.0) + r.reward elapsed = time.time() - t0 return StepLog( step=step_idx, mean_reward=statistics.mean(rewards), max_reward=max(rewards), min_reward=min(rewards), reward_std=statistics.stdev(rewards) if len(rewards) > 1 else 0.0, mean_episode_return=statistics.mean(list(episode_returns.values())), parse_error_rate=1.0 - (parse_ok_count / max(1, len(rollout))), rogue_recall=(n_caught / n_rogue_weeks) if n_rogue_weeks > 0 else 0.0, loss=total_loss / max(1, seen), pg_loss=total_pg / max(1, seen), kl_loss=total_kl / max(1, seen), entropy=total_ent / max(1, seen), elapsed_s=elapsed, ) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def _reduce_log(log: StepLog, world_size: int, device) -> StepLog: """All-reduce metric averages across ranks so rank-0 logs reflect the full batch, not just one GPU's slice.""" if world_size == 1: return log fields = [ "mean_reward", "max_reward", "min_reward", "reward_std", "mean_episode_return", "parse_error_rate", "rogue_recall", "loss", "pg_loss", "kl_loss", "entropy", ] values = torch.tensor( [getattr(log, f) for f in fields], device=device, dtype=torch.float32, ) dist.all_reduce(values, op=dist.ReduceOp.SUM) values = values / world_size # max and min should stay max/min, not averaged vmax = torch.tensor([log.max_reward], device=device, dtype=torch.float32) vmin = torch.tensor([log.min_reward], device=device, dtype=torch.float32) dist.all_reduce(vmax, op=dist.ReduceOp.MAX) dist.all_reduce(vmin, op=dist.ReduceOp.MIN) kv = dict(zip(fields, values.tolist())) kv["max_reward"] = vmax.item() kv["min_reward"] = vmin.item() kv["step"] = log.step kv["elapsed_s"] = log.elapsed_s return StepLog(**kv) def main() -> int: p = argparse.ArgumentParser() p.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct") p.add_argument("--steps", type=int, default=30) p.add_argument("--batch", type=int, default=4) p.add_argument("--lr", type=float, default=1e-5) p.add_argument("--lr-min", type=float, default=0.0, help="If > 0, cosine-anneal lr -> lr_min over --steps (default: flat lr)") p.add_argument("--max-seq-len", type=int, default=4096) p.add_argument("--max-new-tokens", type=int, default=768) p.add_argument("--kl", type=float, default=0.02, dest="beta_kl") p.add_argument("--entropy", type=float, default=0.01, dest="entropy_coef") p.add_argument("--seed-offset", type=int, default=0) p.add_argument("--out", default="./simmart-run", dest="out_dir") p.add_argument("--save-every", type=int, default=10) p.add_argument("--wandb", type=str, default=None) p.add_argument("--dtype", choices=["float16", "bfloat16"], default="bfloat16") p.add_argument("--init-adapter", default=None, help="Load LoRA weights from this SFT checkpoint at 
startup (action head)") p.add_argument("--journal-adapter", default=None, help="Optional: load a frozen journal LoRA as second adapter. " "Triggers dual-head (two-pass) rollout mode.") p.add_argument("--action-max-tokens", type=int, default=300, help="Output budget for the action head JSON (dual-head mode)") p.add_argument("--journal-max-tokens", type=int, default=400, help="Output budget for the journal head text (dual-head mode)") p.add_argument("--mb-size", type=int, default=8, help="Minibatch size for the PG pass (was 2; bigger = fewer forward passes)") p.add_argument("--rollout-temperature", type=float, default=0.9, dest="rollout_temperature", help="Sampling temperature for training rollouts (default 0.9; lower shrinks sampled-vs-greedy gap)") p.add_argument("--rollout-top-p", type=float, default=0.95, dest="rollout_top_p", help="Nucleus sampling top_p for training rollouts (default 0.95)") p.add_argument("--fmt-penalty", type=float, default=0.0, dest="fmt_penalty", help="Magnitude (>=0) subtracted from env reward when parse fails (default 0)") p.add_argument("--rollout-temp-low", type=float, default=-1.0, dest="rollout_temp_low", help="Low temperature for mixed-temp rollouts (set both -low and -high to enable; default off)") p.add_argument("--rollout-temp-high", type=float, default=-1.0, dest="rollout_temp_high", help="High temperature for mixed-temp rollouts (default off; falls back to --rollout-temperature)") args = p.parse_args() rank, local_rank, world_size = init_distributed() is_main = rank == 0 cfg = TrainConfig( model=args.model, steps=args.steps, batch=args.batch, max_seq_len=args.max_seq_len, max_new_tokens=args.max_new_tokens, lr=args.lr, lr_min=args.lr_min, beta_kl=args.beta_kl, entropy_coef=args.entropy_coef, seed_offset=args.seed_offset, out_dir=args.out_dir, save_every=args.save_every, wandb_project=args.wandb, dtype=args.dtype, init_adapter=args.init_adapter, journal_adapter=args.journal_adapter, action_max_tokens=args.action_max_tokens, journal_max_tokens=args.journal_max_tokens, mb_size=args.mb_size, rollout_temperature=args.rollout_temperature, rollout_top_p=args.rollout_top_p, fmt_penalty=args.fmt_penalty, rollout_temp_low=args.rollout_temp_low, rollout_temp_high=args.rollout_temp_high, rank=rank, local_rank=local_rank, world_size=world_size, ) out_dir = Path(cfg.out_dir) if is_main: out_dir.mkdir(parents=True, exist_ok=True) with open(out_dir / "config.json", "w") as f: json.dump(asdict(cfg), f, indent=2) wandb = None if cfg.wandb_project and is_main: try: import wandb as _wb wandb = _wb wandb.init(project=cfg.wandb_project, config=asdict(cfg)) except Exception as e: print(f"[warn] wandb disabled: {e}") if is_main: print(f"Loading policy: {cfg.model} (4-bit LoRA r={cfg.lora_r})") if cfg.init_adapter: print(f"Init adapter: {cfg.init_adapter}") print(f"DDP: rank={rank}/{world_size} local_rank={local_rank} " f"per_rank_batch={cfg.batch} global_batch={cfg.batch * world_size}") model, tokenizer = load_policy(cfg) if is_main: print(f"Device: {model.device} dtype: {cfg.dtype}") ref_model = model optimizer = torch.optim.AdamW( [p for p in model.parameters() if p.requires_grad], lr=cfg.lr, betas=(0.9, 0.95), weight_decay=0.0, ) scheduler = None if 0 < cfg.lr_min < cfg.lr: scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( optimizer, T_max=cfg.steps, eta_min=cfg.lr_min, ) if is_main: print(f"[lr] cosine decay: {cfg.lr:.2e} -> {cfg.lr_min:.2e} over {cfg.steps} steps") # Barrier so all ranks have loaded weights before step 1 starts. 
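    # Launch sketch (assumed, not pinned by this file): `accelerate launch
    # --num_processes 2 train.py ...` or `torchrun --nproc_per_node=2 train.py
    # ...`; both set LOCAL_RANK, which init_distributed keys on.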
if world_size > 1: dist.barrier() history: List[StepLog] = [] for step in range(1, cfg.steps + 1): log = train_step(model, tokenizer, ref_model, optimizer, cfg, step) reduced = _reduce_log(log, world_size, model.device) history.append(reduced) if is_main: log_kv(step, { "mean_r": reduced.mean_reward, "ep_ret": reduced.mean_episode_return, "r_std": reduced.reward_std, "max_r": reduced.max_reward, "parse_err": reduced.parse_error_rate, "rogue_rec": reduced.rogue_recall, "loss": reduced.loss, "pg": reduced.pg_loss, "kl": reduced.kl_loss, "ent": reduced.entropy, "sec/step": reduced.elapsed_s, "lr": optimizer.param_groups[0]["lr"], }) if wandb: wandb.log({f"train/{k}": v for k, v in asdict(reduced).items()}, step=step) with open(out_dir / "history.jsonl", "a") as f: f.write(json.dumps(asdict(reduced)) + "\n") if (step % cfg.save_every == 0 or step == cfg.steps): if is_main: ckpt = out_dir / f"adapter-step-{step:03d}" model.save_pretrained(str(ckpt)) tokenizer.save_pretrained(str(ckpt)) print(f"[ckpt] saved {ckpt}") if world_size > 1: dist.barrier() if scheduler is not None: scheduler.step() if is_main: print("Training complete.") if wandb: wandb.finish() if world_size > 1: dist.destroy_process_group() return 0 if __name__ == "__main__": raise SystemExit(main())
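
# Example stdout line (format per log_kv; values illustrative):
#   [step 007] mean_r=+0.1234 ep_ret=+1.6042 r_std=+0.4821 max_r=+0.9100
#   parse_err=+0.0192 rogue_rec=+0.5000 loss=+0.0873 pg=+0.0911 kl=+0.0021
#   ent=+1.2345 sec/step=+41.0210 lr=+0.0000
# (Everything prints on one line; note that {:+.4f} renders an lr of 1e-5
# as +0.0000, so read the lr from config.json instead.)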