| |
| """Strawman baseline for the rebuttal. |
| |
| Vanilla Qwen2.5-1.5B-Instruct + LoRA on top of the *existing* JSONL data |
| (`data/sudoku_t3_20empty_value_qwen_text_stage1_{train,eval}.jsonl`). |
| |
| Compared to the cell-policy / latent recipes, this strawman intentionally |
| removes everything that helped: |
| |
| - NO curriculum (single stage; we don't even read `stage_i`). |
| - NO chain-of-thought / latent thought tokens. |
| - NO per-cell expansion (one example == one whole puzzle). |
- NO multi-value oversampling and no special reward shaping (the GRPO reward
  is just the per-cell match count, plus a full-solve bonus and parse /
  length-mismatch penalties).
| |
| It uses the *same* model, *same* LoRA config, *same* tokenizer + chat |
| template wrapping that every cell-policy experiment used, so any solve |
| gap vs the cell-policy / latent runs is purely due to task framing, |
| not data, prompt, model, or PEFT differences. |
| |
Usage:
    python simple_baseline_sudoku_train.py --phase sft \
        --train_jsonl <train.jsonl> --eval_jsonl <eval.jsonl> \
        --output_dir <out>/sft --learning_rate 5e-5
    python simple_baseline_sudoku_train.py --phase grpo --init_adapter_dir <out>/sft/final \
        --train_jsonl <train.jsonl> --eval_jsonl <eval.jsonl> \
        --output_dir <out>/grpo --learning_rate 5e-6
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import math |
| import os |
| import re |
| import sys |
| import time |
| from pathlib import Path |
| from typing import Any, Callable, Dict, List, Optional |
|
|
| import torch |
| from datasets import Dataset |
| from peft import LoraConfig, PeftModel, get_peft_model |
| from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed |
|
|
| |
# Repository root: the parent of the directory holding this file. It is put
# at the front of sys.path (once) so the sibling packages imported below
# resolve when this file is executed directly as a script.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))
|
|
| from multi_output_cell_policy.sft_multi_output_train import ( |
| load_jsonl_rows, |
| pick_dtype, |
| ) |
| from multi_output_cell_policy.rewards import score_prediction_text |
| from multi_output_cell_policy.shared_multi_output_policy import ( |
| make_solved_grid_from_row, |
| stage_i_consistent_values, |
| ) |
| from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
# System message for the whole-puzzle strawman: the model must answer with a
# bare JSON list holding one value per empty cell, in row-major order.
SYSTEM_PROMPT_STRAWMAN = "\n".join(
    (
        "You are a Sudoku solver.",
        "You will be given a 9x9 Sudoku grid encoded as (row,col,value) tuples in "
        "row-major order, where value 0 marks an empty cell.",
        "Predict the missing values for ALL empty cells in row-major order.",
        "Return ONLY a JSON list of integers like [v1,v2,...,vK], where K is the "
        "number of empty cells (typically 20). Each value must be an integer in "
        "[1,9].",
        "Do not include any explanation, markdown, or text outside the JSON list.",
    )
)


def build_chat_prompt(tokenizer: Any, raw_prompt: str) -> str:
    """Wrap ``raw_prompt`` as a system + user chat turn ready for generation.

    When the tokenizer carries a (truthy) ``chat_template``, the messages are
    rendered with ``apply_chat_template(..., add_generation_prompt=True)`` so
    the returned text ends at the start of the assistant turn. Without a
    template, a plain-text fallback is returned:
    ``<system prompt>\\n\\n<raw_prompt>\\n``.

    Args:
        tokenizer: Any tokenizer-like object; only ``chat_template`` and
            ``apply_chat_template`` are used.
        raw_prompt: The user-turn text (the encoded Sudoku grid).

    Returns:
        The prompt string, untokenized.
    """
    system_text = SYSTEM_PROMPT_STRAWMAN.strip()
    if not getattr(tokenizer, "chat_template", None):
        return system_text + "\n\n" + raw_prompt + "\n"
    conversation = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": raw_prompt},
    ]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
|
|
|
|
| |
|
|
# A bracketed span containing no nested brackets, e.g. "[3,1,4]". Used to dig
# a JSON list out of surrounding text when the whole completion is not JSON.
LIST_RE = re.compile(r"\[[^\[\]]*\]")


def _decode_cell_values(candidate: str) -> Optional[List[int]]:
    """Decode ``candidate`` as a JSON list of ints in [1, 9]; return None otherwise."""
    try:
        obj = json.loads(candidate)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError; RecursionError guards
        # against pathologically nested input such as "[[[[...".
        return None
    if not isinstance(obj, list):
        return None
    # bool is a subclass of int in Python, so reject it explicitly.
    if all(isinstance(v, int) and not isinstance(v, bool) and 1 <= v <= 9 for v in obj):
        return [int(v) for v in obj]
    return None


def parse_int_list(text: str) -> Optional[List[int]]:
    """Parse the model's emission as a JSON int list with values in [1,9].

    Tolerant: the whole (stripped) text is tried first, then every
    non-nested ``[...]`` span in order of appearance. The first candidate that
    decodes to a list whose items are all ints in [1, 9] (``bool`` excluded)
    is returned. Previously only the *first* bracketed span was tried, so a
    stray ``[x]`` before the real answer made the whole parse fail.

    Args:
        text: Raw model output or dataset completion; coerced with ``str``.

    Returns:
        The parsed values (``[]`` for an empty JSON list), or ``None`` when the
        input is empty/whitespace or no candidate qualifies.
    """
    s = str(text).strip()
    if not s:
        return None
    spans = [m.group(0) for m in LIST_RE.finditer(s)]
    # dict.fromkeys de-duplicates (e.g. when the whole text IS the first span)
    # while preserving try-order.
    for cand in dict.fromkeys([s, *spans]):
        values = _decode_cell_values(cand)
        if values is not None:
            return values
    return None
|
|
|
|
def whole_puzzle_reward(
    *,
    pred_list: Optional[List[int]],
    target_list: List[int],
    parse_penalty: float = 4.0,
    length_mismatch_penalty: float = 0.5,
    full_solve_bonus: float = 5.0,
) -> float:
    """Score one whole-puzzle prediction against its target value list.

    - ``pred_list is None`` (unparseable output): returns ``-parse_penalty``.
    - Otherwise: +1 for every position where prediction and target agree,
      compared over the shorter of the two lists;
    - minus ``length_mismatch_penalty`` per missing or surplus value;
    - plus ``full_solve_bonus`` when the lengths match and every value agrees.

    The score is not normalized by the number of cells.
    """
    if pred_list is None:
        return -float(parse_penalty)

    expected_len = len(target_list)
    hits = sum(1 for got, want in zip(pred_list, target_list) if int(got) == int(want))
    score = float(hits)

    length_gap = abs(len(pred_list) - expected_len)
    if length_gap:
        score -= float(length_mismatch_penalty) * length_gap
    elif hits == expected_len:
        score += float(full_solve_bonus)
    return score
|
|
|
|
| |
|
|
|
|
def build_dataset(rows: List[Dict[str, Any]], tokenizer: Any) -> Dataset:
    """Convert JSONL rows into a prompt / completion / target ``Dataset``.

    Each kept row contributes:
      - ``prompt``: the stripped row prompt wrapped by ``build_chat_prompt``;
      - ``completion``: the stripped reference completion text;
      - ``target``: the parsed completion re-serialised as compact JSON
        (e.g. ``"[3,1,4]"``), which the GRPO reward function decodes.

    Rows whose completion does not parse via ``parse_int_list`` are dropped.
    """
    columns: Dict[str, List[str]] = {"prompt": [], "completion": [], "target": []}
    for row in rows:
        user_text = str(row["prompt"]).strip()
        reference = str(row["completion"]).strip()
        parsed = parse_int_list(reference)
        if parsed is None:
            continue
        columns["prompt"].append(build_chat_prompt(tokenizer, user_text))
        columns["completion"].append(reference)
        columns["target"].append(json.dumps(parsed, separators=(",", ":")))
    return Dataset.from_dict(columns)
|
|
|
|
| |
|
|
|
|
def _greedy_generate(
    model: torch.nn.Module,
    tokenizer: Any,
    prompt: str,
    device: torch.device,
    max_new_tokens: int,
) -> str:
    """Greedily decode ``prompt`` and return only the newly generated text, stripped.

    Intended to be called from ``run_eval``, which runs under ``torch.no_grad()``.
    """
    enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    enc = {k: v.to(device) for k, v in enc.items()}
    prompt_len = int(enc["input_ids"].shape[1])
    out = model.generate(
        **enc,
        max_new_tokens=int(max_new_tokens),
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        # Gradient-checkpointing setups set config.use_cache=False; greedy
        # output is identical with the KV cache, just much faster.
        use_cache=True,
    )
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()


@torch.no_grad()
def run_eval(
    model: torch.nn.Module,
    tokenizer: Any,
    eval_rows: List[Dict[str, Any]],
    device: torch.device,
    max_new_tokens: int = 96,
    print_n: int = 3,
    stage_i: int = 3,
) -> Dict[str, float]:
    """Apples-to-apples eval with the cell-policy framework.

    The strawman model emits the WHOLE puzzle (a JSON list of integers) in
    one greedy generation. That list is split into per-cell SINGLETON
    predictions and each cell is scored with the same
    ``score_prediction_text`` function the cell-policy / latent baselines use,
    against the i-consistent target set at ``stage_i`` (default 3 — matching
    the S3 eval used for the rebuttal v6 baseline and the latent champion).

    Reported metrics mirror ``multi_output_cell_policy/sft_multi_output_train.py::run_eval``
    so numbers are directly comparable across all four 2x2 ablation cells.

    Row handling:
      - rows whose reference completion does not parse are ignored entirely;
      - rows whose cell metadata cannot be expanded are skipped *before*
        generation, counted in ``n_skipped_puzzles``, and excluded from the
        per-puzzle denominators (``solve_rate``, ``puzzle_parse_fail_rate``).

    The model is put in eval mode for the duration and switched back to
    training mode afterwards if it was training on entry.
    """
    was_training = bool(model.training)
    model.eval()
    total_cells = 0
    parse_ok = 0.0
    canonical_ok = 0.0
    exact_set_match = 0.0
    includes_gt = 0.0
    precision_sum = 0.0
    recall_sum = 0.0
    cardinality_match_sum = 0.0
    n_solve = 0
    n_total_puzzles = 0
    n_parse_fail_puzzles = 0
    n_skipped_puzzles = 0
    printed = 0
    try:
        for row in eval_rows:
            target_completion = parse_int_list(str(row["completion"]))
            if target_completion is None:
                continue

            # Expand metadata first so un-scorable rows cost no generation and
            # do not distort the per-puzzle rates. Broad catch is deliberate:
            # this is best-effort and the reason is printed below.
            try:
                cells = build_cell_examples_from_row(row)
                solved = make_solved_grid_from_row(row)
            except Exception as e:
                n_skipped_puzzles += 1
                if printed < print_n:
                    print(f"[strawman eval debug] row skipped (no metadata): {e}", flush=True)
                    printed += 1
                continue

            n_total_puzzles += 1
            prompt = build_chat_prompt(tokenizer, str(row["prompt"]).strip())
            gen = _greedy_generate(model, tokenizer, prompt, device, max_new_tokens)
            pred_list = parse_int_list(gen)
            if pred_list is None:
                n_parse_fail_puzzles += 1

            row_all_exact = True
            row_has_eval_cell = False
            # NOTE(review): assumes `cells` is ordered like the target list
            # (row-major empty cells) so pred_list[idx] belongs to cells[idx].
            for idx, ex in enumerate(cells):
                target_values = stage_i_consistent_values(
                    ex.grid, target_cell=ex.target_cell, stage_i=int(stage_i)
                )
                row_has_eval_cell = True
                if pred_list is not None and idx < len(pred_list):
                    pred_text = json.dumps({"values": [int(pred_list[idx])]})
                else:
                    # Missing/unparseable prediction is scored as empty output.
                    pred_text = ""
                info = score_prediction_text(
                    text=pred_text,
                    grid=ex.grid,
                    solved=solved,
                    target_cell=ex.target_cell,
                    stage_i=int(stage_i),
                    reward_good_value=1.0,
                    penalty_bad_value=1.75,
                    penalty_malformed=4.0,
                    penalty_empty=0.5,
                    penalty_singleton=1.5,
                )
                total_cells += 1
                parse_ok += float(info["parse_ok"])
                canonical_ok += float(info["strict_canonical"])
                exact_set_match += float(info["exact_set_match"])
                includes_gt += float(info["includes_ground_truth"])
                precision_sum += float(info["value_precision"])
                recall_sum += float(info["value_recall"])
                if int(info["num_predicted_values"]) == int(len(target_values)):
                    cardinality_match_sum += 1.0
                if float(info["exact_set_match"]) < 0.5:
                    row_all_exact = False

            row_solved = row_has_eval_cell and row_all_exact
            if row_solved:
                n_solve += 1
            if printed < print_n:
                head_pred = pred_list if pred_list is not None else "PARSE_FAIL"
                print(
                    f"[strawman eval debug] target={target_completion} pred={head_pred} "
                    f"solve={int(row_solved)} gen={gen!r}",
                    flush=True,
                )
                printed += 1
    finally:
        if was_training:
            model.train()

    cells_denom = max(1, total_cells)
    puzzles_denom = max(1, n_total_puzzles)
    return {
        "n_total_cells": float(total_cells),
        "n_total_puzzles": float(n_total_puzzles),
        "n_skipped_puzzles": float(n_skipped_puzzles),
        "parse_rate": float(parse_ok / cells_denom),
        "strict_canonical_rate": float(canonical_ok / cells_denom),
        "exact_set_match_rate": float(exact_set_match / cells_denom),
        "includes_ground_truth_rate": float(includes_gt / cells_denom),
        "value_precision": float(precision_sum / cells_denom),
        "value_recall": float(recall_sum / cells_denom),
        "cardinality_match_rate": float(cardinality_match_sum / cells_denom),
        "puzzle_parse_fail_rate": float(n_parse_fail_puzzles / puzzles_denom),
        "solve_rate": float(n_solve) / puzzles_denom,
    }
|
|
|
|
| |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Define the command-line interface shared by both phases and parse ``sys.argv``.

    Required: ``--phase`` (``sft`` or ``grpo``), ``--train_jsonl``,
    ``--eval_jsonl`` and ``--output_dir``. Flags are grouped below as run
    setup, data sizes, optimisation, LoRA, GRPO sampling/reward knobs, and
    Weights & Biases logging.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # --- run setup / I/O ---
    add("--phase", choices=["sft", "grpo"], required=True)
    add("--model_name", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
    add("--train_jsonl", type=str, required=True)
    add("--eval_jsonl", type=str, required=True)
    add("--output_dir", type=str, required=True)
    add("--cache_dir", type=str, default=str(ROOT / ".hf_cache"))
    # Non-empty => resume from an existing adapter (e.g. SFT output for GRPO).
    add("--init_adapter_dir", type=str, default="")
    add("--seed", type=int, default=0)

    # --- data sizes ---
    add("--limit_train_rows", type=int, default=10000)
    add("--eval_rows", type=int, default=100)

    # --- optimisation ---
    add("--per_device_train_batch_size", type=int, default=8)
    add("--gradient_accumulation_steps", type=int, default=2)
    add("--learning_rate", type=float, default=5e-5)
    add("--weight_decay", type=float, default=0.0)
    add("--num_epochs", type=float, default=8.0)
    # HF Trainer: a positive max_steps takes precedence over num_epochs.
    add("--max_steps", type=int, default=2000)
    add("--logging_steps", type=int, default=25)
    add("--save_steps", type=int, default=200)
    # NOTE: parsed but not consumed by run_sft / run_grpo in this file.
    add("--eval_steps", type=int, default=150)
    add("--max_grad_norm", type=float, default=1.0)
    add("--max_completion_length", type=int, default=96)
    add("--max_prompt_length", type=int, default=1024)

    # --- LoRA ---
    add("--lora_r", type=int, default=32)
    add("--lora_alpha", type=int, default=64)
    add("--lora_dropout", type=float, default=0.05)
    add("--enable_gradient_checkpointing", action="store_true")

    # --- GRPO sampling and reward ---
    add("--num_generations", type=int, default=8)
    add("--beta", type=float, default=0.0)
    add("--temperature", type=float, default=1.0)
    add("--full_solve_bonus", type=float, default=5.0)
    add("--length_mismatch_penalty", type=float, default=0.5)
    add("--parse_penalty", type=float, default=4.0)

    # --- Weights & Biases ---
    add("--use_wandb", action="store_true")
    add("--wandb_project", type=str, default="sudoku-strawman-baseline")
    add("--wandb_run_name", type=str, default="")
    add("--wandb_mode", type=str, default="offline")

    return parser.parse_args()
|
|
|
|
def setup_model_and_tokenizer(args: argparse.Namespace, device: torch.device):
    """Load the tokenizer and a LoRA-wrapped causal LM, then move the model to ``device``.

    - Tokenizer: a pad token is guaranteed (EOS, else ``"<|endoftext|>"``) and
      padding is forced to the left.
    - Model: loaded in ``pick_dtype()``; if ``--init_adapter_dir`` is non-empty,
      that adapter is loaded as trainable (the ``--lora_*`` flags are then
      unused), otherwise a fresh LoRA adapter is attached to the attention and
      MLP projections.
    - ``--enable_gradient_checkpointing`` turns on non-reentrant checkpointing,
      makes inputs require grad, and disables ``config.use_cache``.

    Returns:
        ``(model, tokenizer)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, use_fast=True
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token or "<|endoftext|>"
    # Decoder-only generation expects prompts padded on the left.
    tokenizer.padding_side = "left"

    base = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        cache_dir=args.cache_dir,
        torch_dtype=pick_dtype(),
        low_cpu_mem_usage=True,
    )
    if str(args.init_adapter_dir).strip():
        model = PeftModel.from_pretrained(base, args.init_adapter_dir, is_trainable=True)
    else:
        model = get_peft_model(
            base,
            LoraConfig(
                r=args.lora_r,
                lora_alpha=args.lora_alpha,
                lora_dropout=args.lora_dropout,
                bias="none",
                task_type="CAUSAL_LM",
                target_modules=[
                    "q_proj", "k_proj", "v_proj", "o_proj",
                    "gate_proj", "up_proj", "down_proj",
                ],
            ),
        )

    if args.enable_gradient_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
        # Base weights are frozen under LoRA; making the inputs require grad
        # keeps checkpointed segments connected to the adapter parameters.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        # The KV cache is not used together with checkpointed training.
        if hasattr(model, "config"):
            model.config.use_cache = False

    model.to(device)
    return model, tokenizer
|
|
|
|
def run_sft(args: argparse.Namespace) -> None:
    """Supervised fine-tuning phase: one whole-puzzle JSON list per prompt.

    Trains the LoRA adapter with TRL's ``SFTTrainer`` (loss on completion
    tokens only), saves it to ``<output_dir>/final`` and prints ``run_eval``
    metrics before and after training. Eval generation is given the same
    token budget as training completions (``--max_completion_length``);
    previously ``run_eval``'s default of 96 was used regardless of that flag.
    """
    from trl import SFTConfig, SFTTrainer

    set_seed(int(args.seed))
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_rows = load_jsonl_rows(args.train_jsonl, limit_rows=int(args.limit_train_rows))
    eval_rows = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))

    model, tokenizer = setup_model_and_tokenizer(args, device)
    train_ds = build_dataset(train_rows, tokenizer)

    dtype = pick_dtype()
    cfg = SFTConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=int(args.per_device_train_batch_size),
        gradient_accumulation_steps=int(args.gradient_accumulation_steps),
        learning_rate=float(args.learning_rate),
        weight_decay=float(args.weight_decay),
        num_train_epochs=float(args.num_epochs),
        max_steps=int(args.max_steps),
        logging_steps=int(args.logging_steps),
        save_steps=int(args.save_steps),
        save_strategy="steps",
        save_total_limit=4,
        eval_strategy="no",
        bf16=(dtype == torch.bfloat16),
        fp16=(dtype == torch.float16),
        max_grad_norm=float(args.max_grad_norm),
        gradient_checkpointing=bool(args.enable_gradient_checkpointing),
        report_to=("wandb" if args.use_wandb else "none"),
        run_name=(args.wandb_run_name or None),
        # Prompt budget + completion budget + a small slack.
        max_length=int(args.max_prompt_length + args.max_completion_length + 8),
        # Train only on the completion (the JSON list), not on the prompt.
        completion_only_loss=True,
        seed=int(args.seed),
    )

    trainer = SFTTrainer(
        model=model,
        args=cfg,
        train_dataset=train_ds,
        processing_class=tokenizer,
    )

    def _evaluate() -> Dict[str, float]:
        # Evaluate the model object the trainer actually trains, with the
        # same generation budget as training completions.
        return run_eval(
            trainer.model,
            tokenizer,
            eval_rows,
            device,
            max_new_tokens=int(args.max_completion_length),
        )

    print("[strawman sft] BEFORE-train eval:", _evaluate(), flush=True)

    t0 = time.time()
    trainer.train()
    print(f"[strawman sft] training time = {time.time() - t0:.1f}s", flush=True)

    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    print(f"[strawman sft] saved final adapter to {final_dir}", flush=True)

    print("[strawman sft] AFTER-train eval:", _evaluate(), flush=True)
|
|
|
|
def run_grpo(args: argparse.Namespace) -> None:
    """GRPO phase: RL fine-tuning with the whole-puzzle reward.

    Usually started from the SFT adapter via ``--init_adapter_dir``. Each
    prompt gets ``--num_generations`` sampled completions scored by
    ``whole_puzzle_reward`` against the row's ``target`` column. The adapter
    is saved to ``<output_dir>/final`` and ``run_eval`` metrics are printed
    before and after training, with eval generation given the same token
    budget as training completions (``--max_completion_length``) instead of
    ``run_eval``'s fixed default of 96.
    """
    from trl import GRPOConfig, GRPOTrainer

    set_seed(int(args.seed))
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_rows = load_jsonl_rows(args.train_jsonl, limit_rows=int(args.limit_train_rows))
    eval_rows = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))

    model, tokenizer = setup_model_and_tokenizer(args, device)
    train_ds = build_dataset(train_rows, tokenizer)

    parse_penalty = float(args.parse_penalty)
    length_mismatch_penalty = float(args.length_mismatch_penalty)
    full_solve_bonus = float(args.full_solve_bonus)

    def reward_fn(completions, target, **kwargs):
        # TRL passes extra dataset columns to reward functions as keyword
        # arguments; ``target`` is the compact JSON list built by build_dataset.
        rewards: List[float] = []
        for c, tgt in zip(completions, target):
            tgt_list = json.loads(tgt) if isinstance(tgt, str) else list(tgt)
            pred = parse_int_list(str(c))
            rewards.append(
                whole_puzzle_reward(
                    pred_list=pred,
                    target_list=tgt_list,
                    parse_penalty=parse_penalty,
                    length_mismatch_penalty=length_mismatch_penalty,
                    full_solve_bonus=full_solve_bonus,
                )
            )
        return rewards

    dtype = pick_dtype()
    cfg = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=int(args.per_device_train_batch_size),
        gradient_accumulation_steps=int(args.gradient_accumulation_steps),
        learning_rate=float(args.learning_rate),
        weight_decay=float(args.weight_decay),
        num_train_epochs=float(args.num_epochs),
        max_steps=int(args.max_steps),
        logging_steps=int(args.logging_steps),
        save_steps=int(args.save_steps),
        save_strategy="steps",
        save_total_limit=6,
        bf16=(dtype == torch.bfloat16),
        fp16=(dtype == torch.float16),
        max_grad_norm=float(args.max_grad_norm),
        gradient_checkpointing=bool(args.enable_gradient_checkpointing),
        report_to=("wandb" if args.use_wandb else "none"),
        run_name=(args.wandb_run_name or None),
        max_prompt_length=int(args.max_prompt_length),
        max_completion_length=int(args.max_completion_length),
        num_generations=int(args.num_generations),
        beta=float(args.beta),
        temperature=float(args.temperature),
        seed=int(args.seed),
    )

    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[reward_fn],
        args=cfg,
        train_dataset=train_ds,
        processing_class=tokenizer,
    )

    def _evaluate() -> Dict[str, float]:
        # Evaluate the model object the trainer actually trains, with the
        # same generation budget as training completions.
        return run_eval(
            trainer.model,
            tokenizer,
            eval_rows,
            device,
            max_new_tokens=int(args.max_completion_length),
        )

    print("[strawman grpo] BEFORE-train eval:", _evaluate(), flush=True)

    t0 = time.time()
    trainer.train()
    print(f"[strawman grpo] training time = {time.time() - t0:.1f}s", flush=True)

    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    print(f"[strawman grpo] saved final adapter to {final_dir}", flush=True)

    print("[strawman grpo] AFTER-train eval:", _evaluate(), flush=True)
|
|
|
|
def main() -> None:
    """CLI entry point: parse flags, export W&B env vars if enabled, run the chosen phase."""
    args = parse_args()
    if args.use_wandb:
        # A WANDB_MODE already exported in the environment wins over
        # --wandb_mode; the project name always follows --wandb_project.
        os.environ.setdefault("WANDB_MODE", str(args.wandb_mode))
        os.environ["WANDB_PROJECT"] = args.wandb_project
    phase_fn = run_sft if args.phase == "sft" else run_grpo
    phase_fn(args)
|
|
|
|
if __name__ == "__main__":
    # Run only when executed as a script, not when imported.
    main()
|
|