| """SimMart RL training — lightweight GRPO-style loop. |
| |
| Trains a Qwen 2.5 Instruct CEO against the SimMart environment using |
| group-normalised REINFORCE (the "GRPO" special case from DeepSeek). |
| |
| Design notes (tuned for a 2-day hackathon timeline + MI250X + Unsloth): |
| • Rollout batch B env instances with B distinct seeds → same policy |
| parameters at each step, different trajectories. Variance reduction |
| comes from the group baseline rather than multiple completions for the |
| same prompt (which would require state-checkpointing). |
| • Each training step: |
| 1. 13-week rollout across B parallel envs |
| 2. Per-week advantage = (reward - group_mean) / (group_std + eps) |
| 3. Policy-gradient loss = -E[A_t * log π(a_t | s_t)] |
| 4. KL penalty vs. frozen reference policy (optional; β from --kl) |
| 5. Adam step on LoRA adapters only (Unsloth 4-bit base) |
| • Log reward mean/max and parse-error rate to stdout (+ W&B optional). |
| |
| Usage (inside edaamd/unsloth-vllm container): |
| python train.py --model Qwen/Qwen2.5-1.5B-Instruct --steps 30 --batch 4 \\ |
| --lr 1e-5 --max-new-tokens 768 --kl 0.02 \\ |
| --out /mnt/dcgpuval/hkandala/simmart-runs/smoke-1p5b |
| |
| Hero run (7B overnight): |
| python train.py --model Qwen/Qwen2.5-7B-Instruct --steps 120 --batch 6 \\ |
| --lr 5e-6 --kl 0.02 \\ |
| --out /mnt/dcgpuval/hkandala/simmart-runs/hero-7b |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import random |
| import statistics |
| import sys |
| import time |
| from dataclasses import asdict, dataclass, field |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
# Make sibling modules (models, prompts, server.*) importable when this file
# is executed as a script from an arbitrary working directory.
HERE = os.path.dirname(os.path.abspath(__file__))
if HERE not in sys.path:
    sys.path.insert(0, HERE)
|
|
|
|
| |
| from unsloth import FastLanguageModel |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.nn.functional as F |
|
|
| from models import ProposalDecision, SimMartAction |
| from prompts import ( |
| SYSTEM_PROMPT, build_chat, parse_response, render_observation, |
| build_action_chat, build_journal_chat, parse_journal_response, |
| ) |
| from server.environment import SimMartEnvironment |
|
|
|
|
| |
| |
| |
|
|
def init_distributed() -> Tuple[int, int, int]:
    """Set up torch.distributed when launched under a multi-process runner.

    The LOCAL_RANK environment variable (set by ``accelerate launch`` /
    torchrun) is the trigger; when absent we run single-process.

    Returns:
        (rank, local_rank, world_size) — (0, 0, 1) in the fallback case.
    """
    local_rank_env = os.environ.get("LOCAL_RANK")
    if local_rank_env is None:
        # Plain `python train.py` invocation: no process group needed.
        return 0, 0, 1
    local_rank = int(local_rank_env)
    if torch.cuda.is_available():
        # Pin this process to its GPU before NCCL initialisation.
        torch.cuda.set_device(local_rank)
    if not dist.is_initialized():
        dist.init_process_group(backend="nccl")
    return dist.get_rank(), local_rank, dist.get_world_size()
|
|
|
|
def all_reduce_mean_gradients(model, world_size: int) -> None:
    """Average gradients across ranks by hand (sum, then divide).

    A manual stand-in for DDP's gradient hooks; wrapping Unsloth's 4-bit
    model in DDP proper has known quirks, so this keeps things simple.
    No-op when running single-process.
    """
    if world_size == 1:
        return
    for param in model.parameters():
        grad = param.grad
        if grad is None:
            continue
        dist.all_reduce(grad, op=dist.ReduceOp.SUM)
        grad.div_(world_size)
|
|
|
|
def all_reduce_scalar(value: float, world_size: int, device) -> float:
    """Average a Python float across all ranks; identity when single-process."""
    if world_size == 1:
        return value
    buf = torch.full((1,), value, device=device, dtype=torch.float32)
    dist.all_reduce(buf, op=dist.ReduceOp.SUM)
    return buf.item() / world_size
|
|
|
|
| |
| |
| |
|
|
@dataclass
class TrainConfig:
    """All knobs for one GRPO training run.

    Serialised to ``config.json`` in the output directory at startup so a
    run can be reproduced from its artefacts.
    """
    model: str = "Qwen/Qwen2.5-1.5B-Instruct"   # HF hub id or local path
    steps: int = 30                # outer optimiser steps (one rollout batch each)
    batch: int = 4                 # parallel env episodes per rank per step
    max_seq_len: int = 4096        # model context window passed to Unsloth
    max_new_tokens: int = 768      # generation budget (legacy single-pass mode)
    lr: float = 1e-5               # AdamW learning rate for the LoRA params
    lr_min: float = 0.0            # if > 0, cosine-anneal lr down to this value
    beta_kl: float = 0.02          # weight of the KL penalty vs. the reference
    entropy_coef: float = 0.01     # weight of the entropy bonus
    clip_grad: float = 1.0         # global grad-norm clip
    seed_offset: int = 0           # shifts all env seeds (for repeat runs)
    out_dir: str = "./simmart-run" # checkpoints, config.json, history.jsonl
    log_every: int = 1             # NOTE(review): not consulted below — logging happens every step
    save_every: int = 10           # adapter checkpoint interval, in steps
    wandb_project: Optional[str] = None  # enable W&B logging when set (rank 0 only)
    lora_r: int = 16               # LoRA rank
    lora_alpha: int = 32           # LoRA scaling alpha
    lora_dropout: float = 0.0      # LoRA dropout probability
    load_in_4bit: bool = True      # 4-bit quantised base weights (Unsloth)
    dtype: str = "bfloat16"        # compute dtype: "bfloat16" or "float16"
    init_adapter: Optional[str] = None  # warm-start the action LoRA from an SFT checkpoint
    mb_size: int = 8               # minibatch size for the policy-gradient pass

    # Dual-head mode: a second, frozen LoRA ("journal") writes the journal
    # text while the trainable "default" adapter emits the action JSON.
    journal_adapter: Optional[str] = None
    action_max_tokens: int = 300   # generation budget for the action head
    journal_max_tokens: int = 400  # generation budget for the journal head

    rollout_temperature: float = 0.9  # sampling temperature during rollouts
    rollout_top_p: float = 0.95       # nucleus top-p during rollouts
    fmt_penalty: float = 0.0          # reward penalty applied when parsing fails

    # Mixed-temperature exploration: when both values are > 0 and differ,
    # half the batch samples at the low temperature and half at the high one.
    rollout_temp_low: float = -1.0
    rollout_temp_high: float = -1.0

    # Distributed bookkeeping — filled from init_distributed(), not the CLI.
    rank: int = 0
    local_rank: int = 0
    world_size: int = 1
|
|
|
|
@dataclass
class StepLog:
    """Per-step training metrics, written to stdout, W&B and history.jsonl."""
    step: int                   # 1-based outer step index
    mean_reward: float          # mean shaped per-week reward across the rollout
    max_reward: float           # best single-week shaped reward
    min_reward: float           # worst single-week shaped reward
    reward_std: float           # std-dev of per-week shaped rewards
    mean_episode_return: float  # mean over envs of summed weekly rewards
    parse_error_rate: float     # fraction of weeks whose action JSON failed to parse
    rogue_recall: float         # caught-rogue weeks / rogue-active weeks
    loss: float                 # total loss (pg + kl + entropy terms)
    pg_loss: float              # policy-gradient component
    kl_loss: float              # β-weighted KL-penalty component
    entropy: float              # mean per-token entropy (before the coefficient)
    elapsed_s: float            # wall-clock seconds for the whole step
|
|
|
|
| |
| |
| |
|
|
def log_kv(step: int, kvs: Dict[str, float]) -> None:
    """Print one compact ``[step NNN] k=v ...`` line to stdout (flushed).

    Floats are rendered with an explicit sign and 4 decimals; everything
    else falls back to plain str formatting.
    """
    rendered = [f"[step {step:03d}]"]
    rendered.extend(
        f"{key}={val:+.4f}" if isinstance(val, float) else f"{key}={val}"
        for key, val in kvs.items()
    )
    print(" ".join(rendered), flush=True)
|
|
|
|
| |
| |
| |
|
|
def load_policy(cfg: TrainConfig):
    """Load the (4-bit) base model plus LoRA adapters as configured.

    Returns ``(model, tokenizer)``. The trainable "default" adapter may be
    warm-started from an SFT checkpoint (``cfg.init_adapter``); in dual-head
    mode a second "journal" adapter is attached and fully frozen so only the
    action head receives gradients.
    """
    dtype = torch.bfloat16 if cfg.dtype == "bfloat16" else torch.float16
    # Under multi-GPU each rank pins the whole model to its own device;
    # single-process runs let HF place layers automatically.
    device_map = {"": cfg.local_rank} if cfg.world_size > 1 else "auto"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=cfg.model,
        max_seq_length=cfg.max_seq_len,
        dtype=dtype,
        load_in_4bit=cfg.load_in_4bit,
        device_map=device_map,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=cfg.lora_r,
        lora_alpha=cfg.lora_alpha,
        lora_dropout=cfg.lora_dropout,
        bias="none",
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        use_gradient_checkpointing="unsloth",
        random_state=42,
    )
    if tokenizer.pad_token is None:
        # Some tokenizers ship without a pad token; reuse EOS for padding.
        tokenizer.pad_token = tokenizer.eos_token

    if cfg.init_adapter:
        if cfg.rank == 0:
            print(f"[init] loading SFT adapter weights (action) from {cfg.init_adapter}")
        # Local import: peft is only needed when warm-starting.
        from peft import set_peft_model_state_dict, load_peft_weights
        state_dict = load_peft_weights(cfg.init_adapter)
        set_peft_model_state_dict(model, state_dict)

    if cfg.journal_adapter:
        if cfg.rank == 0:
            print(f"[init] loading FROZEN journal adapter from {cfg.journal_adapter}")
        model.load_adapter(cfg.journal_adapter, adapter_name="journal")

        # Freeze every journal-adapter parameter so the optimiser (built from
        # requires_grad params in main) only updates the action head.
        n_frozen = 0
        for name, param in model.named_parameters():
            if ".journal." in name:
                param.requires_grad = False
                n_frozen += 1
        if cfg.rank == 0:
            print(f"[init] froze {n_frozen} journal-adapter params")

        # Leave the trainable action adapter active by default.
        model.set_adapter("default")

    return model, tokenizer
|
|
|
|
| |
| |
| |
|
|
| def _apply_chat(tokenizer, chat) -> str: |
| """Render the chat in the model's template with the assistant header open.""" |
| return tokenizer.apply_chat_template( |
| chat, tokenize=False, add_generation_prompt=True, |
| ) |
|
|
|
|
def batched_generate(
    model, tokenizer, prompts: List[str], max_new_tokens: int,
    temperature: float = 0.9, top_p: float = 0.95,
) -> Tuple[List[str], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Generate B completions, return:
        completions_text : List[str], len B
        input_ids        : (B, L_in_max)  padded-left
        completion_ids   : (B, L_gen_max) padded-right
        input_mask       : (B, L_in_max)  1 where real prompt token
        completion_mask  : (B, L_gen_max) 1 where real generated token
    """
    # Left padding keeps each prompt's last token adjacent to its generated
    # tokens — required for batched decoder-only generation.
    tokenizer.padding_side = "left"
    enc = tokenizer(
        prompts, return_tensors="pt", padding=True, truncation=True,
        max_length=model.config.max_position_embeddings,
    ).to(model.device)

    # Unsloth requires an explicit mode switch around generate(); the finally
    # guarantees we flip back to training mode even if generation raises.
    FastLanguageModel.for_inference(model)
    try:
        with torch.inference_mode():
            out = model.generate(
                **enc,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                temperature=temperature,
                top_p=top_p,
                pad_token_id=tokenizer.pad_token_id,
                use_cache=True,
            )
    finally:
        FastLanguageModel.for_training(model)

    # generate() returns prompt+completion; strip the prompt columns.
    input_len = enc.input_ids.size(1)
    completion_ids = out[:, input_len:]
    # NOTE(review): load_policy sets pad_token = eos_token, so this mask also
    # zeroes genuine EOS tokens in the completion — presumably intended
    # (EOS carries no gradient signal), but worth confirming.
    completion_mask = (completion_ids != tokenizer.pad_token_id).long()

    completions_text = tokenizer.batch_decode(completion_ids, skip_special_tokens=True)
    return completions_text, enc.input_ids, completion_ids, enc.attention_mask, completion_mask
|
|
|
|
def compute_completion_logprobs(
    model, input_ids: torch.Tensor, completion_ids: torch.Tensor,
    input_mask: torch.Tensor, completion_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Teacher-forced per-token log-probs of the completion under ``model``.

    Args:
        input_ids:       (B, L_in)  left-padded prompt token ids.
        completion_ids:  (B, L_gen) right-padded completion token ids.
        input_mask:      (B, L_in)  1 where the prompt token is real.
        completion_mask: (B, L_gen) 1 where the generated token is real.

    Returns:
        completion_logp: (B, L_gen) log π(token) for each completion token,
                         zeroed at padding positions.
        comp_mask:       (B, L_gen) float mask (1.0 = real completion token).

    (Fix: the previous annotation/docstring claimed a single tensor return,
    but the function has always returned this 2-tuple.)
    """
    full_ids = torch.cat([input_ids, completion_ids], dim=1)
    full_mask = torch.cat([input_mask, completion_mask], dim=1)

    out = model(input_ids=full_ids, attention_mask=full_mask)
    # Standard shift: logits at position t predict token t+1.
    logits = out.logits[:, :-1, :]
    targets = full_ids[:, 1:]
    # float() before log_softmax for numerical stability under bf16/fp16.
    logp = F.log_softmax(logits.float(), dim=-1)
    tok_logp = logp.gather(-1, targets.unsqueeze(-1)).squeeze(-1)

    # The first completion token is predicted by the last prompt position,
    # hence the (L_in - 1) offset into the shifted sequence.
    L_in = input_ids.size(1)
    completion_logp = tok_logp[:, L_in - 1:]
    comp_mask = completion_mask.float()
    completion_logp = completion_logp[:, :comp_mask.size(1)] * comp_mask
    return completion_logp, comp_mask
|
|
|
|
def compute_entropy_and_logp(
    model, input_ids: torch.Tensor, completion_ids: torch.Tensor,
    input_mask: torch.Tensor, completion_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Teacher-forced pass returning (completion_logp, entropy, comp_mask).

    completion_logp: (B, L_gen) per-token log π(a_t), zeroed on padding.
    entropy:         (B, L_gen) per-token H(π) over the vocab, zeroed on padding.
    comp_mask:       (B, L_gen) float mask of real completion tokens.
    """
    seq_ids = torch.cat([input_ids, completion_ids], dim=1)
    seq_mask = torch.cat([input_mask, completion_mask], dim=1)

    # Shifted teacher forcing: logits at position t predict token t+1.
    logits = model(input_ids=seq_ids, attention_mask=seq_mask).logits[:, :-1, :]
    shifted_targets = seq_ids[:, 1:]
    log_probs = F.log_softmax(logits.float(), dim=-1)
    # H(π) = -Σ p·log p over the vocabulary, per position.
    token_entropy = -(log_probs.exp() * log_probs).sum(-1)
    token_logp = log_probs.gather(-1, shifted_targets.unsqueeze(-1)).squeeze(-1)

    # Logits index (L_in - 1) predicts the first completion token.
    start = input_ids.size(1) - 1
    comp_mask = completion_mask.float()
    width = comp_mask.size(1)
    completion_logp = token_logp[:, start:][:, :width] * comp_mask
    entropy = token_entropy[:, start:][:, :width] * comp_mask
    return completion_logp, entropy, comp_mask
|
|
|
|
@dataclass
class RolloutStep:
    """One (env, week) interaction: prompt, completion, reward, bookkeeping."""
    week: int                       # 1-based week index within the episode
    env_idx: int                    # which of the B parallel envs produced this
    prompt_text: str                # rendered chat-template prompt
    completion_text: str            # decoded model completion
    prompt_ids: torch.Tensor        # (L_in,) left-padded prompt token ids (CPU)
    completion_ids: torch.Tensor    # (L_gen,) completion token ids (CPU)
    prompt_mask: torch.Tensor       # (L_in,) 1 where prompt token is real
    completion_mask: torch.Tensor   # (L_gen,) 1 where generated token is real
    reward: float                   # env reward minus optional format penalty
    parse_ok: bool                  # action JSON parsed (fully or partially)
    was_rogue_week: bool = False    # a rogue incident was active this week
    caught_rogue: bool = False      # CEO flagged a proposal tied to the rogue incident
|
|
|
|
def rollout_batch(
    model, tokenizer, cfg: TrainConfig, step_idx: int,
) -> List[RolloutStep]:
    """Run `cfg.batch` parallel 13-week episodes on this rank.

    Dual-head mode (``cfg.journal_adapter`` set):
        Week t ──> action pass (default adapter, trainable) → JSON
                   journal pass (journal adapter, frozen)   → text
                   env.step(combined_action)
        The RolloutStep stores action-pass tensors only so policy-gradient
        flows exclusively through action tokens (clean credit assignment).

    Legacy mode (``cfg.journal_adapter`` = None):
        Single ``build_chat`` generation — identical to v5 behaviour.

    Seeds spread across ranks so different GPUs explore different env
    trajectories: global_seed = seed_offset + step * world_size * batch
    + rank * batch + i.
    """
    dual_head = cfg.journal_adapter is not None

    # One fresh env per batch slot, each with a rank/step-unique seed.
    envs: List[SimMartEnvironment] = []
    obss = []
    for i in range(cfg.batch):
        env = SimMartEnvironment()
        seed = (cfg.seed_offset
                + step_idx * cfg.world_size * cfg.batch
                + cfg.rank * cfg.batch + i)
        obs = env.reset(seed=seed, episode_id=f"train-r{cfg.rank}-{step_idx}-{i}")
        envs.append(env)
        obss.append(obs)

    rollout: List[RolloutStep] = []

    for week in range(1, SimMartEnvironment.MAX_WEEKS + 1):
        # --- action pass: build prompts + pick the generation budget -------
        if dual_head:
            model.set_adapter("default")
            prompts_text = [
                _apply_chat(tokenizer, build_action_chat(obs)) for obs in obss
            ]
            action_budget = cfg.action_max_tokens
        else:
            prompts_text = [_apply_chat(tokenizer, build_chat(obs)) for obs in obss]
            action_budget = cfg.max_new_tokens

        # Mixed-temperature exploration: first half of the batch samples at
        # the low temperature, second half at the high one.
        mixed_temp = (
            cfg.rollout_temp_low > 0
            and cfg.rollout_temp_high > 0
            and cfg.rollout_temp_low != cfg.rollout_temp_high
            and len(prompts_text) >= 2
        )
        if mixed_temp:
            mid = len(prompts_text) // 2
            txt_lo, in_lo, comp_lo, im_lo, cm_lo = batched_generate(
                model, tokenizer, prompts_text[:mid], max_new_tokens=action_budget,
                temperature=cfg.rollout_temp_low, top_p=cfg.rollout_top_p,
            )
            txt_hi, in_hi, comp_hi, im_hi, cm_hi = batched_generate(
                model, tokenizer, prompts_text[mid:], max_new_tokens=action_budget,
                temperature=cfg.rollout_temp_high, top_p=cfg.rollout_top_p,
            )
            completions_text = list(txt_lo) + list(txt_hi)
            # Per-sample CPU tensors: the two halves may have different padded
            # widths, so keep them as lists and re-pad later in train_step.
            input_ids_list = (
                [in_lo[i].detach().cpu() for i in range(in_lo.size(0))]
                + [in_hi[i].detach().cpu() for i in range(in_hi.size(0))]
            )
            completion_ids_list = (
                [comp_lo[i].detach().cpu() for i in range(comp_lo.size(0))]
                + [comp_hi[i].detach().cpu() for i in range(comp_hi.size(0))]
            )
            input_mask_list = (
                [im_lo[i].detach().cpu() for i in range(im_lo.size(0))]
                + [im_hi[i].detach().cpu() for i in range(im_hi.size(0))]
            )
            completion_mask_list = (
                [cm_lo[i].detach().cpu() for i in range(cm_lo.size(0))]
                + [cm_hi[i].detach().cpu() for i in range(cm_hi.size(0))]
            )
        else:
            completions_text_t, input_ids, completion_ids, input_mask, completion_mask = \
                batched_generate(
                    model, tokenizer, prompts_text, max_new_tokens=action_budget,
                    temperature=cfg.rollout_temperature, top_p=cfg.rollout_top_p,
                )
            completions_text = list(completions_text_t)
            input_ids_list = [input_ids[i].detach().cpu() for i in range(input_ids.size(0))]
            completion_ids_list = [completion_ids[i].detach().cpu() for i in range(completion_ids.size(0))]
            input_mask_list = [input_mask[i].detach().cpu() for i in range(input_mask.size(0))]
            completion_mask_list = [completion_mask[i].detach().cpu() for i in range(completion_mask.size(0))]

        # Parse each completion into (action, telemetry).
        # (Annotation fixed: the original used `Any`, which is not imported —
        # harmless only because of `from __future__ import annotations`.)
        parsed: List[Tuple[object, Dict[str, object]]] = []
        for obs, comp_text in zip(obss, completions_text):
            parsed.append(parse_response(comp_text, obs.inbox))

        # --- journal pass (dual-head only): frozen adapter, no gradients ---
        if dual_head:
            model.set_adapter("journal")
            journal_prompts = [
                _apply_chat(
                    tokenizer,
                    build_journal_chat(
                        obs, action.decisions, action.budget_allocations,
                    ),
                )
                for obs, (action, _) in zip(obss, parsed)
            ]
            with torch.inference_mode():
                jc_text, _, _, _, _ = batched_generate(
                    model, tokenizer, journal_prompts,
                    max_new_tokens=cfg.journal_max_tokens,
                    temperature=cfg.rollout_temperature, top_p=cfg.rollout_top_p,
                )
            for (action, _), jt in zip(parsed, jc_text):
                action.journal_entry = parse_journal_response(jt)

            # Restore the trainable adapter for the next week's action pass.
            model.set_adapter("default")

        # --- step every env and record a RolloutStep per (env, week) -------
        for i, (env, obs, (action, tel), comp_text) in enumerate(
            zip(envs, obss, parsed, completions_text),
        ):
            # Rogue bookkeeping must be read BEFORE env.step mutates state.
            rogue_ids_this_week: set = set()
            for r in env.state.rogue_incidents:
                if week in r.active_weeks:
                    rogue_ids_this_week.update(r.associated_proposal_ids)
            was_rogue_week = len(rogue_ids_this_week) > 0

            step_obs = env.step(action)

            caught = any(
                d.verdict == "flag_suspicious" and d.proposal_id in rogue_ids_this_week
                for d in action.decisions
            )

            env_reward = float(step_obs.reward or 0.0)
            # Partial parses count as OK: the env still got a usable action.
            parse_ok = tel["parse_ok"] or tel.get("parse_partial", False)
            shaped_reward = env_reward - cfg.fmt_penalty * (0.0 if parse_ok else 1.0)
            rollout.append(RolloutStep(
                week=week,
                env_idx=i,
                prompt_text=prompts_text[i],
                completion_text=comp_text,
                prompt_ids=input_ids_list[i],
                completion_ids=completion_ids_list[i],
                prompt_mask=input_mask_list[i],
                completion_mask=completion_mask_list[i],
                reward=shaped_reward,
                parse_ok=parse_ok,
                was_rogue_week=was_rogue_week,
                caught_rogue=caught,
            ))

            obss[i] = step_obs

        # Free this week's GPU generation tensors before the next week.
        if mixed_temp:
            del in_lo, in_hi, comp_lo, comp_hi, im_lo, im_hi, cm_lo, cm_hi
        else:
            del input_ids, completion_ids, input_mask, completion_mask
        torch.cuda.empty_cache()

    return rollout
|
|
|
|
| |
| |
| |
|
|
def compute_advantages(rollout: "List[RolloutStep]") -> "List[float]":
    """Group-normalise rewards: per-week (r - mean) / (std + eps).

    Matches the DeepSeek GRPO trick: the 'group' is the B parallel rollouts
    at the same week index, so the advantage stays well-conditioned even as
    absolute reward drifts over training.

    Returns one advantage per RolloutStep, in rollout order. A week with a
    single sample gets advantage 0 (its std is floored near eps). An empty
    rollout yields [].
    """
    by_week: Dict[int, List[float]] = {}
    for r in rollout:
        by_week.setdefault(r.week, []).append(r.reward)

    eps = 1e-6
    # fmean is C-accelerated (statistics.mean is pure Python — needless cost
    # here); stdev needs n >= 2, so single-sample weeks use a tiny floor,
    # which makes their (r - mean) numerator exactly 0 anyway.
    week_stats: Dict[int, Tuple[float, float]] = {
        w: (statistics.fmean(xs), statistics.stdev(xs) if len(xs) > 1 else 1e-6)
        for w, xs in by_week.items()
    }
    return [
        (r.reward - week_stats[r.week][0]) / (week_stats[r.week][1] + eps)
        for r in rollout
    ]
|
|
|
|
| |
| |
| |
|
|
def train_step(
    model, tokenizer, ref_model, optimizer, cfg: TrainConfig, step_idx: int,
) -> StepLog:
    """One outer GRPO step: rollout → advantages → minibatched PG updates.

    NOTE(review): the optimiser steps once per *minibatch*, not once per
    rollout, so later minibatches are scored under slightly updated weights
    while their advantages were computed under the rollout policy —
    presumably an accepted approximation; confirm it's intended.
    """
    t0 = time.time()
    rollout = rollout_batch(model, tokenizer, cfg, step_idx)

    # Rollout-level telemetry.
    rewards = [r.reward for r in rollout]
    parse_ok_count = sum(1 for r in rollout if r.parse_ok)
    n_rogue_weeks = sum(1 for r in rollout if r.was_rogue_week)
    n_caught = sum(1 for r in rollout if r.caught_rogue)

    advantages = compute_advantages(rollout)

    # Running sums of loss components, weighted by minibatch size.
    total_loss = 0.0
    total_pg = 0.0
    total_kl = 0.0
    total_ent = 0.0
    seen = 0

    mb_size = cfg.mb_size
    model.train()
    for mb_start in range(0, len(rollout), mb_size):
        mb = rollout[mb_start:mb_start + mb_size]
        advs_mb = advantages[mb_start:mb_start + mb_size]

        # Re-pad the stored per-sample CPU tensors into rectangular batches.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            [r.prompt_ids for r in mb],
            batch_first=True, padding_value=tokenizer.pad_token_id,
        ).to(model.device)
        completion_ids = torch.nn.utils.rnn.pad_sequence(
            [r.completion_ids for r in mb],
            batch_first=True, padding_value=tokenizer.pad_token_id,
        ).to(model.device)
        input_mask = torch.nn.utils.rnn.pad_sequence(
            [r.prompt_mask for r in mb],
            batch_first=True, padding_value=0,
        ).to(model.device)
        completion_mask = torch.nn.utils.rnn.pad_sequence(
            [r.completion_mask for r in mb],
            batch_first=True, padding_value=0,
        ).to(model.device)

        # Differentiable pass under the current policy.
        completion_logp, entropy, mask = compute_entropy_and_logp(
            model, input_ids, completion_ids, input_mask, completion_mask,
        )

        # KL penalty vs. the reference. ref_model is the same object as model
        # (see main); disable_adapter() makes it score with the base weights,
        # so no second model copy is needed. (log π − log π_ref) per token is
        # the simple KL estimator.
        if ref_model is not None and cfg.beta_kl > 0:
            with torch.inference_mode(), ref_model.disable_adapter():
                ref_logp, _ = compute_completion_logprobs(
                    ref_model, input_ids, completion_ids, input_mask, completion_mask,
                )
            kl_per_tok = (completion_logp - ref_logp.detach()) * mask
        else:
            kl_per_tok = torch.zeros_like(completion_logp)

        # Length-normalise each sample so long completions don't dominate.
        denom = mask.sum(dim=1).clamp_min(1.0)
        logp_per_sample = (completion_logp * mask).sum(dim=1) / denom
        entropy_per_sample = (entropy * mask).sum(dim=1) / denom
        kl_per_sample = kl_per_tok.sum(dim=1) / denom

        adv_t = torch.tensor(advs_mb, device=model.device, dtype=logp_per_sample.dtype)
        pg = -(adv_t * logp_per_sample).mean()
        # Negative sign: maximising entropy = subtracting it from the loss.
        ent_term = -cfg.entropy_coef * entropy_per_sample.mean()
        kl_term = cfg.beta_kl * kl_per_sample.mean()

        loss = pg + kl_term + ent_term

        optimizer.zero_grad()
        loss.backward()

        # Manual DDP: average grads across ranks before clipping/stepping.
        all_reduce_mean_gradients(model, cfg.world_size)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.clip_grad)
        optimizer.step()

        total_loss += float(loss.item()) * len(mb)
        total_pg += float(pg.item()) * len(mb)
        total_kl += float(kl_term.item()) * len(mb)
        total_ent += float(entropy_per_sample.mean().item()) * len(mb)
        seen += len(mb)

        # Free GPU memory eagerly — minibatches of long sequences are big.
        del input_ids, completion_ids, input_mask, completion_mask
        del completion_logp, entropy, mask, kl_per_tok
        torch.cuda.empty_cache()

    # Per-episode return: sum of shaped weekly rewards per env instance.
    episode_returns: Dict[int, float] = {}
    for r in rollout:
        episode_returns[r.env_idx] = episode_returns.get(r.env_idx, 0.0) + r.reward

    elapsed = time.time() - t0
    return StepLog(
        step=step_idx,
        mean_reward=statistics.mean(rewards),
        max_reward=max(rewards),
        min_reward=min(rewards),
        reward_std=statistics.stdev(rewards) if len(rewards) > 1 else 0.0,
        mean_episode_return=statistics.mean(list(episode_returns.values())),
        parse_error_rate=1.0 - (parse_ok_count / max(1, len(rollout))),
        rogue_recall=(n_caught / n_rogue_weeks) if n_rogue_weeks > 0 else 0.0,
        loss=total_loss / max(1, seen),
        pg_loss=total_pg / max(1, seen),
        kl_loss=total_kl / max(1, seen),
        entropy=total_ent / max(1, seen),
        elapsed_s=elapsed,
    )
|
|
|
|
| |
| |
| |
|
|
def _reduce_log(log: StepLog, world_size: int, device) -> StepLog:
    """All-reduce metric averages across ranks so rank-0 logs reflect the
    full global batch, not just one GPU's slice. max/min rewards are reduced
    with MAX/MIN rather than averaged."""
    if world_size == 1:
        return log

    mean_fields = [
        "mean_reward", "max_reward", "min_reward", "reward_std",
        "mean_episode_return", "parse_error_rate", "rogue_recall",
        "loss", "pg_loss", "kl_loss", "entropy",
    ]
    packed = torch.tensor(
        [getattr(log, name) for name in mean_fields],
        device=device, dtype=torch.float32,
    )
    dist.all_reduce(packed, op=dist.ReduceOp.SUM)
    packed = packed / world_size

    # True extremes across ranks, overriding the averaged values above.
    vmax = torch.tensor([log.max_reward], device=device, dtype=torch.float32)
    vmin = torch.tensor([log.min_reward], device=device, dtype=torch.float32)
    dist.all_reduce(vmax, op=dist.ReduceOp.MAX)
    dist.all_reduce(vmin, op=dist.ReduceOp.MIN)

    merged = dict(zip(mean_fields, packed.tolist()))
    merged["max_reward"] = vmax.item()
    merged["min_reward"] = vmin.item()
    merged["step"] = log.step
    merged["elapsed_s"] = log.elapsed_s
    return StepLog(**merged)
|
|
|
|
def main() -> int:
    """CLI entry point: parse args, set up DDP + model, run the training loop.

    Returns the process exit code (0 on normal completion).
    """
    p = argparse.ArgumentParser()
    p.add_argument("--model", default="Qwen/Qwen2.5-1.5B-Instruct")
    p.add_argument("--steps", type=int, default=30)
    p.add_argument("--batch", type=int, default=4)
    p.add_argument("--lr", type=float, default=1e-5)
    p.add_argument("--lr-min", type=float, default=0.0,
                   help="If > 0, cosine-anneal lr -> lr_min over --steps (default: flat lr)")
    p.add_argument("--max-seq-len", type=int, default=4096)
    p.add_argument("--max-new-tokens", type=int, default=768)
    p.add_argument("--kl", type=float, default=0.02, dest="beta_kl")
    p.add_argument("--entropy", type=float, default=0.01, dest="entropy_coef")
    p.add_argument("--seed-offset", type=int, default=0)
    p.add_argument("--out", default="./simmart-run", dest="out_dir")
    p.add_argument("--save-every", type=int, default=10)
    p.add_argument("--wandb", type=str, default=None)
    p.add_argument("--dtype", choices=["float16", "bfloat16"], default="bfloat16")
    p.add_argument("--init-adapter", default=None,
                   help="Load LoRA weights from this SFT checkpoint at startup (action head)")
    p.add_argument("--journal-adapter", default=None,
                   help="Optional: load a frozen journal LoRA as second adapter. "
                        "Triggers dual-head (two-pass) rollout mode.")
    p.add_argument("--action-max-tokens", type=int, default=300,
                   help="Output budget for the action head JSON (dual-head mode)")
    p.add_argument("--journal-max-tokens", type=int, default=400,
                   help="Output budget for the journal head text (dual-head mode)")
    p.add_argument("--mb-size", type=int, default=8,
                   help="Minibatch size for the PG pass (was 2; bigger = fewer forward passes)")
    p.add_argument("--rollout-temperature", type=float, default=0.9, dest="rollout_temperature",
                   help="Sampling temperature for training rollouts (default 0.9; lower shrinks sampled-vs-greedy gap)")
    p.add_argument("--rollout-top-p", type=float, default=0.95, dest="rollout_top_p",
                   help="Nucleus sampling top_p for training rollouts (default 0.95)")
    p.add_argument("--fmt-penalty", type=float, default=0.0, dest="fmt_penalty",
                   help="Magnitude (>=0) subtracted from env reward when parse fails (default 0)")
    p.add_argument("--rollout-temp-low", type=float, default=-1.0, dest="rollout_temp_low",
                   help="Low temperature for mixed-temp rollouts (set both -low and -high to enable; default off)")
    p.add_argument("--rollout-temp-high", type=float, default=-1.0, dest="rollout_temp_high",
                   help="High temperature for mixed-temp rollouts (default off; falls back to --rollout-temperature)")
    args = p.parse_args()

    rank, local_rank, world_size = init_distributed()
    # Rank 0 owns all logging, checkpointing and W&B.
    is_main = rank == 0

    cfg = TrainConfig(
        model=args.model, steps=args.steps, batch=args.batch,
        max_seq_len=args.max_seq_len, max_new_tokens=args.max_new_tokens,
        lr=args.lr, lr_min=args.lr_min, beta_kl=args.beta_kl, entropy_coef=args.entropy_coef,
        seed_offset=args.seed_offset, out_dir=args.out_dir,
        save_every=args.save_every, wandb_project=args.wandb, dtype=args.dtype,
        init_adapter=args.init_adapter,
        journal_adapter=args.journal_adapter,
        action_max_tokens=args.action_max_tokens,
        journal_max_tokens=args.journal_max_tokens,
        mb_size=args.mb_size,
        rollout_temperature=args.rollout_temperature,
        rollout_top_p=args.rollout_top_p,
        fmt_penalty=args.fmt_penalty,
        rollout_temp_low=args.rollout_temp_low,
        rollout_temp_high=args.rollout_temp_high,
        rank=rank, local_rank=local_rank, world_size=world_size,
    )
    out_dir = Path(cfg.out_dir)
    if is_main:
        out_dir.mkdir(parents=True, exist_ok=True)
        # Persist the full config so the run is reproducible from artefacts.
        with open(out_dir / "config.json", "w") as f:
            json.dump(asdict(cfg), f, indent=2)

    # W&B is best-effort: a missing/broken install must not kill the run.
    wandb = None
    if cfg.wandb_project and is_main:
        try:
            import wandb as _wb
            wandb = _wb
            wandb.init(project=cfg.wandb_project, config=asdict(cfg))
        except Exception as e:
            print(f"[warn] wandb disabled: {e}")

    if is_main:
        print(f"Loading policy: {cfg.model} (4-bit LoRA r={cfg.lora_r})")
        if cfg.init_adapter:
            print(f"Init adapter: {cfg.init_adapter}")
        print(f"DDP: rank={rank}/{world_size} local_rank={local_rank} "
              f"per_rank_batch={cfg.batch} global_batch={cfg.batch * world_size}")

    model, tokenizer = load_policy(cfg)
    if is_main:
        print(f"Device: {model.device} dtype: {cfg.dtype}")

    # KL reference is the policy itself: train_step calls
    # ref_model.disable_adapter() to score with the frozen base weights, so
    # no second model copy lives in memory.
    ref_model = model

    # NOTE(review): the comprehension variable `p` shadows the argparse
    # parser above — harmless here (the parser is no longer needed).
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=cfg.lr, betas=(0.9, 0.95), weight_decay=0.0,
    )
    scheduler = None
    if 0 < cfg.lr_min < cfg.lr:
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=cfg.steps, eta_min=cfg.lr_min,
        )
        if is_main:
            print(f"[lr] cosine decay: {cfg.lr:.2e} -> {cfg.lr_min:.2e} over {cfg.steps} steps")

    # Make sure every rank finished loading before the first rollout.
    if world_size > 1:
        dist.barrier()

    history: List[StepLog] = []
    for step in range(1, cfg.steps + 1):
        log = train_step(model, tokenizer, ref_model, optimizer, cfg, step)
        reduced = _reduce_log(log, world_size, model.device)
        history.append(reduced)

        if is_main:
            log_kv(step, {
                "mean_r": reduced.mean_reward,
                "ep_ret": reduced.mean_episode_return,
                "r_std": reduced.reward_std,
                "max_r": reduced.max_reward,
                "parse_err": reduced.parse_error_rate,
                "rogue_rec": reduced.rogue_recall,
                "loss": reduced.loss,
                "pg": reduced.pg_loss,
                "kl": reduced.kl_loss,
                "ent": reduced.entropy,
                "sec/step": reduced.elapsed_s,
                "lr": optimizer.param_groups[0]["lr"],
            })
            if wandb:
                wandb.log({f"train/{k}": v for k, v in asdict(reduced).items()}, step=step)
            with open(out_dir / "history.jsonl", "a") as f:
                f.write(json.dumps(asdict(reduced)) + "\n")

        # Checkpoint on the interval and always at the final step; the
        # barrier keeps other ranks from racing ahead while rank 0 saves.
        if (step % cfg.save_every == 0 or step == cfg.steps):
            if is_main:
                ckpt = out_dir / f"adapter-step-{step:03d}"
                model.save_pretrained(str(ckpt))
                tokenizer.save_pretrained(str(ckpt))
                print(f"[ckpt] saved {ckpt}")
            if world_size > 1:
                dist.barrier()

        if scheduler is not None:
            scheduler.step()

    if is_main:
        print("Training complete.")
        if wandb:
            wandb.finish()
    if world_size > 1:
        dist.destroy_process_group()
    return 0
|
|
|
|
if __name__ == "__main__":
    # sys.exit raises SystemExit with main()'s return code, same as before.
    sys.exit(main())
|
|