from __future__ import annotations

"""
Onsite training entrypoint.

This file is intentionally import-light so it can run locally without GPU
packages. On the finale machine, install the training extras from pyproject and
run without --dry-run to train a small orchestrator policy with GRPO.
"""

import argparse
import json
import random
import re
import sys
from pathlib import Path

ROOT = Path(__file__).resolve().parents[1]
if str(ROOT) not in sys.path:
    sys.path.insert(0, str(ROOT))

from environment import SentinelEnv
from mission_context import build_orchestrator_prompt
from sentinel_config import ADVERSARIAL_AWARENESS_STAKES


# Greedy match over the outermost {...} span so completions wrapped in prose
# or code fences still yield a parseable action payload, e.g.
#   {"action_type": "delegate", "specialist_id": "...", "reasoning": "..."}
ACTION_RE = re.compile(r"\{.*\}", re.DOTALL)


def build_prompt(observation: dict) -> str:
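    """Delegate to the shared mission-context prompt builder."""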
    return build_orchestrator_prompt(observation)


def build_dataset_records(episodes: int, task_type: str, seed: int) -> list[dict]:
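    """Build one prompt record per episode, cycling task types when --task all."""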
    records = []
    task_choices = ["task1", "task2", "task3"] if task_type == "all" else [task_type]
    for idx in range(episodes):
        selected_task = task_choices[idx % len(task_choices)]
        env = SentinelEnv()
        result = env.reset(task_type=selected_task, seed=seed + idx)
        obs = result["observation"]
        records.append(
            {
                "prompt": build_prompt(obs),
                "task_type": selected_task,
                "seed": seed + idx,
            }
        )
    return records


def parse_action(text: str, observation: dict) -> dict:
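    """Extract a JSON action from raw model text and repair invalid fields.

    Malformed payloads degrade to delegating; a missing or unavailable
    specialist is replaced with the highest-trust one from the observation.
    """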
    match = ACTION_RE.search(text or "")
    payload = {}
    if match:
        try:
            payload = json.loads(match.group(0))
        except json.JSONDecodeError:
            payload = {}

    action_type = payload.get("action_type", "delegate")
    specialist_id = payload.get("specialist_id")
    if action_type in ("delegate", "verify") and specialist_id not in observation["available_specialists"]:
        specialist_id = max(
            observation["available_specialists"],
            key=lambda sid: observation["trust_snapshot"].get(sid, 0.5),
        )
    if action_type == "solve_independently":
        specialist_id = None

    return {
        "session_id": observation["session_id"],
        "task_type": observation["task_type"],
        "action_type": action_type,
        "specialist_id": specialist_id,
        "subtask_response": "SELF_SOLVED" if action_type == "solve_independently" else None,
        "reasoning": payload.get("reasoning", "parsed-training-action"),
    }


def score_completion(completion: str, task_type: str, seed: int) -> float:
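    """Replay one completion against a freshly seeded env and return its step reward."""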
    env = SentinelEnv()
    result = env.reset(task_type=task_type, seed=seed)
    obs = result["observation"]
    action = parse_action(completion, obs)
    result = env.step(action)
    return float(result["reward"]["value"])


def sentinel_reward(completions, prompts=None, task_type=None, seed=None, **kwargs):
    """TRL-style reward function: replay each completion in its own environment."""
    rewards = []
    # task_type and seed arrive as per-sample dataset columns; fall back to
    # defaults so the function also works on bare completion lists.
    task_values = task_type or ["task3"] * len(completions)
    seed_values = seed or list(range(len(completions)))
    for idx, completion in enumerate(completions):
        text = _completion_text(completion)
        try:
            rewards.append(score_completion(text, str(task_values[idx]), int(seed_values[idx])))
        except Exception:
            # Never crash the trainer on a malformed completion; assign a
            # small floor reward instead.
            rewards.append(0.01)
    return rewards


def _completion_text(completion) -> str:
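    """Normalize a completion (str, chat-message list, or dict) to plain text."""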
    if isinstance(completion, str):
        return completion
    if isinstance(completion, list):
        parts = []
        for item in completion:
            if isinstance(item, dict):
                parts.append(str(item.get("content", "")))
            else:
                parts.append(str(item))
        return "\n".join(parts)
    if isinstance(completion, dict):
        return str(completion.get("content", completion))
    return str(completion)


def dry_run_rollouts(episodes: int, seed: int) -> dict:
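    """Roll out a heuristic policy locally, with no GPU dependencies.

    The heuristic always picks the most-trusted specialist and verifies about
    half the time once stakes reach the adversarial-awareness threshold.
    """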
    rng = random.Random(seed)
    scores = []
    for idx in range(episodes):
        env = SentinelEnv()
        result = env.reset(task_type="task3", seed=seed + idx)
        while not result["done"]:
            obs = result["observation"]
            specialist = max(obs["available_specialists"], key=lambda sid: obs["trust_snapshot"].get(sid, 0.5))
            action = {
                "session_id": obs["session_id"],
                "task_type": obs["task_type"],
                "action_type": (
                    "verify"
                    if obs["stakes_level"] >= ADVERSARIAL_AWARENESS_STAKES and rng.random() < 0.5
                    else "delegate"
                ),
                "specialist_id": specialist,
                "subtask_response": None,
                "reasoning": "dry-run heuristic",
            }
            result = env.step(action)
        scores.append(result["info"]["score"])
    return {"episodes": episodes, "avg_score": round(sum(scores) / max(1, len(scores)), 4)}


def run_grpo(args) -> None:
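    """LoRA-tune the policy with GRPO via Unsloth and TRL.

    Falls back to printing install instructions when the GPU training extras
    are not available locally.
    """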
    try:
        from datasets import Dataset
        from trl import GRPOConfig, GRPOTrainer
        from unsloth import FastLanguageModel
    except ImportError:
        print("Training dependencies are not installed locally.")
        print("Local check passed. For onsite GPU training run:")
        print("  pip install '.[training]'")
        print("  python training/train.py --episodes 300 --task all")
        return

    records = build_dataset_records(args.episodes, args.task, args.seed)
    dataset = Dataset.from_list(records)

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        load_in_4bit=True,
    )
    model = FastLanguageModel.get_peft_model(
        model,
        r=args.lora_rank,
        target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
        lora_alpha=args.lora_rank,
    )

    # Prompt length budget tracks --max-seq-length; completions stay short
    # since the policy only needs to emit one JSON action object.
    config = GRPOConfig(
        output_dir=args.output_dir,
        learning_rate=args.learning_rate,
        num_train_epochs=args.epochs,
        per_device_train_batch_size=args.batch_size,
        num_generations=args.num_generations,
        logging_steps=10,
        save_steps=50,
        max_prompt_length=args.max_seq_length,
        max_completion_length=192,
    )

    trainer_kwargs = {
        "model": model,
        "reward_funcs": [sentinel_reward],
        "args": config,
        "train_dataset": dataset,
    }
    # TRL renamed `tokenizer` to `processing_class`; support both versions.
    try:
        trainer = GRPOTrainer(processing_class=tokenizer, **trainer_kwargs)
    except TypeError:
        trainer = GRPOTrainer(tokenizer=tokenizer, **trainer_kwargs)

    trainer.train()
    model.save_pretrained(args.output_dir)
    tokenizer.save_pretrained(args.output_dir)
    print(f"Training complete. Saved LoRA adapter to {args.output_dir}")


def main() -> None:
    parser = argparse.ArgumentParser(description="SENTINEL GRPO training harness.")
    parser.add_argument("--dry-run", action="store_true", help="Run local rollouts without GPU dependencies.")
    parser.add_argument("--episodes", type=int, default=5)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--task", default="task3", choices=["task1", "task2", "task3", "all"])
    parser.add_argument("--model", default="unsloth/Qwen2.5-1.5B-Instruct")
    parser.add_argument("--output-dir", default="training/sentinel_model")
    parser.add_argument("--epochs", type=int, default=1)
    parser.add_argument("--batch-size", type=int, default=2)
    parser.add_argument("--learning-rate", type=float, default=5e-6)
    parser.add_argument("--max-seq-length", type=int, default=1024)
    parser.add_argument("--lora-rank", type=int, default=16)
    parser.add_argument("--num-generations", type=int, default=2)
    args = parser.parse_args()

    if args.dry_run:
        print(json.dumps(dry_run_rollouts(args.episodes, args.seed), indent=2))
        return

    run_grpo(args)


if __name__ == "__main__":
    main()