#!/usr/bin/env python3 """Generate supervised fine-tuning data directly from AdaptShield rollouts.""" from __future__ import annotations import argparse import json import random from pathlib import Path from typing import Any, Dict, List from models import AdaptShieldAction from server.adaptshield_environment import AdaptShieldEnvironment from train import ( TASKS, _current_reference, _teacher_payload, build_messages, obs_to_dict, render_messages, task_for_episode, ) from soc_tools import attach_tool_results, investigate_local_with_depth def build_dataset( selected_task: str, curriculum: bool, use_tools: bool, rollout_episodes: int, max_steps: int, seed: int, world_split: str, world_family: str | None, ) -> List[Dict[str, Any]]: random.seed(seed) rows: List[Dict[str, Any]] = [] for episode in range(1, rollout_episodes + 1): task, stage = task_for_episode( episode=episode, total_episodes=rollout_episodes, selected_task=selected_task, curriculum=curriculum, ) env = AdaptShieldEnvironment( task_name=task, world_split=world_split, world_family=world_family, ) obs = env.reset() step_count = 0 while not obs.done and step_count < max_steps: phase = int(getattr(obs, "phase", 1)) tool_results = investigate_local_with_depth( env, obs, use_tools=use_tools, thorough=(task == "polymorphic-zero-day"), ) obs_dict = attach_tool_results(obs_to_dict(obs), tool_results) messages = build_messages(obs_dict) reference = _current_reference(env) teacher_payload = _teacher_payload(phase, reference) response_text = json.dumps(teacher_payload, separators=(",", ":")) rows.append({ "task": task, "stage": stage, "episode": episode, "turn": int(getattr(obs, "turn", 0) or 0), "phase": phase, "attack_stage": reference["stage"], "world_split": getattr(env, "_world_split", world_split), "world_family": getattr(env, "_world_family", world_family or ""), "operational_mode": getattr(env, "_operational_mode", ""), "is_benign": bool(reference["is_benign"]), "expected_threat_type": reference["threat_type"], "expected_target_node": reference["target_node"], "expected_action": reference["expected_action"], "tool_calls": len(tool_results), "messages": messages, "response": response_text, "text": f"{render_messages(messages)}\n\nASSISTANT:\n{response_text}", }) obs = env.step(AdaptShieldAction(**teacher_payload)) step_count += 1 return rows def summarize_rows(rows: List[Dict[str, Any]]) -> Dict[str, Any]: by_task = {task: 0 for task in TASKS} by_phase = {1: 0, 2: 0} with_tools = 0 for row in rows: task = str(row.get("task", "")) phase = int(row.get("phase", 1) or 1) if task in by_task: by_task[task] += 1 by_phase[phase] = by_phase.get(phase, 0) + 1 if int(row.get("tool_calls", 0) or 0) > 0: with_tools += 1 return { "rows": len(rows), "task_counts": by_task, "phase_counts": by_phase, "rows_with_tool_calls": with_tools, } def main() -> None: parser = argparse.ArgumentParser(description="Generate AdaptShield SFT JSONL data") parser.add_argument( "--task", default="all", choices=["all", *TASKS], help="Task to sample. Use 'all' with --curriculum for mixed data.", ) parser.add_argument( "--episodes", type=int, default=120, help="Number of rollout episodes to sample.", ) parser.add_argument( "--max-steps", type=int, default=20, help="Maximum env steps per episode.", ) parser.add_argument( "--seed", type=int, default=42, help="Dataset generation seed.", ) parser.add_argument( "--curriculum", action="store_true", help="Use easy->medium->hard sampling schedule.", ) parser.add_argument( "--use-tools", action="store_true", help="Include SOC tool evidence in prompts where applicable.", ) parser.add_argument( "--output", default="data/adaptshield_sft.jsonl", help="Where to write the JSONL dataset.", ) parser.add_argument( "--world-split", default="train", choices=["train", "eval"], help="World-family split used to generate the dataset.", ) parser.add_argument( "--world-family", default=None, help="Optional fixed world family override (e.g. train-a, eval-x).", ) args = parser.parse_args() rows = build_dataset( selected_task=args.task, curriculum=args.curriculum, use_tools=args.use_tools, rollout_episodes=args.episodes, max_steps=args.max_steps, seed=args.seed, world_split=args.world_split, world_family=args.world_family, ) output_path = Path(args.output) output_path.parent.mkdir(parents=True, exist_ok=True) with output_path.open("w", encoding="utf-8") as handle: for row in rows: handle.write(json.dumps(row, ensure_ascii=True) + "\n") summary = summarize_rows(rows) summary_path = output_path.with_suffix(".summary.json") summary_path.write_text(json.dumps(summary, indent=2), encoding="utf-8") print(f"Wrote {len(rows)} rows to {output_path}") print(f"Summary saved to {summary_path}") print(json.dumps(summary, indent=2)) if __name__ == "__main__": main()