from __future__ import annotations """ Onsite training entrypoint. This file is intentionally import-light so it can run locally without GPU packages. On the finale machine, install the training extras from pyproject and run without --dry-run to train a small orchestrator policy with GRPO. """ import argparse import json import random import re import sys from pathlib import Path ROOT = Path(__file__).resolve().parents[1] if str(ROOT) not in sys.path: sys.path.insert(0, str(ROOT)) from environment import SentinelEnv from mission_context import build_orchestrator_prompt from sentinel_config import ADVERSARIAL_AWARENESS_STAKES ACTION_RE = re.compile(r"\{.*\}", re.DOTALL) def build_prompt(observation: dict) -> str: return build_orchestrator_prompt(observation) def build_dataset_records(episodes: int, task_type: str, seed: int) -> list[dict]: records = [] task_choices = ["task1", "task2", "task3"] if task_type == "all" else [task_type] for idx in range(episodes): selected_task = task_choices[idx % len(task_choices)] env = SentinelEnv() result = env.reset(task_type=selected_task, seed=seed + idx) obs = result["observation"] records.append( { "prompt": build_prompt(obs), "task_type": selected_task, "seed": seed + idx, } ) return records def parse_action(text: str, observation: dict) -> dict: match = ACTION_RE.search(text or "") payload = {} if match: try: payload = json.loads(match.group(0)) except json.JSONDecodeError: payload = {} action_type = payload.get("action_type", "delegate") specialist_id = payload.get("specialist_id") if action_type in ("delegate", "verify") and specialist_id not in observation["available_specialists"]: specialist_id = max( observation["available_specialists"], key=lambda sid: observation["trust_snapshot"].get(sid, 0.5), ) if action_type == "solve_independently": specialist_id = None return { "session_id": observation["session_id"], "task_type": observation["task_type"], "action_type": action_type, "specialist_id": specialist_id, "subtask_response": "SELF_SOLVED" if action_type == "solve_independently" else None, "reasoning": payload.get("reasoning", "parsed-training-action"), } def score_completion(completion: str, task_type: str, seed: int) -> float: env = SentinelEnv() result = env.reset(task_type=task_type, seed=seed) obs = result["observation"] action = parse_action(completion, obs) result = env.step(action) return float(result["reward"]["value"]) def sentinel_reward(completions, prompts=None, task_type=None, seed=None, **kwargs): rewards = [] task_values = task_type or kwargs.get("task_type") or ["task3"] * len(completions) seed_values = seed or kwargs.get("seed") or list(range(len(completions))) for idx, completion in enumerate(completions): text = _completion_text(completion) try: rewards.append(score_completion(text, str(task_values[idx]), int(seed_values[idx]))) except Exception: rewards.append(0.01) return rewards def _completion_text(completion) -> str: if isinstance(completion, str): return completion if isinstance(completion, list): parts = [] for item in completion: if isinstance(item, dict): parts.append(str(item.get("content", ""))) else: parts.append(str(item)) return "\n".join(parts) if isinstance(completion, dict): return str(completion.get("content", completion)) return str(completion) def dry_run_rollouts(episodes: int, seed: int) -> dict: rng = random.Random(seed) scores = [] for idx in range(episodes): env = SentinelEnv() result = env.reset(task_type="task3", seed=seed + idx) while not result["done"]: obs = result["observation"] specialist = max(obs["available_specialists"], key=lambda sid: obs["trust_snapshot"].get(sid, 0.5)) action = { "session_id": obs["session_id"], "task_type": obs["task_type"], "action_type": ( "verify" if obs["stakes_level"] >= ADVERSARIAL_AWARENESS_STAKES and rng.random() < 0.5 else "delegate" ), "specialist_id": specialist, "subtask_response": None, "reasoning": "dry-run heuristic", } result = env.step(action) scores.append(result["info"]["score"]) return {"episodes": episodes, "avg_score": round(sum(scores) / max(1, len(scores)), 4)} def run_grpo(args) -> None: try: from datasets import Dataset from trl import GRPOConfig, GRPOTrainer from unsloth import FastLanguageModel except ImportError: print("Training dependencies are not installed locally.") print("Local check passed. For onsite GPU training run:") print(" pip install '.[training]'") print(" python training/train.py --episodes 300 --task all") return records = build_dataset_records(args.episodes, args.task, args.seed) dataset = Dataset.from_list(records) model, tokenizer = FastLanguageModel.from_pretrained( model_name=args.model, max_seq_length=args.max_seq_length, load_in_4bit=True, ) model = FastLanguageModel.get_peft_model( model, r=args.lora_rank, target_modules=["q_proj", "k_proj", "v_proj", "o_proj"], lora_alpha=args.lora_rank, ) config = GRPOConfig( output_dir=args.output_dir, learning_rate=args.learning_rate, num_train_epochs=args.epochs, per_device_train_batch_size=args.batch_size, num_generations=args.num_generations, logging_steps=10, save_steps=50, max_prompt_length=args.max_seq_length, max_completion_length=192, ) trainer_kwargs = { "model": model, "reward_funcs": [sentinel_reward], "args": config, "train_dataset": dataset, } try: trainer = GRPOTrainer(processing_class=tokenizer, **trainer_kwargs) except TypeError: trainer = GRPOTrainer(tokenizer=tokenizer, **trainer_kwargs) trainer.train() model.save_pretrained(args.output_dir) tokenizer.save_pretrained(args.output_dir) print(f"Training complete. Saved LoRA adapter to {args.output_dir}") def main() -> None: parser = argparse.ArgumentParser(description="SENTINEL GRPO training harness.") parser.add_argument("--dry-run", action="store_true", help="Run local rollouts without GPU dependencies.") parser.add_argument("--episodes", type=int, default=5) parser.add_argument("--seed", type=int, default=0) parser.add_argument("--task", default="task3", choices=["task1", "task2", "task3", "all"]) parser.add_argument("--model", default="unsloth/Qwen2.5-1.5B-Instruct") parser.add_argument("--output-dir", default="training/sentinel_model") parser.add_argument("--epochs", type=int, default=1) parser.add_argument("--batch-size", type=int, default=2) parser.add_argument("--learning-rate", type=float, default=5e-6) parser.add_argument("--max-seq-length", type=int, default=1024) parser.add_argument("--lora-rank", type=int, default=16) parser.add_argument("--num-generations", type=int, default=2) args = parser.parse_args() if args.dry_run: print(json.dumps(dry_run_rollouts(args.episodes, args.seed), indent=2)) return run_grpo(args) if __name__ == "__main__": main()