Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """Generate supervised fine-tuning data directly from AdaptShield rollouts.""" | |
| from __future__ import annotations | |
| import argparse | |
| import json | |
| import random | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| from models import AdaptShieldAction | |
| from server.adaptshield_environment import AdaptShieldEnvironment | |
| from train import ( | |
| TASKS, | |
| _current_reference, | |
| _teacher_payload, | |
| build_messages, | |
| obs_to_dict, | |
| render_messages, | |
| task_for_episode, | |
| ) | |
| from soc_tools import attach_tool_results, investigate_local_with_depth | |
def build_dataset(
    selected_task: str,
    curriculum: bool,
    use_tools: bool,
    rollout_episodes: int,
    max_steps: int,
    seed: int,
    world_split: str,
    world_family: str | None,
) -> List[Dict[str, Any]]:
    """Roll out the teacher policy and record one training row per env step.

    A new ``AdaptShieldEnvironment`` is created for every episode. At each
    step the prompt is built from the current observation (plus whatever
    ``investigate_local_with_depth`` returns), the teacher payload derived
    from the environment's current reference is serialised as the target
    response, and that same payload is then applied to the environment.

    Args:
        selected_task: Forwarded to ``task_for_episode`` to pick the task.
        curriculum: Forwarded to ``task_for_episode``.
        use_tools: Forwarded to ``investigate_local_with_depth``.
        rollout_episodes: Number of episodes; episodes are numbered from 1.
        max_steps: Upper bound on recorded steps within a single episode.
        seed: Value passed to ``random.seed`` before any rollout starts.
        world_split: Forwarded to the environment constructor; also the
            fallback for the ``world_split`` column.
        world_family: Forwarded to the environment constructor; the
            ``world_family`` column falls back to it (or ``""`` if falsy).

    Returns:
        The collected rows, in episode order and then step order.
    """
    # Seeds the module-level RNG of ``random`` (global state).
    random.seed(seed)

    samples: List[Dict[str, Any]] = []
    for episode_idx in range(1, rollout_episodes + 1):
        task_name, curriculum_stage = task_for_episode(
            episode=episode_idx,
            total_episodes=rollout_episodes,
            selected_task=selected_task,
            curriculum=curriculum,
        )
        environment = AdaptShieldEnvironment(
            task_name=task_name,
            world_split=world_split,
            world_family=world_family,
        )
        observation = environment.reset()

        steps_taken = 0
        # Stop as soon as the episode reports done or the step budget is spent.
        while not (observation.done or steps_taken >= max_steps):
            current_phase = int(getattr(observation, "phase", 1))
            evidence = investigate_local_with_depth(
                environment,
                observation,
                use_tools=use_tools,
                # Only this task requests the thorough investigation mode.
                thorough=(task_name == "polymorphic-zero-day"),
            )
            prompt_messages = build_messages(
                attach_tool_results(obs_to_dict(observation), evidence)
            )
            expected = _current_reference(environment)
            action_payload = _teacher_payload(current_phase, expected)
            # Compact separators: no whitespace in the target JSON.
            completion = json.dumps(action_payload, separators=(",", ":"))

            # Provenance and label columns first; key order matches the
            # on-disk column order.
            row: Dict[str, Any] = {
                "task": task_name,
                "stage": curriculum_stage,
                "episode": episode_idx,
                "turn": int(getattr(observation, "turn", 0) or 0),
                "phase": current_phase,
                "attack_stage": expected["stage"],
                "world_split": getattr(environment, "_world_split", world_split),
                "world_family": getattr(
                    environment, "_world_family", world_family or ""
                ),
                "operational_mode": getattr(environment, "_operational_mode", ""),
                "is_benign": bool(expected["is_benign"]),
                "expected_threat_type": expected["threat_type"],
                "expected_target_node": expected["target_node"],
                "expected_action": expected["expected_action"],
            }
            # Prompt / completion columns.
            row["tool_calls"] = len(evidence)
            row["messages"] = prompt_messages
            row["response"] = completion
            row["text"] = (
                f"{render_messages(prompt_messages)}\n\nASSISTANT:\n{completion}"
            )
            samples.append(row)

            # Advance the environment with the teacher's own action.
            observation = environment.step(AdaptShieldAction(**action_payload))
            steps_taken += 1

    return samples
def summarize_rows(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Aggregate simple counts over generated dataset rows.

    Args:
        rows: Row dictionaries; the ``task``, ``phase`` and ``tool_calls``
            keys are read, each with a default when missing.

    Returns:
        A dict with the total row count (``rows``), per-task counts for the
        names listed in ``TASKS`` (``task_counts``), per-phase counts
        (``phase_counts``) and the number of rows whose ``tool_calls`` value
        is positive (``rows_with_tool_calls``).
    """
    # Every known task starts at zero so it always appears in the summary.
    task_counts = dict.fromkeys(TASKS, 0)
    phase_counts = {1: 0, 2: 0}
    tool_rows = 0

    for entry in rows:
        task_name = str(entry.get("task", ""))
        # Missing or falsy phase values are counted as phase 1.
        phase_id = int(entry.get("phase", 1) or 1)

        # Rows whose task is not in TASKS are left out of task_counts.
        if task_name in task_counts:
            task_counts[task_name] += 1

        # Phases other than 1 and 2 get their own bucket on first sight.
        phase_counts.setdefault(phase_id, 0)
        phase_counts[phase_id] += 1

        if int(entry.get("tool_calls", 0) or 0) <= 0:
            continue
        tool_rows += 1

    return {
        "rows": len(rows),
        "task_counts": task_counts,
        "phase_counts": phase_counts,
        "rows_with_tool_calls": tool_rows,
    }
def _build_arg_parser() -> argparse.ArgumentParser:
    """Create the command-line parser used by ``main``."""
    cli = argparse.ArgumentParser(description="Generate AdaptShield SFT JSONL data")
    cli.add_argument(
        "--task",
        default="all",
        choices=["all", *TASKS],
        help="Task to sample. Use 'all' with --curriculum for mixed data.",
    )
    cli.add_argument(
        "--episodes",
        type=int,
        default=120,
        help="Number of rollout episodes to sample.",
    )
    cli.add_argument(
        "--max-steps",
        type=int,
        default=20,
        help="Maximum env steps per episode.",
    )
    cli.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Dataset generation seed.",
    )
    cli.add_argument(
        "--curriculum",
        action="store_true",
        help="Use easy->medium->hard sampling schedule.",
    )
    cli.add_argument(
        "--use-tools",
        action="store_true",
        help="Include SOC tool evidence in prompts where applicable.",
    )
    cli.add_argument(
        "--output",
        default="data/adaptshield_sft.jsonl",
        help="Where to write the JSONL dataset.",
    )
    cli.add_argument(
        "--world-split",
        default="train",
        choices=["train", "eval"],
        help="World-family split used to generate the dataset.",
    )
    cli.add_argument(
        "--world-family",
        default=None,
        help="Optional fixed world family override (e.g. train-a, eval-x).",
    )
    return cli


def main() -> None:
    """Parse CLI options, generate the dataset and write it with a summary.

    The rows are written as one JSON object per line to ``--output`` (parent
    directories are created if needed). A summary produced by
    ``summarize_rows`` is saved beside it, using the output path with its
    final suffix replaced by ``.summary.json``, and is also printed.
    """
    options = _build_arg_parser().parse_args()

    dataset = build_dataset(
        selected_task=options.task,
        curriculum=options.curriculum,
        use_tools=options.use_tools,
        rollout_episodes=options.episodes,
        max_steps=options.max_steps,
        seed=options.seed,
        world_split=options.world_split,
        world_family=options.world_family,
    )

    destination = Path(options.output)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as sink:
        # One ASCII-escaped JSON document per line (JSONL).
        sink.writelines(
            json.dumps(record, ensure_ascii=True) + "\n" for record in dataset
        )

    stats = summarize_rows(dataset)
    # Serialise once; the same text goes to the file and to stdout.
    stats_text = json.dumps(stats, indent=2)
    stats_path = destination.with_suffix(".summary.json")
    stats_path.write_text(stats_text, encoding="utf-8")

    print(f"Wrote {len(dataset)} rows to {destination}")
    print(f"Summary saved to {stats_path}")
    print(stats_text)


if __name__ == "__main__":
    main()