CreativeEngineer committed on
Commit
ede4c5c
·
1 Parent(s): 8254ade

feat: add model-driven llm reward evaluation

Browse files
Files changed (2) hide show
  1. training/README.md +6 -0
  2. training/llm_rollout.py +202 -16
training/README.md CHANGED
@@ -23,6 +23,8 @@ Training policy:
23
  - replay an LLM completion or action plan: `uv run python training/llm_rollout.py replay --seed 0 --completion-file <path>`
24
  - monitor reward terms, action clamping, and verifier outcomes across seeds:
25
  `uv run python training/llm_rollout.py monitor --completion-file <path> --seeds 0,1,2`
 
 
26
 
27
  ## Shared LLM Contract
28
 
@@ -36,3 +38,7 @@ Use that module as the source of truth for:
36
  - action-plan parsing
37
  - local rollout replay
38
  - rollout telemetry structure used by the monitor command
 
 
 
 
 
23
  - replay an LLM completion or action plan: `uv run python training/llm_rollout.py replay --seed 0 --completion-file <path>`
24
  - monitor reward terms, action clamping, and verifier outcomes across seeds:
25
  `uv run python training/llm_rollout.py monitor --completion-file <path> --seeds 0,1,2`
26
+ - generate fresh model completions per seed and save aggregate reward/outcome metrics:
27
+ `uv run python training/llm_rollout.py evaluate --completion-command 'python path/to/model_cli.py' --seeds 0,1,2`
28
 
29
  ## Shared LLM Contract
30
 
 
38
  - action-plan parsing
39
  - local rollout replay
40
  - rollout telemetry structure used by the monitor command
41
+
42
+ For `evaluate`, the completion command reads the prompt from `stdin` and writes a raw completion to `stdout`.
43
+ The current seed is exposed as the `FUSION_LAB_SEED` environment variable so the same command can be used
44
+ for fixed-seed before/after comparisons of untrained and trained checkpoints.
training/llm_rollout.py CHANGED
@@ -2,6 +2,9 @@ from __future__ import annotations
2
 
3
  import argparse
4
  import json
 
 
 
5
  from datetime import UTC, datetime
6
  from pathlib import Path
7
  from typing import Final
@@ -10,12 +13,14 @@ from fusion_lab.llm_agent import (
10
  build_prompt,
11
  parse_action_plan,
12
  run_episode_with_actions,
 
13
  )
14
  from fusion_lab.models import StellaratorAction
15
  from server.environment import StellaratorEnvironment
16
 
17
  DEFAULT_OUTPUT_DIR: Final[Path] = Path("training/artifacts/llm_rollout")
18
  DEFAULT_MONITOR_OUTPUT_DIR: Final[Path] = Path("training/artifacts/llm_monitor")
 
19
 
20
 
21
  def add_action_source_args(parser: argparse.ArgumentParser) -> None:
@@ -81,6 +86,41 @@ def parse_args() -> argparse.Namespace:
81
  default=DEFAULT_MONITOR_OUTPUT_DIR,
82
  help="Directory for monitoring artifacts.",
83
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  return parser.parse_args()
85
 
86
 
@@ -174,6 +214,76 @@ def _reward_terms_summary(reward_breakdown: dict[str, object]) -> str:
174
  return ", ".join(non_zero_terms) if non_zero_terms else "none"
175
 
176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
  def monitor_payload(
178
  *,
179
  source: str,
@@ -181,43 +291,102 @@ def monitor_payload(
181
  seeds: list[int],
182
  ) -> dict[str, object]:
183
  traces = [run_episode_with_actions(actions, seed_idx=seed) for seed in seeds]
184
- feasible_count = sum(1 for trace in traces if trace.constraints_satisfied)
185
- high_fidelity_count = sum(1 for trace in traces if trace.final_evaluation_fidelity == "high")
186
- mean_reward = sum(trace.total_reward for trace in traces) / len(traces)
187
  return {
188
  "created_at_utc": datetime.now(UTC).isoformat(),
189
  "source": source,
190
  "parsed_action_count": len(actions),
191
  "actions": [action.model_dump(exclude_none=True) for action in actions],
192
  "seeds": seeds,
193
- "summary": {
194
- "episode_count": len(traces),
195
- "feasible_episode_count": feasible_count,
196
- "high_fidelity_episode_count": high_fidelity_count,
197
- "mean_total_reward": round(mean_reward, 4),
198
- },
199
  "episodes": [trace.asdict() for trace in traces],
200
  }
201
 
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def write_monitor_summary(payload: dict[str, object]) -> None:
204
  summary = payload["summary"]
205
  print(
206
  "episodes="
207
  f"{summary['episode_count']} feasible={summary['feasible_episode_count']} "
208
  f"high_fidelity={summary['high_fidelity_episode_count']} "
209
- f"mean_total_reward={summary['mean_total_reward']:+.4f}"
 
 
 
210
  )
211
  for episode in payload["episodes"]:
 
212
  print(
213
  "seed="
214
- f"{episode['seed']} total_reward={episode['total_reward']:+.4f} "
215
- f"final_fidelity={episode['final_evaluation_fidelity']} "
216
- f"feasible={episode['constraints_satisfied']} "
217
- f"score={episode['final_score']:.6f} "
218
- f"feasibility={episode['final_feasibility']:.6f}"
219
  )
220
- for step in episode["steps"]:
 
 
221
  action_monitor = step["action_monitor"]
222
  print(
223
  " step="
@@ -240,6 +409,20 @@ def run_monitor(args: argparse.Namespace) -> None:
240
  write_monitor_summary(payload)
241
 
242
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
243
  def main() -> None:
244
  args = parse_args()
245
  if args.command == "prompt":
@@ -248,6 +431,9 @@ def main() -> None:
248
  if args.command == "replay":
249
  run_replay(args)
250
  return
 
 
 
251
  run_monitor(args)
252
 
253
 
 
2
 
3
  import argparse
4
  import json
5
+ import math
6
+ import os
7
+ import subprocess
8
  from datetime import UTC, datetime
9
  from pathlib import Path
10
  from typing import Final
 
13
  build_prompt,
14
  parse_action_plan,
15
  run_episode_with_actions,
16
+ LLMEpisodeTrace,
17
  )
18
  from fusion_lab.models import StellaratorAction
19
  from server.environment import StellaratorEnvironment
20
 
21
  DEFAULT_OUTPUT_DIR: Final[Path] = Path("training/artifacts/llm_rollout")
22
  DEFAULT_MONITOR_OUTPUT_DIR: Final[Path] = Path("training/artifacts/llm_monitor")
23
+ DEFAULT_EVALUATE_OUTPUT_DIR: Final[Path] = Path("training/artifacts/llm_evaluate")
24
 
25
 
26
  def add_action_source_args(parser: argparse.ArgumentParser) -> None:
 
86
  default=DEFAULT_MONITOR_OUTPUT_DIR,
87
  help="Directory for monitoring artifacts.",
88
  )
89
+
90
+ evaluate_parser = subparsers.add_parser(
91
+ "evaluate",
92
+ help=(
93
+ "Generate fresh completions per seed with a model command, replay them, "
94
+ "and save aggregate reward/outcome metrics."
95
+ ),
96
+ )
97
+ evaluate_parser.add_argument(
98
+ "--completion-command",
99
+ type=str,
100
+ required=True,
101
+ help=(
102
+ "Shell command that reads the prompt from stdin and writes a completion to stdout. "
103
+ "The current seed is exposed as FUSION_LAB_SEED."
104
+ ),
105
+ )
106
+ evaluate_parser.add_argument(
107
+ "--seeds",
108
+ type=str,
109
+ default="0,1,2",
110
+ help="Comma-separated reset seed indices to evaluate.",
111
+ )
112
+ evaluate_parser.add_argument(
113
+ "--label",
114
+ type=str,
115
+ default="model",
116
+ help="Short label stored in the evaluation artifact.",
117
+ )
118
+ evaluate_parser.add_argument(
119
+ "--output-dir",
120
+ type=Path,
121
+ default=DEFAULT_EVALUATE_OUTPUT_DIR,
122
+ help="Directory for evaluation artifacts.",
123
+ )
124
  return parser.parse_args()
125
 
126
 
 
214
  return ", ".join(non_zero_terms) if non_zero_terms else "none"
215
 
216
 
217
+ def _mean(values: list[float]) -> float | None:
218
+ if not values:
219
+ return None
220
+ return sum(values) / len(values)
221
+
222
+
223
+ def _round_metric(value: float | None) -> float | None:
224
+ if value is None:
225
+ return None
226
+ return round(value, 4)
227
+
228
+
229
+ def _format_metric(value: object, precision: int = 4, signed: bool = False) -> str:
230
+ if not isinstance(value, (int, float)):
231
+ return "n/a"
232
+ if signed:
233
+ return f"{float(value):+.{precision}f}"
234
+ return f"{float(value):.{precision}f}"
235
+
236
+
237
+ def _pearson_correlation(xs: list[float], ys: list[float]) -> float | None:
238
+ if len(xs) != len(ys) or len(xs) < 2:
239
+ return None
240
+ mean_x = sum(xs) / len(xs)
241
+ mean_y = sum(ys) / len(ys)
242
+ centered_x = [value - mean_x for value in xs]
243
+ centered_y = [value - mean_y for value in ys]
244
+ variance_x = sum(value * value for value in centered_x)
245
+ variance_y = sum(value * value for value in centered_y)
246
+ if math.isclose(variance_x, 0.0) or math.isclose(variance_y, 0.0):
247
+ return None
248
+ covariance = sum(x_value * y_value for x_value, y_value in zip(centered_x, centered_y))
249
+ return covariance / math.sqrt(variance_x * variance_y)
250
+
251
+
252
def summarize_traces(traces: list[LLMEpisodeTrace]) -> dict[str, object]:
    """Aggregate episode traces into counts, rates, reward means, and
    reward/outcome correlations for the monitor and evaluate artifacts.

    Rates and means are None rather than a crash when *traces* is empty
    (e.g. an empty --seeds list) or when the relevant subset (such as
    high-fidelity episodes) is empty.
    """
    episode_count = len(traces)
    feasible_count = sum(1 for trace in traces if trace.constraints_satisfied)
    high_fidelity_traces = [trace for trace in traces if trace.final_evaluation_fidelity == "high"]
    high_fidelity_count = len(high_fidelity_traces)
    failed_count = sum(1 for trace in traces if trace.evaluation_failed)
    total_rewards = [trace.total_reward for trace in traces]
    final_scores = [trace.final_score for trace in traces]
    final_feasibilities = [trace.final_feasibility for trace in traces]
    high_fidelity_scores = [trace.final_score for trace in high_fidelity_traces]
    high_fidelity_feasibilities = [trace.final_feasibility for trace in high_fidelity_traces]
    feasible_flags = [1.0 if trace.constraints_satisfied else 0.0 for trace in traces]

    def _rate(count: int) -> float | None:
        # Guard against ZeroDivisionError for an empty evaluation run.
        return _round_metric(count / episode_count) if episode_count else None

    return {
        "episode_count": episode_count,
        "feasible_episode_count": feasible_count,
        "high_fidelity_episode_count": high_fidelity_count,
        "evaluation_failed_episode_count": failed_count,
        "feasible_rate": _rate(feasible_count),
        "high_fidelity_rate": _rate(high_fidelity_count),
        "evaluation_failed_rate": _rate(failed_count),
        "mean_total_reward": _round_metric(_mean(total_rewards)),
        "mean_final_score": _round_metric(_mean(final_scores)),
        "mean_final_feasibility": _round_metric(_mean(final_feasibilities)),
        "mean_high_fidelity_score": _round_metric(_mean(high_fidelity_scores)),
        "mean_high_fidelity_feasibility": _round_metric(_mean(high_fidelity_feasibilities)),
        "reward_final_score_correlation": _round_metric(
            _pearson_correlation(total_rewards, final_scores)
        ),
        "reward_feasible_correlation": _round_metric(
            _pearson_correlation(total_rewards, feasible_flags)
        ),
    }
285
+
286
+
287
  def monitor_payload(
288
  *,
289
  source: str,
 
291
  seeds: list[int],
292
  ) -> dict[str, object]:
293
  traces = [run_episode_with_actions(actions, seed_idx=seed) for seed in seeds]
 
 
 
294
  return {
295
  "created_at_utc": datetime.now(UTC).isoformat(),
296
  "source": source,
297
  "parsed_action_count": len(actions),
298
  "actions": [action.model_dump(exclude_none=True) for action in actions],
299
  "seeds": seeds,
300
+ "summary": summarize_traces(traces),
 
 
 
 
 
301
  "episodes": [trace.asdict() for trace in traces],
302
  }
303
 
304
 
305
+ def _run_completion_command(*, prompt: str, seed: int, command: str) -> str:
306
+ env = os.environ.copy()
307
+ env["FUSION_LAB_SEED"] = str(seed)
308
+ shell_path = env.get("SHELL", "/bin/sh")
309
+ completed = subprocess.run(
310
+ [shell_path, "-lc", command],
311
+ input=prompt,
312
+ text=True,
313
+ capture_output=True,
314
+ env=env,
315
+ check=False,
316
+ )
317
+ if completed.returncode != 0:
318
+ raise RuntimeError(
319
+ "completion command failed "
320
+ f"(seed={seed}, exit_code={completed.returncode}): {completed.stderr.strip()}"
321
+ )
322
+ return completed.stdout
323
+
324
+
325
def evaluate_payload(
    *,
    completion_command: str,
    label: str,
    seeds: list[int],
) -> dict[str, object]:
    """Generate a fresh completion per seed, replay it, and build the
    evaluation artifact: per-seed records plus an aggregate summary.
    """
    per_seed_records: list[dict[str, object]] = []
    episode_traces: list[LLMEpisodeTrace] = []

    for seed in seeds:
        # Fresh environment per seed so each prompt reflects a clean reset.
        initial_observation = StellaratorEnvironment().reset(seed=seed)
        seed_prompt = build_prompt(initial_observation)
        raw_completion = _run_completion_command(
            prompt=seed_prompt,
            seed=seed,
            command=completion_command,
        )
        plan = parse_action_plan(raw_completion)
        episode_trace = run_episode_with_actions(plan, seed_idx=seed)
        episode_traces.append(episode_trace)
        per_seed_records.append(
            {
                "seed": seed,
                "prompt": seed_prompt,
                "completion": raw_completion,
                "parsed_action_count": len(plan),
                "actions": [item.model_dump(exclude_none=True) for item in plan],
                "trace": episode_trace.asdict(),
            }
        )

    return {
        "created_at_utc": datetime.now(UTC).isoformat(),
        "label": label,
        "completion_command": completion_command,
        "seeds": seeds,
        "summary": summarize_traces(episode_traces),
        "episodes": per_seed_records,
    }
364
+
365
+
366
  def write_monitor_summary(payload: dict[str, object]) -> None:
367
  summary = payload["summary"]
368
  print(
369
  "episodes="
370
  f"{summary['episode_count']} feasible={summary['feasible_episode_count']} "
371
  f"high_fidelity={summary['high_fidelity_episode_count']} "
372
+ f"failed={summary['evaluation_failed_episode_count']} "
373
+ f"mean_total_reward={_format_metric(summary['mean_total_reward'], signed=True)} "
374
+ f"mean_high_fidelity_score={_format_metric(summary['mean_high_fidelity_score'], signed=True)} "
375
+ f"reward_score_corr={summary['reward_final_score_correlation']}"
376
  )
377
  for episode in payload["episodes"]:
378
+ trace = episode.get("trace", episode)
379
  print(
380
  "seed="
381
+ f"{trace['seed']} total_reward={trace['total_reward']:+.4f} "
382
+ f"final_fidelity={trace['final_evaluation_fidelity']} "
383
+ f"feasible={trace['constraints_satisfied']} "
384
+ f"score={trace['final_score']:.6f} "
385
+ f"feasibility={trace['final_feasibility']:.6f}"
386
  )
387
+ if "parsed_action_count" in episode:
388
+ print(f" parsed_actions={episode['parsed_action_count']}")
389
+ for step in trace["steps"]:
390
  action_monitor = step["action_monitor"]
391
  print(
392
  " step="
 
409
  write_monitor_summary(payload)
410
 
411
 
412
def run_evaluate(args: argparse.Namespace) -> None:
    """Handle the `evaluate` CLI command: generate and replay per-seed
    completions, persist the JSON artifact, and print the console summary.
    """
    seed_list = parse_seed_list(args.seeds)
    payload = evaluate_payload(
        completion_command=args.completion_command,
        label=args.label,
        seeds=seed_list,
    )
    # Timestamped filename keeps successive evaluation runs side by side.
    stamp = datetime.now(UTC).strftime("%Y%m%dT%H%M%SZ")
    artifact_path = args.output_dir / f"llm_evaluate_{stamp}.json"
    write_json(artifact_path, payload)
    print(artifact_path)
    write_monitor_summary(payload)
425
+
426
  def main() -> None:
427
  args = parse_args()
428
  if args.command == "prompt":
 
431
  if args.command == "replay":
432
  run_replay(args)
433
  return
434
+ if args.command == "evaluate":
435
+ run_evaluate(args)
436
+ return
437
  run_monitor(args)
438
 
439