adaptshield / train.py
SaiManish123's picture
Initial deploy of AdaptShield two-phase cybersecurity environment
c1060df verified
#!/usr/bin/env python3
"""AdaptShield onsite GPU training harness with safe local fallback."""
from __future__ import annotations
import argparse
import inspect
import json
import os
import random
import sys
import time
from pathlib import Path
from typing import Any, Dict, List, Tuple
REPO_ROOT = Path(__file__).resolve().parent
if str(REPO_ROOT) not in sys.path:
sys.path.insert(0, str(REPO_ROOT))
from models import AdaptShieldAction
from server.adaptshield_environment import AdaptShieldEnvironment
from train_smoke import TASKS, run_smoke_training
from plot_training import plot as plot_training_output
from soc_tools import attach_tool_results, investigate_local_with_depth, summarize_tool_results
# Short size aliases mapped to the Unsloth-hosted Qwen2.5 instruct checkpoints.
# The trainers index this table with ``args.model`` when no explicit model
# path is given, and the loader uses it to find the base model for an adapter.
MODEL_CHOICES = {
    "0.5b": "unsloth/Qwen2.5-0.5B-Instruct",
    "1.5b": "unsloth/Qwen2.5-1.5B-Instruct",
    "3b": "unsloth/Qwen2.5-3B-Instruct",
    "7b": "unsloth/Qwen2.5-7B-Instruct",
}
# A key of MODEL_CHOICES; presumably the default for the model CLI flag
# (the argument parser is not visible in this part of the file).
DEFAULT_MODEL = "1.5b"
# Sequence budget used when loading the model, when truncating training text,
# and (minus the completion budget) as the GRPO prompt length limit.
MAX_SEQ_LEN = 2048
# LoRA rank for freshly created adapters; lora_alpha is set to twice this value.
LORA_RANK = 16
# Ordered (task, fraction-of-run) pairs. task_for_episode walks the cumulative
# fractions to decide which task a given episode trains on.
CURRICULUM_STAGES = [
    ("direct-triage", 0.30),
    ("dual-pivot", 0.40),
    ("polymorphic-zero-day", 0.30),
]
# System prompt for phase 1 (threat analyst). The JSON keys requested here are
# the ones parse_response extracts for phase 1.
PHASE1_SYS = """You are a Threat Analyst for a 4-node enterprise network.
Analyze SIEM metrics and alerts. Identify the threat.
Attack strategies: brute_force, lateral_movement, exfiltration, supply_chain, benign
Nodes: auth_service, payment_service, database, api_gateway
Actions: rate_limit, isolate, honeypot, patch, monitor
If SOC tool evidence is provided, use it to update your belief before classifying.
Respond ONLY with valid JSON:
{"threat_type":"...","confidence":0.0,"target_node":"...","recommended_action":"...","reasoning":"..."}"""
# System prompt for phase 2 (tactical executor). The JSON keys requested here
# are the ones parse_response extracts for any phase other than 1.
PHASE2_SYS = """You are a Tactical Executor. Act only on the analyst handoff.
You cannot see raw network data in Phase 2.
Use the analyst handoff plus any SOC tool trace from this turn.
Actions: rate_limit, isolate, honeypot, patch, monitor
Nodes: auth_service, payment_service, database, api_gateway
Respond ONLY with valid JSON:
{"action":"...","target_node":"...","reasoning":"..."}"""
def obs_to_dict(obs: Any) -> Dict[str, Any]:
    """Convert an observation into a plain dictionary.

    Objects exposing ``model_dump`` are dumped in JSON mode; anything else is
    handed to ``dict()`` unchanged.
    """
    if not hasattr(obs, "model_dump"):
        return dict(obs)
    return obs.model_dump(mode="json")
def make_phase1_prompt(obs: Dict[str, Any]) -> str:
    """Render the phase-1 (analyst) user prompt from an observation dict.

    The prompt lists, in order: the network node metrics, the active alerts,
    a summary of SOC tool results, the last three history entries, and a
    closing instruction. Sections are separated by blank lines.
    """
    nodes_block = json.dumps(obs.get("network_nodes", {}), indent=2)
    alerts_block = "\n".join(obs.get("active_alerts", []))
    evidence_block = summarize_tool_results(obs.get("tool_results", []))
    # Only the three most recent history entries are shown to the model.
    history_block = json.dumps(obs.get("history", [])[-3:], indent=2)

    sections = [
        "Network nodes:",
        nodes_block,
        "",
        "Active alerts:",
        alerts_block,
        "",
        "SOC tool evidence:",
        evidence_block,
        "",
        "Recent history:",
        history_block,
        "",
        "Classify the threat:",
    ]
    return "\n".join(sections)
def make_phase2_prompt(obs: Dict[str, Any]) -> str:
    """Render the phase-2 (executor) user prompt from an observation dict.

    The prompt contains the analyst handoff (``phase1_assessment``) and the
    subset of ``metadata["tool_trace"]`` rows whose ``turn`` equals the
    observation's current turn.

    Malformed trace data is tolerated instead of aborting the episode: a
    missing or ``None`` trace is treated as empty, and rows that are not
    dictionaries or whose ``turn`` cannot be read as an integer are skipped.

    Args:
        obs: Observation dictionary for the current step.

    Returns:
        The prompt text, sections separated by blank lines.
    """
    raw_metadata = obs.get("metadata")
    metadata = raw_metadata if isinstance(raw_metadata, dict) else {}
    current_turn = int(obs.get("turn", 0) or 0)

    tool_trace = []
    # ``or []`` covers a key that is present but explicitly null.
    for row in metadata.get("tool_trace") or []:
        if not isinstance(row, dict):
            continue
        try:
            row_turn = int(row.get("turn", -1))
        except (TypeError, ValueError):
            # A row without a usable turn number cannot belong to this turn.
            continue
        if row_turn == current_turn:
            tool_trace.append(row)

    return "\n".join([
        "Threat assessment from analyst:",
        json.dumps(obs.get("phase1_assessment", {}), indent=2),
        "",
        "SOC tool trace for this turn:",
        json.dumps(tool_trace, indent=2),
        "",
        "Choose the defensive action:",
    ])
def build_messages(obs: Dict[str, Any]) -> List[Dict[str, str]]:
    """Build the system + user chat messages for the observation's phase.

    Phase 1 (also the default when ``phase`` is absent) uses the analyst
    prompt pair; every other phase value uses the executor prompt pair.
    """
    in_phase1 = int(obs.get("phase", 1)) == 1
    if in_phase1:
        system_text = PHASE1_SYS
        user_text = make_phase1_prompt(obs)
    else:
        system_text = PHASE2_SYS
        user_text = make_phase2_prompt(obs)
    return [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
def task_for_episode(
    episode: int,
    total_episodes: int,
    selected_task: str,
    curriculum: bool,
) -> Tuple[str, str]:
    """Pick the task and a stage label for a 1-based episode number.

    With the curriculum enabled, the episode's progress through the run is
    compared against the cumulative fractions in ``CURRICULUM_STAGES`` and
    the first stage that covers it is returned (the last stage is the
    fallback). Without it, ``"all"`` cycles through ``TASKS`` round-robin
    and any other value is returned as a fixed task.

    Returns:
        ``(task_name, stage_label)``.
    """
    if curriculum:
        progress = episode / max(1, total_episodes)
        covered = 0.0
        for stage_task, share in CURRICULUM_STAGES:
            covered += share
            if progress <= covered:
                return stage_task, f"curriculum:{stage_task}"
        # Progress beyond the summed fractions stays on the final stage.
        final_task = CURRICULUM_STAGES[-1][0]
        return final_task, f"curriculum:{final_task}"

    if selected_task != "all":
        return selected_task, "fixed"

    return TASKS[(episode - 1) % len(TASKS)], "round_robin"
def save_metrics(
    output_dir: Path,
    rows: List[Dict[str, Any]],
    model_name: str,
    episodes: int,
    curriculum: bool,
    use_tools: bool,
    trainer: str = "pg",
    evaluation_rows: List[Dict[str, Any]] | None = None,
    heldout_evaluation_rows: List[Dict[str, Any]] | None = None,
    prompt_bank_size: int = 0,
    extra: Dict[str, Any] | None = None,
) -> Path:
    """Write the run summary to ``<output_dir>/metrics.json`` and return its path.

    The payload always carries the run configuration, the per-episode
    ``rows`` and the best ``score`` found in them (0.0 when ``rows`` is
    empty). Evaluation rows are included when they are not ``None``, the
    prompt-bank size when it is non-zero, and ``extra`` is merged last, so
    its keys override earlier ones.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    scores = [float(row["score"]) for row in rows]

    payload: Dict[str, Any] = {
        "model": model_name,
        "episodes": episodes,
        "curriculum": curriculum,
        "curriculum_stages": CURRICULUM_STAGES,
        "use_tools": use_tools,
        "trainer": trainer,
        "rows": rows,
        "best_score": max(scores, default=0.0),
    }

    optional_sections = (
        ("evaluation_rows", evaluation_rows),
        ("heldout_evaluation_rows", heldout_evaluation_rows),
    )
    for key, section in optional_sections:
        if section is not None:
            payload[key] = section
    if prompt_bank_size:
        payload["prompt_bank_size"] = prompt_bank_size
    if extra:
        payload.update(extra)

    metrics_path = output_dir / "metrics.json"
    metrics_path.write_text(json.dumps(payload, indent=2))
    return metrics_path
def maybe_plot(metrics_path: Path, output_dir: Path) -> None:
    """Best-effort rendering of ``reward_curve.png`` next to the metrics.

    Any exception raised while plotting is reported on stdout and swallowed,
    so a plotting failure never aborts a training run.
    """
    try:
        curve_path = output_dir / "reward_curve.png"
        plot_training_output(metrics_path, curve_path)
    except Exception as plot_error:
        print(f"Plot generation skipped: {plot_error}")
def _extract_json_object(text: str) -> Dict[str, Any] | None:
    """Return the first JSON object embedded in ``text``, or ``None``.

    Decoding is attempted at every ``{`` in turn, so objects wrapped in
    Markdown fences or surrounded by explanatory prose are still found.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text, start)
        except ValueError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = text.find("{", start + 1)
    return None


def parse_response(text: str, phase: int) -> Dict[str, Any]:
    """Parse model JSON. Invalid output becomes a safe phase-correct action.

    Args:
        text: Raw model completion. It may be bare JSON, JSON inside a
            Markdown code fence, or JSON surrounded by free text.
        phase: ``1`` selects the analyst schema; any other value selects the
            executor schema.

    Returns:
        A dictionary with exactly the keys of the phase's schema. Missing
        keys are filled with conservative defaults. When no JSON object can
        be recovered, the defaults are returned with ``reasoning`` set to
        ``"parse_error"``. A ``confidence`` that is not numeric falls back to
        ``0.5`` without discarding the rest of the assessment.
    """
    if phase == 1:
        fallback: Dict[str, Any] = {
            "threat_type": "brute_force",
            "confidence": 0.5,
            "target_node": "auth_service",
            "recommended_action": "monitor",
            "reasoning": "parse_error",
        }
    else:
        fallback = {
            "action": "monitor",
            "target_node": "auth_service",
            "reasoning": "parse_error",
        }

    parsed = _extract_json_object(text) if isinstance(text, str) else None
    if parsed is None:
        return fallback

    if phase == 1:
        try:
            confidence = float(parsed.get("confidence", 0.5))
        except (TypeError, ValueError):
            # Keep the classification even if the model wrote e.g. "high".
            confidence = 0.5
        return {
            "threat_type": str(parsed.get("threat_type", "brute_force")),
            "confidence": confidence,
            "target_node": str(parsed.get("target_node", "auth_service")),
            "recommended_action": str(parsed.get("recommended_action", "monitor")),
            "reasoning": str(parsed.get("reasoning", "")),
        }
    return {
        "action": str(parsed.get("action", "monitor")),
        "target_node": str(parsed.get("target_node", "auth_service")),
        "reasoning": str(parsed.get("reasoning", "")),
    }
def render_messages(messages: List[Dict[str, str]], tokenizer: Any | None = None) -> str:
    """Flatten chat messages into a single prompt string.

    When a tokenizer offering ``apply_chat_template`` is supplied, its
    template is used (untokenized, with the generation prompt appended).
    Otherwise each message becomes ``ROLE:\\n<content>`` and the messages are
    joined with blank lines.
    """
    has_template = tokenizer is not None and hasattr(tokenizer, "apply_chat_template")
    if has_template:
        return tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )

    rendered = []
    for message in messages:
        role = message.get("role", "user").upper()
        body = message.get("content", "")
        rendered.append(f"{role}:\n{body}")
    return "\n\n".join(rendered)
def generate_response(model: Any, tokenizer: Any, messages: List[Dict[str, str]]) -> Tuple[str, str]:
    """Sample one completion for ``messages`` and return ``(prompt, response)``.

    The prompt is rendered with ``render_messages``, moved to the model's
    device (CUDA if available, else CPU, when the model exposes no
    ``device``), and up to 220 new tokens are sampled at temperature 0.7
    with gradients disabled. Only the newly generated tokens are decoded.
    """
    import torch

    prompt = render_messages(messages, tokenizer=tokenizer)

    target_device = getattr(model, "device", None)
    if target_device is None:
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    encoded = tokenizer(prompt, return_tensors="pt").to(target_device)

    # Fall back to the EOS id when the tokenizer defines no padding token.
    if getattr(tokenizer, "pad_token_id", None) is None:
        pad_id = tokenizer.eos_token_id
    else:
        pad_id = tokenizer.pad_token_id

    with torch.no_grad():
        _normalize_generation_config(model)
        generated = model.generate(
            **encoded,
            max_new_tokens=220,
            temperature=0.7,
            do_sample=True,
            pad_token_id=pad_id,
        )

    # Drop the echoed prompt tokens; keep only what the model produced.
    prompt_length = encoded["input_ids"].shape[1]
    completion_ids = generated[0][prompt_length:]
    response = tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
    return prompt, response
def _current_reference(env: AdaptShieldEnvironment) -> Dict[str, Any]:
turn_config = dict(getattr(env, "_turn_config", {}) or {})
is_benign = bool(turn_config.get("is_benign", False))
threat_type = "benign" if is_benign else str(turn_config.get("strategy", "benign"))
target_node = str(turn_config.get("correct_target", "auth_service"))
expected_action = str(turn_config.get("correct_action", "monitor"))
return {
"threat_type": threat_type,
"target_node": target_node,
"expected_action": expected_action,
"stage": str(turn_config.get("attack_stage", getattr(env._attacker, "current_stage", lambda: "recon")())),
"is_benign": is_benign,
}
def _align_trainable_dtypes(model: Any, target_dtype: Any | None = None) -> str:
"""Keep LoRA/trainable params on the same compute dtype as the main model.
Some adapter checkpoints reload trainable LoRA weights as float32, while
Unsloth GRPO kernels run activations in float16/bfloat16. That mismatch
trips fast_lora matmuls at runtime. We fix only trainable floating params.
"""
import torch
if target_dtype is None:
for param in model.parameters():
if param.is_floating_point() and not param.requires_grad:
target_dtype = param.dtype
break
if target_dtype is None:
for param in model.parameters():
if param.is_floating_point():
target_dtype = param.dtype
break
if target_dtype is None:
return "no-floating-params"
converted = 0
for param in model.parameters():
if param.requires_grad and param.is_floating_point() and param.dtype != target_dtype:
param.data = param.data.to(target_dtype)
converted += 1
for buffer_name, buffer in model.named_buffers():
if "lora_" in buffer_name and buffer.is_floating_point() and buffer.dtype != target_dtype:
buffer.data = buffer.data.to(target_dtype)
if getattr(model, "generation_config", None) is not None:
_normalize_generation_config(model)
return f"{target_dtype} ({converted} trainable params aligned)"
def _normalize_generation_config(model: Any) -> None:
generation_config = getattr(model, "generation_config", None)
if generation_config is None:
return
for field in ("max_length",):
try:
setattr(generation_config, field, None)
except Exception:
continue
def _load_training_model_and_tokenizer(
    model_name: str,
    model_key: str,
    max_seq_length: int,
    compute_dtype: Any,
    seed: int,
):
    """Load a 4-bit base model plus a trainable LoRA adapter.

    When ``model_name`` is a directory holding a saved adapter, the base
    model named by ``MODEL_CHOICES[model_key]`` is loaded and the adapter is
    attached in trainable mode; the adapter directory's tokenizer is used
    when it can be loaded. Otherwise ``model_name`` itself is loaded and a
    fresh LoRA adapter is created on the attention and MLP projections.

    Returns:
        ``(model, tokenizer)``.
    """
    from unsloth import FastLanguageModel

    adapter_path = model_name if _looks_like_adapter_path(model_name) else ""
    base_checkpoint = MODEL_CHOICES[model_key] if adapter_path else model_name

    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_checkpoint,
        max_seq_length=max_seq_length,
        load_in_4bit=True,
        dtype=compute_dtype,
    )

    if not adapter_path:
        # Fresh run: wrap the base model with a new LoRA adapter.
        lora_targets = [
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ]
        model = FastLanguageModel.get_peft_model(
            model,
            r=LORA_RANK,
            target_modules=lora_targets,
            lora_alpha=LORA_RANK * 2,
            lora_dropout=0.0,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=seed,
        )
        return model, tokenizer

    # Resume: attach the saved adapter and keep it trainable.
    from peft import PeftModel

    model = PeftModel.from_pretrained(
        model,
        adapter_path,
        is_trainable=True,
        autocast_adapter_dtype=False,
    )
    try:
        from transformers import AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
    except Exception:
        # Best effort: keep the base model's tokenizer if this one fails.
        pass
    return model, tokenizer
def _teacher_payload(phase: int, reference: Dict[str, Any]) -> Dict[str, Any]:
if phase == 1:
return {
"threat_type": reference["threat_type"],
"confidence": 0.92 if reference["threat_type"] != "benign" else 0.78,
"target_node": reference["target_node"],
"recommended_action": reference["expected_action"],
"reasoning": "reference policy",
}
return {
"action": reference["expected_action"],
"target_node": reference["target_node"],
"reasoning": "reference policy",
}
def build_prompt_bank(
    tokenizer: Any | None,
    selected_task: str,
    curriculum: bool,
    rollout_episodes: int,
    max_steps: int,
    use_tools: bool,
    seed: int,
    world_split: str = "train",
    world_family: str | None = None,
    hard_multiplier: int = 2,
    borderline_bonus: int = 1,
) -> List[Dict[str, Any]]:
    """Roll out the reference policy and collect labelled prompts.

    For each rollout episode a fresh environment is stepped with the
    reference ("teacher") action. Every visited step contributes one row
    holding the rendered prompt, the expected answer for that phase and
    bookkeeping metadata. Rows judged hard are duplicated according to
    ``_prompt_bank_extra_copies``.

    Seeds the global ``random`` module with ``seed`` before rolling out.

    Returns:
        The list of prompt rows, duplicates included.
    """
    random.seed(seed)
    bank: List[Dict[str, Any]] = []

    for episode in range(1, rollout_episodes + 1):
        task, stage = task_for_episode(
            episode=episode,
            total_episodes=rollout_episodes,
            selected_task=selected_task,
            curriculum=curriculum,
        )
        env = AdaptShieldEnvironment(
            task_name=task,
            world_split=world_split,
            world_family=world_family,
        )
        obs = env.reset()

        steps_taken = 0
        while not obs.done and steps_taken < max_steps:
            phase = int(getattr(obs, "phase", 1))
            tool_results = investigate_local_with_depth(
                env,
                obs,
                use_tools=use_tools,
                thorough=True,
            )
            obs_dict = attach_tool_results(obs_to_dict(obs), tool_results)
            messages = build_messages(obs_dict)
            reference = _current_reference(env)

            tool_call_count = len(tool_results)
            handoff = obs_dict.get("phase1_assessment") or {}
            # The expected action is stored under the key matching the phase;
            # the other key is left empty.
            row = {
                "prompt": render_messages(messages, tokenizer=tokenizer),
                "task": task,
                "stage": stage,
                "phase": phase,
                "turn": int(getattr(obs, "turn", 0) or 0),
                "attack_stage": reference["stage"],
                "world_split": getattr(env, "_world_split", world_split),
                "world_family": getattr(env, "_world_family", world_family or ""),
                "operational_mode": getattr(env, "_operational_mode", ""),
                "expected_threat_type": reference["threat_type"],
                "expected_target_node": reference["target_node"],
                "expected_recommended_action": reference["expected_action"] if phase == 1 else "",
                "expected_action": reference["expected_action"] if phase == 2 else "",
                "tool_calls": tool_call_count,
                "history_length": len(obs_dict.get("history", [])),
                "difficulty_tags": _difficulty_tags(
                    task=task,
                    phase=phase,
                    attack_stage=reference["stage"],
                    tool_calls=tool_call_count,
                    handoff_quality=str(handoff.get("handoff_quality", "")),
                ),
            }
            bank.append(row)

            # Oversample hard rows by appending shallow copies.
            duplicates = _prompt_bank_extra_copies(
                row=row,
                hard_multiplier=hard_multiplier,
                borderline_bonus=borderline_bonus,
            )
            bank.extend(dict(row) for _ in range(duplicates))

            obs = env.step(AdaptShieldAction(**_teacher_payload(phase, reference)))
            steps_taken += 1

    return bank
def _difficulty_tags(
task: str,
phase: int,
attack_stage: str,
tool_calls: int,
handoff_quality: str,
) -> List[str]:
tags: List[str] = []
if task == "polymorphic-zero-day":
tags.append("hard")
elif task == "dual-pivot":
tags.append("medium")
if phase == 2:
tags.append("phase2")
if attack_stage in {"exploit", "exfiltration"}:
tags.append("late_stage")
if tool_calls >= 3:
tags.append("tool_fusion")
if handoff_quality == "degraded":
tags.append("borderline")
return tags
def _prompt_bank_extra_copies(
row: Dict[str, Any],
hard_multiplier: int,
borderline_bonus: int,
) -> int:
tags = set(row.get("difficulty_tags", []) or [])
extra = 0
if row.get("task") == "polymorphic-zero-day":
extra += max(0, hard_multiplier - 1)
elif row.get("task") == "dual-pivot" and "late_stage" in tags:
extra += 1
if "borderline" in tags or ("phase2" in tags and "tool_fusion" in tags and "late_stage" in tags):
extra += max(0, borderline_bonus)
return extra
def _completion_to_text(completion: Any) -> str:
if isinstance(completion, str):
return completion
if isinstance(completion, dict):
if "content" in completion:
return str(completion.get("content", ""))
if "text" in completion:
return str(completion.get("text", ""))
if isinstance(completion, list):
parts = []
for item in completion:
if isinstance(item, dict):
parts.append(str(item.get("content", item.get("text", ""))))
else:
parts.append(str(item))
return "".join(parts)
return str(completion)
def _phase1_reward(
parsed: Dict[str, Any],
expected_threat_type: str,
expected_target_node: str,
expected_recommended_action: str,
) -> float:
reward = 0.08
if parsed.get("threat_type") == expected_threat_type:
reward += 0.36
if parsed.get("target_node") == expected_target_node:
reward += 0.20
if parsed.get("recommended_action") == expected_recommended_action:
reward += 0.18
try:
confidence = float(parsed.get("confidence", 0.5))
except Exception:
confidence = 0.5
if 0.0 <= confidence <= 1.0:
reward += 0.05
if parsed.get("threat_type") == expected_threat_type and confidence >= 0.65:
reward += 0.06
elif parsed.get("threat_type") != expected_threat_type and confidence >= 0.80:
reward -= 0.05
if parsed.get("recommended_action") == "monitor" and expected_threat_type != "benign":
reward -= 0.05
return max(0.01, min(0.99, round(reward, 2)))
def _phase2_reward(
parsed: Dict[str, Any],
expected_action: str,
expected_target_node: str,
tool_calls: int,
) -> float:
reward = 0.08
if parsed.get("action") == expected_action:
reward += 0.62
if parsed.get("target_node") == expected_target_node:
reward += 0.18
if parsed.get("action") == expected_action and tool_calls >= 2:
reward += 0.07
if parsed.get("action") == "monitor" and expected_action != "monitor":
reward -= 0.08
return max(0.01, min(0.99, round(reward, 2)))
def build_grpo_reward_fn():
    """Create the reward callable handed to the GRPO trainer.

    The returned ``reward_fn(completions, **kwargs)`` reads the per-sample
    dataset columns from ``kwargs`` (``phase``, ``expected_*`` and
    ``tool_calls``), parses each completion for its phase, and returns one
    reward per completion using the phase-1 or phase-2 scoring function.
    A missing ``phase`` column defaults every sample to phase 1 and a
    missing ``tool_calls`` column to zero calls.
    """
    def reward_fn(completions: List[Any], **kwargs: Any) -> List[float]:
        phase_column = kwargs.get("phase", [])
        threat_column = kwargs.get("expected_threat_type", [])
        target_column = kwargs.get("expected_target_node", [])
        recommended_column = kwargs.get("expected_recommended_action", [])
        action_column = kwargs.get("expected_action", [])
        tool_call_column = kwargs.get("tool_calls", [])

        def score_one(position: int, completion: Any) -> float:
            phase = int(phase_column[position]) if phase_column else 1
            parsed = parse_response(_completion_to_text(completion), phase)
            if phase == 1:
                return _phase1_reward(
                    parsed=parsed,
                    expected_threat_type=str(threat_column[position]),
                    expected_target_node=str(target_column[position]),
                    expected_recommended_action=str(recommended_column[position]),
                )
            return _phase2_reward(
                parsed=parsed,
                expected_action=str(action_column[position]),
                expected_target_node=str(target_column[position]),
                tool_calls=int(tool_call_column[position]) if tool_call_column else 0,
            )

        return [score_one(position, completion) for position, completion in enumerate(completions)]

    return reward_fn
def _filter_supported_kwargs(callable_obj: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]:
try:
signature = inspect.signature(callable_obj)
except (TypeError, ValueError):
return kwargs
valid = {}
for key, value in kwargs.items():
if key in signature.parameters:
valid[key] = value
return valid
def _trainer_log_rows(log_history: List[Dict[str, Any]], selected_task: str) -> List[Dict[str, Any]]:
rows: List[Dict[str, Any]] = []
for entry in log_history:
step = entry.get("step")
if step is None:
continue
reward_keys = [
"reward",
"mean_reward",
"rewards/mean",
"objective",
"objective/rlhf_reward",
]
score = None
for key in reward_keys:
if key in entry:
try:
score = float(entry[key])
break
except Exception:
continue
if score is None:
score = 0.50
row = {
"episode": int(step),
"task": "mixed" if selected_task == "all" else selected_task,
"stage": "grpo",
"score": max(0.01, min(0.99, score)),
"loss": float(entry.get("loss", 0.0) or 0.0),
"learning_rate": float(entry.get("learning_rate", 0.0) or 0.0),
}
rows.append(row)
return rows
def evaluate_model_suite(
    model: Any,
    tokenizer: Any,
    selected_task: str,
    eval_episodes: int,
    max_steps: int,
    use_tools: bool,
    world_split: str = "train",
    world_family: str | None = None,
    seed_start: int | None = None,
) -> List[Dict[str, Any]]:
    """Run online evaluation episodes and return one averaged row per task.

    Args:
        model: Model passed through to ``run_model_episode``.
        tokenizer: Tokenizer passed through to ``run_model_episode``.
        selected_task: A task name, or ``"all"`` to evaluate every entry of
            ``TASKS``.
        eval_episodes: Episodes to run per task.
        max_steps: Step cap per episode.
        use_tools: Whether SOC tools are consulted during the episode.
        world_split: World split forwarded to the environment.
        world_family: Optional world family forwarded to the environment.
        seed_start: When given, ``ADAPTSHIELD_SEED`` is set per episode to
            ``seed_start + task_index * 100 + episode_index``.

    Returns:
        One row per task with the mean score, steps and tool calls. A task
        with no episodes reports a score of 0.50.

    Raises:
        Whatever ``run_model_episode`` raises. The previous value of
        ``ADAPTSHIELD_SEED`` is restored before the exception propagates, so
        callers that catch it do not inherit an evaluation seed.
    """
    tasks = TASKS if selected_task == "all" else [selected_task]
    rows: List[Dict[str, Any]] = []
    for task in tasks:
        scores: List[float] = []
        steps: List[int] = []
        tool_calls: List[int] = []
        original_seed = os.environ.get("ADAPTSHIELD_SEED")
        try:
            for episode_index in range(eval_episodes):
                if seed_start is not None:
                    # len(rows) is the index of the task currently evaluated.
                    os.environ["ADAPTSHIELD_SEED"] = str(seed_start + len(rows) * 100 + episode_index)
                _, metrics = run_model_episode(
                    model=model,
                    tokenizer=tokenizer,
                    task=task,
                    max_steps=max_steps,
                    use_tools=use_tools,
                    world_split=world_split,
                    world_family=world_family,
                )
                scores.append(float(metrics["score"]))
                steps.append(int(metrics["steps"]))
                tool_calls.append(int(metrics["tool_calls"]))
        finally:
            # Always put the caller's seed back, even when an episode fails.
            if original_seed is None:
                os.environ.pop("ADAPTSHIELD_SEED", None)
            else:
                os.environ["ADAPTSHIELD_SEED"] = original_seed
        rows.append({
            "episode": len(rows) + 1,
            "task": task,
            "stage": "evaluation",
            "score": round(sum(scores) / len(scores), 3) if scores else 0.50,
            "steps": round(sum(steps) / len(steps), 2) if steps else 0.0,
            "tool_calls": round(sum(tool_calls) / len(tool_calls), 2) if tool_calls else 0.0,
            "eval_episodes": eval_episodes,
            "world_split": world_split,
            "world_family": world_family or "auto",
        })
    return rows
def run_model_episode(
    model: Any,
    tokenizer: Any,
    task: str,
    max_steps: int,
    use_tools: bool,
    world_split: str = "train",
    world_family: str | None = None,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Play one episode with the model acting in the environment.

    Each step gathers tool results, renders the phase prompt, samples a
    response, parses it into an action and steps the environment. A step
    that raises is recorded with a 0.01 reward and the error text, and ends
    the episode.

    Returns:
        ``(samples, metrics)``: the per-step prompt/response/reward records
        and an episode summary.

    Raises:
        RuntimeError: If the final observation's metadata carries no
            ``normalized_score``.
    """
    env = AdaptShieldEnvironment(
        task_name=task,
        world_split=world_split,
        world_family=world_family,
    )
    obs = env.reset()

    samples: List[Dict[str, Any]] = []
    rewards: List[float] = []
    total_tool_calls = 0

    while not obs.done and len(samples) < max_steps:
        phase = int(getattr(obs, "phase", 1))
        tool_results = investigate_local_with_depth(
            env,
            obs,
            use_tools=use_tools,
            thorough=True,
        )
        total_tool_calls += len(tool_results)

        obs_dict = attach_tool_results(obs_to_dict(obs), tool_results)
        messages = build_messages(obs_dict)
        prompt, response = generate_response(model, tokenizer, messages)
        payload = parse_response(response, phase)

        failure = None
        try:
            obs = env.step(AdaptShieldAction(**payload))
            reward = float(obs.reward)
        except Exception as exc:
            reward = 0.01
            failure = str(exc)

        samples.append({
            "prompt": prompt,
            "response": response,
            "reward": reward,
            "phase": phase,
            "tool_calls": len(tool_results),
            "error": failure,
        })
        if failure is not None:
            # A failed step is kept as a sample but not counted as a reward.
            break
        rewards.append(reward)

    metadata = obs.metadata if isinstance(obs.metadata, dict) else {}
    if "normalized_score" not in metadata:
        raise RuntimeError("normalized_score missing after training episode")

    summary = {
        "score": float(metadata["normalized_score"]),
        "steps": len(samples),
        "reward_sum": sum(rewards),
        "mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
        "tool_calls": total_tool_calls,
        "world_split": world_split,
        "world_family": metadata.get("world_family", world_family or "auto"),
        "operational_mode": metadata.get("operational_mode", "unknown"),
    }
    return samples, summary
def train_policy_gradient(args: argparse.Namespace) -> None:
    """Train the LoRA policy with a per-episode REINFORCE-style update.

    For every episode the model plays the environment, the mean step reward
    of that episode is used as the baseline, and one optimizer step is taken
    per sample with the sample's advantage weighting the language-model
    loss. The best-scoring, periodic and final adapters are saved under
    ``args.output``; afterwards the model is evaluated on the training and
    held-out world splits and the metrics (and optionally a plot) are
    written.

    Args:
        args: Parsed CLI namespace. Reads ``seed``, ``model``,
            ``model_path``, ``output``, ``task``, ``curriculum``,
            ``use_tools``, ``episodes``, ``lr``, ``max_steps``,
            ``save_every``, ``eval_episodes``, ``train_world_split``,
            ``heldout_world_split``, ``heldout_seed`` and ``plot``.
    """
    import torch
    from torch.optim import AdamW

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    model_name = args.model_path or MODEL_CHOICES[args.model]
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    print("AdaptShield policy-gradient GPU training")
    print(f"Task: {args.task}")
    print(f"Curriculum: {args.curriculum}")
    print(f"Use tools: {args.use_tools}")
    print(f"Model: {model_name}")
    print(f"Episodes: {args.episodes}")
    print(f"Output: {output_dir}")
    print()
    model, tokenizer = _load_training_model_and_tokenizer(
        model_name=model_name,
        model_key=args.model,
        max_seq_length=MAX_SEQ_LEN,
        compute_dtype=None,
        seed=args.seed,
    )
    from unsloth import FastLanguageModel

    FastLanguageModel.for_training(model)
    dtype_summary = _align_trainable_dtypes(model)
    print(f"Aligned trainable parameter dtypes: {dtype_summary}")
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    # Some tokenizers define no EOS token; appending None would raise.
    eos_text = getattr(tokenizer, "eos_token", None) or ""
    rows: List[Dict[str, Any]] = []
    best_score = -1.0
    for episode in range(1, args.episodes + 1):
        started = time.time()
        task, stage = task_for_episode(
            episode=episode,
            total_episodes=args.episodes,
            selected_task=args.task,
            curriculum=args.curriculum,
        )
        samples, metrics = run_model_episode(
            model=model,
            tokenizer=tokenizer,
            task=task,
            max_steps=args.max_steps,
            use_tools=args.use_tools,
            world_split=args.train_world_split,
        )
        rewards = [float(sample["reward"]) for sample in samples]
        # Episode-mean baseline: samples are compared against their own episode.
        baseline = sum(rewards) / len(rewards) if rewards else 0.0
        total_loss = 0.0
        for sample in samples:
            advantage = float(sample["reward"]) - baseline
            full_text = sample["prompt"] + sample["response"] + eos_text
            inputs = tokenizer(
                full_text,
                return_tensors="pt",
                truncation=True,
                max_length=MAX_SEQ_LEN,
            ).to("cuda")
            outputs = model(**inputs, labels=inputs["input_ids"])
            # outputs.loss is the mean negative log-likelihood of the text.
            # REINFORCE maximises advantage * log-prob, i.e. minimises
            # advantage * NLL, so the advantage must keep its sign here:
            # a positive advantage pulls the likelihood up, a negative one
            # pushes it down.
            loss = outputs.loss * advantage
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            total_loss += float(loss.item())
        row = {
            "episode": episode,
            "task": task,
            "stage": stage,
            "score": metrics["score"],
            "steps": metrics["steps"],
            "reward_sum": metrics["reward_sum"],
            "mean_reward": metrics["mean_reward"],
            "tool_calls": metrics["tool_calls"],
            "loss": total_loss,
            "seconds": round(time.time() - started, 2),
        }
        rows.append(row)
        print(
            f"episode={episode:03d} task={task:<20} "
            f"stage={stage:<32} "
            f"score={row['score']:.3f} mean_reward={row['mean_reward']:.3f} "
            f"loss={row['loss']:.4f} steps={row['steps']:02d} tools={row['tool_calls']:02d}"
        )
        if row["score"] > best_score:
            best_score = row["score"]
            model.save_pretrained(output_dir / "best")
            tokenizer.save_pretrained(output_dir / "best")
        if args.save_every and episode % args.save_every == 0:
            model.save_pretrained(output_dir / f"checkpoint-{episode}")
            tokenizer.save_pretrained(output_dir / f"checkpoint-{episode}")
    model.save_pretrained(output_dir / "final")
    tokenizer.save_pretrained(output_dir / "final")
    # Both evaluations start from the same seed; only the world split differs.
    evaluation_rows = evaluate_model_suite(
        model=model,
        tokenizer=tokenizer,
        selected_task=args.task,
        eval_episodes=args.eval_episodes,
        max_steps=args.max_steps,
        use_tools=args.use_tools,
        world_split=args.train_world_split,
        seed_start=args.heldout_seed,
    )
    heldout_evaluation_rows = evaluate_model_suite(
        model=model,
        tokenizer=tokenizer,
        selected_task=args.task,
        eval_episodes=args.eval_episodes,
        max_steps=args.max_steps,
        use_tools=args.use_tools,
        world_split=args.heldout_world_split,
        seed_start=args.heldout_seed,
    )
    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=rows,
        model_name=model_name,
        episodes=args.episodes,
        curriculum=args.curriculum,
        use_tools=args.use_tools,
        trainer="pg",
        evaluation_rows=evaluation_rows,
        heldout_evaluation_rows=heldout_evaluation_rows,
        extra={
            "train_world_split": args.train_world_split,
            "heldout_world_split": args.heldout_world_split,
            "heldout_seed": args.heldout_seed,
        },
    )
    if args.plot:
        maybe_plot(metrics_path, output_dir)
    print()
    print(f"Training complete. Best score: {best_score:.3f}")
    print("Post-train online evaluation:")
    for row in evaluation_rows:
        print(
            f" task={row['task']:<20} score={row['score']:.3f} "
            f"steps={row['steps']} tools={row['tool_calls']}"
        )
    print("Held-out family evaluation:")
    for row in heldout_evaluation_rows:
        print(
            f" task={row['task']:<20} score={row['score']:.3f} "
            f"steps={row['steps']} tools={row['tool_calls']}"
        )
    print(f"Metrics saved to: {metrics_path}")
def train_grpo(args: argparse.Namespace) -> None:
    """Train the LoRA policy with TRL's GRPO trainer on a fixed prompt bank.

    Steps: load the model and tokenizer, build a prompt bank by rolling out
    the reference policy, run ``GRPOTrainer`` with the phase-aware reward
    function, save the final adapter, then evaluate on the training and
    held-out world splits and write metrics (and optionally a plot).

    Evaluation failures are reported on stdout and recorded as empty row
    lists instead of aborting the run.

    Raises:
        RuntimeError: If the prompt bank comes back empty.
    """
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer
    import torch
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    # An explicit model path wins over the size alias.
    model_name = args.model_path or MODEL_CHOICES[args.model]
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    print("AdaptShield GRPO training")
    print(f"Task: {args.task}")
    print(f"Curriculum: {args.curriculum}")
    print(f"Use tools: {args.use_tools}")
    print(f"Model: {model_name}")
    print(f"Prompt-bank episodes: {args.prompt_bank_episodes}")
    print(f"GRPO epochs: {args.grpo_epochs}")
    print(f"Eval episodes: {args.eval_episodes}")
    print(f"Output: {output_dir}")
    print()
    # Use bfloat16 when torch reports support for it, otherwise float16.
    bf16_supported = bool(getattr(torch.cuda, "is_bf16_supported", lambda: False)())
    compute_dtype = torch.bfloat16 if bf16_supported else torch.float16
    model, tokenizer = _load_training_model_and_tokenizer(
        model_name=model_name,
        model_key=args.model,
        max_seq_length=MAX_SEQ_LEN,
        compute_dtype=compute_dtype,
        seed=args.seed,
    )
    from unsloth import FastLanguageModel
    # Reuse EOS as the padding token when the tokenizer defines none.
    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token
    # Each config tweak below is best effort: a config object that rejects
    # the assignment is left as it is.
    if getattr(model, "config", None) is not None:
        try:
            model.config.return_dict = True
        except Exception:
            pass
        try:
            model.config.use_cache = False
        except Exception:
            pass
    if getattr(model, "generation_config", None) is not None:
        try:
            model.generation_config.pad_token_id = tokenizer.pad_token_id
        except Exception:
            pass
    FastLanguageModel.for_training(model)
    # Cast trainable params to the compute dtype chosen above.
    dtype_summary = _align_trainable_dtypes(model, target_dtype=compute_dtype)
    print(f"Using GRPO compute dtype: {compute_dtype}")
    print(f"Aligned trainable parameter dtypes: {dtype_summary}")
    # Labelled prompts produced by the reference policy; these are the
    # training dataset and also feed the reward function's expected_* columns.
    prompt_bank = build_prompt_bank(
        tokenizer=tokenizer,
        selected_task=args.task,
        curriculum=args.curriculum,
        rollout_episodes=args.prompt_bank_episodes,
        max_steps=args.max_steps,
        use_tools=args.use_tools,
        seed=args.seed,
        world_split=args.train_world_split,
        hard_multiplier=args.prompt_bank_hard_multiplier,
        borderline_bonus=args.prompt_bank_borderline_bonus,
    )
    if not prompt_bank:
        raise RuntimeError("Prompt bank is empty; cannot start GRPO training.")
    dataset = Dataset.from_list(prompt_bank)
    reward_fn = build_grpo_reward_fn()
    # The prompt and completion budgets together add up to MAX_SEQ_LEN.
    config_kwargs = {
        "output_dir": str(output_dir),
        "learning_rate": args.lr,
        "per_device_train_batch_size": args.per_device_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_train_epochs": args.grpo_epochs,
        "max_prompt_length": MAX_SEQ_LEN - 256,
        "max_completion_length": 256,
        "num_generations": args.num_generations,
        "logging_steps": 1,
        "save_strategy": "no" if args.save_every <= 0 else "steps",
        "report_to": "none",
        # Keep the extra dataset columns; the reward function reads them.
        "remove_unused_columns": False,
        "bf16": bf16_supported,
        "fp16": not bf16_supported,
        "max_grad_norm": 1.0,
        "seed": args.seed,
    }
    if args.save_every > 0:
        config_kwargs["save_steps"] = args.save_every
    # Keys that the installed GRPOConfig signature does not name are dropped.
    grpo_config = GRPOConfig(**_filter_supported_kwargs(GRPOConfig, config_kwargs))
    # Both "processing_class" and "tokenizer" are offered; the filter keeps
    # only the ones the installed GRPOTrainer signature names.
    trainer_kwargs = {
        "model": model,
        "reward_funcs": [reward_fn],
        "args": grpo_config,
        "train_dataset": dataset,
        "processing_class": tokenizer,
        "tokenizer": tokenizer,
    }
    trainer = GRPOTrainer(**_filter_supported_kwargs(GRPOTrainer, trainer_kwargs))
    trainer.train()
    model.save_pretrained(output_dir / "final")
    tokenizer.save_pretrained(output_dir / "final")
    # Turn the trainer's log history into metrics rows; tolerate a trainer
    # without state or with an empty history.
    log_history = list(getattr(getattr(trainer, "state", None), "log_history", []) or [])
    train_rows = _trainer_log_rows(log_history, selected_task=args.task)
    if not train_rows:
        # No usable log entries: emit placeholder rows (score 0.50), one per epoch.
        train_rows = [{
            "episode": index + 1,
            "task": "mixed" if args.task == "all" else args.task,
            "stage": "grpo",
            "score": 0.50,
        } for index in range(max(1, args.grpo_epochs))]
    # Evaluate on the training split; a failure is reported and recorded as
    # an empty result instead of aborting the run.
    try:
        evaluation_rows = evaluate_model_suite(
            model=model,
            tokenizer=tokenizer,
            selected_task=args.task,
            eval_episodes=args.eval_episodes,
            max_steps=args.max_steps,
            use_tools=args.use_tools,
            world_split=args.train_world_split,
            seed_start=args.heldout_seed,
        )
    except Exception as exc:
        print(f"GRPO in-distribution evaluation failed: {exc}")
        evaluation_rows = []
    # Same treatment for the held-out split (same starting seed).
    try:
        heldout_evaluation_rows = evaluate_model_suite(
            model=model,
            tokenizer=tokenizer,
            selected_task=args.task,
            eval_episodes=args.eval_episodes,
            max_steps=args.max_steps,
            use_tools=args.use_tools,
            world_split=args.heldout_world_split,
            seed_start=args.heldout_seed,
        )
    except Exception as exc:
        print(f"GRPO held-out evaluation failed: {exc}")
        heldout_evaluation_rows = []
    # "episodes" here is the number of logged training rows.
    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=train_rows,
        model_name=model_name,
        episodes=max(1, len(train_rows)),
        curriculum=args.curriculum,
        use_tools=args.use_tools,
        trainer="grpo",
        evaluation_rows=evaluation_rows,
        heldout_evaluation_rows=heldout_evaluation_rows,
        prompt_bank_size=len(prompt_bank),
        extra={
            "train_world_split": args.train_world_split,
            "heldout_world_split": args.heldout_world_split,
            "heldout_seed": args.heldout_seed,
            "base_model": model_name,
        },
    )
    if args.plot:
        maybe_plot(metrics_path, output_dir)
    print("GRPO training complete.")
    print(f"Prompt bank size: {len(prompt_bank)}")
    print("Post-train online evaluation:")
    for row in evaluation_rows:
        print(
            f" task={row['task']:<20} score={row['score']:.3f} "
            f"steps={row['steps']} tools={row['tool_calls']}"
        )
    print("Held-out family evaluation:")
    for row in heldout_evaluation_rows:
        print(
            f" task={row['task']:<20} score={row['score']:.3f} "
            f"steps={row['steps']} tools={row['tool_calls']}"
        )
    if log_history:
        # Show which metric names the trainer logged in its last entry.
        final_keys = sorted(log_history[-1].keys())
        print(f"Trainer log keys: {final_keys}")
    print(f"Metrics saved to: {metrics_path}")
def _looks_like_adapter_path(model_name: str) -> bool:
path = Path(str(model_name))
return path.exists() and (path / "adapter_config.json").exists()
def run_fallback_smoke(args: argparse.Namespace) -> None:
    """Run the smoke-training fallback and persist its metrics.

    When ``args.use_tools`` is set the work is delegated entirely to
    ``run_tool_fallback_smoke``. Otherwise ``run_smoke_training`` is invoked
    for ``min(args.episodes, args.smoke_episodes)`` episodes, each returned
    row is tagged with the stage reported by ``task_for_episode``, and the
    rows are written through ``save_metrics``. A plot is produced afterwards
    when ``args.plot`` is set.

    Args:
        args: Parsed command-line namespace from ``parse_args``.
    """
    if args.use_tools:
        run_tool_fallback_smoke(args)
        return

    # The smoke run never exceeds the dedicated smoke-episode cap.
    episode_budget = min(args.episodes, args.smoke_episodes)

    if not args.curriculum:
        tasks = TASKS if args.task == "all" else [args.task]
    else:
        # One scheduled task per episode, as chosen by the curriculum.
        tasks = []
        for episode_number in range(1, episode_budget + 1):
            scheduled = task_for_episode(
                episode=episode_number,
                total_episodes=episode_budget,
                selected_task=args.task,
                curriculum=True,
            )
            tasks.append(scheduled[0])

    raw_rows = run_smoke_training(
        tasks=tasks,
        episodes=episode_budget,
        output=Path(args.output) / "train_smoke.csv",
        seed=args.seed,
        epsilon=0.85,
        epsilon_decay=0.94,
        epsilon_floor=0.08,
        lr=0.35,
        max_steps=args.max_steps,
    )

    output_dir = Path(args.output)
    metrics_rows = []
    for raw_row in raw_rows:
        # Copy so the rows returned by the trainer are left untouched.
        record = dict(raw_row)
        stage_info = task_for_episode(
            episode=int(record["episode"]),
            total_episodes=episode_budget,
            selected_task=args.task,
            curriculum=args.curriculum,
        )
        _, record["stage"] = stage_info
        metrics_rows.append(record)

    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=metrics_rows,
        model_name="smoke-tabular-policy",
        episodes=episode_budget,
        curriculum=args.curriculum,
        use_tools=False,
    )
    print(f"Metrics saved to: {metrics_path}")
    if args.plot:
        maybe_plot(metrics_path, output_dir)
def run_tool_fallback_smoke(args: argparse.Namespace) -> None:
    """No-GPU tool-aware rehearsal. This validates flow, not model learning.

    Runs ``tool_baseline.run_task`` once per episode for
    ``min(args.episodes, args.smoke_episodes)`` episodes, prints a one-line
    summary per episode, and stores the collected rows via ``save_metrics``.
    A plot is produced afterwards when ``args.plot`` is set.

    Args:
        args: Parsed command-line namespace from ``parse_args``.
    """
    from tool_baseline import run_task as run_tool_task

    total = min(args.episodes, args.smoke_episodes)

    if not args.curriculum:
        tasks = TASKS if args.task == "all" else [args.task]
    else:
        tasks = []
        for episode_number in range(1, total + 1):
            scheduled = task_for_episode(
                episode=episode_number,
                total_episodes=total,
                selected_task=args.task,
                curriculum=True,
            )
            tasks.append(scheduled[0])

    print("AdaptShield tool-aware smoke evaluation")
    print("Mode: no-GPU flow validation, not model learning")
    print(f"Tasks: {', '.join(tasks)}")
    print(f"Episodes: {total}")
    print()

    rows: List[Dict[str, Any]] = []
    for episode in range(1, total + 1):
        # Cycle through the task list when there are fewer tasks than episodes.
        task = tasks[(episode - 1) % len(tasks)]
        result = run_tool_task(task, emit_logs=False)

        metadata = result.get("metadata", {})
        if isinstance(metadata, dict):
            tool_calls = len(metadata.get("tool_trace", []))
        else:
            tool_calls = 0

        _, stage = task_for_episode(
            episode=episode,
            total_episodes=total,
            selected_task=args.task,
            curriculum=args.curriculum,
        )

        score = result["score"]
        steps = result["steps"]
        rewards = result["rewards"]
        reward_total = sum(rewards)
        # Guard against an empty reward list to avoid dividing by zero.
        mean_reward = reward_total / len(rewards) if rewards else 0.0
        status = "PASS" if result["success"] else "FAIL"

        rows.append(
            {
                "episode": episode,
                "task": task,
                "stage": stage,
                "score": score,
                "steps": steps,
                "reward_sum": reward_total,
                "mean_reward": mean_reward,
                "tool_calls": tool_calls,
                "status": status,
            }
        )
        print(
            f"episode={episode:03d} task={task:<20} "
            f"score={score:.3f} steps={steps:02d} "
            f"tools={tool_calls:02d} {status}"
        )

    output_dir = Path(args.output)
    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=rows,
        model_name="tool-aware-smoke-policy",
        episodes=total,
        curriculum=args.curriculum,
        use_tools=True,
    )
    print(f"Metrics saved to: {metrics_path}")
    if args.plot:
        maybe_plot(metrics_path, output_dir)
def parse_args() -> argparse.Namespace:
    """Build the AdaptShield command-line parser and parse ``sys.argv``.

    Options are declared in a table and registered in order, so the generated
    ``--help`` output lists them in the sequence shown below.

    Returns:
        The namespace produced by ``ArgumentParser.parse_args``.
    """
    parser = argparse.ArgumentParser(description="AdaptShield training harness.")

    # (flag, keyword arguments for ``add_argument``), in display order.
    option_table = [
        ("--task", {"default": "direct-triage", "choices": TASKS + ["all"]}),
        ("--model", {"default": DEFAULT_MODEL, "choices": list(MODEL_CHOICES)}),
        (
            "--model-path",
            {
                "default": "",
                "help": "Optional local/HF adapter path to continue training from.",
            },
        ),
        ("--episodes", {"type": int, "default": 60}),
        ("--max-steps", {"type": int, "default": 30}),
        ("--output", {"default": "checkpoints/adaptshield"}),
        ("--seed", {"type": int, "default": 42}),
        ("--lr", {"type": float, "default": 1e-5}),
        ("--save-every", {"type": int, "default": 20}),
        (
            "--smoke",
            {
                "action": "store_true",
                "help": "Force dependency-free smoke mode.",
            },
        ),
        ("--smoke-episodes", {"type": int, "default": 30}),
        (
            "--curriculum",
            {
                "action": "store_true",
                "help": "Train direct -> dual -> hard instead of fixed/round-robin tasks.",
            },
        ),
        (
            "--use-tools",
            {
                "action": "store_true",
                "help": "Let GPU training query SOC tools before hard-task actions.",
            },
        ),
        (
            "--plot",
            {
                "action": "store_true",
                "help": "Generate reward_curve.png from metrics.json after training.",
            },
        ),
        (
            "--trainer",
            {
                "default": "auto",
                "choices": ["auto", "pg", "grpo"],
                "help": "Training backend: safe policy-gradient fallback or TRL GRPO.",
            },
        ),
        (
            "--prompt-bank-episodes",
            {
                "type": int,
                "default": 24,
                "help": "Reference rollout episodes used to build the GRPO prompt bank.",
            },
        ),
        (
            "--prompt-bank-hard-multiplier",
            {
                "type": int,
                "default": 2,
                "help": "Duplicate hard-task GRPO prompts this many times to emphasize difficult slices.",
            },
        ),
        (
            "--prompt-bank-borderline-bonus",
            {
                "type": int,
                "default": 1,
                "help": "Extra copies for degraded-handoff / borderline GRPO prompts.",
            },
        ),
        (
            "--grpo-epochs",
            {
                "type": int,
                "default": 1,
                "help": "Number of epochs over the prompt bank for GRPO runs.",
            },
        ),
        (
            "--num-generations",
            {
                "type": int,
                "default": 4,
                "help": "GRPO generations per prompt when TRL path is active.",
            },
        ),
        (
            "--per-device-batch-size",
            {
                "type": int,
                "default": 1,
                "help": "Per-device batch size for GRPO training.",
            },
        ),
        (
            "--gradient-accumulation-steps",
            {
                "type": int,
                "default": 4,
                "help": "Gradient accumulation for GRPO training.",
            },
        ),
        (
            "--eval-episodes",
            {
                "type": int,
                "default": 2,
                "help": "Online environment episodes per task after GPU training.",
            },
        ),
        (
            "--train-world-split",
            {
                "default": "train",
                "choices": ["train", "eval"],
                "help": "World split used for training/prompt-bank generation.",
            },
        ),
        (
            "--heldout-world-split",
            {
                "default": "eval",
                "choices": ["train", "eval"],
                "help": "World split used for held-out evaluation.",
            },
        ),
        (
            "--heldout-seed",
            {
                "type": int,
                "default": 314,
                "help": "Seed offset used for held-out evaluation episodes.",
            },
        ),
    ]
    for flag, options in option_table:
        parser.add_argument(flag, **options)

    return parser.parse_args()
def main() -> int:
    """Entry point: pick a training backend and degrade gracefully.

    With ``--smoke`` the smoke fallback runs immediately. Otherwise the
    backend is taken from ``--trainer``; ``auto`` resolves to ``grpo`` when
    both ``datasets`` and ``trl`` import successfully and to ``pg`` when
    either import fails. An ``ImportError`` raised by the chosen trainer
    triggers a fallback: GRPO falls back to the policy-gradient trainer
    first, and if that is unavailable too (or the policy-gradient trainer was
    the original choice) the smoke fallback runs.

    Returns:
        Always ``0``.
    """
    args = parse_args()
    if args.smoke:
        run_fallback_smoke(args)
        return 0

    backend = args.trainer
    if backend == "auto":
        # Probe the optional GRPO dependencies; the imports are otherwise unused.
        try:
            import datasets  # noqa: F401
            import trl  # noqa: F401
        except ImportError:
            backend = "pg"
        else:
            backend = "grpo"

    try:
        if backend == "grpo":
            train_grpo(args)
        else:
            train_policy_gradient(args)
    except ImportError as exc:
        print(f"GPU training dependency missing for trainer={backend}: {exc}")
        if backend == "grpo":
            print("Falling back to policy-gradient GPU trainer.")
            try:
                train_policy_gradient(args)
            except ImportError as nested_exc:
                print(f"Policy-gradient fallback also unavailable: {nested_exc}")
            else:
                return 0
        print("Falling back to dependency-free smoke training.")
        run_fallback_smoke(args)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    sys.exit(main())