PEFT
qlora
sft
trl
qwen3
tmf921
intent-based-networking
network-slicing
rtx-6000-ada
ml-intern
tmf921-intent-training / scripts /evaluate_grpo_correct.py
nraptisss's picture
Add correct GRPO evaluation script that loads SFT-merged base + GRPO adapter
9b7e923 verified
#!/usr/bin/env python3
"""Correct evaluation of GRPO adapter on top of SFT-merged base.
The GRPO adapter was trained on: SFT-merged Qwen3-8B (not raw Qwen3-8B).
Previous evaluations incorrectly loaded: raw Qwen3-8B + GRPO adapter.
This script loads the correct stack: raw Qwen3-8B + SFT adapter (merged) + GRPO adapter.
Usage:
export PYTHONPATH="$PWD/src"
python scripts/evaluate_grpo_correct.py
"""
import os
import gc
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
# ---- Configuration ----
# Module-level settings read by main(); edit these to point the evaluation at
# a different model stack, dataset, or output location.

# Hub id of the raw base model; also used to load the tokenizer.
BASE_MODEL = "Qwen/Qwen3-8B"
# SFT LoRA adapter that is loaded onto the base model and merged into it
# before the GRPO adapter is applied.
SFT_ADAPTER = "nraptisss/Qwen3-8B-TMF921-Intent-QLoRA-qwen3-8b-qlora-20260501-083834"
# GRPO adapter loaded on top of the SFT-merged model.
GRPO_ADAPTER = "outputs/qwen3-8b-tmf921-grpo-v3" # local path from training
# Dataset passed to datasets.load_dataset(); the test splits are read from it.
DATASET = "nraptisss/TMF921-intent-to-config-research-sota"
# Directory that receives per-split predictions/metrics and all_metrics.json.
OUTPUT_DIR = "outputs/qwen3-8b-tmf921-grpo-v3/eval_correct"
# Upper bound on examples evaluated per split; a falsy value disables the cap.
MAX_SAMPLES = 100
# Number of prompts generated per model.generate() call.
BATCH_SIZE = 4
def _load_model_stack():
    """Build the evaluation model stack and its tokenizer.

    The stack is assembled in the order the module docstring describes:
    4-bit quantized base model, SFT adapter merged into it, then the GRPO
    adapter loaded on top of the merged weights.

    Returns:
        A ``(model, tokenizer)`` tuple. The model has been switched to eval
        mode and the tokenizer is guaranteed to have a pad token.
    """
    # Step 1: Load base model in 4-bit
    print("\nStep 1: Loading base model in 4-bit...")
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        trust_remote_code=True,
        quantization_config=bnb_config,
        device_map={"": 0},
        torch_dtype=torch.bfloat16,
    )
    print(f" GPU memory after base: {torch.cuda.memory_allocated()/1e9:.1f} GB")

    # Step 2: Load and merge SFT adapter (recreate the GRPO training base)
    print("\nStep 2: Loading SFT adapter and merging...")
    model = PeftModel.from_pretrained(model, SFT_ADAPTER)
    model = model.merge_and_unload()
    # Drop the now-unused adapter wrapper before loading the next adapter.
    gc.collect()
    torch.cuda.empty_cache()
    print(f" GPU memory after SFT merge: {torch.cuda.memory_allocated()/1e9:.1f} GB")

    # Step 3: Load GRPO adapter on top of merged SFT
    print("\nStep 3: Loading GRPO adapter on top of SFT-merged base...")
    model = PeftModel.from_pretrained(model, GRPO_ADAPTER)
    print(f" GPU memory after GRPO adapter: {torch.cuda.memory_allocated()/1e9:.1f} GB")
    model.eval()

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


def _generate_predictions(model, tokenizer, batch):
    """Greedy-decode one completion for every example in ``batch``.

    Args:
        model: The loaded model stack; must expose ``generate`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        batch: List of dataset examples, each with a ``"messages"`` list of
            ``{"role": ..., ...}`` dicts.

    Returns:
        List of decoded strings (special tokens removed), one per example,
        containing only the newly generated tokens.
    """
    # Build prompts (system + user, no assistant)
    texts = [
        tokenizer.apply_chat_template(
            [m for m in ex["messages"] if m["role"] != "assistant"],
            tokenize=False,
            add_generation_prompt=True,
        )
        for ex in batch
    ]

    # Left padding keeps every prompt flush against the generated tokens, so
    # slicing at the padded prompt length below isolates the completion.
    old_side = tokenizer.padding_side
    tokenizer.padding_side = "left"
    try:
        inputs = tokenizer(texts, return_tensors="pt", padding=True).to(model.device)
    finally:
        # Restore even if tokenization fails so the tokenizer is not left altered.
        tokenizer.padding_side = old_side

    # Use an explicit None check: a pad token id of 0 is valid and must not
    # be replaced by the EOS id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=1536,
            do_sample=False,
            temperature=None,
            pad_token_id=pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # Decode only new tokens
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.batch_decode(outputs[:, prompt_len:], skip_special_tokens=True)


def _score_example(ex, pred):
    """Score one prediction against the example's gold assistant message.

    Args:
        ex: Dataset example (mapping) holding ``"messages"`` and metadata.
        pred: Raw decoded model output for this example.

    Returns:
        A result row (dict) with identifying metadata, parse flags, the
        exact-match flag, field-level metrics, constraint flags, and the
        stripped prediction and gold strings.
    """
    from tmf921_train.utils import (
        field_f1,
        get_message,
        json_exact_match,
        metadata_constraint_pass,
        parse_json,
    )

    pred = pred.strip()
    gold = get_message(ex["messages"], "assistant")
    # parse_json returns (object, error); only the parsed object is needed.
    pred_obj, _ = parse_json(pred)
    gold_obj, _ = parse_json(gold)
    row = {
        "id": ex.get("id"),
        "target_layer": ex.get("target_layer"),
        "slice_type": ex.get("slice_type"),
        "lifecycle_operation": ex.get("lifecycle_operation"),
        "parse_json": pred_obj is not None,
        "gold_parse_json": gold_obj is not None,
        "exact_match": False,
        "prediction": pred,
        "gold": gold,
    }
    if pred_obj is not None and gold_obj is not None:
        row["exact_match"] = json_exact_match(pred_obj, gold_obj)
        row.update(field_f1(pred_obj, gold_obj))
        row.update(metadata_constraint_pass(ex, pred, pred_obj))
    else:
        # Unparseable prediction or gold: every structured metric scores zero.
        row.update({"field_precision": 0.0, "field_recall": 0.0, "field_f1": 0.0,
                    "field_tp": 0, "field_fp": 0, "field_fn": 0})
        row.update({"slice_sst_pass": False, "kpi_text_presence_pass": False,
                    "adversarial_status_pass": False})
    return row


def _evaluate_split(model, tokenizer, split_ds, split):
    """Generate and score predictions for every example of one split.

    Args:
        model: The loaded model stack.
        tokenizer: Tokenizer matching ``model``.
        split_ds: The (possibly truncated) dataset split to evaluate.
        split: Split name, used for progress reporting only.

    Returns:
        List of result rows, one per example, in dataset order.
    """
    from tqdm import tqdm

    total = len(split_ds)
    print(f"Evaluating {split}: {total} examples, batch_size={BATCH_SIZE}")
    rows = []
    for start in tqdm(range(0, total, BATCH_SIZE), desc=split):
        batch = [split_ds[i] for i in range(start, min(start + BATCH_SIZE, total))]
        predictions = _generate_predictions(model, tokenizer, batch)
        rows.extend(_score_example(ex, pred) for ex, pred in zip(batch, predictions))
    return rows


def main():
    """Evaluate the GRPO adapter on the SFT-merged base over all test splits.

    Loads the model stack, runs greedy generation on up to ``MAX_SAMPLES``
    examples of each test split, and writes ``predictions.json`` and
    ``metrics.json`` per split plus ``all_metrics.json`` under ``OUTPUT_DIR``.
    A summary table is printed at the end.
    """
    from collections import defaultdict
    from pathlib import Path

    # Imported before the model is loaded so that a missing PYTHONPATH or
    # dependency fails immediately rather than after the expensive load.
    from datasets import load_dataset
    from tmf921_train.utils import aggregate_metrics, write_json

    print("=" * 60)
    print("GRPO Correct Evaluation")
    print("Model stack: Qwen3-8B + SFT adapter (merged) + GRPO adapter")
    print("=" * 60)

    model, tokenizer = _load_model_stack()

    print("\nModel stack loaded correctly!")
    print("Now running evaluation via evaluate_model.py logic...\n")

    # Step 4: Run evaluation using the same logic as evaluate_model.py
    ds = load_dataset(DATASET)
    splits = ["test_in_distribution", "test_template_ood", "test_use_case_ood",
              "test_sector_ood", "test_adversarial"]
    out_dir = Path(OUTPUT_DIR)
    out_dir.mkdir(parents=True, exist_ok=True)
    all_summary = {}

    for split in splits:
        split_ds = ds[split]
        if MAX_SAMPLES and MAX_SAMPLES < len(split_ds):
            split_ds = split_ds.select(range(MAX_SAMPLES))
        split_dir = out_dir / split
        split_dir.mkdir(parents=True, exist_ok=True)

        rows = _evaluate_split(model, tokenizer, split_ds, split)

        # Save split results
        write_json(split_dir / "predictions.json", rows)
        summary = aggregate_metrics(rows)

        # Per-layer breakdown
        groups = defaultdict(list)
        for r in rows:
            groups[str(r.get("target_layer"))].append(r)
        summary["by_target_layer"] = {g: aggregate_metrics(v) for g, v in sorted(groups.items())}
        write_json(split_dir / "metrics.json", summary)
        all_summary[split] = summary

        # Print key metrics
        print(f" {split}: parse={summary.get('parse_json', 0):.3f}, field_f1={summary.get('field_f1', 0):.4f}")

    write_json(out_dir / "all_metrics.json", all_summary)

    # Print final summary
    print("\n" + "=" * 60)
    print("GRPO v3 CORRECT Evaluation Results")
    print("=" * 60)
    print(f"{'Split':<25} {'Parse':>8} {'Field F1':>10} {'Exact':>8}")
    print("-" * 55)
    for split, s in all_summary.items():
        print(f"{split:<25} {s.get('parse_json',0):>8.3f} {s.get('field_f1',0):>10.4f} {s.get('exact_match',0):>8.4f}")
    print("=" * 60)
    print(f"\nResults saved to: {OUTPUT_DIR}")
    print("Compare with SFT Stage 1: parse=1.0, field_f1=0.687")
# Run the evaluation only when this file is executed as a script, so that
# importing the module does not trigger model loading.
if __name__ == "__main__":
    main()