Eshit commited on
Commit
a070e2f
Β·
1 Parent(s): ad92ece

Add trained model evaluation script

Browse files
Files changed (1) hide show
  1. scripts/eval_trained_model.py +290 -0
scripts/eval_trained_model.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Evaluate a trained HF adapter model against heuristic and random baselines
3
+ on the Wildfire Containment Simulator.
4
+
5
+ Saves results to scripts/trained_results.json.
6
+
7
+ Usage:
8
+ python scripts/eval_trained_model.py --model-path Eshit/wildfire-grpo-7b
9
+ python scripts/eval_trained_model.py --model-path ./grpo_final --num-seeds 10
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import argparse
15
+ import json
16
+ import os
17
+ import sys
18
+ import time
19
+ from pathlib import Path
20
+
21
+ import numpy as np
22
+
23
+ PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
24
+ sys.path.insert(0, PROJECT_ROOT)
25
+
26
+ from env.wildfire_env import WildfireEnv
27
+ from env.serialization import serialize_observation
28
+ from env.action_parser import parse_action
29
+ from env.models import TIER_EASY, TIER_MEDIUM, TIER_HARD
30
+ from agents.heuristic_agent import HeuristicAgent
31
+ from agents.random_agent import RandomAgent
32
+
33
+ TIER_MAX_STEPS = {
34
+ "easy": TIER_EASY.episode_length,
35
+ "medium": TIER_MEDIUM.episode_length,
36
+ "hard": TIER_HARD.episode_length,
37
+ }
38
+
39
+ SYSTEM_PROMPT = (
40
+ "You are an AI Incident Commander managing wildfire containment. "
41
+ "You will receive a situation briefing each step. "
42
+ "Respond with ONLY a valid JSON action object and nothing else. "
43
+ 'Example: {"action_type": "idle"}'
44
+ )
45
+
46
+
47
+ class LLMAgent:
48
+ """
49
+ Wraps the trained model for grader compatibility.
50
+ Must be re-instantiated for every episode β€” _step and _prev_burning
51
+ are per-episode state and will produce wrong prompts if reused.
52
+ """
53
+
54
+ def __init__(self, model, tokenizer, tier, max_steps):
55
+ self.model = model
56
+ self.tokenizer = tokenizer
57
+ self.tier = tier
58
+ self.max_steps = max_steps
59
+ self._step = 0
60
+ self._prev_burning = 0
61
+ self.json_success = self.regex_fallback = self.safe_idle = 0
62
+
63
+ def act(self, obs):
64
+ import torch
65
+
66
+ prompt = serialize_observation(
67
+ obs, self._step, self.max_steps,
68
+ tier=self.tier,
69
+ prev_cells_burning=self._prev_burning,
70
+ )
71
+ self._prev_burning = obs.stats.cells_burning
72
+ messages = [
73
+ {"role": "system", "content": SYSTEM_PROMPT},
74
+ {"role": "user", "content": prompt},
75
+ ]
76
+ input_ids = self.tokenizer.apply_chat_template(
77
+ messages, tokenize=True,
78
+ add_generation_prompt=True, return_tensors="pt",
79
+ ).to(self.model.device)
80
+ with torch.no_grad():
81
+ out = self.model.generate(
82
+ input_ids, max_new_tokens=128,
83
+ pad_token_id=self.tokenizer.eos_token_id,
84
+ )
85
+ text = self.tokenizer.decode(out[0][input_ids.shape[1]:], skip_special_tokens=True)
86
+ action, status = parse_action(text, obs)
87
+ if status == "json_success":
88
+ self.json_success += 1
89
+ elif status == "regex_fallback":
90
+ self.regex_fallback += 1
91
+ else:
92
+ self.safe_idle += 1
93
+ self._step += 1
94
+ return action
95
+
96
+
97
+ def run_llm_episode(model, tokenizer, tier, seed):
98
+ """Run a full episode with a fresh LLMAgent. Returns (reward, details)."""
99
+ max_steps = TIER_MAX_STEPS[tier]
100
+ agent = LLMAgent(model, tokenizer, tier, max_steps)
101
+ env = WildfireEnv()
102
+ obs = env.reset(task_id=tier, seed=seed)
103
+ total_reward = 0.0
104
+
105
+ while not env.done:
106
+ action = agent.act(obs)
107
+ result = env.step(action)
108
+ total_reward += result.reward
109
+ obs = result.observation
110
+
111
+ final = env.state()
112
+ total_pop = final.get("total_population", 1) or 1
113
+ pop_lost = final.get("population_lost", 0)
114
+
115
+ details = {
116
+ "total_reward": round(total_reward, 4),
117
+ "containment_pct": round(
118
+ final.get("reward_breakdown", {}).get("containment", 0.0), 4
119
+ ),
120
+ "pop_saved_pct": round(1.0 - pop_lost / total_pop, 4),
121
+ "steps": env.current_step,
122
+ "crew_casualty": env._crew_casualty_occurred,
123
+ "json_success": agent.json_success,
124
+ "regex_fallback": agent.regex_fallback,
125
+ "safe_idle": agent.safe_idle,
126
+ }
127
+ return total_reward, details
128
+
129
+
130
+ def load_model(model_path: str, base_model: str):
131
+ """Load a trained model, handling both full repos and PEFT adapters."""
132
+ from unsloth import FastLanguageModel
133
+
134
+ # Try loading directly (works for merged models and HF adapter repos
135
+ # that embed base_model_name_or_path in adapter_config.json)
136
+ try:
137
+ model, tokenizer = FastLanguageModel.from_pretrained(
138
+ model_name=model_path,
139
+ max_seq_length=2048,
140
+ load_in_4bit=True,
141
+ )
142
+ print(f"Loaded model directly from: {model_path}")
143
+ return model, tokenizer
144
+ except Exception as e:
145
+ print(f"Direct load failed ({e}), trying base + adapter...")
146
+
147
+ # Fallback: load base model then attach adapter (for standalone PEFT adapters)
148
+ model, tokenizer = FastLanguageModel.from_pretrained(
149
+ model_name=base_model,
150
+ max_seq_length=2048,
151
+ load_in_4bit=True,
152
+ )
153
+ model.load_adapter(model_path, adapter_name="default")
154
+ print(f"Loaded base model ({base_model}) + adapter ({model_path})")
155
+ return model, tokenizer
156
+
157
+
158
+ def main():
159
+ parser = argparse.ArgumentParser(description="Evaluate trained model vs baselines")
160
+ parser.add_argument("--model-path", required=True,
161
+ help="HF hub ID or local path to the trained adapter")
162
+ parser.add_argument("--base-model", default="unsloth/Qwen2.5-7B-Instruct",
163
+ help="Base model for PEFT adapter loading "
164
+ "(default: unsloth/Qwen2.5-7B-Instruct)")
165
+ parser.add_argument("--num-seeds", type=int, default=15,
166
+ help="Evaluation seeds per tier (default: 15, uses seeds 200+)")
167
+ parser.add_argument("--tiers", nargs="+", default=["easy", "medium", "hard"],
168
+ help="Tiers to evaluate (default: easy medium hard)")
169
+ args = parser.parse_args()
170
+
171
+ seeds = list(range(200, 200 + args.num_seeds))
172
+
173
+ # Load trained model (Issue 1 fix: uses --base-model for adapter fallback)
174
+ print(f"Loading model: {args.model_path}")
175
+ model, tokenizer = load_model(args.model_path, args.base_model)
176
+ if tokenizer.pad_token is None:
177
+ tokenizer.pad_token = tokenizer.eos_token
178
+
179
+ from unsloth import FastLanguageModel
180
+ FastLanguageModel.for_inference(model)
181
+ print("Model ready for inference.\n")
182
+
183
+ # Load existing baselines (Issue 3 fix: use stored values for comparison table)
184
+ baselines_path = os.path.join(os.path.dirname(__file__), "results.json")
185
+ if not os.path.exists(baselines_path):
186
+ print(f"WARNING: {baselines_path} not found. Run scripts/evaluate.py first.")
187
+ sys.exit(1)
188
+ with open(baselines_path, "r") as f:
189
+ baselines = json.load(f)
190
+
191
+ # Output in same shape as results.json: {agent: {tier: {...}}} (Issue 2 fix)
192
+ all_results = {"trained": {}}
193
+
194
+ for tier in args.tiers:
195
+ max_steps = TIER_MAX_STEPS[tier]
196
+ print(f"{'='*60}")
197
+ print(f" Tier: {tier} | Seeds: {seeds[0]}-{seeds[-1]} | Max steps: {max_steps}")
198
+ print(f"{'='*60}")
199
+
200
+ tier_rewards = []
201
+ tier_pop_saved = []
202
+ tier_containment = []
203
+ tier_json_success = 0
204
+ tier_total_actions = 0
205
+ tier_casualty_count = 0
206
+ tier_times = []
207
+
208
+ for seed in seeds:
209
+ start = time.time()
210
+ reward, details = run_llm_episode(model, tokenizer, tier, seed)
211
+ elapsed = time.time() - start
212
+
213
+ tier_rewards.append(reward)
214
+ tier_pop_saved.append(details["pop_saved_pct"])
215
+ tier_containment.append(details["containment_pct"])
216
+ tier_json_success += details["json_success"]
217
+ tier_total_actions += (details["json_success"]
218
+ + details["regex_fallback"]
219
+ + details["safe_idle"])
220
+ if details["crew_casualty"]:
221
+ tier_casualty_count += 1
222
+ tier_times.append(elapsed)
223
+
224
+ print(f" seed={seed}: reward={reward:+.2f}, "
225
+ f"pop_saved={details['pop_saved_pct']*100:.0f}%, "
226
+ f"steps={details['steps']}, time={elapsed:.1f}s")
227
+
228
+ json_rate = (100.0 * tier_json_success / tier_total_actions
229
+ if tier_total_actions > 0 else 0)
230
+
231
+ all_results["trained"][tier] = {
232
+ "scores": [round(r, 4) for r in tier_rewards],
233
+ "mean": round(float(np.mean(tier_rewards)), 4),
234
+ "std": round(float(np.std(tier_rewards)), 4),
235
+ "mean_containment_pct": round(float(np.mean(tier_containment)), 4),
236
+ "mean_pop_saved_pct": round(float(np.mean(tier_pop_saved)), 4),
237
+ "crew_casualty_rate": round(tier_casualty_count / len(seeds), 2),
238
+ "mean_time_s": round(float(np.mean(tier_times)), 3),
239
+ "json_success_rate": round(json_rate, 2),
240
+ }
241
+ print()
242
+
243
+ # ── Print comparison table using stored baselines ──
244
+ print()
245
+ print("=" * 65)
246
+ print("=== Evaluation: Trained Model vs Baselines ===")
247
+ print(f"Model: {args.model_path}")
248
+ print(f"Seeds: {seeds[0]}-{seeds[-1]} ({len(seeds)} per tier)")
249
+ print("=" * 65)
250
+ print(f"{'Tier':<10} {'Trained':>12} {'Heuristic':>12} {'Random':>12} {'vs Heuristic':>14}")
251
+ print("-" * 65)
252
+
253
+ for tier in args.tiers:
254
+ t = all_results["trained"][tier]
255
+ h_mean = baselines["heuristic"][tier]["mean"]
256
+ h_std = baselines["heuristic"][tier]["std"]
257
+ r_mean = baselines["random"][tier]["mean"]
258
+ r_std = baselines["random"][tier]["std"]
259
+ delta = t["mean"] - h_mean
260
+ marker = " OK" if delta >= -1.0 else ""
261
+ print(
262
+ f"{tier:<10} "
263
+ f"{t['mean']:+.2f}+/-{t['std']:.1f} "
264
+ f"{h_mean:+.2f}+/-{h_std:.1f} "
265
+ f"{r_mean:+.2f}+/-{r_std:.1f} "
266
+ f"{delta:+.2f}{marker}"
267
+ )
268
+
269
+ print()
270
+ print("JSON success rate: ", end="")
271
+ print(" ".join(
272
+ f"{t}={all_results['trained'][t]['json_success_rate']:.1f}%"
273
+ for t in args.tiers
274
+ ))
275
+ print("Pop saved rate: ", end="")
276
+ print(" ".join(
277
+ f"{t}={all_results['trained'][t]['mean_pop_saved_pct']*100:.0f}%"
278
+ for t in args.tiers
279
+ ))
280
+ print("=" * 65)
281
+
282
+ # ── Save results (same top-level shape as results.json) ──
283
+ output_path = os.path.join(os.path.dirname(__file__), "trained_results.json")
284
+ with open(output_path, "w") as f:
285
+ json.dump(all_results, f, indent=2)
286
+ print(f"\nResults saved to {output_path}")
287
+
288
+
289
+ if __name__ == "__main__":
290
+ main()