| """ |
| Generate supervised fine-tuning (SFT) training examples by running the |
| HeuristicAgent through episodes and recording (prompt, action) pairs. |
| |
| Usage: |
| python scripts/generate_sft_data.py |
| python scripts/generate_sft_data.py --output training/sft_data.jsonl --easy-seeds 500 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import json |
| import os |
| import random |
| import sys |
| from pathlib import Path |
|
|
# Put the directory two levels above this file at the front of sys.path so the
# project-local imports that follow resolve when the script is run directly.
_SCRIPT_DIR = Path(__file__).resolve().parent
PROJECT_ROOT = str(_SCRIPT_DIR.parent)
sys.path.insert(0, PROJECT_ROOT)
|
|
| from env.wildfire_env import WildfireEnv |
| from env.serialization import serialize_observation |
| from env.models import TIER_EASY, TIER_MEDIUM, TIER_HARD, ActionType |
| from agents.heuristic_agent import HeuristicAgent |
|
|
# System message placed at the start of every recorded conversation. The
# sentences are joined with single spaces into one line of text.
SYSTEM_PROMPT = " ".join(
    (
        "You are an AI Incident Commander managing wildfire containment.",
        "You will receive a situation briefing each step.",
        "Respond with ONLY a valid JSON action object and nothing else.",
        'Example: {"action_type": "idle"}',
    )
)
|
|
# Per-tier generation settings, keyed by tier name:
#   "max_steps" - the tier's episode_length, forwarded to the prompt serializer
#   "target"    - number of training examples to collect for that tier
TIER_CONFIGS = {
    name: {"max_steps": tier_spec.episode_length, "target": quota}
    for name, tier_spec, quota in (
        ("easy", TIER_EASY, 2000),
        ("medium", TIER_MEDIUM, 1500),
        ("hard", TIER_HARD, 800),
    )
}
|
|
|
|
def run_episode(tier: str, seed: int) -> list[dict] | None:
    """Play one HeuristicAgent episode and collect its training records.

    Args:
        tier: Key into ``TIER_CONFIGS``; also passed to the environment as
            the task id and to the prompt serializer.
        seed: Seed handed to ``WildfireEnv.reset``.

    Returns:
        One record per step taken at or after a randomly drawn starting
        step. Each record holds the chat ``messages``, the action JSON as
        ``completion``, and the ``tier``/``seed``/``step``/``action_type``
        bookkeeping fields. Returns ``None`` when the finished episode
        reports a non-zero ``population_lost``.
    """
    step_budget = TIER_CONFIGS[tier]["max_steps"]
    env = WildfireEnv()
    obs = env.reset(task_id=tier, seed=seed)
    agent = HeuristicAgent()

    # Steps before this index are still played, just not recorded.
    first_recorded = random.randint(0, min(30, step_budget // 4))

    samples: list[dict] = []
    burning_before = 0
    step_idx = 0

    while not env.done:
        # The agent decides first; the prompt is rendered from the same
        # observation afterwards.
        action = agent.act(obs)

        if step_idx >= first_recorded:
            user_turn = serialize_observation(
                obs,
                step_idx,
                step_budget,
                tier=tier,
                prev_cells_burning=burning_before,
            )
            completion = action.model_dump_json(exclude_none=True)
            conversation = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_turn},
            ]
            samples.append(
                {
                    "messages": conversation,
                    "completion": completion,
                    "tier": tier,
                    "seed": seed,
                    "step": step_idx,
                    "action_type": action.action_type.value,
                }
            )

        # Remember the burning-cell count of the observation just acted on;
        # it is passed to the serializer on the next step.
        burning_before = obs.stats.cells_burning
        obs = env.step(action).observation
        step_idx += 1

    # Episodes that lost any population are discarded wholesale.
    if env.state()["population_lost"] != 0:
        return None
    return samples
|
|
|
|
def filter_idle(records: list[dict]) -> list[dict]:
    """Down-sample IDLE steps while keeping every non-IDLE step.

    The number of IDLE records kept is at most ``max(1, 25% of the non-IDLE
    count)``, which works out to roughly 20% of the returned list. When too
    many IDLE records exist, the survivors are picked by shuffling with the
    module-level ``random`` generator. If there are no non-IDLE records at
    all, every IDLE record is returned uncapped.

    Args:
        records: Step records carrying ``"action_type"`` and ``"step"`` keys.

    Returns:
        A new list, ordered by ``"step"`` whenever any non-IDLE record exists.
    """
    active: list[dict] = []
    idle_steps: list[dict] = []
    for rec in records:
        bucket = idle_steps if rec["action_type"] == "idle" else active
        bucket.append(rec)

    # Nothing to balance against: hand back the idle records untouched.
    if not active:
        return idle_steps

    idle_cap = max(1, len(active) // 4)
    if len(idle_steps) > idle_cap:
        random.shuffle(idle_steps)
        del idle_steps[idle_cap:]

    # Restore chronological order after mixing the two groups.
    return sorted(active + idle_steps, key=lambda rec: rec["step"])
|
|
|
|
def strip_internal_fields(records: list[dict]) -> list[dict]:
    """Drop the ``action_type`` helper key from each record before writing.

    The dictionaries are modified in place and the same list object is
    returned; records without the key are left unchanged.
    """
    for record in records:
        if "action_type" in record:
            del record["action_type"]
    return records
|
|
|
|
def generate(output_path: str, max_seeds: dict[str, int]) -> None:
    """Collect SFT examples for every tier and write them as JSON Lines.

    For each tier (easy, medium, hard) episodes are run with seeds 0, 1, 2, ...
    until either the tier's ``target`` from ``TIER_CONFIGS`` is reached or
    ``max_seeds[tier]`` seeds have been tried. Episodes for which
    ``run_episode`` returns ``None`` contribute nothing. Progress and a final
    summary are printed to stdout.

    Args:
        output_path: Destination ``.jsonl`` file; its parent directory is
            created if missing and the file is overwritten.
        max_seeds: Maximum number of seeds to try, keyed by tier name.
    """
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    tier_order = ("easy", "medium", "hard")
    rule = "=" * 50
    all_examples: list[dict] = []
    tier_counts = dict.fromkeys(TIER_CONFIGS, 0)

    for tier in tier_order:
        target = TIER_CONFIGS[tier]["target"]
        limit = max_seeds[tier]

        print(f"\n{rule}")
        print(f"Generating {tier} tier (target={target}, max_seeds={limit})")
        print(rule)

        collected = 0
        seeds_tried = 0
        while collected < target and seeds_tried < limit:
            episode = run_episode(tier, seeds_tried)
            seeds_tried += 1

            if episode is not None:
                # Never overshoot the tier target: keep only as many records
                # as are still missing.
                kept = filter_idle(episode)[: target - collected]
                all_examples.extend(strip_internal_fields(kept))
                collected += len(kept)
                tier_counts[tier] = collected

            if seeds_tried % 50 == 0:
                print(f" [{tier}] seed={seeds_tried}, examples={collected}/{target}")

        print(f" [{tier}] DONE — {collected} examples from {seeds_tried} seeds")

    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps(example, ensure_ascii=False) + "\n"
            for example in all_examples
        )

    print(f"\n{rule}")
    print(f"SFT data saved to {output_path}")
    print(f"Total examples: {len(all_examples)}")
    print("Tier distribution:")
    for tier in tier_order:
        print(f" {tier}: {tier_counts[tier]}")
    print(rule)
|
|
|
|
def main():
    """Parse the command-line options and run the data generation."""
    parser = argparse.ArgumentParser(description="Generate SFT training data from HeuristicAgent episodes")
    parser.add_argument(
        "--output",
        default="training/sft_data.jsonl",
        help="Output JSONL file path (default: training/sft_data.jsonl)",
    )

    # One "--<tier>-seeds" option per tier, all sharing the same type/default.
    seed_options = (
        ("easy", "--easy-seeds", "Max seeds to try for easy tier"),
        ("medium", "--medium-seeds", "Max seeds to try for medium tier"),
        ("hard", "--hard-seeds", "Max seeds to try for hard tier"),
    )
    for _tier, flag, help_text in seed_options:
        parser.add_argument(flag, type=int, default=500, help=help_text)

    options = vars(parser.parse_args())
    seed_limits = {tier: options[f"{tier}_seeds"] for tier, _flag, _help in seed_options}
    generate(options["output"], seed_limits)


if __name__ == "__main__":
    main()
|
|