"""
Generate supervised fine-tuning (SFT) training examples by running the
HeuristicAgent through episodes and recording (prompt, action) pairs.

Usage:
    python scripts/generate_sft_data.py
    python scripts/generate_sft_data.py --output training/sft_data.jsonl --easy-seeds 500
"""
from __future__ import annotations
import argparse
import json
import os
import random
import sys
from pathlib import Path
PROJECT_ROOT = str(Path(__file__).resolve().parent.parent)
sys.path.insert(0, PROJECT_ROOT)
from env.wildfire_env import WildfireEnv
from env.serialization import serialize_observation
from env.models import TIER_EASY, TIER_MEDIUM, TIER_HARD, ActionType
from agents.heuristic_agent import HeuristicAgent
# System message placed first in every recorded example. The sentences are
# joined with single spaces to form one instruction string.
SYSTEM_PROMPT = " ".join(
    [
        "You are an AI Incident Commander managing wildfire containment.",
        "You will receive a situation briefing each step.",
        "Respond with ONLY a valid JSON action object and nothing else.",
        'Example: {"action_type": "idle"}',
    ]
)
# Per-tier generation settings: "max_steps" is the tier's episode length and
# "target" is the number of examples to collect for that tier.
TIER_CONFIGS = {
    tier_name: {"max_steps": tier_spec.episode_length, "target": quota}
    for tier_name, tier_spec, quota in (
        ("easy", TIER_EASY, 2000),
        ("medium", TIER_MEDIUM, 1500),
        ("hard", TIER_HARD, 800),
    )
}
def run_episode(tier: str, seed: int) -> list[dict] | None:
    """Roll out one HeuristicAgent episode and collect its SFT records.

    Every collected record carries the system/user chat messages, the
    agent's action serialized as JSON under ``completion``, and the
    ``tier``, ``seed``, ``step`` and ``action_type`` bookkeeping fields.
    Steps before a randomly drawn start offset are played but not recorded.

    Returns the list of records, or None when the final environment state
    reports a non-zero ``population_lost``.
    """
    step_budget = TIER_CONFIGS[tier]["max_steps"]

    env = WildfireEnv()
    obs = env.reset(task_id=tier, seed=seed)
    agent = HeuristicAgent()

    # Recording starts at a random step in [0, min(30, step_budget // 4)].
    first_recorded_step = random.randint(0, min(30, step_budget // 4))

    burning_last_step = 0
    collected: list[dict] = []
    step_num = 0
    while not env.done:
        action = agent.act(obs)

        if step_num >= first_recorded_step:
            briefing = serialize_observation(
                obs,
                step_num,
                step_budget,
                tier=tier,
                prev_cells_burning=burning_last_step,
            )
            chat = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": briefing},
            ]
            collected.append(
                {
                    "messages": chat,
                    "completion": action.model_dump_json(exclude_none=True),
                    "tier": tier,
                    "seed": seed,
                    "step": step_num,
                    "action_type": action.action_type.value,
                }
            )

        # Keep this step's burning count so the next briefing can use it.
        burning_last_step = obs.stats.cells_burning
        obs = env.step(action).observation
        step_num += 1

    if env.state()["population_lost"] != 0:
        return None
    return collected
def filter_idle(records: list[dict]) -> list[dict]:
    """Down-sample idle steps so they make up roughly 20% of the output.

    All non-idle records are kept. Idle records are limited to 25% of the
    non-idle count (at least one is always allowed), which is about 20% of
    the combined total; when over the limit, a random subset is kept. The
    result is ordered by ``step``. If there are no non-idle records at all,
    the idle records are returned as they are.
    """
    active: list[dict] = []
    passive: list[dict] = []
    for rec in records:
        bucket = passive if rec["action_type"] == "idle" else active
        bucket.append(rec)

    if not active:
        return passive

    idle_cap = max(1, int(len(active) * 0.25))
    if len(passive) > idle_cap:
        # Shuffle first so the surviving idle steps are a random sample.
        random.shuffle(passive)
        del passive[idle_cap:]

    return sorted(active + passive, key=lambda rec: rec["step"])
def strip_internal_fields(records: list[dict]) -> list[dict]:
    """Drop the ``action_type`` helper key from each record, in place.

    Returns the same list object that was passed in.
    """
    for entry in records:
        if "action_type" in entry:
            del entry["action_type"]
    return records
def generate(output_path: str, max_seeds: dict[str, int]) -> None:
    """Collect examples for every tier and write them to a JSONL file.

    For each of the easy, medium and hard tiers, episodes are run with
    seeds 0, 1, 2, ... until the tier's example target from TIER_CONFIGS is
    reached or ``max_seeds[tier]`` seeds have been tried. Episodes for which
    run_episode returns None contribute nothing. All examples are written to
    ``output_path`` (one JSON object per line) after every tier has finished,
    and a summary is printed.
    """
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)

    banner = "=" * 50
    tier_order = ("easy", "medium", "hard")
    all_examples: list[dict] = []
    tier_counts = dict.fromkeys(TIER_CONFIGS, 0)

    for tier in tier_order:
        target = TIER_CONFIGS[tier]["target"]
        limit = max_seeds[tier]
        print(f"\n{banner}")
        print(f"Generating {tier} tier (target={target}, max_seeds={limit})")
        print(banner)

        seed = 0
        while tier_counts[tier] < target and seed < limit:
            records = run_episode(tier, seed)
            if records is not None:
                # Never take more than what is still missing for this tier.
                room = target - tier_counts[tier]
                kept = strip_internal_fields(filter_idle(records)[:room])
                all_examples.extend(kept)
                tier_counts[tier] += len(kept)
            seed += 1
            if seed % 50 == 0:
                print(f" [{tier}] seed={seed}, examples={tier_counts[tier]}/{target}")
        print(f" [{tier}] DONE — {tier_counts[tier]} examples from {seed} seeds")

    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(
            json.dumps(ex, ensure_ascii=False) + "\n" for ex in all_examples
        )

    total = len(all_examples)
    print(f"\n{banner}")
    print(f"SFT data saved to {output_path}")
    print(f"Total examples: {total}")
    print("Tier distribution:")
    for tier in tier_order:
        print(f" {tier}: {tier_counts[tier]}")
    print(banner)
def main():
    """Parse the command-line options and run the data generation."""
    tiers = ("easy", "medium", "hard")

    parser = argparse.ArgumentParser(description="Generate SFT training data from HeuristicAgent episodes")
    parser.add_argument(
        "--output",
        default="training/sft_data.jsonl",
        help="Output JSONL file path (default: training/sft_data.jsonl)",
    )
    # One "--<tier>-seeds" option per tier, all sharing the same default.
    for tier in tiers:
        parser.add_argument(
            f"--{tier}-seeds",
            type=int,
            default=500,
            help=f"Max seeds to try for {tier} tier",
        )
    args = parser.parse_args()

    seed_limits = {tier: getattr(args, f"{tier}_seeds") for tier in tiers}
    generate(args.output, seed_limits)


if __name__ == "__main__":
    main()
|