curriculum-cot-code / _runs /simple_baseline_sudoku_train.py
Avra98's picture
Add data/ JSONLs + _runs/ launch scripts (override .gitignore)
48c96cf verified
#!/usr/bin/env python3
"""Strawman baseline for the rebuttal.
Vanilla Qwen2.5-1.5B-Instruct + LoRA on top of the *existing* JSONL data
(`data/sudoku_t3_20empty_value_qwen_text_stage1_{train,eval}.jsonl`).
Compared to the cell-policy / latent recipes, this strawman intentionally
removes everything that helped:
- NO curriculum (single stage; we don't even read `stage_i`).
- NO chain-of-thought / latent thought tokens.
- NO per-cell expansion (one example == one whole puzzle).
- NO multi-value oversampling and only minimal reward shaping (per-cell match count, plus a full-solve bonus and parse / length-mismatch penalties).
It uses the *same* model, *same* LoRA config, *same* tokenizer + chat
template wrapping that every cell-policy experiment used, so any solve
gap vs the cell-policy / latent runs is purely due to task framing,
not data, prompt, model, or PEFT differences.
Usage:
python simple_baseline_sudoku_train.py --phase sft --train_jsonl <train.jsonl> \
--eval_jsonl <eval.jsonl> --output_dir <out>/sft --learning_rate 5e-5
python simple_baseline_sudoku_train.py --phase grpo --train_jsonl <train.jsonl> \
--eval_jsonl <eval.jsonl> --init_adapter_dir <out>/sft/final \
--output_dir <out>/grpo --learning_rate 5e-6
"""
from __future__ import annotations
import argparse
import json
import math
import os
import re
import sys
import time
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional
import torch
from datasets import Dataset
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, set_seed
# Reuse existing helpers (these are the canonical ones used by every cell-policy run).
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT) not in sys.path:
sys.path.insert(0, str(ROOT))
from multi_output_cell_policy.sft_multi_output_train import ( # type: ignore
load_jsonl_rows,
pick_dtype,
)
from multi_output_cell_policy.rewards import score_prediction_text # type: ignore
from multi_output_cell_policy.shared_multi_output_policy import ( # type: ignore
make_solved_grid_from_row,
stage_i_consistent_values,
)
from aligned_cell_policy.shared_cell_policy import build_cell_examples_from_row # type: ignore
# ---- Strawman task definition -----------------------------------------------
# This is the ONLY new piece relative to the cell-policy experiments. The
# system prompt asks the model to emit the missing values for ALL empty cells
# in one shot, in the row-major order that the existing JSONL `completion`
# field already uses. The user message is the raw `prompt` field from the
# JSONL (puzzle as (row,col,value) tuples), which is byte-identical to what
# `prompt_builder.py` consumes in cell-policy runs.
# System prompt for the whole-puzzle strawman task. build_chat_prompt strips it
# and sends it as the system turn (or prepends it as plain text when the
# tokenizer has no chat template). The "(typically 20)" hint presumably matches
# the 20-empty-cell data files named in the module docstring — verify if the
# data changes.
SYSTEM_PROMPT_STRAWMAN = (
"You are a Sudoku solver.\n"
"You will be given a 9x9 Sudoku grid encoded as (row,col,value) tuples in "
"row-major order, where value 0 marks an empty cell.\n"
"Predict the missing values for ALL empty cells in row-major order.\n"
"Return ONLY a JSON list of integers like [v1,v2,...,vK], where K is the "
"number of empty cells (typically 20). Each value must be an integer in "
"[1,9].\n"
"Do not include any explanation, markdown, or text outside the JSON list."
)
def build_chat_prompt(tokenizer: Any, raw_prompt: str) -> str:
    """Wrap a raw puzzle prompt in the system+user chat format.

    When the tokenizer defines a chat template, the stripped system prompt and
    ``raw_prompt`` are rendered through it as text, with the assistant
    generation prompt appended. Without a template, the fallback is plain text:
    system prompt, a blank line, the raw prompt, then a trailing newline.
    """
    system_text = SYSTEM_PROMPT_STRAWMAN.strip()
    if not getattr(tokenizer, "chat_template", None):
        return system_text + "\n\n" + raw_prompt + "\n"
    conversation = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": raw_prompt},
    ]
    return tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
# ---- Reward -----------------------------------------------------------------
# Matches a bracketed span that contains no nested brackets, e.g. "[1,2,3]" or
# "[r,c]". parse_int_list uses it as a fallback when the whole completion is
# not itself an acceptable JSON list.
LIST_RE = re.compile(r"\[[^\[\]]*\]")
def parse_int_list(text: str) -> Optional[List[int]]:
    """Parse the model's emission as a JSON list of integers in [1, 9].

    Tolerant: the whole (stripped) text is tried first, then every
    non-nested bracketed span found by ``LIST_RE``, in order of appearance.
    The first candidate that decodes to a JSON list whose items are all ints
    (bools excluded) in [1, 9] is returned. A stray non-JSON span such as
    ``[r,c]`` before the answer therefore no longer hides it. An empty list
    ``[]`` is accepted and returned as ``[]``.

    Args:
        text: Raw completion; coerced with ``str()``.

    Returns:
        The parsed values, or None if the text is blank or no candidate
        qualifies.
    """
    s = str(text).strip()
    if not s:
        return None
    candidates: List[str] = [s]
    # Every bracketed span, not just the first: the first one may be prose.
    candidates.extend(m.group(0) for m in LIST_RE.finditer(s))
    for cand in candidates:
        try:
            obj = json.loads(cand)
        except (ValueError, RecursionError):
            # Invalid JSON (JSONDecodeError subclasses ValueError) or
            # pathologically deep nesting: try the next candidate.
            continue
        if not isinstance(obj, list):
            continue
        # bool is a subclass of int, so it must be rejected explicitly.
        if all(
            isinstance(v, int) and not isinstance(v, bool) and 1 <= v <= 9
            for v in obj
        ):
            return [int(v) for v in obj]
    return None
def whole_puzzle_reward(
    *,
    pred_list: Optional[List[int]],
    target_list: List[int],
    parse_penalty: float = 4.0,
    length_mismatch_penalty: float = 0.5,
    full_solve_bonus: float = 5.0,
) -> float:
    """Score one whole-puzzle prediction against its target value list.

    - Unparseable prediction (``None``): ``-parse_penalty``.
    - Otherwise: +1 per position where prediction and target agree, compared
      over the overlapping prefix.
    - Length differs from the target: subtract ``length_mismatch_penalty``
      times the absolute length difference.
    - Same length and every position matches: add ``full_solve_bonus``.
    """
    if pred_list is None:
        return -float(parse_penalty)

    expected_len = len(target_list)
    length_gap = abs(len(pred_list) - expected_len)
    # zip stops at the shorter list, i.e. only the overlapping prefix counts.
    hits = sum(
        1 for got, want in zip(pred_list, target_list) if int(got) == int(want)
    )

    score = float(hits)
    if length_gap:
        score -= float(length_mismatch_penalty) * length_gap
    elif hits == expected_len:
        score += float(full_solve_bonus)
    return score
# ---- Dataset construction ---------------------------------------------------
def build_dataset(rows: List[Dict[str, Any]], tokenizer: Any) -> Dataset:
    """Convert JSONL rows into a ``Dataset`` with prompt/completion/target columns.

    Each row's stripped ``prompt`` is chat-wrapped via ``build_chat_prompt``.
    Its stripped ``completion`` is kept verbatim as the SFT target text. The
    parsed completion values are re-serialized as compact JSON in ``target``
    (consumed by the GRPO reward). Rows whose completion does not parse are
    dropped.
    """
    columns: Dict[str, List[str]] = {"prompt": [], "completion": [], "target": []}
    for row in rows:
        prompt_text = str(row["prompt"]).strip()
        completion_text = str(row["completion"]).strip()
        target_values = parse_int_list(completion_text)
        if target_values is None:
            continue
        columns["prompt"].append(build_chat_prompt(tokenizer, prompt_text))
        columns["completion"].append(completion_text)
        columns["target"].append(
            json.dumps(target_values, separators=(",", ":"))
        )
    return Dataset.from_dict(columns)
# ---- Eval (deterministic, greedy, single-shot) ------------------------------
def _greedy_generate(
    model: torch.nn.Module,
    tokenizer: Any,
    prompt: str,
    device: torch.device,
    max_new_tokens: int,
) -> str:
    """Greedily decode one completion for ``prompt``; return only the new text, stripped."""
    enc = tokenizer(prompt, return_tensors="pt", add_special_tokens=False)
    enc = {k: v.to(device) for k, v in enc.items()}
    out = model.generate(
        **enc,
        max_new_tokens=int(max_new_tokens),
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        # setup_model_and_tokenizer sets config.use_cache=False when gradient
        # checkpointing is enabled; that only matters for training, so re-enable
        # the KV cache for this no-grad decode instead of re-encoding the
        # prompt at every step.
        use_cache=True,
    )
    prompt_len = int(enc["input_ids"].shape[1])
    return tokenizer.decode(out[0][prompt_len:], skip_special_tokens=True).strip()
@torch.no_grad()
def run_eval(
    model: torch.nn.Module,
    tokenizer: Any,
    eval_rows: List[Dict[str, Any]],
    device: torch.device,
    max_new_tokens: int = 96,
    print_n: int = 3,
    stage_i: int = 3,
) -> Dict[str, float]:
    """Apples-to-apples eval with the cell-policy framework.

    The strawman model emits the WHOLE puzzle (a JSON list of integers) in a
    single greedy generation. That list is split into per-cell SINGLETON
    predictions. Each cell is scored with the same ``score_prediction_text``
    function the cell-policy / latent baselines use, against the i-consistent
    target set at ``stage_i`` (default 3 — matching the S3 eval used for the
    rebuttal v6 baseline and the latent champion).

    Reported metrics mirror
    ``multi_output_cell_policy/sft_multi_output_train.py::run_eval`` so
    numbers are directly comparable across all four 2x2 ablation cells.

    Args:
        model: Model to evaluate; switched to eval mode.
        tokenizer: Tokenizer used for chat wrapping, encoding and decoding.
        eval_rows: JSONL rows; rows whose ``completion`` does not parse are
            ignored entirely.
        device: Device the encoded prompt is moved to.
        max_new_tokens: Generation budget per puzzle.
        print_n: Maximum number of debug lines printed (shared by skipped and
            scored rows).
        stage_i: Stage used for the i-consistent target sets.

    Returns:
        Float metrics. Per-cell rates are averaged over scored cells.
        ``puzzle_parse_fail_rate`` and ``solve_rate`` are averaged over every
        row whose ``completion`` parses. This includes rows skipped for missing
        cell metadata, which therefore count as unsolved. A puzzle is solved
        only when every cell has an exact set match; predictions beyond the
        number of cells are ignored.
    """
    model.eval()
    stage = int(stage_i)
    total_cells = 0
    parse_ok = 0.0
    canonical_ok = 0.0
    exact_set_match = 0.0
    includes_gt = 0.0
    precision_sum = 0.0
    recall_sum = 0.0
    cardinality_match_sum = 0.0
    n_solve = 0
    n_total_puzzles = 0
    n_parse_fail_puzzles = 0
    printed = 0
    for row in eval_rows:
        target_completion = parse_int_list(str(row["completion"]))
        if target_completion is None:
            continue
        n_total_puzzles += 1
        # Build per-cell metadata BEFORE generating, so an unscorable row does
        # not cost a model.generate call. Best-effort: any failure skips the
        # row, which still counts in n_total_puzzles (i.e. as unsolved).
        try:
            cells = build_cell_examples_from_row(row)
            solved = make_solved_grid_from_row(row)
        except Exception as e:
            if printed < print_n:
                print(f"[strawman eval debug] row skipped (no metadata): {e}", flush=True)
                printed += 1
            continue
        prompt = build_chat_prompt(tokenizer, str(row["prompt"]).strip())
        gen = _greedy_generate(model, tokenizer, prompt, device, max_new_tokens)
        pred_list = parse_int_list(gen)
        row_all_exact = True
        row_has_eval_cell = False
        # NOTE(review): pred_list[idx] is paired with cells[idx]. This assumes
        # build_cell_examples_from_row yields cells in the same row-major
        # order as the JSONL completion list — confirm.
        for idx, ex in enumerate(cells):
            target_values = stage_i_consistent_values(
                ex.grid, target_cell=ex.target_cell, stage_i=stage
            )
            row_has_eval_cell = True
            if pred_list is not None and idx < len(pred_list):
                pred_text = json.dumps({"values": [int(pred_list[idx])]})
            else:
                # Unparseable or too-short prediction: score this cell as empty.
                pred_text = ""
            info = score_prediction_text(
                text=pred_text,
                grid=ex.grid,
                solved=solved,
                target_cell=ex.target_cell,
                stage_i=stage,
                reward_good_value=1.0,
                penalty_bad_value=1.75,
                penalty_malformed=4.0,
                penalty_empty=0.5,
                penalty_singleton=1.5,
            )
            total_cells += 1
            parse_ok += float(info["parse_ok"])
            canonical_ok += float(info["strict_canonical"])
            exact_set_match += float(info["exact_set_match"])
            includes_gt += float(info["includes_ground_truth"])
            precision_sum += float(info["value_precision"])
            recall_sum += float(info["value_recall"])
            if int(info["num_predicted_values"]) == int(len(target_values)):
                cardinality_match_sum += 1.0
            if float(info["exact_set_match"]) < 0.5:
                row_all_exact = False
        if row_has_eval_cell and row_all_exact:
            n_solve += 1
        if pred_list is None:
            n_parse_fail_puzzles += 1
        if printed < print_n:
            head_pred = pred_list if pred_list is not None else "PARSE_FAIL"
            print(
                f"[strawman eval debug] target={target_completion} pred={head_pred} "
                f"solve={int(row_all_exact and row_has_eval_cell)} gen={gen!r}",
                flush=True,
            )
            printed += 1
    return {
        "n_total_cells": float(total_cells),
        "n_total_puzzles": float(n_total_puzzles),
        "parse_rate": float(parse_ok / max(1, total_cells)),
        "strict_canonical_rate": float(canonical_ok / max(1, total_cells)),
        "exact_set_match_rate": float(exact_set_match / max(1, total_cells)),
        "includes_ground_truth_rate": float(includes_gt / max(1, total_cells)),
        "value_precision": float(precision_sum / max(1, total_cells)),
        "value_recall": float(recall_sum / max(1, total_cells)),
        "cardinality_match_rate": float(cardinality_match_sum / max(1, total_cells)),
        "puzzle_parse_fail_rate": float(n_parse_fail_puzzles / max(1, n_total_puzzles)),
        "solve_rate": float(n_solve) / max(1, n_total_puzzles),
    }
# ---- Main -------------------------------------------------------------------
def parse_args() -> argparse.Namespace:
    """Define and parse the command-line flags shared by the SFT and GRPO phases.

    Flags are registered in a fixed order: run identity/paths, data limits,
    trainer and LoRA hyperparameters, the gradient-checkpointing switch,
    GRPO-only reward/sampling knobs, then Weights & Biases settings.
    """
    parser = argparse.ArgumentParser()
    add = parser.add_argument

    # -- run identity and paths
    add("--phase", choices=["sft", "grpo"], required=True)
    add("--model_name", type=str, default="Qwen/Qwen2.5-1.5B-Instruct")
    for required_path in ("--train_jsonl", "--eval_jsonl", "--output_dir"):
        add(required_path, type=str, required=True)
    add("--cache_dir", type=str, default=str(ROOT / ".hf_cache"))
    add("--init_adapter_dir", type=str, default="")
    add("--seed", type=int, default=0)

    # -- data limits
    add("--limit_train_rows", type=int, default=10000)
    add("--eval_rows", type=int, default=100)

    # -- trainer hyperparameters, then LoRA shape
    typed_options = (
        ("--per_device_train_batch_size", int, 8),
        ("--gradient_accumulation_steps", int, 2),
        ("--learning_rate", float, 5e-5),
        ("--weight_decay", float, 0.0),
        ("--num_epochs", float, 8.0),
        ("--max_steps", int, 2000),
        ("--logging_steps", int, 25),
        ("--save_steps", int, 200),
        ("--eval_steps", int, 150),
        ("--max_grad_norm", float, 1.0),
        ("--max_completion_length", int, 96),
        ("--max_prompt_length", int, 1024),
        ("--lora_r", int, 32),
        ("--lora_alpha", int, 64),
        ("--lora_dropout", float, 0.05),
    )
    for flag, kind, default in typed_options:
        add(flag, type=kind, default=default)
    add("--enable_gradient_checkpointing", action="store_true")

    # -- GRPO-only sampling and reward shaping
    grpo_options = (
        ("--num_generations", int, 8),
        ("--beta", float, 0.0),
        ("--temperature", float, 1.0),
        ("--full_solve_bonus", float, 5.0),
        ("--length_mismatch_penalty", float, 0.5),
        ("--parse_penalty", float, 4.0),
    )
    for flag, kind, default in grpo_options:
        add(flag, type=kind, default=default)

    # -- Weights & Biases
    add("--use_wandb", action="store_true")
    add("--wandb_project", type=str, default="sudoku-strawman-baseline")
    add("--wandb_run_name", type=str, default="")
    add("--wandb_mode", type=str, default="offline")

    return parser.parse_args()
def setup_model_and_tokenizer(args: argparse.Namespace, device: torch.device):
    """Load the tokenizer and base model, then attach a trainable LoRA adapter.

    Tokenizer: when no pad token exists, it is set to EOS (or
    ``<|endoftext|>``); padding side is forced to ``"left"``.

    Model: loaded with ``pick_dtype()``. A non-blank ``args.init_adapter_dir``
    is loaded as a trainable adapter. Otherwise a fresh LoRA
    (``lora_r``/``lora_alpha``/``lora_dropout``) is applied to the attention
    and MLP projections. With ``--enable_gradient_checkpointing``,
    non-reentrant checkpointing and input grads are enabled, and
    ``config.use_cache`` is turned off. The model is moved to ``device`` last.

    Returns:
        ``(model, tokenizer)``.
    """
    tok = AutoTokenizer.from_pretrained(
        args.model_name, cache_dir=args.cache_dir, use_fast=True
    )
    if tok.pad_token_id is None:
        tok.pad_token = tok.eos_token or "<|endoftext|>"
    if tok.padding_side != "left":
        tok.padding_side = "left"

    base_model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        cache_dir=args.cache_dir,
        torch_dtype=pick_dtype(),
        low_cpu_mem_usage=True,
    )

    resume_from_adapter = bool(str(args.init_adapter_dir).strip())
    if resume_from_adapter:
        model = PeftModel.from_pretrained(
            base_model, args.init_adapter_dir, is_trainable=True
        )
    else:
        projection_names = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ]
        fresh_lora = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=projection_names,
        )
        model = get_peft_model(base_model, fresh_lora)

    if args.enable_gradient_checkpointing:
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable(
                gradient_checkpointing_kwargs={"use_reentrant": False}
            )
        # Lets gradients reach LoRA weights through checkpointed blocks
        # whose frozen embeddings would otherwise produce no-grad inputs.
        if hasattr(model, "enable_input_require_grads"):
            model.enable_input_require_grads()
        if hasattr(model, "config"):
            model.config.use_cache = False

    model.to(device)
    return model, tok
def run_sft(args: argparse.Namespace) -> None:
    """SFT phase: teach the model to emit the whole-puzzle JSON list.

    Loads train/eval rows and builds the chat-templated prompt/completion
    dataset. Trains with TRL's ``SFTTrainer`` (completion-only loss, no
    in-training eval), saves the adapter to ``<output_dir>/final``, and prints
    ``run_eval`` metrics before and after training. ``args.eval_steps`` is not
    consulted here.
    """
    from trl import SFTConfig, SFTTrainer  # type: ignore

    set_seed(int(args.seed))
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_rows = load_jsonl_rows(args.train_jsonl, limit_rows=int(args.limit_train_rows))
    eval_rows = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))
    model, tokenizer = setup_model_and_tokenizer(args, device)

    def report(when: str) -> None:
        # run_eval's own debug lines are printed before this summary line.
        metrics = run_eval(model, tokenizer, eval_rows, device)
        print(f"[strawman sft] {when}-train eval:", metrics, flush=True)

    # Prompts in the dataset are already chat-templated strings.
    sft_dataset = build_dataset(train_rows, tokenizer)
    sft_config = SFTConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=int(args.per_device_train_batch_size),
        gradient_accumulation_steps=int(args.gradient_accumulation_steps),
        learning_rate=float(args.learning_rate),
        weight_decay=float(args.weight_decay),
        num_train_epochs=float(args.num_epochs),
        max_steps=int(args.max_steps),
        logging_steps=int(args.logging_steps),
        save_steps=int(args.save_steps),
        save_strategy="steps",
        save_total_limit=4,
        eval_strategy="no",
        bf16=(pick_dtype() == torch.bfloat16),
        fp16=(pick_dtype() == torch.float16),
        max_grad_norm=float(args.max_grad_norm),
        gradient_checkpointing=bool(args.enable_gradient_checkpointing),
        report_to=("wandb" if args.use_wandb else "none"),
        run_name=(args.wandb_run_name or None),
        max_length=int(args.max_prompt_length + args.max_completion_length + 8),
        completion_only_loss=True,
        seed=int(args.seed),
    )
    trainer = SFTTrainer(
        model=model,
        args=sft_config,
        train_dataset=sft_dataset,
        processing_class=tokenizer,
    )

    report("BEFORE")
    started = time.time()
    trainer.train()
    print(f"[strawman sft] training time = {time.time() - started:.1f}s", flush=True)
    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    print(f"[strawman sft] saved final adapter to {final_dir}", flush=True)
    report("AFTER")
def run_grpo(args: argparse.Namespace) -> None:
    """GRPO phase: optimize the whole-puzzle emission with ``whole_puzzle_reward``.

    Uses the same dataset as SFT. The ``target`` column (compact JSON list)
    reaches the reward function as a keyword argument. The adapter is saved to
    ``<output_dir>/final``, and ``run_eval`` metrics are printed before and
    after training.
    """
    from trl import GRPOConfig, GRPOTrainer  # type: ignore

    set_seed(int(args.seed))
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_rows = load_jsonl_rows(args.train_jsonl, limit_rows=int(args.limit_train_rows))
    eval_rows = load_jsonl_rows(args.eval_jsonl, limit_rows=int(args.eval_rows))
    model, tokenizer = setup_model_and_tokenizer(args, device)
    grpo_dataset = build_dataset(train_rows, tokenizer)

    # Freeze the reward knobs as floats once, outside the reward closure.
    shaping = {
        "parse_penalty": float(args.parse_penalty),
        "length_mismatch_penalty": float(args.length_mismatch_penalty),
        "full_solve_bonus": float(args.full_solve_bonus),
    }

    # Keep this function's name: the trainer logs rewards under it.
    def reward_fn(completions, target, **kwargs):
        return [
            whole_puzzle_reward(
                target_list=(json.loads(tgt) if isinstance(tgt, str) else list(tgt)),
                pred_list=parse_int_list(str(text)),
                **shaping,
            )
            for text, tgt in zip(completions, target)
        ]

    def report(when: str) -> None:
        metrics = run_eval(model, tokenizer, eval_rows, device)
        print(f"[strawman grpo] {when}-train eval:", metrics, flush=True)

    grpo_config = GRPOConfig(
        output_dir=args.output_dir,
        per_device_train_batch_size=int(args.per_device_train_batch_size),
        gradient_accumulation_steps=int(args.gradient_accumulation_steps),
        learning_rate=float(args.learning_rate),
        weight_decay=float(args.weight_decay),
        num_train_epochs=float(args.num_epochs),
        max_steps=int(args.max_steps),
        logging_steps=int(args.logging_steps),
        save_steps=int(args.save_steps),
        save_strategy="steps",
        save_total_limit=6,
        bf16=(pick_dtype() == torch.bfloat16),
        fp16=(pick_dtype() == torch.float16),
        max_grad_norm=float(args.max_grad_norm),
        gradient_checkpointing=bool(args.enable_gradient_checkpointing),
        report_to=("wandb" if args.use_wandb else "none"),
        run_name=(args.wandb_run_name or None),
        max_prompt_length=int(args.max_prompt_length),
        max_completion_length=int(args.max_completion_length),
        num_generations=int(args.num_generations),
        beta=float(args.beta),
        temperature=float(args.temperature),
        seed=int(args.seed),
    )
    trainer = GRPOTrainer(
        model=model,
        reward_funcs=[reward_fn],
        args=grpo_config,
        train_dataset=grpo_dataset,
        processing_class=tokenizer,
    )

    report("BEFORE")
    started = time.time()
    trainer.train()
    print(f"[strawman grpo] training time = {time.time() - started:.1f}s", flush=True)
    final_dir = os.path.join(args.output_dir, "final")
    trainer.save_model(final_dir)
    print(f"[strawman grpo] saved final adapter to {final_dir}", flush=True)
    report("AFTER")
def main() -> None:
    """CLI entry point: parse flags, export W&B env vars if requested, run the phase."""
    args = parse_args()
    if args.use_wandb:
        # An already-set WANDB_MODE wins; the project name is always overridden.
        os.environ.setdefault("WANDB_MODE", str(args.wandb_mode))
        os.environ["WANDB_PROJECT"] = args.wandb_project
    phase_runner = run_sft if args.phase == "sft" else run_grpo
    phase_runner(args)


if __name__ == "__main__":
    main()