Add trained model evaluation script
Browse files- scripts/eval_trained_model.py +290 -0
scripts/eval_trained_model.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Evaluate a trained HF adapter model against heuristic and random baselines
|
| 3 |
+
on the Wildfire Containment Simulator.
|
| 4 |
+
|
| 5 |
+
Saves results to scripts/trained_results.json.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python scripts/eval_trained_model.py --model-path Eshit/wildfire-grpo-7b
|
| 9 |
+
python scripts/eval_trained_model.py --model-path ./grpo_final --num-seeds 10
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
from __future__ import annotations
|
| 13 |
+
|
| 14 |
+
import argparse
|
| 15 |
+
import json
|
| 16 |
+
import os
|
| 17 |
+
import sys
|
| 18 |
+
import time
|
| 19 |
+
from pathlib import Path
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
|
| 23 |
+
PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
|
| 24 |
+
sys.path.insert(0, PROJECT_ROOT)
|
| 25 |
+
|
| 26 |
+
from env.wildfire_env import WildfireEnv
|
| 27 |
+
from env.serialization import serialize_observation
|
| 28 |
+
from env.action_parser import parse_action
|
| 29 |
+
from env.models import TIER_EASY, TIER_MEDIUM, TIER_HARD
|
| 30 |
+
from agents.heuristic_agent import HeuristicAgent
|
| 31 |
+
from agents.random_agent import RandomAgent
|
| 32 |
+
|
| 33 |
+
# Episode length (maximum number of steps) for each difficulty tier,
# keyed by the tier name used on the command line and in env.reset().
TIER_MAX_STEPS = {
    tier_name: tier_config.episode_length
    for tier_name, tier_config in (
        ("easy", TIER_EASY),
        ("medium", TIER_MEDIUM),
        ("hard", TIER_HARD),
    )
}
|
| 38 |
+
|
| 39 |
+
# System message sent with every step's briefing; it constrains the model to
# answer with a bare JSON action object.
SYSTEM_PROMPT = " ".join(
    [
        "You are an AI Incident Commander managing wildfire containment.",
        "You will receive a situation briefing each step.",
        "Respond with ONLY a valid JSON action object and nothing else.",
        'Example: {"action_type": "idle"}',
    ]
)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class LLMAgent:
    """
    Adapter that lets the trained language model act as an environment agent.

    Create a new instance for every episode - ``_step`` and ``_prev_burning``
    are per-episode state, so a reused instance would build wrong prompts.

    The counters ``json_success``, ``regex_fallback`` and ``safe_idle`` record
    how each generated reply was classified by ``parse_action``.
    """

    def __init__(self, model, tokenizer, tier, max_steps):
        # Generation backend.
        self.model = model
        self.tokenizer = tokenizer
        # Episode configuration forwarded to the prompt serializer.
        self.tier = tier
        self.max_steps = max_steps
        # Per-episode prompt state.
        self._step = 0
        self._prev_burning = 0
        # Parse-outcome tallies.
        self.json_success = 0
        self.regex_fallback = 0
        self.safe_idle = 0

    def act(self, obs):
        """Generate one action for ``obs`` and update the per-episode state.

        The observation is serialized into a briefing, sent to the model as a
        system + user chat, and the newly generated tokens are handed to
        ``parse_action``. Returns the action produced by ``parse_action``.
        """
        import torch

        briefing = serialize_observation(
            obs,
            self._step,
            self.max_steps,
            tier=self.tier,
            prev_cells_burning=self._prev_burning,
        )
        # Remember the current burning-cell count for the next step's briefing.
        self._prev_burning = obs.stats.cells_burning

        chat = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": briefing},
        ]
        encoded = self.tokenizer.apply_chat_template(
            chat,
            tokenize=True,
            add_generation_prompt=True,
            return_tensors="pt",
        )
        input_ids = encoded.to(self.model.device)
        prompt_length = input_ids.shape[1]

        with torch.no_grad():
            generated = self.model.generate(
                input_ids,
                max_new_tokens=128,
                pad_token_id=self.tokenizer.eos_token_id,
            )

        # Decode only the tokens produced after the prompt.
        reply = self.tokenizer.decode(
            generated[0][prompt_length:], skip_special_tokens=True
        )
        action, status = parse_action(reply, obs)

        # Any status other than the two recognised ones is tallied as safe_idle.
        counter = status if status in ("json_success", "regex_fallback") else "safe_idle"
        setattr(self, counter, getattr(self, counter) + 1)

        self._step += 1
        return action
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def run_llm_episode(model, tokenizer, tier, seed):
    """Play one complete episode with a brand-new LLMAgent.

    Returns ``(total_reward, details)``: the unrounded sum of step rewards and
    a dict with rounded reward/containment/population figures, the step count,
    the crew-casualty flag and the agent's parse-outcome counters.
    """
    episode_length = TIER_MAX_STEPS[tier]
    agent = LLMAgent(model, tokenizer, tier, episode_length)
    env = WildfireEnv()

    observation = env.reset(task_id=tier, seed=seed)
    episode_return = 0.0
    while not env.done:
        step_result = env.step(agent.act(observation))
        episode_return += step_result.reward
        observation = step_result.observation

    end_state = env.state()
    # A missing or zero population falls back to 1 so the division is safe.
    population = end_state.get("total_population", 1) or 1
    lost = end_state.get("population_lost", 0)
    containment = end_state.get("reward_breakdown", {}).get("containment", 0.0)

    details = {
        "total_reward": round(episode_return, 4),
        "containment_pct": round(containment, 4),
        "pop_saved_pct": round(1.0 - lost / population, 4),
        "steps": env.current_step,
        "crew_casualty": env._crew_casualty_occurred,
        "json_success": agent.json_success,
        "regex_fallback": agent.regex_fallback,
        "safe_idle": agent.safe_idle,
    }
    return episode_return, details
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def load_model(model_path: str, base_model: str):
    """Return ``(model, tokenizer)`` for a trained checkpoint.

    ``model_path`` is first loaded as-is. If that raises for any reason,
    ``base_model`` is loaded instead and ``model_path`` is attached to it as
    the adapter named "default".
    """
    from unsloth import FastLanguageModel

    def _from_pretrained(name):
        # Both attempts share the same sequence length and 4-bit setting.
        return FastLanguageModel.from_pretrained(
            model_name=name,
            max_seq_length=2048,
            load_in_4bit=True,
        )

    # First attempt: merged models and adapter repos whose adapter_config.json
    # embeds base_model_name_or_path are expected to load directly.
    try:
        model, tokenizer = _from_pretrained(model_path)
        print(f"Loaded model directly from: {model_path}")
        return model, tokenizer
    except Exception as e:
        print(f"Direct load failed ({e}), trying base + adapter...")

    # Second attempt (standalone PEFT adapters): base model first, adapter on top.
    model, tokenizer = _from_pretrained(base_model)
    model.load_adapter(model_path, adapter_name="default")
    print(f"Loaded base model ({base_model}) + adapter ({model_path})")
    return model, tokenizer
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def main():
    """Evaluate the trained model on each tier and compare it with stored baselines.

    Command-line flags: ``--model-path`` (required), ``--base-model``,
    ``--num-seeds`` and ``--tiers``. Episodes use seeds
    ``200 .. 200 + num_seeds - 1``. Baseline mean/std values are read from
    ``results.json`` next to this script, and the trained-model summary is
    written to ``trained_results.json`` in the same directory, shaped as
    ``{"trained": {tier: {...}}}``.

    All cheap validation (seed count, tier names, baseline file contents)
    runs before the model is loaded, so a bad invocation fails immediately
    rather than after the model load or after a full evaluation run.

    Raises:
        SystemExit: code 2 for invalid arguments (via argparse), code 1 when
            the baseline file is missing or lacks a requested tier.
    """
    parser = argparse.ArgumentParser(description="Evaluate trained model vs baselines")
    parser.add_argument("--model-path", required=True,
                        help="HF hub ID or local path to the trained adapter")
    parser.add_argument("--base-model", default="unsloth/Qwen2.5-7B-Instruct",
                        help="Base model for PEFT adapter loading "
                             "(default: unsloth/Qwen2.5-7B-Instruct)")
    parser.add_argument("--num-seeds", type=int, default=15,
                        help="Evaluation seeds per tier (default: 15, uses seeds 200+)")
    parser.add_argument("--tiers", nargs="+", default=["easy", "medium", "hard"],
                        help="Tiers to evaluate (default: easy medium hard)")
    args = parser.parse_args()

    # An empty seed list would otherwise crash later on seeds[0] / len(seeds).
    if args.num_seeds < 1:
        parser.error("--num-seeds must be at least 1")

    # Reject unknown tiers now instead of with a KeyError after the model load.
    unknown_tiers = [tier for tier in args.tiers if tier not in TIER_MAX_STEPS]
    if unknown_tiers:
        parser.error(
            f"unknown tier(s): {', '.join(unknown_tiers)} "
            f"(choose from: {', '.join(TIER_MAX_STEPS)})"
        )

    seeds = list(range(200, 200 + args.num_seeds))

    # Load the stored baselines first: this is cheap, and a missing or
    # incomplete file should not cost a model load or an evaluation run.
    baselines_path = os.path.join(os.path.dirname(__file__), "results.json")
    if not os.path.exists(baselines_path):
        print(f"WARNING: {baselines_path} not found. Run scripts/evaluate.py first.")
        sys.exit(1)
    with open(baselines_path, "r", encoding="utf-8") as f:
        baselines = json.load(f)

    # The comparison table reads baselines[agent][tier]["mean"/"std"].
    missing = [
        f"{agent}/{tier}"
        for agent in ("heuristic", "random")
        for tier in args.tiers
        if not {"mean", "std"} <= set(baselines.get(agent, {}).get(tier, {}))
    ]
    if missing:
        print(f"ERROR: {baselines_path} has no mean/std baseline for: {', '.join(missing)}")
        sys.exit(1)

    # Load the trained model (--base-model is used for the adapter fallback).
    print(f"Loading model: {args.model_path}")
    model, tokenizer = load_model(args.model_path, args.base_model)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    from unsloth import FastLanguageModel
    FastLanguageModel.for_inference(model)
    print("Model ready for inference.\n")

    # Output has the same shape as results.json: {agent: {tier: {...}}}
    all_results = {"trained": {}}

    for tier in args.tiers:
        max_steps = TIER_MAX_STEPS[tier]
        print(f"{'='*60}")
        print(f"  Tier: {tier} | Seeds: {seeds[0]}-{seeds[-1]} | Max steps: {max_steps}")
        print(f"{'='*60}")

        tier_rewards = []
        tier_pop_saved = []
        tier_containment = []
        tier_json_success = 0
        tier_total_actions = 0
        tier_casualty_count = 0
        tier_times = []

        for seed in seeds:
            start = time.time()
            reward, details = run_llm_episode(model, tokenizer, tier, seed)
            elapsed = time.time() - start

            tier_rewards.append(reward)
            tier_pop_saved.append(details["pop_saved_pct"])
            tier_containment.append(details["containment_pct"])
            tier_json_success += details["json_success"]
            tier_total_actions += (details["json_success"]
                                   + details["regex_fallback"]
                                   + details["safe_idle"])
            if details["crew_casualty"]:
                tier_casualty_count += 1
            tier_times.append(elapsed)

            print(f"  seed={seed}: reward={reward:+.2f}, "
                  f"pop_saved={details['pop_saved_pct']*100:.0f}%, "
                  f"steps={details['steps']}, time={elapsed:.1f}s")

        json_rate = (100.0 * tier_json_success / tier_total_actions
                     if tier_total_actions > 0 else 0)

        all_results["trained"][tier] = {
            "scores": [round(r, 4) for r in tier_rewards],
            "mean": round(float(np.mean(tier_rewards)), 4),
            "std": round(float(np.std(tier_rewards)), 4),
            "mean_containment_pct": round(float(np.mean(tier_containment)), 4),
            "mean_pop_saved_pct": round(float(np.mean(tier_pop_saved)), 4),
            "crew_casualty_rate": round(tier_casualty_count / len(seeds), 2),
            "mean_time_s": round(float(np.mean(tier_times)), 3),
            "json_success_rate": round(json_rate, 2),
        }
        print()

    # -- Save results (same top-level shape as results.json) --
    # Written before the report so a failure while printing cannot lose the run.
    output_path = os.path.join(os.path.dirname(__file__), "trained_results.json")
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)

    # -- Print comparison table using stored baselines --
    print()
    print("=" * 65)
    print("=== Evaluation: Trained Model vs Baselines ===")
    print(f"Model: {args.model_path}")
    print(f"Seeds: {seeds[0]}-{seeds[-1]} ({len(seeds)} per tier)")
    print("=" * 65)
    print(f"{'Tier':<10} {'Trained':>12} {'Heuristic':>12} {'Random':>12} {'vs Heuristic':>14}")
    print("-" * 65)

    for tier in args.tiers:
        t = all_results["trained"][tier]
        h_mean = baselines["heuristic"][tier]["mean"]
        h_std = baselines["heuristic"][tier]["std"]
        r_mean = baselines["random"][tier]["mean"]
        r_std = baselines["random"][tier]["std"]
        delta = t["mean"] - h_mean
        # "OK" marks tiers where the trained mean is within 1.0 of the heuristic mean, or above it.
        marker = " OK" if delta >= -1.0 else ""
        print(
            f"{tier:<10} "
            f"{t['mean']:+.2f}+/-{t['std']:.1f}  "
            f"{h_mean:+.2f}+/-{h_std:.1f}  "
            f"{r_mean:+.2f}+/-{r_std:.1f}  "
            f"{delta:+.2f}{marker}"
        )

    print()
    print("JSON success rate: ", end="")
    print("  ".join(
        f"{t}={all_results['trained'][t]['json_success_rate']:.1f}%"
        for t in args.tiers
    ))
    print("Pop saved rate:    ", end="")
    print("  ".join(
        f"{t}={all_results['trained'][t]['mean_pop_saved_pct']*100:.0f}%"
        for t in args.tiers
    ))
    print("=" * 65)

    print(f"\nResults saved to {output_path}")
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
if __name__ == "__main__":
|
| 290 |
+
main()
|