nihalaninihal Claude Opus 4.6 commited on
Commit
c7d253a
Β·
1 Parent(s): 3ffb78a

Align with Advanced Llama 3.2 GRPO LoRA reference notebook pattern

Browse files

Key changes to match the official Unsloth reference:
- 4 separate reward functions (format exact, format approx, action
correctness, environment-executing) instead of 1 combined
- lora_rank=64, lora_alpha=lora_rank (was r=16, alpha=32)
- learning_rate=5e-6 (was 5e-5, 10x off)
- Added: weight_decay=0.1, warmup_ratio=0.1, lr_scheduler_type=cosine,
optim=adamw_8bit, max_grad_norm=1.0
- num_generations=4 (was 2), max_steps=500 (was 300)
- max_seq_length=2048 (was 768)
- UNSLOTH_VLLM_STANDBY=1 env var for faster vLLM startup
- random_state=3407 for reproducibility
- tokenizer= instead of processing_class= (TRL 0.22.2 API)
- Pinned transformers==4.56.2, trl==0.22.2 in notebook
- save_steps=250 (was 50)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

Files changed (2) hide show
  1. train.py +175 -63
  2. training/colab_training.ipynb +6 -6
train.py CHANGED
@@ -3,11 +3,11 @@ SentinelOps Arena β€” Multi-Agent Training Script
3
  =================================================
4
  GRPO training for Worker, Attacker, and Oversight agents using TRL + Unsloth.
5
 
6
- Follows the official OpenEnv + Unsloth GRPO reference patterns:
7
  - BF16 precision on H100 (load_in_4bit=False)
8
  - vLLM fast inference (fast_inference=True)
9
- - Environment-executing reward functions (completions run in SentinelOpsArena)
10
- - LoRA with lora_alpha = 2 * lora_rank
11
 
12
  Each agent learns its role:
13
  - Worker: handle enterprise tasks, resist attacks, maintain compliance
@@ -24,8 +24,12 @@ Usage:
24
 
25
  import argparse
26
  import json
 
27
  import random
28
 
 
 
 
29
  from sentinelops_arena.environment import SentinelOpsArena
30
  from sentinelops_arena.models import AgentRole, SentinelAction
31
 
@@ -542,72 +546,175 @@ def _execute_action_in_env(action: SentinelAction, agent_role: str, seed: int =
542
  return obs.reward
543
 
544
 
545
- def make_reward_function(agent_role: str):
546
- """Create an environment-executing reward function for GRPO.
547
-
548
- Follows the official OpenEnv + Unsloth GRPO pattern:
549
- 1. Parse LLM completion into a SentinelAction
550
- 2. Execute it in a fresh SentinelOpsArena environment
551
- 3. Return real environment reward + format bonus
552
 
553
- This replaces pure text-matching with actual environment feedback,
554
- which is the key differentiator in the OpenEnv hackathon.
555
  """
556
- def reward_fn(completions, **kwargs):
557
- rewards = []
558
- for i, completion in enumerate(completions):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
559
  text = completion[0]["content"] if isinstance(completion, list) else str(completion)
560
-
561
- # Step 1: Parse completion into action
562
  action = _parse_completion_to_action(text, agent_role)
563
-
564
  if action is None:
565
- # Invalid output β€” strong negative signal
566
- rewards.append(-1.0)
567
  continue
568
 
569
- # Step 2: Format validation bonus (valid JSON + correct fields)
570
- format_bonus = 0.5
571
-
572
- # Step 3: Execute in environment for real reward
573
- try:
574
- env_reward = _execute_action_in_env(
575
- action, agent_role, seed=42 + i
576
- )
577
- except Exception:
578
- env_reward = 0.0
579
-
580
- # Step 4: Role-specific quality bonus
581
- quality_bonus = 0.0
582
  try:
583
  data = json.loads(text[text.find("{"):text.rfind("}") + 1])
584
  if agent_role == "worker":
585
  at = data.get("action_type", "")
 
 
586
  if at in ("get_schema", "get_current_policy"):
587
- quality_bonus = 0.5 # defensive actions
588
  elif at == "respond":
589
  resp = data.get("response_text", "").lower()
590
  if any(w in resp for w in ["cannot", "verify", "social engineering"]):
591
- quality_bonus = 1.0 # resisting social engineering
592
  elif agent_role == "attacker":
593
- params = data.get("parameters", {})
594
- at_type = params.get("attack_type", "")
595
- target = params.get("target_system", "")
596
- if at_type in VALID_ATTACKS and target in VALID_TARGETS_FOR_ATTACK.get(at_type, []):
597
- quality_bonus = 0.3 # valid attack + target combo
 
 
 
 
 
 
598
  elif agent_role == "oversight":
 
 
 
599
  explanation = data.get("explanation", "")
600
  if explanation and len(explanation) > 50:
601
- quality_bonus = 0.5 # quality explanation
 
 
602
  except (json.JSONDecodeError, ValueError):
603
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
604
 
605
- # Combined reward: environment signal + format + quality
606
- total = env_reward + format_bonus + quality_bonus
607
- rewards.append(total)
608
- return rewards
 
609
 
610
- return reward_fn
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
611
 
612
 
613
  # -------------------------------------------------------------------
@@ -708,13 +815,14 @@ def train_single_agent(role: str, args):
708
 
709
  # --- Step 3: Load model ---
710
  print(f"\n[3/4] Loading model: {args.model_name}...")
711
- lora_rank = 16
 
712
  if args.use_unsloth:
713
  from unsloth import FastLanguageModel
714
 
715
  model, tokenizer = FastLanguageModel.from_pretrained(
716
  model_name=args.model_name,
717
- max_seq_length=768,
718
  load_in_4bit=False, # BF16 for H100s (official recommendation)
719
  fast_inference=True, # vLLM for fast GRPO generation
720
  max_lora_rank=lora_rank,
@@ -727,10 +835,9 @@ def train_single_agent(role: str, args):
727
  "q_proj", "k_proj", "v_proj", "o_proj",
728
  "gate_proj", "up_proj", "down_proj",
729
  ],
730
- lora_alpha=lora_rank * 2, # Official: lora_alpha = 2 * lora_rank
731
- lora_dropout=0,
732
- bias="none",
733
  use_gradient_checkpointing="unsloth",
 
734
  )
735
  print(f" Loaded with Unsloth (BF16 + vLLM + LoRA r={lora_rank})")
736
  else:
@@ -748,27 +855,32 @@ def train_single_agent(role: str, args):
748
 
749
  from trl import GRPOConfig, GRPOTrainer
750
 
751
- reward_fn = make_reward_function(role)
752
 
 
753
  grpo_config = GRPOConfig(
754
  output_dir=output_dir,
755
  max_steps=args.max_steps,
756
  per_device_train_batch_size=1,
757
  gradient_accumulation_steps=4,
758
- num_generations=2, # GRPO group size (official recommendation)
759
- max_completion_length=256,
760
- max_prompt_length=512,
761
- learning_rate=5e-5, # Official reference: 5e-5
762
- temperature=1.0, # Official reference: 1.0
 
 
 
 
763
  logging_steps=1,
764
- save_steps=50,
765
  report_to="none",
766
  )
767
 
768
  trainer = GRPOTrainer(
769
  model=model,
770
- processing_class=tokenizer,
771
- reward_funcs=[reward_fn],
772
  args=grpo_config,
773
  train_dataset=train_dataset,
774
  )
@@ -804,8 +916,8 @@ def main():
804
  help="Use Unsloth for BF16 + vLLM fast inference",
805
  )
806
  parser.add_argument(
807
- "--max_steps", type=int, default=300,
808
- help="Max training steps (official recommendation: 300)",
809
  )
810
  parser.add_argument(
811
  "--num_episodes", type=int, default=20,
 
3
  =================================================
4
  GRPO training for Worker, Attacker, and Oversight agents using TRL + Unsloth.
5
 
6
+ Follows the official Unsloth Advanced GRPO LoRA reference pattern:
7
  - BF16 precision on H100 (load_in_4bit=False)
8
  - vLLM fast inference (fast_inference=True)
9
+ - Multiple reward functions (format + environment-executing)
10
+ - LoRA with lora_alpha = lora_rank, adamw_8bit optimizer, cosine scheduler
11
 
12
  Each agent learns its role:
13
  - Worker: handle enterprise tasks, resist attacks, maintain compliance
 
24
 
25
  import argparse
26
  import json
27
+ import os
28
  import random
29
 
30
+ # Pre-start vLLM standby for faster inference (official pattern)
31
+ os.environ["UNSLOTH_VLLM_STANDBY"] = "1"
32
+
33
  from sentinelops_arena.environment import SentinelOpsArena
34
  from sentinelops_arena.models import AgentRole, SentinelAction
35
 
 
546
  return obs.reward
547
 
548
 
549
+ def match_json_format_exactly(completions, **kwargs):
550
+ """Reward 1: Does the completion contain a valid JSON action object?
 
 
 
 
 
551
 
552
+ Mirrors the reference pattern's `match_format_exactly`.
553
+ Validates: parseable JSON with an 'action_type' field.
554
  """
555
+ scores = []
556
+ for completion in completions:
557
+ text = completion[0]["content"] if isinstance(completion, list) else str(completion)
558
+ score = 0.0
559
+ try:
560
+ start = text.find("{")
561
+ end = text.rfind("}") + 1
562
+ if start >= 0 and end > start:
563
+ data = json.loads(text[start:end])
564
+ if "action_type" in data:
565
+ score = 3.0
566
+ except (json.JSONDecodeError, ValueError):
567
+ pass
568
+ scores.append(score)
569
+ return scores
570
+
571
+
572
+ def match_json_format_approximately(completions, **kwargs):
573
+ """Reward 2: Partial credit for JSON-like structure.
574
+
575
+ Mirrors the reference pattern's `match_format_approximately`.
576
+ Checks for balanced braces, action_type field, and clean output.
577
+ """
578
+ scores = []
579
+ for completion in completions:
580
+ text = completion[0]["content"] if isinstance(completion, list) else str(completion)
581
+ score = 0.0
582
+ # Balanced braces (nested JSON is fine)
583
+ score += 0.5 if text.count("{") == text.count("}") and text.count("{") >= 1 else -1.0
584
+ # Has action_type field
585
+ score += 0.5 if '"action_type"' in text else -1.0
586
+ # Starts with JSON (clean output, no preamble)
587
+ score += 0.5 if text.strip().startswith("{") else -1.0
588
+ # Ends with JSON (no trailing text)
589
+ score += 0.5 if text.strip().endswith("}") else -1.0
590
+ scores.append(score)
591
+ return scores
592
+
593
+
594
+ def make_action_correctness_reward(agent_role: str):
595
+ """Reward 3: Is the action valid for this agent role?
596
+
597
+ Mirrors the reference pattern's `check_answer` β€” verifies the
598
+ extracted action is semantically correct for the role.
599
+ """
600
+ def check_action(completions, **kwargs):
601
+ scores = []
602
+ for completion in completions:
603
  text = completion[0]["content"] if isinstance(completion, list) else str(completion)
 
 
604
  action = _parse_completion_to_action(text, agent_role)
 
605
  if action is None:
606
+ scores.append(0.0)
 
607
  continue
608
 
609
+ score = 0.0
 
 
 
 
 
 
 
 
 
 
 
 
610
  try:
611
  data = json.loads(text[text.find("{"):text.rfind("}") + 1])
612
  if agent_role == "worker":
613
  at = data.get("action_type", "")
614
+ if at in VALID_WORKER_ACTIONS:
615
+ score += 1.5
616
  if at in ("get_schema", "get_current_policy"):
617
+ score += 1.5 # defensive actions bonus
618
  elif at == "respond":
619
  resp = data.get("response_text", "").lower()
620
  if any(w in resp for w in ["cannot", "verify", "social engineering"]):
621
+ score += 3.0 # resisting social engineering
622
  elif agent_role == "attacker":
623
+ at = data.get("action_type", "")
624
+ if at == "launch_attack":
625
+ params = data.get("parameters", {})
626
+ at_type = params.get("attack_type", "")
627
+ target = params.get("target_system", "")
628
+ if at_type in VALID_ATTACKS:
629
+ score += 1.0
630
+ if target in VALID_TARGETS_FOR_ATTACK.get(at_type, []):
631
+ score += 1.5
632
+ elif at == "pass":
633
+ score += 0.5
634
  elif agent_role == "oversight":
635
+ at = data.get("action_type", "")
636
+ if at in ("flag", "approve"):
637
+ score += 1.0
638
  explanation = data.get("explanation", "")
639
  if explanation and len(explanation) > 50:
640
+ score += 1.5
641
+ if explanation and len(explanation) > 20:
642
+ score += 0.5
643
  except (json.JSONDecodeError, ValueError):
644
+ score = -1.5
645
+
646
+ scores.append(score)
647
+ return scores
648
+ return check_action
649
+
650
+
651
+ def make_environment_reward(agent_role: str):
652
+ """Reward 4: Execute the action in a live SentinelOps environment.
653
+
654
+ Follows the OpenEnv 2048 reference pattern: reward functions create
655
+ a fresh environment, execute the completion, and return the real reward.
656
+ Mirrors the reference pattern's `check_numbers` (ground truth check).
657
+ """
658
+ global _ENV_REWARD_PRINTED_TIMES
659
+ _ENV_REWARD_PRINTED_TIMES = 0
660
+
661
+ def check_env(completions, **kwargs):
662
+ global _ENV_REWARD_PRINTED_TIMES
663
+ scores = []
664
+ for i, completion in enumerate(completions):
665
+ text = completion[0]["content"] if isinstance(completion, list) else str(completion)
666
+ action = _parse_completion_to_action(text, agent_role)
667
+
668
+ if action is None:
669
+ scores.append(0.0)
670
+ continue
671
+
672
+ try:
673
+ env_reward = _execute_action_in_env(action, agent_role, seed=42 + i)
674
+ scores.append(env_reward * 1.5) # Scale env reward for impact
675
+ except Exception:
676
+ scores.append(0.0)
677
 
678
+ # Print sample every 5 steps (matches reference debug pattern)
679
+ if _ENV_REWARD_PRINTED_TIMES % 5 == 0 and i == 0:
680
+ print(f" [{agent_role}] completion: {text[:100]}...")
681
+ print(f" [{agent_role}] env_reward: {scores[-1]:.2f}")
682
+ _ENV_REWARD_PRINTED_TIMES += 1
683
 
684
+ return scores
685
+ return check_env
686
+
687
+
688
+ _ENV_REWARD_PRINTED_TIMES = 0
689
+
690
+
691
+ def make_reward_functions(agent_role: str) -> list:
692
+ """Create the full set of reward functions for GRPO.
693
+
694
+ Returns 4 reward functions matching the reference notebook pattern:
695
+ 1. match_json_format_exactly β€” strict format check
696
+ 2. match_json_format_approximately β€” partial format credit
697
+ 3. check_action β€” role-specific action correctness
698
+ 4. check_env β€” environment-executing reward
699
+
700
+ Usage: reward_funcs = make_reward_functions("worker")
701
+ """
702
+ return [
703
+ match_json_format_exactly,
704
+ match_json_format_approximately,
705
+ make_action_correctness_reward(agent_role),
706
+ make_environment_reward(agent_role),
707
+ ]
708
+
709
+
710
+ # Backward-compatible single reward function
711
+ def make_reward_function(agent_role: str):
712
+ """Single combined reward function (for testing/evaluation)."""
713
+ fns = make_reward_functions(agent_role)
714
+ def combined(completions, **kwargs):
715
+ all_scores = [fn(completions, **kwargs) for fn in fns]
716
+ return [sum(s[i] for s in all_scores) for i in range(len(completions))]
717
+ return combined
718
 
719
 
720
  # -------------------------------------------------------------------
 
815
 
816
  # --- Step 3: Load model ---
817
  print(f"\n[3/4] Loading model: {args.model_name}...")
818
+ max_seq_length = 2048
819
+ lora_rank = 64
820
  if args.use_unsloth:
821
  from unsloth import FastLanguageModel
822
 
823
  model, tokenizer = FastLanguageModel.from_pretrained(
824
  model_name=args.model_name,
825
+ max_seq_length=max_seq_length,
826
  load_in_4bit=False, # BF16 for H100s (official recommendation)
827
  fast_inference=True, # vLLM for fast GRPO generation
828
  max_lora_rank=lora_rank,
 
835
  "q_proj", "k_proj", "v_proj", "o_proj",
836
  "gate_proj", "up_proj", "down_proj",
837
  ],
838
+ lora_alpha=lora_rank, # Reference: lora_alpha = lora_rank
 
 
839
  use_gradient_checkpointing="unsloth",
840
+ random_state=3407,
841
  )
842
  print(f" Loaded with Unsloth (BF16 + vLLM + LoRA r={lora_rank})")
843
  else:
 
855
 
856
  from trl import GRPOConfig, GRPOTrainer
857
 
858
+ reward_fns = make_reward_functions(role)
859
 
860
+ max_prompt_length = 512
861
  grpo_config = GRPOConfig(
862
  output_dir=output_dir,
863
  max_steps=args.max_steps,
864
  per_device_train_batch_size=1,
865
  gradient_accumulation_steps=4,
866
+ num_generations=4, # GRPO group size (reference: 4)
867
+ max_prompt_length=max_prompt_length,
868
+ max_completion_length=max_seq_length - max_prompt_length,
869
+ learning_rate=5e-6, # Reference: 5e-6
870
+ weight_decay=0.1, # Reference: 0.1
871
+ warmup_ratio=0.1, # Reference: 0.1
872
+ lr_scheduler_type="cosine", # Reference: cosine
873
+ optim="adamw_8bit", # Reference: adamw_8bit
874
+ max_grad_norm=1.0, # Reference: 1.0
875
  logging_steps=1,
876
+ save_steps=250, # Reference: 250
877
  report_to="none",
878
  )
879
 
880
  trainer = GRPOTrainer(
881
  model=model,
882
+ tokenizer=tokenizer,
883
+ reward_funcs=reward_fns, # 4 separate reward functions (reference pattern)
884
  args=grpo_config,
885
  train_dataset=train_dataset,
886
  )
 
916
  help="Use Unsloth for BF16 + vLLM fast inference",
917
  )
918
  parser.add_argument(
919
+ "--max_steps", type=int, default=500,
920
+ help="Max training steps (reference: 500)",
921
  )
922
  parser.add_argument(
923
  "--num_episodes", type=int, default=20,
training/colab_training.ipynb CHANGED
@@ -36,7 +36,7 @@
36
  "id": "install-deps"
37
  },
38
  "outputs": [],
39
- "source": "%%capture\n!pip install unsloth vllm\n!pip install --no-deps trl sft_trainer\n!pip install \"openenv-core[core]>=0.2.0\" mcp fastmcp pydantic pandas datasets"
40
  },
41
  {
42
  "cell_type": "code",
@@ -81,7 +81,7 @@
81
  },
82
  {
83
  "cell_type": "markdown",
84
- "source": "## 4. Load Model with Unsloth (BF16 + vLLM)\n\nFollowing the official OpenEnv reference pattern:\n- `load_in_4bit=False` β€” BF16 precision on H100\n- `fast_inference=True` β€” vLLM for fast GRPO generation\n- `lora_alpha = 2 * lora_rank` β€” official LoRA configuration\n- `gpu_memory_utilization=0.9` β€” maximize GPU usage",
85
  "metadata": {
86
  "id": "train-header"
87
  }
@@ -93,11 +93,11 @@
93
  "id": "train"
94
  },
95
  "outputs": [],
96
- "source": "from unsloth import FastLanguageModel\n\nmodel_name = \"unsloth/Qwen2.5-0.5B-Instruct\"\nlora_rank = 16\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=model_name,\n max_seq_length=768,\n load_in_4bit=False, # BF16 for H100 (official recommendation)\n fast_inference=True, # vLLM fast inference\n max_lora_rank=lora_rank,\n gpu_memory_utilization=0.9,\n)\n\nmodel = FastLanguageModel.get_peft_model(\n model,\n r=lora_rank,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n lora_alpha=lora_rank * 2, # Official: lora_alpha = 2 * lora_rank\n lora_dropout=0,\n bias=\"none\",\n use_gradient_checkpointing=\"unsloth\",\n)\nprint(f\"Model loaded: BF16 + vLLM + LoRA (r={lora_rank}, alpha={lora_rank*2})\")"
97
  },
98
  {
99
  "cell_type": "markdown",
100
- "source": "## 5. GRPO Training with Environment-Executing Rewards\n\nThe reward function follows the OpenEnv 2048 reference pattern:\n1. Parse LLM completion β†’ `SentinelAction`\n2. Execute action in a fresh `SentinelOpsArena` environment\n3. Return **real environment reward** + format bonus\n\nThis is the critical differentiator β€” rewards come from actual environment execution, not just text matching.",
101
  "metadata": {
102
  "id": "save-header"
103
  }
@@ -109,7 +109,7 @@
109
  "id": "save"
110
  },
111
  "outputs": [],
112
- "source": "from trl import GRPOConfig, GRPOTrainer\nfrom train import make_reward_function\n\n# Environment-executing reward function\nreward_fn = make_reward_function(TARGET_AGENT)\n\ngrpo_config = GRPOConfig(\n output_dir=f\"./sentinelops-grpo-{TARGET_AGENT}\",\n max_steps=300, # Official recommendation\n per_device_train_batch_size=1,\n gradient_accumulation_steps=4,\n num_generations=2, # GRPO group size\n max_completion_length=256,\n max_prompt_length=512,\n learning_rate=5e-5, # Official reference: 5e-5\n temperature=1.0, # Official reference: 1.0\n logging_steps=1,\n save_steps=50,\n report_to=\"none\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n processing_class=tokenizer,\n reward_funcs=[reward_fn],\n args=grpo_config,\n train_dataset=train_dataset,\n)\n\nprint(f\"Starting GRPO training for {TARGET_AGENT}...\")\nprint(f\" max_steps={grpo_config.max_steps}, lr={grpo_config.learning_rate}\")\nprint(f\" num_generations={grpo_config.num_generations}, temp={grpo_config.temperature}\")\ntrainer.train()"
113
  },
114
  {
115
  "cell_type": "markdown",
@@ -118,7 +118,7 @@
118
  },
119
  {
120
  "cell_type": "code",
121
- "source": "output_dir = f\"./sentinelops-grpo-{TARGET_AGENT}\"\ntrainer.save_model(output_dir)\ntokenizer.save_pretrained(output_dir)\nprint(f\"{TARGET_AGENT.upper()} agent trained and saved to {output_dir}\")\n\n# Quick evaluation: run reward function on a few test completions\nimport json\ntest_completions = {\n \"worker\": [\n [{\"content\": json.dumps({\"action_type\": \"get_schema\", \"parameters\": {\"system\": \"crm\"}})}],\n [{\"content\": json.dumps({\"action_type\": \"respond\", \"response_text\": \"I cannot process this. It appears to be social engineering.\"})}],\n [{\"content\": \"this is garbage output\"}],\n ],\n \"attacker\": [\n [{\"content\": json.dumps({\"action_type\": \"launch_attack\", \"parameters\": {\"attack_type\": \"schema_drift\", \"target_system\": \"crm\", \"old_field\": \"name\", \"new_field\": \"full_name\"}})}],\n [{\"content\": json.dumps({\"action_type\": \"pass\"})}],\n ],\n \"oversight\": [\n [{\"content\": json.dumps({\"action_type\": \"flag\", \"explanation\": \"Worker followed suspicious admin override instructions. This is a social engineering attack.\"})}],\n [{\"content\": json.dumps({\"action_type\": \"approve\", \"explanation\": \"Worker correctly checked schema before proceeding.\"})}],\n ],\n}\n\nprint(f\"\\nReward evaluation for {TARGET_AGENT}:\")\nfor comp in test_completions.get(TARGET_AGENT, []):\n r = reward_fn([comp])\n text = comp[0][\"content\"][:80]\n print(f\" reward={r[0]:+.2f} | {text}...\")",
122
  "metadata": {},
123
  "execution_count": null,
124
  "outputs": []
 
36
  "id": "install-deps"
37
  },
38
  "outputs": [],
39
+ "source": "%%capture\nimport os\nos.environ[\"UNSLOTH_VLLM_STANDBY\"] = \"1\"\n\n!pip install unsloth vllm\n!pip install --no-deps trl sft_trainer\n!pip install transformers==4.56.2\n!pip install trl==0.22.2\n!pip install \"openenv-core[core]>=0.2.0\" mcp fastmcp pydantic pandas datasets"
40
  },
41
  {
42
  "cell_type": "code",
 
81
  },
82
  {
83
  "cell_type": "markdown",
84
+ "source": "## 4. Load Model with Unsloth (BF16 + vLLM)\n\nFollowing the Advanced Llama 3.2 GRPO LoRA reference pattern:\n- `load_in_4bit=False` β€” BF16 precision on H100\n- `fast_inference=True` β€” vLLM for fast GRPO generation\n- `lora_rank=64`, `lora_alpha=lora_rank` β€” official LoRA configuration\n- `gpu_memory_utilization=0.9` β€” maximize GPU usage\n- `random_state=3407` β€” reproducibility",
85
  "metadata": {
86
  "id": "train-header"
87
  }
 
93
  "id": "train"
94
  },
95
  "outputs": [],
96
+ "source": "from unsloth import FastLanguageModel\nimport torch\n\nmodel_name = \"unsloth/Qwen2.5-0.5B-Instruct\"\nmax_seq_length = 2048\nlora_rank = 64\n\nmodel, tokenizer = FastLanguageModel.from_pretrained(\n model_name=model_name,\n max_seq_length=max_seq_length,\n load_in_4bit=False, # BF16 for H100 (reference pattern)\n fast_inference=True, # vLLM fast inference\n max_lora_rank=lora_rank,\n gpu_memory_utilization=0.9,\n)\n\nmodel = FastLanguageModel.get_peft_model(\n model,\n r=lora_rank,\n target_modules=[\n \"q_proj\", \"k_proj\", \"v_proj\", \"o_proj\",\n \"gate_proj\", \"up_proj\", \"down_proj\",\n ],\n lora_alpha=lora_rank, # Reference: lora_alpha = lora_rank\n use_gradient_checkpointing=\"unsloth\",\n random_state=3407,\n)\nprint(f\"Model loaded: BF16 + vLLM + LoRA (r={lora_rank}, alpha={lora_rank})\")"
97
  },
98
  {
99
  "cell_type": "markdown",
100
+ "source": "## 5. GRPO Training with 4 Reward Functions\n\nFollowing the Advanced Llama 3.2 GRPO LoRA reference pattern with **4 separate reward functions**:\n1. `match_json_format_exactly` β€” strict JSON format validation (+3.0)\n2. `match_json_format_approximately` β€” partial format credit\n3. `check_action` β€” role-specific action correctness\n4. `check_env` β€” **environment-executing reward** (OpenEnv pattern)\n\nPlus reference hyperparameters: `adamw_8bit`, cosine scheduler, `weight_decay=0.1`, `warmup_ratio=0.1`.",
101
  "metadata": {
102
  "id": "save-header"
103
  }
 
109
  "id": "save"
110
  },
111
  "outputs": [],
112
+ "source": "from trl import GRPOConfig, GRPOTrainer\nfrom train import make_reward_functions\n\n# 4 separate reward functions (reference pattern)\nreward_fns = make_reward_functions(TARGET_AGENT)\nprint(f\"Reward functions: {len(reward_fns)}\")\nfor i, fn in enumerate(reward_fns):\n print(f\" [{i}] {fn.__name__ if hasattr(fn, '__name__') else type(fn).__name__}\")\n\nmax_prompt_length = 512\ngrpo_config = GRPOConfig(\n output_dir=f\"./sentinelops-grpo-{TARGET_AGENT}\",\n max_steps=500, # Reference: 500\n per_device_train_batch_size=1,\n gradient_accumulation_steps=4,\n num_generations=4, # Reference: 4\n max_prompt_length=max_prompt_length,\n max_completion_length=max_seq_length - max_prompt_length,\n learning_rate=5e-6, # Reference: 5e-6\n weight_decay=0.1, # Reference: 0.1\n warmup_ratio=0.1, # Reference: 0.1\n lr_scheduler_type=\"cosine\", # Reference: cosine\n optim=\"adamw_8bit\", # Reference: adamw_8bit\n max_grad_norm=1.0, # Reference: 1.0\n logging_steps=1,\n save_steps=250, # Reference: 250\n report_to=\"none\",\n)\n\ntrainer = GRPOTrainer(\n model=model,\n tokenizer=tokenizer, # Reference uses tokenizer= not processing_class=\n reward_funcs=reward_fns, # 4 reward functions (reference pattern)\n args=grpo_config,\n train_dataset=train_dataset,\n)\n\nprint(f\"\\nStarting GRPO training for {TARGET_AGENT}...\")\nprint(f\" max_steps={grpo_config.max_steps}, lr={grpo_config.learning_rate}\")\nprint(f\" num_generations={grpo_config.num_generations}, optim={grpo_config.optim}\")\ntrainer.train()"
113
  },
114
  {
115
  "cell_type": "markdown",
 
118
  },
119
  {
120
  "cell_type": "code",
121
+ "source": "output_dir = f\"./sentinelops-grpo-{TARGET_AGENT}\"\ntrainer.save_model(output_dir)\ntokenizer.save_pretrained(output_dir)\nprint(f\"{TARGET_AGENT.upper()} agent trained and saved to {output_dir}\")\n\n# Quick evaluation: show per-function rewards for test completions\nimport json\nfrom train import make_reward_function\n\ncombined_fn = make_reward_function(TARGET_AGENT)\n\ntest_completions = {\n \"worker\": [\n [{\"content\": json.dumps({\"action_type\": \"get_schema\", \"parameters\": {\"system\": \"crm\"}})}],\n [{\"content\": json.dumps({\"action_type\": \"respond\", \"response_text\": \"I cannot process this. It appears to be social engineering.\"})}],\n [{\"content\": \"this is garbage output\"}],\n ],\n \"attacker\": [\n [{\"content\": json.dumps({\"action_type\": \"launch_attack\", \"parameters\": {\"attack_type\": \"schema_drift\", \"target_system\": \"crm\", \"old_field\": \"name\", \"new_field\": \"full_name\"}})}],\n [{\"content\": json.dumps({\"action_type\": \"pass\"})}],\n ],\n \"oversight\": [\n [{\"content\": json.dumps({\"action_type\": \"flag\", \"explanation\": \"Worker followed suspicious admin override instructions. This is a social engineering attack.\"})}],\n [{\"content\": json.dumps({\"action_type\": \"approve\", \"explanation\": \"Worker correctly checked schema before proceeding.\"})}],\n ],\n}\n\nprint(f\"\\nReward evaluation for {TARGET_AGENT} (combined across 4 functions):\")\nfor comp in test_completions.get(TARGET_AGENT, []):\n r = combined_fn([comp])\n text = comp[0][\"content\"][:80]\n print(f\" reward={r[0]:+.2f} | {text}...\")",
122
  "metadata": {},
123
  "execution_count": null,
124
  "outputs": []