#!/usr/bin/env python3
"""Correct evaluation of GRPO adapter on top of SFT-merged base.

The GRPO adapter was trained on: SFT-merged Qwen3-8B (not raw Qwen3-8B).
Previous evaluations incorrectly loaded: raw Qwen3-8B + GRPO adapter.

This script loads the correct stack:
    raw Qwen3-8B + SFT adapter (merged) + GRPO adapter.

Usage:
    export PYTHONPATH="$PWD/src"
    python scripts/evaluate_grpo_correct.py
"""

import os
import gc
from collections import defaultdict
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# ---- Configuration ----
BASE_MODEL = "Qwen/Qwen3-8B"
SFT_ADAPTER = "nraptisss/Qwen3-8B-TMF921-Intent-QLoRA-qwen3-8b-qlora-20260501-083834"
GRPO_ADAPTER = "outputs/qwen3-8b-tmf921-grpo-v3"  # local path from training
DATASET = "nraptisss/TMF921-intent-to-config-research-sota"
OUTPUT_DIR = "outputs/qwen3-8b-tmf921-grpo-v3/eval_correct"
MAX_SAMPLES = 100  # per-split cap; a falsy value disables the cap
BATCH_SIZE = 4
MAX_NEW_TOKENS = 1536  # generation budget per example
SPLITS = (
    "test_in_distribution",
    "test_template_ood",
    "test_use_case_ood",
    "test_sector_ood",
    "test_adversarial",
)

# Values recorded for an example whose prediction or gold answer is not valid JSON.
_FAILED_FIELD_METRICS = {
    "field_precision": 0.0,
    "field_recall": 0.0,
    "field_f1": 0.0,
    "field_tp": 0,
    "field_fp": 0,
    "field_fn": 0,
}
_FAILED_CONSTRAINTS = {
    "slice_sst_pass": False,
    "kpi_text_presence_pass": False,
    "adversarial_status_pass": False,
}


def _load_model_stack():
    """Load base model (4-bit), merge the SFT adapter, attach the GRPO adapter.

    Returns:
        A ``(model, tokenizer)`` pair. The model is in eval mode and the
        tokenizer is guaranteed to have a pad token.
    """
    # Step 1: Load base model in 4-bit
    print("\nStep 1: Loading base model in 4-bit...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map={"": 0},
        torch_dtype=torch.bfloat16,
    )
    print(f" GPU memory after base: {torch.cuda.memory_allocated()/1e9:.1f} GB")

    # Step 2: Load and merge SFT adapter (recreate the GRPO training base)
    # NOTE(review): the merge happens on 4-bit quantized weights; confirm this
    # matches how the GRPO training base was produced.
    print("\nStep 2: Loading SFT adapter and merging...")
    model = PeftModel.from_pretrained(model, SFT_ADAPTER)
    model = model.merge_and_unload()
    gc.collect()
    torch.cuda.empty_cache()
    print(f" GPU memory after SFT merge: {torch.cuda.memory_allocated()/1e9:.1f} GB")

    # Step 3: Load GRPO adapter on top of merged SFT
    print("\nStep 3: Loading GRPO adapter on top of SFT-merged base...")
    model = PeftModel.from_pretrained(model, GRPO_ADAPTER)
    print(f" GPU memory after GRPO adapter: {torch.cuda.memory_allocated()/1e9:.1f} GB")
    model.eval()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


def _build_prompts(tokenizer, batch):
    """Render each example's non-assistant messages into a generation prompt."""
    return [
        tokenizer.apply_chat_template(
            [m for m in ex["messages"] if m["role"] != "assistant"],
            tokenize=False,
            add_generation_prompt=True,
        )
        for ex in batch
    ]


def _generate_predictions(model, tokenizer, texts):
    """Greedy-decode one batch of prompts and return only the newly generated text."""
    # Left padding keeps every prompt flush against its generated tokens, so
    # the prompt can be cut off with one shared slice below. The previous
    # padding side is restored even if tokenization raises.
    old_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
    finally:
        tokenizer.padding_side = old_side

    # Explicit None check: a pad token id of 0 is valid and must not fall
    # through to the EOS id the way `pad_token_id or eos_token_id` would.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS,
            do_sample=False,
            temperature=None,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only new tokens
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)


def _score_example(ex, pred):
    """Score one prediction against the example's gold assistant message.

    Returns:
        A result row with the example metadata, parse flags, exact-match
        flag, field-level metrics, constraint flags, and the raw
        prediction and gold strings.
    """
    # Imported lazily, as in the original script, so the project package is
    # only required once evaluation actually starts.
    from tmf921_train.utils import (
        field_f1,
        get_message,
        json_exact_match,
        metadata_constraint_pass,
        parse_json,
    )

    pred = pred.strip()
    gold = get_message(ex["messages"], "assistant")
    pred_obj, _pred_err = parse_json(pred)
    gold_obj, _gold_err = parse_json(gold)

    row = {
        "id": ex.get("id"),
        "target_layer": ex.get("target_layer"),
        "slice_type": ex.get("slice_type"),
        "lifecycle_operation": ex.get("lifecycle_operation"),
        "parse_json": pred_obj is not None,
        "gold_parse_json": gold_obj is not None,
        "exact_match": False,
        "prediction": pred,
        "gold": gold,
    }

    if pred_obj is not None and gold_obj is not None:
        row["exact_match"] = json_exact_match(pred_obj, gold_obj)
        row.update(field_f1(pred_obj, gold_obj))
        row.update(metadata_constraint_pass(ex, pred, pred_obj))
    else:
        # Either side failed to parse: record zeroed metrics and failed checks.
        row.update(_FAILED_FIELD_METRICS)
        row.update(_FAILED_CONSTRAINTS)
    return row


def _evaluate_split(model, tokenizer, split, split_ds, out_dir):
    """Generate and score one split, write its result files, return its summary.

    Writes ``predictions.json`` and ``metrics.json`` under ``out_dir / split``.
    """
    from tqdm import tqdm
    from tmf921_train.utils import aggregate_metrics, write_json

    if MAX_SAMPLES and MAX_SAMPLES < len(split_ds):
        split_ds = split_ds.select(range(MAX_SAMPLES))

    split_dir = out_dir / split
    split_dir.mkdir(parents=True, exist_ok=True)

    total = len(split_ds)
    rows = []
    print(f"Evaluating {split}: {total} examples, batch_size={BATCH_SIZE}")

    for start in tqdm(range(0, total, BATCH_SIZE), desc=split):
        batch = [split_ds[i] for i in range(start, min(start + BATCH_SIZE, total))]

        # Build prompts (system + user, no assistant)
        texts = _build_prompts(tokenizer, batch)
        predictions = _generate_predictions(model, tokenizer, texts)

        # Score each example
        rows.extend(_score_example(ex, pred) for ex, pred in zip(batch, predictions))

    # Save split results
    write_json(split_dir / "predictions.json", rows)
    summary = aggregate_metrics(rows)

    # Per-layer breakdown
    groups = defaultdict(list)
    for r in rows:
        groups[str(r.get("target_layer"))].append(r)
    summary["by_target_layer"] = {
        g: aggregate_metrics(v) for g, v in sorted(groups.items())
    }

    write_json(split_dir / "metrics.json", summary)
    return summary


def _print_summary(all_summary):
    """Print the final per-split results table."""
    print("\n" + "=" * 60)
    print("GRPO v3 CORRECT Evaluation Results")
    print("=" * 60)
    print(f"{'Split':<25} {'Parse':>8} {'Field F1':>10} {'Exact':>8}")
    print("-" * 55)
    for split, s in all_summary.items():
        print(f"{split:<25} {s.get('parse_json',0):>8.3f} {s.get('field_f1',0):>10.4f} {s.get('exact_match',0):>8.4f}")
    print("=" * 60)
    print(f"\nResults saved to: {OUTPUT_DIR}")
    print("Compare with SFT Stage 1: parse=1.0, field_f1=0.687")


def main():
    """Evaluate the GRPO adapter on every split in ``SPLITS``.

    Loads the model stack, evaluates each split, writes per-split results
    plus ``all_metrics.json`` under ``OUTPUT_DIR``, and prints a summary.

    Raises:
        KeyError: If the dataset lacks any split named in ``SPLITS``.
    """
    print("=" * 60)
    print("GRPO Correct Evaluation")
    print("Model stack: Qwen3-8B + SFT adapter (merged) + GRPO adapter")
    print("=" * 60)

    model, tokenizer = _load_model_stack()

    print("\nModel stack loaded correctly!")
    print("Now running evaluation via evaluate_model.py logic...\n")

    # Step 4: Run evaluation using the same logic as evaluate_model.py
    from datasets import load_dataset
    from tmf921_train.utils import write_json

    ds = load_dataset(DATASET)

    # Check every split before generating anything, so a missing split does
    # not surface only after earlier splits have already been evaluated.
    missing = [split for split in SPLITS if split not in ds]
    if missing:
        raise KeyError(f"Dataset {DATASET} is missing splits: {missing}")

    out_dir = Path(OUTPUT_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)

    all_summary = {}
    for split in SPLITS:
        summary = _evaluate_split(model, tokenizer, split, ds[split], out_dir)
        all_summary[split] = summary

        # Print key metrics
        print(f" {split}: parse={summary.get('parse_json', 0):.3f}, field_f1={summary.get('field_f1', 0):.4f}")

    write_json(out_dir / "all_metrics.json", all_summary)

    # Print final summary
    _print_summary(all_summary)


if __name__ == "__main__":
    main()