# Source: Wildfire-Containment-Simulator / scripts/generate_sft_data.py
# Latest commit ad92ece (avatar: Eshit): "Improve wildfire metrics and training assets"
"""
Generate supervised fine-tuning (SFT) training examples by running the
HeuristicAgent through episodes and recording (prompt, action) pairs.

Usage:
    python scripts/generate_sft_data.py
    python scripts/generate_sft_data.py --output training/sft_data.jsonl --easy-seeds 500
"""
from __future__ import annotations
import argparse
import json
import os
import random
import sys
from pathlib import Path
PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
sys.path.insert(0, PROJECT_ROOT)
from env.wildfire_env import WildfireEnv
from env.serialization import serialize_observation
from env.models import TIER_EASY, TIER_MEDIUM, TIER_HARD, ActionType
from agents.heuristic_agent import HeuristicAgent
# System message placed at the start of every recorded example. It is built
# from individual sentences joined by single spaces.
SYSTEM_PROMPT = " ".join(
    (
        "You are an AI Incident Commander managing wildfire containment.",
        "You will receive a situation briefing each step.",
        "Respond with ONLY a valid JSON action object and nothing else.",
        'Example: {"action_type": "idle"}',
    )
)
# Per-tier generation settings, keyed by tier name (insertion order: easy,
# medium, hard).
#   "max_steps": the tier's episode_length, passed to the prompt serializer.
#   "target":    how many SFT examples generate() tries to collect for the tier.
TIER_CONFIGS = {
    tier_name: {"max_steps": tier_spec.episode_length, "target": example_quota}
    for tier_name, tier_spec, example_quota in (
        ("easy", TIER_EASY, 2000),
        ("medium", TIER_MEDIUM, 1500),
        ("hard", TIER_HARD, 800),
    )
}
def run_episode(tier: str, seed: int) -> list[dict] | None:
    """Play one HeuristicAgent episode and capture its decisions as SFT records.

    The environment is reset with ``task_id=tier`` and the given ``seed``, then
    stepped with the heuristic agent's actions until ``env.done`` is true.
    A random start offset in ``[0, min(30, max_steps // 4)]`` is drawn from the
    global ``random`` module; steps before that offset are played but not
    recorded.

    Args:
        tier: Key into TIER_CONFIGS; also used as the environment task id.
        seed: Seed forwarded to ``env.reset``.

    Returns:
        One dict per recorded step with the keys ``messages`` (system + user
        prompt), ``completion`` (the action as JSON), ``tier``, ``seed``,
        ``step`` and ``action_type``; or None when the final environment state
        reports a non-zero ``population_lost``.
    """
    step_budget = TIER_CONFIGS[tier]["max_steps"]

    env = WildfireEnv()
    obs = env.reset(task_id=tier, seed=seed)
    agent = HeuristicAgent()

    # First step index that gets recorded; earlier steps only advance the env.
    offset = random.randint(0, min(30, step_budget // 4))

    captured: list[dict] = []
    burning_before = 0  # cells_burning seen on the previous step (0 at start)
    step_num = 0

    while not env.done:
        action = agent.act(obs)

        if step_num >= offset:
            user_prompt = serialize_observation(
                obs,
                step_num,
                step_budget,
                tier=tier,
                prev_cells_burning=burning_before,
            )
            completion = action.model_dump_json(exclude_none=True)
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ]
            captured.append(
                {
                    "messages": messages,
                    "completion": completion,
                    "tier": tier,
                    "seed": seed,
                    "step": step_num,
                    # Helper field used by filter_idle; removed before writing.
                    "action_type": action.action_type.value,
                }
            )

        # Remember this step's fire size before advancing, whether or not the
        # step was recorded, so the next prompt can reference it.
        burning_before = obs.stats.cells_burning
        obs = env.step(action).observation
        step_num += 1

    # Only episodes that finish with zero population lost are kept.
    final_state = env.state()
    return captured if final_state["population_lost"] == 0 else None
def filter_idle(records: list[dict]) -> list[dict]:
    """Keep every non-idle record and down-sample the idle ones.

    Idle records are capped at ``max(1, int(0.25 * non_idle_count))``. A quarter
    of the non-idle count corresponds to roughly 20% of the combined result
    (0.25 / 1.25). When the cap is exceeded, the survivors are picked by
    shuffling with the global ``random`` module, so the selection is only
    reproducible if the caller seeds it.

    If the episode contains no non-idle records at all, every idle record is
    returned unchanged and in its original order.

    Args:
        records: Episode records carrying ``action_type`` and ``step`` keys.

    Returns:
        A new list: the kept records ordered by ``step``.
    """
    active: list[dict] = []
    idle_steps: list[dict] = []
    for rec in records:
        bucket = idle_steps if rec["action_type"] == "idle" else active
        bucket.append(rec)

    if active:
        idle_cap = max(1, int(len(active) * 0.25))
        if len(idle_steps) > idle_cap:
            random.shuffle(idle_steps)
            del idle_steps[idle_cap:]
        # Restore chronological order after merging the two groups.
        return sorted(active + idle_steps, key=lambda rec: rec["step"])

    return idle_steps
def strip_internal_fields(records: list[dict]) -> list[dict]:
    """Drop the ``action_type`` helper key from each record, in place.

    Args:
        records: Records to clean; the dicts themselves are mutated.

    Returns:
        The same list object that was passed in.
    """
    for record in records:
        if "action_type" in record:
            del record["action_type"]
    return records
def generate(output_path: str, max_seeds: dict[str, int]) -> None:
    """Build the SFT dataset tier by tier and write it as JSON Lines.

    For each tier (easy, medium, hard) episodes are run with seeds 0, 1, 2, ...
    until the tier's ``target`` from TIER_CONFIGS is reached or
    ``max_seeds[tier]`` seeds have been tried. Episodes for which run_episode
    returns None are skipped. The remaining ones go through filter_idle, are
    truncated so the tier never exceeds its target, and have helper fields
    stripped. All examples are held in memory and written once at the end.

    Args:
        output_path: Destination JSONL file; its parent directory is created
            if missing.
        max_seeds: Maximum number of seeds to try, keyed by tier name.
    """
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    all_examples: list[dict] = []
    tier_counts = dict.fromkeys(TIER_CONFIGS, 0)

    for tier in ("easy", "medium", "hard"):
        target = TIER_CONFIGS[tier]["target"]
        limit = max_seeds[tier]

        print(f"\n{'='*50}")
        print(f"Generating {tier} tier (target={target}, max_seeds={limit})")
        print(f"{'='*50}")

        seed = 0
        while tier_counts[tier] < target and seed < limit:
            episode = run_episode(tier, seed)
            seed += 1  # from here on, `seed` counts the seeds tried so far

            if episode is not None:
                # Never overshoot the tier quota: keep at most what is missing.
                room_left = target - tier_counts[tier]
                kept = filter_idle(episode)[:room_left]
                all_examples.extend(strip_internal_fields(kept))
                tier_counts[tier] += len(kept)

            if seed % 50 == 0:
                print(f" [{tier}] seed={seed}, examples={tier_counts[tier]}/{target}")

        print(f" [{tier}] DONE — {tier_counts[tier]} examples from {seed} seeds")

    with open(output_path, "w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(example, ensure_ascii=False) + "\n"
            for example in all_examples
        )

    total = len(all_examples)
    print(f"\n{'='*50}")
    print(f"SFT data saved to {output_path}")
    print(f"Total examples: {total}")
    print(f"Tier distribution:")
    for tier in ("easy", "medium", "hard"):
        print(f" {tier}: {tier_counts[tier]}")
    print(f"{'='*50}")
def main():
    """Parse command-line options and run the dataset generation.

    Options: ``--output`` (JSONL path) and one ``--<tier>-seeds`` integer per
    tier (easy, medium, hard), each defaulting to 500.
    """
    tiers = ("easy", "medium", "hard")

    parser = argparse.ArgumentParser(
        description="Generate SFT training data from HeuristicAgent episodes"
    )
    parser.add_argument(
        "--output",
        default="training/sft_data.jsonl",
        help="Output JSONL file path (default: training/sft_data.jsonl)",
    )
    for tier in tiers:
        parser.add_argument(
            f"--{tier}-seeds",
            type=int,
            default=500,
            help=f"Max seeds to try for {tier} tier",
        )
    args = parser.parse_args()

    # argparse stores "--easy-seeds" under the attribute "easy_seeds", etc.
    seed_limits = {tier: getattr(args, f"{tier}_seeds") for tier in tiers}
    generate(args.output, seed_limits)


if __name__ == "__main__":
    main()