# curriculum-cot-code / analysis / eval_saved_hard9x9_checkpoints.py
# Author: Avra98
# Commit 76de008 (verified): "Initial code dump (rebuttal-ready snapshot)"
from __future__ import annotations
import json
import os
from types import SimpleNamespace
import sys
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# Repository root holding the project packages, data and saved checkpoints.
# Can be overridden with the CURRICULUM_COT_ROOT environment variable; an unset
# or empty variable falls back to the original hard-coded location.
ROOT = os.environ.get("CURRICULUM_COT_ROOT") or "/home/ubuntu/curriculum_cot"
# Put ROOT on the import path so the project packages imported below resolve
# when this file is run directly as a script.
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)
from multi_output_cell_policy import grpo_multi_output_train as baseline_grpo
from multi_output_cell_policy import sft_multi_output_train as baseline_sft
from latent_multi_output_cell_policy import grpo_residual_projector_latent_train as latent_grpo
from latent_multi_output_cell_policy import residual_projector_warmstart_sft_latent_multi_output_train as latent_sft
# --- Evaluation settings -----------------------------------------------------
# Hugging Face id of the base model every adapter checkpoint is loaded onto.
MODEL_NAME: str = "Qwen/Qwen2.5-0.5B-Instruct"
# Number of leading rows of DATA_PATH to evaluate (passed as limit_rows).
EVAL_ROWS: int = 20
# Forwarded to each run_eval as ``total_empties_hint``.
TOTAL_EMPTIES_HINT: int = 30

# --- Paths (all relative to ROOT) ---------------------------------------------
# Local Hugging Face cache shared by model and tokenizer loads.
CACHE_DIR: str = os.path.join(ROOT, ".hf_cache")
# JSONL evaluation data read by the load_jsonl_rows helpers.
DATA_PATH: str = os.path.join(ROOT, "data", "sudoku_t3_30empty_value_qwen_text.jsonl")
def make_tokenizer() -> AutoTokenizer:
    """Load the fast tokenizer for MODEL_NAME from CACHE_DIR.

    If the loaded tokenizer has no pad token, its EOS token (or the literal
    ``"<|endoftext|>"`` when EOS is also unset) is installed as the pad token.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR, use_fast=True)
    if tok.pad_token_id is not None:
        return tok
    fallback_pad = tok.eos_token or "<|endoftext|>"
    tok.pad_token = fallback_pad
    return tok
def make_device() -> torch.device:
    """Return ``torch.device("cuda")`` when CUDA is available, else the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
def make_baseline_sft_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
    """Load MODEL_NAME and wrap it with the baseline SFT PEFT adapter for inference.

    Args:
        checkpoint_dir: Directory of the saved adapter, passed to
            ``PeftModel.from_pretrained`` with ``is_trainable=False``.
        device: Target device; the model is moved there and set to eval mode.

    Returns:
        The adapter-wrapped model, on ``device`` and in eval mode.
    """
    # Use the training module's dtype choice only when the model will actually
    # live on a CUDA device. Previously this keyed off torch.cuda.is_available()
    # alone, so a CPU target on a GPU host got a reduced-precision CPU model.
    on_cuda = torch.cuda.is_available() and torch.device(device).type == "cuda"
    dtype = baseline_sft.pick_dtype() if on_cuda else torch.float32
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    model = PeftModel.from_pretrained(base, checkpoint_dir, is_trainable=False)
    model.to(device)
    model.eval()
    return model
def make_baseline_grpo_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
    """Load MODEL_NAME with a baseline GRPO adapter checkpoint for evaluation.

    The adapter is attached via ``baseline_grpo.load_trainable_adapter``
    (project helper; its exact wrapping is defined in that module).

    Args:
        checkpoint_dir: Saved GRPO checkpoint directory.
        device: Target device; the model is moved there and set to eval mode.

    Returns:
        The adapter-wrapped model, on ``device`` and in eval mode.
    """
    # Select the reduced-precision dtype only for an actual CUDA target, not
    # merely because a GPU exists on the host.
    on_cuda = torch.cuda.is_available() and torch.device(device).type == "cuda"
    dtype = baseline_grpo.pick_dtype() if on_cuda else torch.float32
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    model = baseline_grpo.load_trainable_adapter(base, checkpoint_dir)
    model.to(device)
    # eval() after loading a "trainable" adapter so dropout etc. are disabled
    # for inference.
    model.eval()
    return model
def make_latent_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
    """Load MODEL_NAME with a latent adapter checkpoint and its residual projector.

    Steps:
      1. load the base model;
      2. attach the adapter via ``latent_grpo.load_trainable_adapter``;
      3. attach residual projector modules whose hidden width is inferred from
         the checkpoint (falling back to ``latent_grpo.PROJECTOR_HIDDEN`` when
         inference returns a falsy value);
      4. load saved projector state via ``latent_grpo.maybe_load_projector_state``
         (name suggests it tolerates a missing state file — confirm);
      5. move to ``device`` and switch to eval mode.

    Args:
        checkpoint_dir: Saved latent SFT or GRPO checkpoint directory.
        device: Target device for the returned model.

    Returns:
        The assembled model, on ``device`` and in eval mode.
    """
    # Reduced precision only when the target device is really CUDA.
    on_cuda = torch.cuda.is_available() and torch.device(device).type == "cuda"
    dtype = latent_grpo.pick_dtype() if on_cuda else torch.float32
    base = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    model = latent_grpo.load_trainable_adapter(base, checkpoint_dir)
    projector_hidden = latent_grpo.infer_projector_hidden_from_state(checkpoint_dir) or latent_grpo.PROJECTOR_HIDDEN
    # The projector modules must exist before their saved state can be loaded
    # into them, so attach first, then load.
    latent_grpo.attach_residual_projector_modules(
        model,
        hidden_size=int(latent_grpo.unwrap_backbone(model).config.hidden_size),
        projector_hidden=projector_hidden,
    )
    latent_grpo.maybe_load_projector_state(model, checkpoint_dir)
    model.to(device)
    model.eval()
    return model
def common_reward_args() -> dict:
    """Return the reward/penalty settings merged into the GRPO and latent eval args.

    A fresh dict is built on every call, so callers may modify the result
    without affecting later calls.
    """
    return dict(
        reward_good_value=1.0,
        penalty_bad_value=1.75,
        penalty_malformed=4.0,
        penalty_empty=0.5,
        penalty_singleton=1.5,
    )
def eval_baseline_sft(checkpoint_dir: str, stage_i: int) -> dict:
    """Evaluate one baseline SFT checkpoint on the first EVAL_ROWS rows of DATA_PATH.

    Builds a fresh device, tokenizer and model and delegates scoring to
    ``baseline_sft.run_eval``. It then drops its model reference and, on CUDA
    hosts, empties the CUDA cache.

    Args:
        checkpoint_dir: Adapter checkpoint directory to load.
        stage_i: Curriculum stage index forwarded to ``run_eval``.

    Returns:
        Whatever metrics dict ``baseline_sft.run_eval`` produces.
    """
    dev = make_device()
    tok = make_tokenizer()
    net = make_baseline_sft_model(checkpoint_dir, dev)
    eval_rows = baseline_sft.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
    # SFT eval takes no reward settings; only generation/stage options.
    eval_args = SimpleNamespace(
        stage_i=int(stage_i),
        total_empties_hint=TOTAL_EMPTIES_HINT,
        max_completion_length=24,
        debug_print_limit=0,
    )
    result = baseline_sft.run_eval(eval_args, eval_rows, net, tok, dev)
    # Drop the model before the next checkpoint is loaded in the same process.
    del net
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result
def eval_baseline_grpo(checkpoint_dir: str, stage_i: int) -> dict:
    """Evaluate one baseline GRPO checkpoint on the first EVAL_ROWS rows of DATA_PATH.

    Args:
        checkpoint_dir: GRPO checkpoint directory to load.
        stage_i: Curriculum stage index forwarded to ``run_eval``.

    Returns:
        Whatever metrics dict ``baseline_grpo.run_eval`` produces.
    """
    dev = make_device()
    tok = make_tokenizer()
    net = make_baseline_grpo_model(checkpoint_dir, dev)
    eval_rows = baseline_grpo.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
    # GRPO eval additionally receives the shared reward/penalty settings.
    eval_args = SimpleNamespace(
        stage_i=int(stage_i),
        total_empties_hint=TOTAL_EMPTIES_HINT,
        max_completion_length=24,
        debug_print_limit=0,
        **common_reward_args(),
    )
    result = baseline_grpo.run_eval(
        args=eval_args,
        rows=eval_rows,
        model=net,
        tokenizer=tok,
        device=dev,
    )
    del net
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result
def eval_latent_sft(checkpoint_dir: str, stage_i: int, num_cot_tokens: int) -> dict:
    """Evaluate one latent SFT checkpoint on the first EVAL_ROWS rows of DATA_PATH.

    Args:
        checkpoint_dir: Latent checkpoint directory (adapter + projector state).
        stage_i: Curriculum stage index; also passed as ``eval_stage_i``.
        num_cot_tokens: Number of latent CoT tokens, forwarded in the eval args.

    Returns:
        Whatever metrics dict ``latent_sft.run_eval`` produces.
    """
    stage = int(stage_i)
    dev = make_device()
    tok = make_tokenizer()
    net = make_latent_model(checkpoint_dir, dev)
    # NOTE(review): rows come from baseline_sft's loader, not latent_sft's;
    # presumably both read the same JSONL format — confirm.
    eval_rows = baseline_sft.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
    eval_args = SimpleNamespace(
        stage_i=stage,
        num_cot_tokens=int(num_cot_tokens),
        total_empties_hint=TOTAL_EMPTIES_HINT,
        max_completion_length=32,
        debug_print_limit=0,
        **common_reward_args(),
    )
    result = latent_sft.run_eval(
        args=eval_args,
        rows=eval_rows,
        model=net,
        tokenizer=tok,
        device=dev,
        eval_stage_i=stage,
    )
    del net
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result
def eval_latent_grpo(checkpoint_dir: str, stage_i: int, num_cot_tokens: int) -> dict:
    """Evaluate one latent GRPO checkpoint on the first EVAL_ROWS rows of DATA_PATH.

    Args:
        checkpoint_dir: Latent checkpoint directory (adapter + projector state).
        stage_i: Curriculum stage index; also passed as ``eval_stage_i``.
        num_cot_tokens: Number of latent CoT tokens, forwarded in the eval args.

    Returns:
        Whatever metrics dict ``latent_grpo.run_eval`` produces.
    """
    stage = int(stage_i)
    dev = make_device()
    tok = make_tokenizer()
    net = make_latent_model(checkpoint_dir, dev)
    eval_rows = latent_grpo.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)
    eval_args = SimpleNamespace(
        stage_i=stage,
        num_cot_tokens=int(num_cot_tokens),
        total_empties_hint=TOTAL_EMPTIES_HINT,
        max_completion_length=32,
        debug_print_limit=0,
        **common_reward_args(),
    )
    result = latent_grpo.run_eval(
        args=eval_args,
        rows=eval_rows,
        model=net,
        tokenizer=tok,
        device=dev,
        eval_stage_i=stage,
    )
    del net
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return result
# Evaluator kinds understood by _evaluate_checkpoint(); any other value is
# rejected before a single model is loaded.
_CHECKPOINT_KINDS = ("baseline_sft", "baseline_grpo", "latent_sft", "latent_grpo")


def _checkpoint_specs() -> list[dict]:
    """Return the checkpoint specs to evaluate, in run order.

    Each spec has ``label`` (key in the printed results), ``stage_i``,
    ``kind`` (one of ``_CHECKPOINT_KINDS``), ``checkpoint_dir`` and, for latent
    kinds, ``num_cot_tokens``.
    """
    # Explicit step dirs (not run roots) so metrics match the agreed endpoints.
    return [
        {
            "label": "baseline_stage1_sft",
            "stage_i": 1,
            "kind": "baseline_sft",
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/20260404_023600_baseline30_clean/baseline_pipeline_30empty_4stage_hard9x9/stage01_sft_i1_30empty/checkpoint-step-01000",
            ),
        },
        {
            "label": "baseline_stage1_grpo",
            "stage_i": 1,
            "kind": "baseline_grpo",
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo/i1_20260404_fixed_baseline_grpo_i1/checkpoint-5350",
            ),
        },
        {
            "label": "baseline_stage2_sft",
            "stage_i": 2,
            "kind": "baseline_sft",
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/sft/i2_20260404_stage2_baseline_sft_from_grpo5350/checkpoint-step-13100",
            ),
        },
        {
            "label": "baseline_stage2_grpo",
            "stage_i": 2,
            "kind": "baseline_grpo",
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo/i2_20260405_stage2_baseline_grpo_from_sft13100/checkpoint-4325",
            ),
        },
        {
            "label": "latent_stage1_sft",
            "stage_i": 1,
            "kind": "latent_sft",
            "num_cot_tokens": 1,
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/20260404_013500_latent30_frombaseline/latent_pipeline_30empty_4stage_hard9x9/stage01_sft_i1_30empty_residual_projector/checkpoint-step-00200",
            ),
        },
        {
            "label": "latent_stage1_grpo",
            "stage_i": 1,
            "kind": "latent_grpo",
            "num_cot_tokens": 1,
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo/i1_cot1_20260404_fixed_latent_grpo_i1/checkpoint-2740",
            ),
        },
        {
            "label": "latent_stage2_sft",
            "stage_i": 2,
            "kind": "latent_sft",
            "num_cot_tokens": 2,
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/sft/i2_cot2_20260404_stage2_latent_sft_from_grpo2740/checkpoint-step-00700",
            ),
        },
        {
            "label": "latent_stage2_grpo",
            "stage_i": 2,
            "kind": "latent_grpo",
            "num_cot_tokens": 2,
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo/i2_cot2_20260405_stage2_latent_grpo_from_sft00700/checkpoint-1620",
            ),
        },
    ]


def _evaluate_checkpoint(item: dict) -> dict:
    """Run the evaluator matching ``item["kind"]`` and return its metrics.

    Raises:
        ValueError: if ``item["kind"]`` is not one of ``_CHECKPOINT_KINDS``
            (previously any unknown kind silently fell through to latent GRPO).
    """
    kind = item["kind"]
    ckpt = item["checkpoint_dir"]
    stage = item["stage_i"]
    if kind == "baseline_sft":
        return eval_baseline_sft(ckpt, stage)
    if kind == "baseline_grpo":
        return eval_baseline_grpo(ckpt, stage)
    if kind == "latent_sft":
        return eval_latent_sft(ckpt, stage, item["num_cot_tokens"])
    if kind == "latent_grpo":
        return eval_latent_grpo(ckpt, stage, item["num_cot_tokens"])
    raise ValueError(f"unknown checkpoint kind {kind!r} for {item.get('label')!r}")


def main() -> None:
    """Evaluate every checkpoint from ``_checkpoint_specs()`` and print metrics.

    Prints a JSON line per checkpoint as soon as it finishes, then the combined
    results keyed by label.

    Raises:
        ValueError: if any spec has an unknown ``kind`` or labels are not
            unique. Both are checked before any model is loaded, so a typo
            does not waste earlier evaluations.
    """
    checkpoints = _checkpoint_specs()
    bad_kinds = [item["label"] for item in checkpoints if item["kind"] not in _CHECKPOINT_KINDS]
    if bad_kinds:
        raise ValueError(f"unknown checkpoint kind for labels: {bad_kinds}")
    labels = [item["label"] for item in checkpoints]
    if len(set(labels)) != len(labels):
        # Duplicate labels would silently overwrite each other in `results`.
        raise ValueError(f"duplicate checkpoint labels: {labels}")
    results: dict[str, dict] = {}
    for item in checkpoints:
        label = item["label"]
        print(f"[eval] starting {label}", flush=True)
        metrics = _evaluate_checkpoint(item)
        results[label] = metrics
        print(json.dumps({"label": label, "metrics": metrics}, sort_keys=True), flush=True)
    print("[eval] complete", flush=True)
    print(json.dumps(results, sort_keys=True, indent=2), flush=True)


if __name__ == "__main__":
    main()