| from __future__ import annotations |
|
|
| import json |
| import os |
| from types import SimpleNamespace |
| import sys |
|
|
| import torch |
| from peft import PeftModel |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Repository root. It is pushed to the front of sys.path (once) before the
# project-local imports that follow -- presumably so those packages resolve
# when the script is launched from another working directory.
ROOT = "/home/ubuntu/curriculum_cot"
if ROOT not in sys.path:
    sys.path.insert(0, ROOT)
|
|
| from multi_output_cell_policy import grpo_multi_output_train as baseline_grpo |
| from multi_output_cell_policy import sft_multi_output_train as baseline_sft |
| from latent_multi_output_cell_policy import grpo_residual_projector_latent_train as latent_grpo |
| from latent_multi_output_cell_policy import residual_projector_warmstart_sft_latent_multi_output_train as latent_sft |
|
|
# Number of dataset rows each evaluation reads (passed as ``limit_rows``).
EVAL_ROWS: int = 20
# Forwarded unchanged to every ``run_eval`` call as ``args.total_empties_hint``.
TOTAL_EMPTIES_HINT: int = 30

# Model id handed to ``from_pretrained`` for both the tokenizer and the base model.
MODEL_NAME: str = "Qwen/Qwen2.5-0.5B-Instruct"
# Cache directory used for those ``from_pretrained`` calls.
CACHE_DIR: str = os.path.join(ROOT, ".hf_cache")
# JSONL file the evaluation rows are loaded from.
DATA_PATH: str = os.path.join(ROOT, "data", "sudoku_t3_30empty_value_qwen_text.jsonl")
|
|
|
|
def make_tokenizer() -> AutoTokenizer:
    """Load the fast tokenizer for ``MODEL_NAME`` and make sure it has a pad token.

    Returns:
        The tokenizer loaded from ``MODEL_NAME`` (cached under ``CACHE_DIR``).
        When it has no pad token id, the pad token is set to the EOS token,
        or to ``"<|endoftext|>"`` if the EOS token is empty/unset.
    """
    tok = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR, use_fast=True)
    if tok.pad_token_id is not None:
        return tok

    # No pad token configured: reuse EOS, with a literal fallback.
    tok.pad_token = tok.eos_token if tok.eos_token else "<|endoftext|>"
    return tok
|
|
|
|
def make_device() -> torch.device:
    """Return the CUDA device when CUDA is available, otherwise the CPU device."""
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")
|
|
|
|
def make_baseline_sft_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
    """Build the baseline SFT model: base weights plus a frozen PEFT adapter.

    Args:
        checkpoint_dir: Directory passed to ``PeftModel.from_pretrained``.
        device: Device the assembled model is moved to.

    Returns:
        The adapter-wrapped model, on ``device`` and in eval mode.
    """
    # Use the training module's dtype choice on GPU; plain float32 on CPU.
    if torch.cuda.is_available():
        dtype = baseline_sft.pick_dtype()
    else:
        dtype = torch.float32

    backbone = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    peft_model = PeftModel.from_pretrained(backbone, checkpoint_dir, is_trainable=False)
    peft_model.to(device)
    peft_model.eval()
    return peft_model
|
|
|
|
def make_baseline_grpo_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
    """Build the baseline GRPO model from the base weights and a saved adapter.

    Args:
        checkpoint_dir: Directory handed to ``baseline_grpo.load_trainable_adapter``.
        device: Device the assembled model is moved to.

    Returns:
        The model returned by ``load_trainable_adapter``, on ``device`` and
        switched to eval mode.
    """
    # Use the training module's dtype choice on GPU; plain float32 on CPU.
    if torch.cuda.is_available():
        dtype = baseline_grpo.pick_dtype()
    else:
        dtype = torch.float32

    backbone = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    adapted = baseline_grpo.load_trainable_adapter(backbone, checkpoint_dir)
    adapted.to(device)
    adapted.eval()
    return adapted
|
|
|
|
def make_latent_model(checkpoint_dir: str, device: torch.device) -> torch.nn.Module:
    """Build the latent model: base weights, adapter and residual projector.

    Args:
        checkpoint_dir: Directory the adapter and the projector state are read from.
        device: Device the assembled model is moved to.

    Returns:
        The adapter-wrapped model with projector modules attached, on
        ``device`` and in eval mode.
    """
    # Use the training module's dtype choice on GPU; plain float32 on CPU.
    if torch.cuda.is_available():
        dtype = latent_grpo.pick_dtype()
    else:
        dtype = torch.float32

    backbone = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=CACHE_DIR,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,
    )
    latent_model = latent_grpo.load_trainable_adapter(backbone, checkpoint_dir)

    # Take the projector width from the checkpoint; fall back to the module
    # default when the lookup yields a falsy value.
    projector_width = latent_grpo.infer_projector_hidden_from_state(checkpoint_dir)
    if not projector_width:
        projector_width = latent_grpo.PROJECTOR_HIDDEN

    backbone_hidden = int(latent_grpo.unwrap_backbone(latent_model).config.hidden_size)
    latent_grpo.attach_residual_projector_modules(
        latent_model,
        hidden_size=backbone_hidden,
        projector_hidden=projector_width,
    )
    # The projector modules must exist before their saved state is loaded.
    latent_grpo.maybe_load_projector_state(latent_model, checkpoint_dir)

    latent_model.to(device)
    latent_model.eval()
    return latent_model
|
|
|
|
def common_reward_args() -> dict:
    """Return the reward/penalty settings spread into the evaluation args.

    A new dict is built on every call, so callers may mutate the result.
    """
    return dict(
        reward_good_value=1.0,
        penalty_bad_value=1.75,
        penalty_malformed=4.0,
        penalty_empty=0.5,
        penalty_singleton=1.5,
    )
|
|
|
|
def eval_baseline_sft(checkpoint_dir: str, stage_i: int) -> dict:
    """Evaluate a baseline SFT checkpoint on the first ``EVAL_ROWS`` dataset rows.

    Args:
        checkpoint_dir: Adapter checkpoint directory to load.
        stage_i: Stage index, stored on the args namespace as ``int(stage_i)``.

    Returns:
        Whatever ``baseline_sft.run_eval`` returns for this checkpoint.
    """
    device = make_device()
    tokenizer = make_tokenizer()
    sft_model = make_baseline_sft_model(checkpoint_dir, device)
    eval_rows = baseline_sft.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)

    eval_args = SimpleNamespace(
        stage_i=int(stage_i),
        total_empties_hint=TOTAL_EMPTIES_HINT,
        max_completion_length=24,
        debug_print_limit=0,
    )
    # This module's run_eval is called positionally, unlike the other evaluators.
    metrics = baseline_sft.run_eval(eval_args, eval_rows, sft_model, tokenizer, device)

    # Drop the local model reference, then release cached GPU memory before
    # the caller moves on.
    del sft_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return metrics
|
|
|
|
def eval_baseline_grpo(checkpoint_dir: str, stage_i: int) -> dict:
    """Evaluate a baseline GRPO checkpoint on the first ``EVAL_ROWS`` dataset rows.

    Args:
        checkpoint_dir: Adapter checkpoint directory to load.
        stage_i: Stage index, stored on the args namespace as ``int(stage_i)``.

    Returns:
        Whatever ``baseline_grpo.run_eval`` returns for this checkpoint.
    """
    device = make_device()
    tokenizer = make_tokenizer()
    grpo_model = make_baseline_grpo_model(checkpoint_dir, device)
    eval_rows = baseline_grpo.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)

    # The shared reward/penalty settings are spread in after the fixed fields.
    eval_args = SimpleNamespace(
        stage_i=int(stage_i),
        total_empties_hint=TOTAL_EMPTIES_HINT,
        max_completion_length=24,
        debug_print_limit=0,
        **common_reward_args(),
    )
    metrics = baseline_grpo.run_eval(
        args=eval_args,
        rows=eval_rows,
        model=grpo_model,
        tokenizer=tokenizer,
        device=device,
    )

    # Drop the local model reference, then release cached GPU memory.
    del grpo_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return metrics
|
|
|
|
def eval_latent_sft(checkpoint_dir: str, stage_i: int, num_cot_tokens: int) -> dict:
    """Evaluate a latent SFT checkpoint on the first ``EVAL_ROWS`` dataset rows.

    Args:
        checkpoint_dir: Checkpoint directory holding the adapter and projector state.
        stage_i: Stage index; used both on the args namespace and as ``eval_stage_i``.
        num_cot_tokens: Stored on the args namespace as ``int(num_cot_tokens)``.

    Returns:
        Whatever ``latent_sft.run_eval`` returns for this checkpoint.
    """
    device = make_device()
    tokenizer = make_tokenizer()
    latent_model = make_latent_model(checkpoint_dir, device)
    # Rows are read with the baseline SFT module's loader here.
    eval_rows = baseline_sft.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)

    eval_args = SimpleNamespace(
        stage_i=int(stage_i),
        num_cot_tokens=int(num_cot_tokens),
        total_empties_hint=TOTAL_EMPTIES_HINT,
        max_completion_length=32,
        debug_print_limit=0,
        **common_reward_args(),
    )
    metrics = latent_sft.run_eval(
        args=eval_args,
        rows=eval_rows,
        model=latent_model,
        tokenizer=tokenizer,
        device=device,
        eval_stage_i=int(stage_i),
    )

    # Drop the local model reference, then release cached GPU memory.
    del latent_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return metrics
|
|
|
|
def eval_latent_grpo(checkpoint_dir: str, stage_i: int, num_cot_tokens: int) -> dict:
    """Evaluate a latent GRPO checkpoint on the first ``EVAL_ROWS`` dataset rows.

    Args:
        checkpoint_dir: Checkpoint directory holding the adapter and projector state.
        stage_i: Stage index; used both on the args namespace and as ``eval_stage_i``.
        num_cot_tokens: Stored on the args namespace as ``int(num_cot_tokens)``.

    Returns:
        Whatever ``latent_grpo.run_eval`` returns for this checkpoint.
    """
    device = make_device()
    tokenizer = make_tokenizer()
    latent_model = make_latent_model(checkpoint_dir, device)
    eval_rows = latent_grpo.load_jsonl_rows(DATA_PATH, limit_rows=EVAL_ROWS)

    eval_args = SimpleNamespace(
        stage_i=int(stage_i),
        num_cot_tokens=int(num_cot_tokens),
        total_empties_hint=TOTAL_EMPTIES_HINT,
        max_completion_length=32,
        debug_print_limit=0,
        **common_reward_args(),
    )
    metrics = latent_grpo.run_eval(
        args=eval_args,
        rows=eval_rows,
        model=latent_model,
        tokenizer=tokenizer,
        device=device,
        eval_stage_i=int(stage_i),
    )

    # Drop the local model reference, then release cached GPU memory.
    del latent_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return metrics
|
|
|
|
# Every ``kind`` value that ``_evaluate_checkpoint`` knows how to dispatch.
_KNOWN_KINDS = frozenset({"baseline_sft", "baseline_grpo", "latent_sft", "latent_grpo"})


def _evaluate_checkpoint(item: dict) -> dict:
    """Run the evaluator that matches ``item["kind"]`` and return its metrics.

    Raises:
        ValueError: If ``item["kind"]`` is not one of ``_KNOWN_KINDS``.
    """
    kind = item["kind"]
    checkpoint_dir = item["checkpoint_dir"]
    stage_i = item["stage_i"]
    if kind == "baseline_sft":
        return eval_baseline_sft(checkpoint_dir, stage_i)
    if kind == "baseline_grpo":
        return eval_baseline_grpo(checkpoint_dir, stage_i)
    if kind == "latent_sft":
        return eval_latent_sft(checkpoint_dir, stage_i, item["num_cot_tokens"])
    if kind == "latent_grpo":
        return eval_latent_grpo(checkpoint_dir, stage_i, item["num_cot_tokens"])
    # Previously any other value silently fell through to the latent GRPO evaluator.
    raise ValueError(f"unknown checkpoint kind {kind!r} for {item['label']!r}")


def main() -> None:
    """Evaluate each hard-coded checkpoint and print its metrics as JSON.

    For every entry a ``[eval] starting <label>`` line is printed, the
    evaluator matching the entry's ``kind`` is run, and one JSON line with
    the label and metrics is printed. After all entries, the combined results
    are printed as indented JSON.

    Raises:
        ValueError: If an entry has an unrecognised ``kind``. This is checked
            for all entries before any model is loaded.
    """
    checkpoints = [
        {
            "label": "baseline_stage1_sft",
            "stage_i": 1,
            "kind": "baseline_sft",
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/20260404_023600_baseline30_clean/baseline_pipeline_30empty_4stage_hard9x9/stage01_sft_i1_30empty/checkpoint-step-01000",
            ),
        },
        {
            "label": "baseline_stage1_grpo",
            "stage_i": 1,
            "kind": "baseline_grpo",
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo/i1_20260404_fixed_baseline_grpo_i1/checkpoint-5350",
            ),
        },
        {
            "label": "baseline_stage2_sft",
            "stage_i": 2,
            "kind": "baseline_sft",
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/sft/i2_20260404_stage2_baseline_sft_from_grpo5350/checkpoint-step-13100",
            ),
        },
        {
            "label": "baseline_stage2_grpo",
            "stage_i": 2,
            "kind": "baseline_grpo",
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_baseline_extension/hard_9x9_qwen05b/baseline/grpo/i2_20260405_stage2_baseline_grpo_from_sft13100/checkpoint-4325",
            ),
        },
        {
            "label": "latent_stage1_sft",
            "stage_i": 1,
            "kind": "latent_sft",
            "num_cot_tokens": 1,
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/20260404_013500_latent30_frombaseline/latent_pipeline_30empty_4stage_hard9x9/stage01_sft_i1_30empty_residual_projector/checkpoint-step-00200",
            ),
        },
        {
            "label": "latent_stage1_grpo",
            "stage_i": 1,
            "kind": "latent_grpo",
            "num_cot_tokens": 1,
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo/i1_cot1_20260404_fixed_latent_grpo_i1/checkpoint-2740",
            ),
        },
        {
            "label": "latent_stage2_sft",
            "stage_i": 2,
            "kind": "latent_sft",
            "num_cot_tokens": 2,
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/sft/i2_cot2_20260404_stage2_latent_sft_from_grpo2740/checkpoint-step-00700",
            ),
        },
        {
            "label": "latent_stage2_grpo",
            "stage_i": 2,
            "kind": "latent_grpo",
            "num_cot_tokens": 2,
            "checkpoint_dir": os.path.join(
                ROOT,
                "final_checkpoint/large_latent_extension/hard_9x9_qwen05b/latent/grpo/i2_cot2_20260405_stage2_latent_grpo_from_sft00700/checkpoint-1620",
            ),
        },
    ]

    # Fail before loading any model if an entry carries a mistyped kind.
    for item in checkpoints:
        if item["kind"] not in _KNOWN_KINDS:
            raise ValueError(
                f"unknown checkpoint kind {item['kind']!r} for {item['label']!r}"
            )

    results: dict[str, dict] = {}
    for item in checkpoints:
        label = item["label"]
        print(f"[eval] starting {label}", flush=True)
        metrics = _evaluate_checkpoint(item)
        results[label] = metrics
        # One JSON line per checkpoint, emitted as soon as it finishes.
        print(json.dumps({"label": label, "metrics": metrics}, sort_keys=True), flush=True)

    print("[eval] complete", flush=True)
    print(json.dumps(results, sort_keys=True, indent=2), flush=True)


if __name__ == "__main__":
    main()
|
|