Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| """AdaptShield onsite GPU training harness with safe local fallback.""" | |
| from __future__ import annotations | |
| import argparse | |
| import inspect | |
| import json | |
| import os | |
| import random | |
| import sys | |
| import time | |
| from pathlib import Path | |
| from typing import Any, Dict, List, Tuple | |
# Directory containing this script; put it at the front of sys.path so the
# sibling project modules imported below resolve no matter which working
# directory the script is launched from.
REPO_ROOT = Path(__file__).resolve().parent
_repo_root_entry = str(REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)
| from models import AdaptShieldAction | |
| from server.adaptshield_environment import AdaptShieldEnvironment | |
| from train_smoke import TASKS, run_smoke_training | |
| from plot_training import plot as plot_training_output | |
| from soc_tools import attach_tool_results, investigate_local_with_depth, summarize_tool_results | |
# CLI size alias -> Unsloth-hosted Qwen2.5 instruct checkpoint,
# e.g. "1.5b" -> "unsloth/Qwen2.5-1.5B-Instruct".
MODEL_CHOICES = {
    size: f"unsloth/Qwen2.5-{size.upper()}-Instruct"
    for size in ("0.5b", "1.5b", "3b", "7b")
}
DEFAULT_MODEL = "1.5b"  # a key of MODEL_CHOICES
MAX_SEQ_LEN = 2048  # max sequence length used for model loading and tokenization
LORA_RANK = 16  # rank of freshly created LoRA adapters
# (task name, share of episodes). task_for_episode walks these in order and
# picks the first stage whose running total of shares reaches the current
# training progress; the shares sum to 1.0.
CURRICULUM_STAGES = [
    ("direct-triage", 0.30),
    ("dual-pivot", 0.40),
    ("polymorphic-zero-day", 0.30),
]
# System prompt for phase 1 (the analyst). build_messages pairs it with the
# user turn from make_phase1_prompt (node data, alerts, SOC tool evidence and
# recent history); the model must answer with one JSON assessment object.
PHASE1_SYS = """You are a Threat Analyst for a 4-node enterprise network.
Analyze SIEM metrics and alerts. Identify the threat.
Attack strategies: brute_force, lateral_movement, exfiltration, supply_chain, benign
Nodes: auth_service, payment_service, database, api_gateway
Actions: rate_limit, isolate, honeypot, patch, monitor
If SOC tool evidence is provided, use it to update your belief before classifying.
Respond ONLY with valid JSON:
{"threat_type":"...","confidence":0.0,"target_node":"...","recommended_action":"...","reasoning":"..."}"""
# System prompt for phase 2 (the executor). build_messages pairs it with the
# user turn from make_phase2_prompt, which carries only the analyst assessment
# and this turn's SOC tool trace; the model must answer with one JSON action.
PHASE2_SYS = """You are a Tactical Executor. Act only on the analyst handoff.
You cannot see raw network data in Phase 2.
Use the analyst handoff plus any SOC tool trace from this turn.
Actions: rate_limit, isolate, honeypot, patch, monitor
Nodes: auth_service, payment_service, database, api_gateway
Respond ONLY with valid JSON:
{"action":"...","target_node":"...","reasoning":"..."}"""
def obs_to_dict(obs: Any) -> Dict[str, Any]:
    """Return an observation as a plain dict.

    Objects exposing ``model_dump`` (pydantic-style) are dumped in JSON mode;
    anything else is handed to ``dict()``.
    """
    if not hasattr(obs, "model_dump"):
        return dict(obs)
    return obs.model_dump(mode="json")
def make_phase1_prompt(obs: Dict[str, Any]) -> str:
    """Build the phase-1 (analyst) user prompt from an observation dict.

    Four titled sections separated by blank lines -- network node data
    (indented JSON), active alerts (one per line), the summarized SOC tool
    results, and the last three history entries (indented JSON) -- followed
    by the classification instruction.
    """
    sections = (
        ("Network nodes:", json.dumps(obs.get("network_nodes", {}), indent=2)),
        ("Active alerts:", "\n".join(obs.get("active_alerts", []))),
        ("SOC tool evidence:", summarize_tool_results(obs.get("tool_results", []))),
        ("Recent history:", json.dumps(obs.get("history", [])[-3:], indent=2)),
    )
    body = "\n\n".join("\n".join(section) for section in sections)
    return body + "\n\nClassify the threat:"
def _trace_row_turn(row: Dict[str, Any]) -> int:
    """Return the integer ``turn`` recorded on a tool-trace row, or -1 if missing/invalid."""
    try:
        return int(row.get("turn", -1))
    except (TypeError, ValueError):
        return -1


def make_phase2_prompt(obs: Dict[str, Any]) -> str:
    """Build the phase-2 (executor) user prompt from an observation dict.

    The executor must not see raw network data, so the prompt holds only the
    analyst's ``phase1_assessment`` and the ``metadata["tool_trace"]`` rows
    whose ``turn`` equals the observation's current ``turn``.

    Args:
        obs: Observation dict (e.g. from ``obs_to_dict``). ``metadata`` and
            its ``tool_trace`` may be missing or null.

    Returns:
        The prompt text, ending with the action instruction.
    """
    metadata = obs.get("metadata")
    if not isinstance(metadata, dict):
        metadata = {}
    current_turn = int(obs.get("turn", 0) or 0)
    # A null trace, malformed non-dict rows, or rows with an unusable turn
    # must not crash prompt construction; they simply contribute nothing.
    tool_trace = [
        row
        for row in (metadata.get("tool_trace") or [])
        if isinstance(row, dict) and _trace_row_turn(row) == current_turn
    ]
    return "\n".join([
        "Threat assessment from analyst:",
        json.dumps(obs.get("phase1_assessment", {}), indent=2),
        "",
        "SOC tool trace for this turn:",
        json.dumps(tool_trace, indent=2),
        "",
        "Choose the defensive action:",
    ])
def build_messages(obs: Dict[str, Any]) -> List[Dict[str, str]]:
    """Return the system + user chat messages for the observation's phase.

    ``phase == 1`` (the default when absent) uses the analyst prompts; any
    other phase uses the executor prompts.
    """
    if int(obs.get("phase", 1)) == 1:
        system_text, user_text = PHASE1_SYS, make_phase1_prompt(obs)
    else:
        system_text, user_text = PHASE2_SYS, make_phase2_prompt(obs)
    return [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
def task_for_episode(
    episode: int,
    total_episodes: int,
    selected_task: str,
    curriculum: bool,
) -> Tuple[str, str]:
    """Pick the task for a 1-based ``episode`` and a label describing why.

    With ``curriculum`` the task comes from ``CURRICULUM_STAGES``, using the
    fraction of training completed (``episode / total_episodes``). Without
    it, ``selected_task == "all"`` cycles through ``TASKS`` and any other
    value is used as-is.

    Returns:
        ``(task, label)`` where label is ``"curriculum:<task>"``,
        ``"round_robin"`` or ``"fixed"``.
    """
    if curriculum:
        progress = episode / max(1, total_episodes)
        threshold = 0.0
        for stage_task, share in CURRICULUM_STAGES:
            threshold += share
            if progress <= threshold:
                return stage_task, f"curriculum:{stage_task}"
        last_task = CURRICULUM_STAGES[-1][0]
        return last_task, f"curriculum:{last_task}"
    if selected_task != "all":
        return selected_task, "fixed"
    return TASKS[(episode - 1) % len(TASKS)], "round_robin"
def save_metrics(
    output_dir: Path,
    rows: List[Dict[str, Any]],
    model_name: str,
    episodes: int,
    curriculum: bool,
    use_tools: bool,
    trainer: str = "pg",
    evaluation_rows: List[Dict[str, Any]] | None = None,
    heldout_evaluation_rows: List[Dict[str, Any]] | None = None,
    prompt_bank_size: int = 0,
    extra: Dict[str, Any] | None = None,
) -> Path:
    """Write the run summary to ``<output_dir>/metrics.json`` and return that path.

    ``best_score`` is the maximum ``score`` over ``rows`` (0.0 when there are
    none). Evaluation row lists are included only when provided,
    ``prompt_bank_size`` only when non-zero, and ``extra`` is merged last, so
    its keys override anything written before them.
    """
    output_dir.mkdir(parents=True, exist_ok=True)
    scores = [float(row["score"]) for row in rows]
    payload: Dict[str, Any] = {
        "model": model_name,
        "episodes": episodes,
        "curriculum": curriculum,
        "curriculum_stages": CURRICULUM_STAGES,
        "use_tools": use_tools,
        "trainer": trainer,
        "rows": rows,
        "best_score": max(scores, default=0.0),
    }
    for key, value in (
        ("evaluation_rows", evaluation_rows),
        ("heldout_evaluation_rows", heldout_evaluation_rows),
    ):
        if value is not None:
            payload[key] = value
    if prompt_bank_size:
        payload["prompt_bank_size"] = prompt_bank_size
    if extra:
        payload.update(extra)
    metrics_path = output_dir / "metrics.json"
    metrics_path.write_text(json.dumps(payload, indent=2))
    return metrics_path
def maybe_plot(metrics_path: Path, output_dir: Path) -> None:
    """Best-effort: render ``reward_curve.png`` in ``output_dir`` from the metrics file.

    Any failure is reported on stdout and swallowed so plotting can never
    fail a finished training run.
    """
    try:
        curve_path = output_dir / "reward_curve.png"
        plot_training_output(metrics_path, curve_path)
    except Exception as err:
        print(f"Plot generation skipped: {err}")
def _response_json_candidates(text: str):
    """Yield substrings of ``text`` that may hold the model's JSON object, most specific first."""
    if "```" in text:
        # Fenced output: try every fenced chunk that contains an object.
        for part in text.split("```"):
            if "{" in part:
                yield part.strip().removeprefix("json").strip()
    yield text
    # Object embedded in surrounding prose: take the outermost brace span.
    start, end = text.find("{"), text.rfind("}")
    if 0 <= start < end:
        yield text[start:end + 1]


def _load_response_json(text: str) -> Dict[str, Any] | None:
    """Return the first candidate in ``text`` that decodes to a JSON object, else None."""
    for candidate in _response_json_candidates(text):
        try:
            parsed = json.loads(candidate)
        except (ValueError, RecursionError):  # JSONDecodeError subclasses ValueError
            continue
        if isinstance(parsed, dict):
            return parsed
    return None


def _coerce_confidence(value: Any, default: float = 0.5) -> float:
    """Convert a model-supplied confidence to float, falling back to ``default``."""
    try:
        return float(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _fallback_response(phase: int) -> Dict[str, Any]:
    """Safe phase-correct action used when no JSON object can be recovered."""
    if phase == 1:
        return {
            "threat_type": "brute_force",
            "confidence": 0.5,
            "target_node": "auth_service",
            "recommended_action": "monitor",
            "reasoning": "parse_error",
        }
    return {
        "action": "monitor",
        "target_node": "auth_service",
        "reasoning": "parse_error",
    }


def parse_response(text: str, phase: int) -> Dict[str, Any]:
    """Parse model JSON. Invalid output becomes a safe phase-correct action.

    The object is looked for in fenced code blocks, then in the whole text,
    then between the first ``{`` and the last ``}`` (models often wrap JSON
    in prose). A non-numeric ``confidence`` falls back to 0.5 instead of
    discarding the rest of the assessment. If no JSON object is found, the
    returned dict has ``reasoning == "parse_error"``.

    Args:
        text: Raw completion text.
        phase: 1 for an analyst assessment, anything else for an executor action.

    Returns:
        Phase 1: ``threat_type``, ``confidence``, ``target_node``,
        ``recommended_action``, ``reasoning``. Otherwise: ``action``,
        ``target_node``, ``reasoning``.
    """
    parsed = _load_response_json(text)
    if parsed is None:
        return _fallback_response(phase)
    if phase == 1:
        return {
            "threat_type": str(parsed.get("threat_type", "brute_force")),
            "confidence": _coerce_confidence(parsed.get("confidence", 0.5)),
            "target_node": str(parsed.get("target_node", "auth_service")),
            "recommended_action": str(parsed.get("recommended_action", "monitor")),
            "reasoning": str(parsed.get("reasoning", "")),
        }
    return {
        "action": str(parsed.get("action", "monitor")),
        "target_node": str(parsed.get("target_node", "auth_service")),
        "reasoning": str(parsed.get("reasoning", "")),
    }
def render_messages(messages: List[Dict[str, str]], tokenizer: Any | None = None) -> str:
    """Render chat messages into a single prompt string.

    Uses the tokenizer's chat template (untokenized, with the generation
    prompt appended) when one is available. Otherwise, each message becomes
    ``"ROLE:\\n<content>"`` (role upper-cased, default ``user``), and the
    messages are joined with blank lines.
    """
    if tokenizer is None or not hasattr(tokenizer, "apply_chat_template"):
        blocks = []
        for message in messages:
            role = message.get("role", "user").upper()
            blocks.append(f"{role}:\n{message.get('content', '')}")
        return "\n\n".join(blocks)
    return tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
def generate_response(model: Any, tokenizer: Any, messages: List[Dict[str, str]]) -> Tuple[str, str]:
    """Sample one completion for ``messages`` and return ``(prompt, response)``.

    The prompt is rendered with ``render_messages`` and tokenized onto the
    model's device, falling back to CUDA/CPU when the model exposes no
    ``device``. Up to 220 new tokens are sampled at temperature 0.7. The
    response is the decoded new tokens, with special tokens skipped and
    surrounding whitespace stripped.
    """
    import torch

    prompt = render_messages(messages, tokenizer=tokenizer)
    device = getattr(model, "device", None)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # render_messages uses the tokenizer's chat template whenever it has one,
    # and that rendered text already contains the model's special tokens.
    # Letting the tokenizer add them again (e.g. a second BOS) would feed the
    # model a sequence its template never produces. The plain "ROLE:" fallback
    # has no special tokens, so keep the tokenizer default there.
    uses_chat_template = hasattr(tokenizer, "apply_chat_template")
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=not uses_chat_template,
    ).to(device)
    pad_token_id = (
        tokenizer.pad_token_id
        if getattr(tokenizer, "pad_token_id", None) is not None
        else tokenizer.eos_token_id
    )
    with torch.no_grad():
        _normalize_generation_config(model)
        output_ids = model.generate(
            **inputs,
            max_new_tokens=220,
            temperature=0.7,
            do_sample=True,
            pad_token_id=pad_token_id,
        )
    # Keep only the tokens generated after the prompt.
    new_ids = output_ids[0][inputs["input_ids"].shape[1]:]
    response = tokenizer.decode(new_ids, skip_special_tokens=True).strip()
    return prompt, response
| def _current_reference(env: AdaptShieldEnvironment) -> Dict[str, Any]: | |
| turn_config = dict(getattr(env, "_turn_config", {}) or {}) | |
| is_benign = bool(turn_config.get("is_benign", False)) | |
| threat_type = "benign" if is_benign else str(turn_config.get("strategy", "benign")) | |
| target_node = str(turn_config.get("correct_target", "auth_service")) | |
| expected_action = str(turn_config.get("correct_action", "monitor")) | |
| return { | |
| "threat_type": threat_type, | |
| "target_node": target_node, | |
| "expected_action": expected_action, | |
| "stage": str(turn_config.get("attack_stage", getattr(env._attacker, "current_stage", lambda: "recon")())), | |
| "is_benign": is_benign, | |
| } | |
def _align_trainable_dtypes(model: Any, target_dtype: Any | None = None) -> str:
    """Cast trainable floating-point parameters to a single compute dtype.

    Some adapter checkpoints reload trainable LoRA weights as float32, while
    Unsloth GRPO kernels run activations in float16/bfloat16; that mismatch
    trips fast_lora matmuls at runtime. Only trainable floating params (plus
    floating buffers whose name contains ``lora_``) are cast.

    When ``target_dtype`` is omitted it is taken from the first frozen
    floating-point parameter, or failing that the first floating-point
    parameter. The generation config is normalized afterwards if present.

    Returns:
        ``"<dtype> (<n> trainable params aligned)"``, or
        ``"no-floating-params"`` if no dtype could be determined.
    """
    import torch  # noqa: F401  -- kept so a missing torch fails here, as before

    if target_dtype is None:
        target_dtype = next(
            (p.dtype for p in model.parameters() if p.is_floating_point() and not p.requires_grad),
            None,
        )
    if target_dtype is None:
        target_dtype = next(
            (p.dtype for p in model.parameters() if p.is_floating_point()),
            None,
        )
    if target_dtype is None:
        return "no-floating-params"
    aligned = 0
    for param in model.parameters():
        needs_cast = param.requires_grad and param.is_floating_point() and param.dtype != target_dtype
        if needs_cast:
            param.data = param.data.to(target_dtype)
            aligned += 1
    for name, buf in model.named_buffers():
        if "lora_" in name and buf.is_floating_point() and buf.dtype != target_dtype:
            buf.data = buf.data.to(target_dtype)
    if getattr(model, "generation_config", None) is not None:
        _normalize_generation_config(model)
    return f"{target_dtype} ({aligned} trainable params aligned)"
| def _normalize_generation_config(model: Any) -> None: | |
| generation_config = getattr(model, "generation_config", None) | |
| if generation_config is None: | |
| return | |
| for field in ("max_length",): | |
| try: | |
| setattr(generation_config, field, None) | |
| except Exception: | |
| continue | |
def _load_training_model_and_tokenizer(
    model_name: str,
    model_key: str,
    max_seq_length: int,
    compute_dtype: Any,
    seed: int,
):
    """Load a 4-bit base model with Unsloth and attach a trainable LoRA adapter.

    If ``_looks_like_adapter_path(model_name)`` is true (presumably a saved
    PEFT adapter directory -- that helper is not visible here, confirm),
    the base weights come from ``MODEL_CHOICES[model_key]``, and the saved
    adapter is re-attached as trainable. The tokenizer is also reloaded from
    the adapter directory when that succeeds. Otherwise ``model_name`` is
    loaded as the base model, and a fresh LoRA adapter (rank ``LORA_RANK``,
    alpha ``2 * LORA_RANK``, no dropout) is attached to all attention and MLP
    projections.

    Args:
        model_name: Hub id / path of a base model, or an adapter path.
        model_key: ``MODEL_CHOICES`` key used as the base when resuming an adapter.
        max_seq_length: Sequence length passed to Unsloth's loader.
        compute_dtype: dtype passed to Unsloth (callers pass ``None`` or a
            torch float dtype; ``None`` presumably lets Unsloth choose -- confirm).
        seed: ``random_state`` for initializing a fresh LoRA adapter.

    Returns:
        ``(model, tokenizer)``.
    """
    # Unsloth is imported before peft/transformers below; keep this order.
    from unsloth import FastLanguageModel
    adapter_path = model_name if _looks_like_adapter_path(model_name) else ""
    # When resuming from an adapter, load the matching base checkpoint under it.
    base_model_name = MODEL_CHOICES[model_key] if adapter_path else model_name
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=max_seq_length,
        load_in_4bit=True,
        dtype=compute_dtype,
    )
    if adapter_path:
        from peft import PeftModel
        model = PeftModel.from_pretrained(
            model,
            adapter_path,
            is_trainable=True,
            # Keep adapter weights in their saved dtype rather than letting PEFT
            # upcast them; callers align dtypes afterwards via _align_trainable_dtypes.
            autocast_adapter_dtype=False,
        )
        # Best-effort: prefer the tokenizer saved with the adapter, otherwise
        # keep the base model's tokenizer.
        # NOTE(review): trust_remote_code=True executes code shipped alongside the
        # tokenizer files -- acceptable only for trusted adapter directories.
        try:
            from transformers import AutoTokenizer
            tokenizer = AutoTokenizer.from_pretrained(adapter_path, trust_remote_code=True)
        except Exception:
            pass
    else:
        model = FastLanguageModel.get_peft_model(
            model,
            r=LORA_RANK,
            target_modules=[
                "q_proj", "k_proj", "v_proj", "o_proj",
                "gate_proj", "up_proj", "down_proj",
            ],
            lora_alpha=LORA_RANK * 2,
            lora_dropout=0.0,
            bias="none",
            use_gradient_checkpointing="unsloth",
            random_state=seed,
        )
    return model, tokenizer
| def _teacher_payload(phase: int, reference: Dict[str, Any]) -> Dict[str, Any]: | |
| if phase == 1: | |
| return { | |
| "threat_type": reference["threat_type"], | |
| "confidence": 0.92 if reference["threat_type"] != "benign" else 0.78, | |
| "target_node": reference["target_node"], | |
| "recommended_action": reference["expected_action"], | |
| "reasoning": "reference policy", | |
| } | |
| return { | |
| "action": reference["expected_action"], | |
| "target_node": reference["target_node"], | |
| "reasoning": "reference policy", | |
| } | |
def _prompt_bank_row(
    *,
    prompt: str,
    task: str,
    stage: str,
    phase: int,
    obs: Any,
    obs_dict: Dict[str, Any],
    env: Any,
    reference: Dict[str, Any],
    tool_count: int,
    world_split: str,
    world_family: str | None,
) -> Dict[str, Any]:
    """Assemble one prompt-bank row: prompt, episode context, reference labels, difficulty tags."""
    return {
        "prompt": prompt,
        "task": task,
        "stage": stage,
        "phase": phase,
        "turn": int(getattr(obs, "turn", 0) or 0),
        "attack_stage": reference["stage"],
        "world_split": getattr(env, "_world_split", world_split),
        "world_family": getattr(env, "_world_family", world_family or ""),
        "operational_mode": getattr(env, "_operational_mode", ""),
        "expected_threat_type": reference["threat_type"],
        "expected_target_node": reference["target_node"],
        # Only the label that matches the row's phase is filled in.
        "expected_recommended_action": reference["expected_action"] if phase == 1 else "",
        "expected_action": reference["expected_action"] if phase == 2 else "",
        "tool_calls": tool_count,
        "history_length": len(obs_dict.get("history", [])),
        "difficulty_tags": _difficulty_tags(
            task=task,
            phase=phase,
            attack_stage=reference["stage"],
            tool_calls=tool_count,
            handoff_quality=str((obs_dict.get("phase1_assessment") or {}).get("handoff_quality", "")),
        ),
    }


def build_prompt_bank(
    tokenizer: Any | None,
    selected_task: str,
    curriculum: bool,
    rollout_episodes: int,
    max_steps: int,
    use_tools: bool,
    seed: int,
    world_split: str = "train",
    world_family: str | None = None,
    hard_multiplier: int = 2,
    borderline_bonus: int = 1,
) -> List[Dict[str, Any]]:
    """Collect GRPO training prompts by replaying episodes with the reference policy.

    The global ``random`` module is re-seeded with ``seed`` first. Each of
    ``rollout_episodes`` episodes (task chosen by ``task_for_episode``)
    contributes one row per step, up to ``max_steps``: the rendered prompt
    the policy would see, the reference labels and difficulty tags. Rows
    singled out by ``_prompt_bank_extra_copies`` are duplicated. The
    environment is then advanced with the teacher action, so later prompts
    follow a reference trajectory.
    """
    random.seed(seed)
    bank: List[Dict[str, Any]] = []
    for episode in range(1, rollout_episodes + 1):
        task, stage = task_for_episode(
            episode=episode,
            total_episodes=rollout_episodes,
            selected_task=selected_task,
            curriculum=curriculum,
        )
        env = AdaptShieldEnvironment(
            task_name=task,
            world_split=world_split,
            world_family=world_family,
        )
        obs = env.reset()
        steps_taken = 0
        while not obs.done and steps_taken < max_steps:
            phase = int(getattr(obs, "phase", 1))
            tool_results = investigate_local_with_depth(
                env,
                obs,
                use_tools=use_tools,
                thorough=True,
            )
            obs_dict = attach_tool_results(obs_to_dict(obs), tool_results)
            messages = build_messages(obs_dict)
            reference = _current_reference(env)
            row = _prompt_bank_row(
                prompt=render_messages(messages, tokenizer=tokenizer),
                task=task,
                stage=stage,
                phase=phase,
                obs=obs,
                obs_dict=obs_dict,
                env=env,
                reference=reference,
                tool_count=len(tool_results),
                world_split=world_split,
                world_family=world_family,
            )
            bank.append(row)
            copies = _prompt_bank_extra_copies(
                row=row,
                hard_multiplier=hard_multiplier,
                borderline_bonus=borderline_bonus,
            )
            bank.extend(dict(row) for _ in range(copies))
            obs = env.step(AdaptShieldAction(**_teacher_payload(phase, reference)))
            steps_taken += 1
    return bank
| def _difficulty_tags( | |
| task: str, | |
| phase: int, | |
| attack_stage: str, | |
| tool_calls: int, | |
| handoff_quality: str, | |
| ) -> List[str]: | |
| tags: List[str] = [] | |
| if task == "polymorphic-zero-day": | |
| tags.append("hard") | |
| elif task == "dual-pivot": | |
| tags.append("medium") | |
| if phase == 2: | |
| tags.append("phase2") | |
| if attack_stage in {"exploit", "exfiltration"}: | |
| tags.append("late_stage") | |
| if tool_calls >= 3: | |
| tags.append("tool_fusion") | |
| if handoff_quality == "degraded": | |
| tags.append("borderline") | |
| return tags | |
| def _prompt_bank_extra_copies( | |
| row: Dict[str, Any], | |
| hard_multiplier: int, | |
| borderline_bonus: int, | |
| ) -> int: | |
| tags = set(row.get("difficulty_tags", []) or []) | |
| extra = 0 | |
| if row.get("task") == "polymorphic-zero-day": | |
| extra += max(0, hard_multiplier - 1) | |
| elif row.get("task") == "dual-pivot" and "late_stage" in tags: | |
| extra += 1 | |
| if "borderline" in tags or ("phase2" in tags and "tool_fusion" in tags and "late_stage" in tags): | |
| extra += max(0, borderline_bonus) | |
| return extra | |
| def _completion_to_text(completion: Any) -> str: | |
| if isinstance(completion, str): | |
| return completion | |
| if isinstance(completion, dict): | |
| if "content" in completion: | |
| return str(completion.get("content", "")) | |
| if "text" in completion: | |
| return str(completion.get("text", "")) | |
| if isinstance(completion, list): | |
| parts = [] | |
| for item in completion: | |
| if isinstance(item, dict): | |
| parts.append(str(item.get("content", item.get("text", "")))) | |
| else: | |
| parts.append(str(item)) | |
| return "".join(parts) | |
| return str(completion) | |
| def _phase1_reward( | |
| parsed: Dict[str, Any], | |
| expected_threat_type: str, | |
| expected_target_node: str, | |
| expected_recommended_action: str, | |
| ) -> float: | |
| reward = 0.08 | |
| if parsed.get("threat_type") == expected_threat_type: | |
| reward += 0.36 | |
| if parsed.get("target_node") == expected_target_node: | |
| reward += 0.20 | |
| if parsed.get("recommended_action") == expected_recommended_action: | |
| reward += 0.18 | |
| try: | |
| confidence = float(parsed.get("confidence", 0.5)) | |
| except Exception: | |
| confidence = 0.5 | |
| if 0.0 <= confidence <= 1.0: | |
| reward += 0.05 | |
| if parsed.get("threat_type") == expected_threat_type and confidence >= 0.65: | |
| reward += 0.06 | |
| elif parsed.get("threat_type") != expected_threat_type and confidence >= 0.80: | |
| reward -= 0.05 | |
| if parsed.get("recommended_action") == "monitor" and expected_threat_type != "benign": | |
| reward -= 0.05 | |
| return max(0.01, min(0.99, round(reward, 2))) | |
| def _phase2_reward( | |
| parsed: Dict[str, Any], | |
| expected_action: str, | |
| expected_target_node: str, | |
| tool_calls: int, | |
| ) -> float: | |
| reward = 0.08 | |
| if parsed.get("action") == expected_action: | |
| reward += 0.62 | |
| if parsed.get("target_node") == expected_target_node: | |
| reward += 0.18 | |
| if parsed.get("action") == expected_action and tool_calls >= 2: | |
| reward += 0.07 | |
| if parsed.get("action") == "monitor" and expected_action != "monitor": | |
| reward -= 0.08 | |
| return max(0.01, min(0.99, round(reward, 2))) | |
def build_grpo_reward_fn():
    """Build the reward callable handed to TRL's ``GRPOTrainer``.

    The returned ``reward_fn(completions, **kwargs)`` reads per-sample
    reference labels from the prompt-bank columns, which TRL forwards as
    keyword arguments: ``phase``, ``expected_threat_type``,
    ``expected_target_node``, ``expected_recommended_action``,
    ``expected_action``, ``tool_calls``. It returns one float per
    completion. Completions that ``parse_response`` cannot parse get the
    floor reward 0.01.
    """
    def reward_fn(completions: List[Any], **kwargs: Any) -> List[float]:
        phases = kwargs.get("phase", [])
        expected_threat_types = kwargs.get("expected_threat_type", [])
        expected_targets = kwargs.get("expected_target_node", [])
        expected_recommended_actions = kwargs.get("expected_recommended_action", [])
        expected_actions = kwargs.get("expected_action", [])
        tool_calls = kwargs.get("tool_calls", [])
        rewards: List[float] = []
        for index, completion in enumerate(completions):
            phase = int(phases[index]) if phases else 1
            parsed = parse_response(_completion_to_text(completion), phase)
            if parsed.get("reasoning") == "parse_error":
                # parse_response substitutes a fixed default action for
                # unparseable output (brute_force / auth_service / monitor).
                # Scoring that default would reward garbage whenever it happens
                # to match the reference, so invalid output gets the floor.
                rewards.append(0.01)
                continue
            if phase == 1:
                reward = _phase1_reward(
                    parsed=parsed,
                    expected_threat_type=str(expected_threat_types[index]),
                    expected_target_node=str(expected_targets[index]),
                    expected_recommended_action=str(expected_recommended_actions[index]),
                )
            else:
                reward = _phase2_reward(
                    parsed=parsed,
                    expected_action=str(expected_actions[index]),
                    expected_target_node=str(expected_targets[index]),
                    tool_calls=int(tool_calls[index]) if tool_calls else 0,
                )
            rewards.append(reward)
        return rewards
    return reward_fn
| def _filter_supported_kwargs(callable_obj: Any, kwargs: Dict[str, Any]) -> Dict[str, Any]: | |
| try: | |
| signature = inspect.signature(callable_obj) | |
| except (TypeError, ValueError): | |
| return kwargs | |
| valid = {} | |
| for key, value in kwargs.items(): | |
| if key in signature.parameters: | |
| valid[key] = value | |
| return valid | |
| def _trainer_log_rows(log_history: List[Dict[str, Any]], selected_task: str) -> List[Dict[str, Any]]: | |
| rows: List[Dict[str, Any]] = [] | |
| for entry in log_history: | |
| step = entry.get("step") | |
| if step is None: | |
| continue | |
| reward_keys = [ | |
| "reward", | |
| "mean_reward", | |
| "rewards/mean", | |
| "objective", | |
| "objective/rlhf_reward", | |
| ] | |
| score = None | |
| for key in reward_keys: | |
| if key in entry: | |
| try: | |
| score = float(entry[key]) | |
| break | |
| except Exception: | |
| continue | |
| if score is None: | |
| score = 0.50 | |
| row = { | |
| "episode": int(step), | |
| "task": "mixed" if selected_task == "all" else selected_task, | |
| "stage": "grpo", | |
| "score": max(0.01, min(0.99, score)), | |
| "loss": float(entry.get("loss", 0.0) or 0.0), | |
| "learning_rate": float(entry.get("learning_rate", 0.0) or 0.0), | |
| } | |
| rows.append(row) | |
| return rows | |
def evaluate_model_suite(
    model: Any,
    tokenizer: Any,
    selected_task: str,
    eval_episodes: int,
    max_steps: int,
    use_tools: bool,
    world_split: str = "train",
    world_family: str | None = None,
    seed_start: int | None = None,
) -> List[Dict[str, Any]]:
    """Run ``eval_episodes`` model-driven episodes per task and average the results.

    Tasks are all of ``TASKS`` when ``selected_task == "all"``, otherwise
    just ``selected_task``. When ``seed_start`` is given, each episode runs
    with the ``ADAPTSHIELD_SEED`` environment variable set to
    ``seed_start + task_index * stride + episode_index``, where
    ``stride = max(100, eval_episodes)``. Presumably the environment reads
    this variable -- confirm. The caller's previous value (or its absence)
    is restored afterwards, even if an episode raises.

    Returns:
        One row per task with mean ``score`` (3 d.p.), ``steps`` and
        ``tool_calls`` (2 d.p.), plus the evaluation settings.
    """
    tasks = TASKS if selected_task == "all" else [selected_task]
    # Up to 100 episodes per task the stride is 100; beyond that it grows so
    # one task's seeds never overlap the next task's.
    seed_stride = max(100, eval_episodes)
    original_seed = os.environ.get("ADAPTSHIELD_SEED")
    rows: List[Dict[str, Any]] = []
    try:
        for task_index, task in enumerate(tasks):
            scores: List[float] = []
            steps: List[int] = []
            tool_calls: List[int] = []
            for episode_index in range(eval_episodes):
                if seed_start is not None:
                    os.environ["ADAPTSHIELD_SEED"] = str(
                        seed_start + task_index * seed_stride + episode_index
                    )
                _, metrics = run_model_episode(
                    model=model,
                    tokenizer=tokenizer,
                    task=task,
                    max_steps=max_steps,
                    use_tools=use_tools,
                    world_split=world_split,
                    world_family=world_family,
                )
                scores.append(float(metrics["score"]))
                steps.append(int(metrics["steps"]))
                tool_calls.append(int(metrics["tool_calls"]))
            rows.append({
                "episode": task_index + 1,
                "task": task,
                "stage": "evaluation",
                "score": round(sum(scores) / len(scores), 3) if scores else 0.50,
                "steps": round(sum(steps) / len(steps), 2) if steps else 0.0,
                "tool_calls": round(sum(tool_calls) / len(tool_calls), 2) if tool_calls else 0.0,
                "eval_episodes": eval_episodes,
                "world_split": world_split,
                "world_family": world_family or "auto",
            })
    finally:
        # Never leak an evaluation seed into the rest of the process.
        if original_seed is None:
            os.environ.pop("ADAPTSHIELD_SEED", None)
        else:
            os.environ["ADAPTSHIELD_SEED"] = original_seed
    return rows
def run_model_episode(
    model: Any,
    tokenizer: Any,
    task: str,
    max_steps: int,
    use_tools: bool,
    world_split: str = "train",
    world_family: str | None = None,
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
    """Play one episode in which the model chooses every action.

    Each step runs the local SOC tools, prompts the model for the current
    phase, parses its JSON, and steps the environment. If building or
    applying the action raises (e.g. the model named an action the
    environment rejects), that step is recorded with reward 0.01 and its
    error message, and the episode stops.

    Returns:
        ``(samples, metrics)``. ``samples`` has one dict per step (prompt,
        response, reward, phase, tool_calls, error). ``metrics`` holds score,
        steps, reward sums, tool calls, world info and ``error`` (None
        unless the episode was cut short).

    Raises:
        RuntimeError: if the episode ended without an error but the final
            observation's metadata has no ``normalized_score``.
    """
    env = AdaptShieldEnvironment(
        task_name=task,
        world_split=world_split,
        world_family=world_family,
    )
    obs = env.reset()
    samples: List[Dict[str, Any]] = []
    rewards: List[float] = []
    tool_calls = 0
    step_error = None
    while not obs.done and len(samples) < max_steps:
        phase = int(getattr(obs, "phase", 1))
        tool_results = investigate_local_with_depth(
            env,
            obs,
            use_tools=use_tools,
            thorough=True,
        )
        tool_calls += len(tool_results)
        obs_dict = attach_tool_results(obs_to_dict(obs), tool_results)
        messages = build_messages(obs_dict)
        prompt, response = generate_response(model, tokenizer, messages)
        payload = parse_response(response, phase)
        try:
            obs = env.step(AdaptShieldAction(**payload))
            reward = float(obs.reward)
        except Exception as exc:
            step_error = str(exc)
            reward = 0.01
        else:
            rewards.append(reward)
        samples.append({
            "prompt": prompt,
            "response": response,
            "reward": reward,
            "phase": phase,
            "tool_calls": len(tool_results),
            "error": step_error,
        })
        if step_error is not None:
            break
    metadata = obs.metadata if isinstance(obs.metadata, dict) else {}
    if "normalized_score" in metadata:
        score = float(metadata["normalized_score"])
    elif step_error is not None:
        # An invalid model action ended the episode early, so the environment
        # may never have produced a final score. Score the episode at the
        # reward floor instead of aborting the whole training/evaluation run.
        score = 0.01
    else:
        raise RuntimeError("normalized_score missing after training episode")
    return samples, {
        "score": score,
        "steps": len(samples),
        "reward_sum": sum(rewards),
        "mean_reward": sum(rewards) / len(rewards) if rewards else 0.0,
        "tool_calls": tool_calls,
        "world_split": world_split,
        "world_family": metadata.get("world_family", world_family or "auto"),
        "operational_mode": metadata.get("operational_mode", "unknown"),
        "error": step_error,
    }
def _policy_gradient_update(
    model: Any,
    tokenizer: Any,
    optimizer: Any,
    samples: List[Dict[str, Any]],
    device: Any,
) -> float:
    """Apply one REINFORCE update per sample (mean-reward baseline); return the summed loss.

    The per-sample loss is ``advantage * NLL(response | prompt)``. Minimizing
    it raises the likelihood of above-baseline responses and lowers it for
    below-baseline ones. Prompt tokens are masked out of the labels
    (approximately -- the prompt is tokenized separately), so only the
    model's own response is credited or blamed. Samples whose response was
    truncated away entirely are skipped.
    """
    import torch

    rewards = [float(sample["reward"]) for sample in samples]
    baseline = sum(rewards) / len(rewards) if rewards else 0.0
    # Prompts rendered through a chat template already contain the special tokens.
    add_special_tokens = not hasattr(tokenizer, "apply_chat_template")
    total_loss = 0.0
    for sample, reward in zip(samples, rewards):
        advantage = reward - baseline
        full_text = sample["prompt"] + sample["response"] + tokenizer.eos_token
        inputs = tokenizer(
            full_text,
            return_tensors="pt",
            truncation=True,
            max_length=MAX_SEQ_LEN,
            add_special_tokens=add_special_tokens,
        ).to(device)
        prompt_len = len(tokenizer(sample["prompt"], add_special_tokens=add_special_tokens)["input_ids"])
        labels = inputs["input_ids"].clone()
        labels[:, :prompt_len] = -100  # -100 is ignored by the causal-LM loss
        if not bool((labels != -100).any()):
            # Nothing of the response survived truncation; a mean over zero
            # tokens would be NaN.
            continue
        outputs = model(**inputs, labels=labels)
        # outputs.loss is the NLL (= -log p) of the response, so multiplying by
        # +advantage gives the REINFORCE objective -advantage * log p.
        loss = outputs.loss * advantage
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        total_loss += float(loss.item())
    return total_loss


def _print_evaluation_rows(rows: List[Dict[str, Any]]) -> None:
    """Print one summary line per evaluation row."""
    for row in rows:
        print(
            f"  task={row['task']:<20} score={row['score']:.3f} "
            f"steps={row['steps']} tools={row['tool_calls']}"
        )


def train_policy_gradient(args: argparse.Namespace) -> None:
    """Fine-tune a LoRA policy with on-policy REINFORCE over AdaptShield episodes.

    For each episode: pick a task (``task_for_episode``), roll it out with
    the current model, and apply a policy-gradient update. The adapter is
    saved to ``best/`` whenever the episode score improves, every
    ``save_every`` episodes, and to ``final/``. The model is then evaluated
    on the training and held-out world splits, and ``metrics.json`` is
    written (plus a reward plot when ``args.plot``).

    Reads ``args``: seed, model_path, model, output, task, curriculum,
    use_tools, episodes, lr, max_steps, train_world_split, save_every,
    eval_episodes, heldout_world_split, heldout_seed, plot.
    """
    import torch
    from torch.optim import AdamW

    random.seed(args.seed)
    torch.manual_seed(args.seed)
    model_name = args.model_path or MODEL_CHOICES[args.model]
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)
    print("AdaptShield policy-gradient GPU training")
    print(f"Task: {args.task}")
    print(f"Curriculum: {args.curriculum}")
    print(f"Use tools: {args.use_tools}")
    print(f"Model: {model_name}")
    print(f"Episodes: {args.episodes}")
    print(f"Output: {output_dir}")
    print()
    model, tokenizer = _load_training_model_and_tokenizer(
        model_name=model_name,
        model_key=args.model,
        max_seq_length=MAX_SEQ_LEN,
        compute_dtype=None,
        seed=args.seed,
    )
    from unsloth import FastLanguageModel
    FastLanguageModel.for_training(model)
    dtype_summary = _align_trainable_dtypes(model)
    print(f"Aligned trainable parameter dtypes: {dtype_summary}")
    optimizer = AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
    # Train on whatever device the model lives on, as generation does.
    device = getattr(model, "device", None)
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    rows: List[Dict[str, Any]] = []
    best_score = -1.0
    for episode in range(1, args.episodes + 1):
        started = time.time()
        task, stage = task_for_episode(
            episode=episode,
            total_episodes=args.episodes,
            selected_task=args.task,
            curriculum=args.curriculum,
        )
        samples, metrics = run_model_episode(
            model=model,
            tokenizer=tokenizer,
            task=task,
            max_steps=args.max_steps,
            use_tools=args.use_tools,
            world_split=args.train_world_split,
        )
        total_loss = _policy_gradient_update(model, tokenizer, optimizer, samples, device)
        row = {
            "episode": episode,
            "task": task,
            "stage": stage,
            "score": metrics["score"],
            "steps": metrics["steps"],
            "reward_sum": metrics["reward_sum"],
            "mean_reward": metrics["mean_reward"],
            "tool_calls": metrics["tool_calls"],
            "loss": total_loss,
            "seconds": round(time.time() - started, 2),
        }
        rows.append(row)
        print(
            f"episode={episode:03d} task={task:<20} "
            f"stage={stage:<32} "
            f"score={row['score']:.3f} mean_reward={row['mean_reward']:.3f} "
            f"loss={row['loss']:.4f} steps={row['steps']:02d} tools={row['tool_calls']:02d}"
        )
        if row["score"] > best_score:
            best_score = row["score"]
            model.save_pretrained(output_dir / "best")
            tokenizer.save_pretrained(output_dir / "best")
        if args.save_every and episode % args.save_every == 0:
            model.save_pretrained(output_dir / f"checkpoint-{episode}")
            tokenizer.save_pretrained(output_dir / f"checkpoint-{episode}")
    model.save_pretrained(output_dir / "final")
    tokenizer.save_pretrained(output_dir / "final")
    evaluation_rows = evaluate_model_suite(
        model=model,
        tokenizer=tokenizer,
        selected_task=args.task,
        eval_episodes=args.eval_episodes,
        max_steps=args.max_steps,
        use_tools=args.use_tools,
        world_split=args.train_world_split,
        seed_start=args.heldout_seed,
    )
    heldout_evaluation_rows = evaluate_model_suite(
        model=model,
        tokenizer=tokenizer,
        selected_task=args.task,
        eval_episodes=args.eval_episodes,
        max_steps=args.max_steps,
        use_tools=args.use_tools,
        world_split=args.heldout_world_split,
        seed_start=args.heldout_seed,
    )
    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=rows,
        model_name=model_name,
        episodes=args.episodes,
        curriculum=args.curriculum,
        use_tools=args.use_tools,
        trainer="pg",
        evaluation_rows=evaluation_rows,
        heldout_evaluation_rows=heldout_evaluation_rows,
        extra={
            "train_world_split": args.train_world_split,
            "heldout_world_split": args.heldout_world_split,
            "heldout_seed": args.heldout_seed,
        },
    )
    if args.plot:
        maybe_plot(metrics_path, output_dir)
    print()
    print(f"Training complete. Best score: {best_score:.3f}")
    print("Post-train online evaluation:")
    _print_evaluation_rows(evaluation_rows)
    print("Held-out family evaluation:")
    _print_evaluation_rows(heldout_evaluation_rows)
    print(f"Metrics saved to: {metrics_path}")
def train_grpo(args: argparse.Namespace) -> None:
    """Fine-tune a model with TRL's ``GRPOTrainer``, then evaluate and record metrics.

    Steps, in order:
      1. Seed ``random`` and torch, resolve the model (``--model-path`` wins over
         ``--model``) and create the output directory.
      2. Load model + tokenizer, give the tokenizer a pad token if it lacks one,
         best-effort adjust the model config / generation config, and switch the
         model to training mode.
      3. Build the GRPO prompt bank from reference rollouts on
         ``args.train_world_split`` and train on it.
      4. Save model and tokenizer under ``<output>/final``.
      5. Evaluate on ``args.train_world_split`` and ``args.heldout_world_split``;
         an evaluation failure is printed and recorded as an empty row list.
      6. Save metrics (optionally plot them) and print a summary.

    Raises:
        ImportError: when ``datasets``, ``trl``, ``torch`` or ``unsloth`` cannot
            be imported (``main`` catches this and falls back).
        RuntimeError: when the prompt bank comes back empty.
    """
    from datasets import Dataset
    from trl import GRPOConfig, GRPOTrainer
    import torch

    random.seed(args.seed)
    torch.manual_seed(args.seed)

    model_name = args.model_path or MODEL_CHOICES[args.model]
    output_dir = Path(args.output)
    output_dir.mkdir(parents=True, exist_ok=True)

    banner = (
        "AdaptShield GRPO training",
        f"Task: {args.task}",
        f"Curriculum: {args.curriculum}",
        f"Use tools: {args.use_tools}",
        f"Model: {model_name}",
        f"Prompt-bank episodes: {args.prompt_bank_episodes}",
        f"GRPO epochs: {args.grpo_epochs}",
        f"Eval episodes: {args.eval_episodes}",
        f"Output: {output_dir}",
        "",
    )
    for banner_line in banner:
        print(banner_line)

    # getattr guard: treat a torch build without this probe as "no bf16".
    bf16_probe = getattr(torch.cuda, "is_bf16_supported", lambda: False)
    bf16_supported = bool(bf16_probe())
    compute_dtype = torch.bfloat16 if bf16_supported else torch.float16

    model, tokenizer = _load_training_model_and_tokenizer(
        model_name=model_name,
        model_key=args.model,
        max_seq_length=MAX_SEQ_LEN,
        compute_dtype=compute_dtype,
        seed=args.seed,
    )
    from unsloth import FastLanguageModel

    if getattr(tokenizer, "pad_token", None) is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Best-effort config tweaks: any failure to set an attribute is ignored.
    model_config = getattr(model, "config", None)
    if model_config is not None:
        for attr_name, attr_value in (("return_dict", True), ("use_cache", False)):
            try:
                setattr(model_config, attr_name, attr_value)
            except Exception:
                pass
    generation_config = getattr(model, "generation_config", None)
    if generation_config is not None:
        try:
            generation_config.pad_token_id = tokenizer.pad_token_id
        except Exception:
            pass

    FastLanguageModel.for_training(model)
    dtype_summary = _align_trainable_dtypes(model, target_dtype=compute_dtype)
    print(f"Using GRPO compute dtype: {compute_dtype}")
    print(f"Aligned trainable parameter dtypes: {dtype_summary}")

    prompt_bank = build_prompt_bank(
        tokenizer=tokenizer,
        selected_task=args.task,
        curriculum=args.curriculum,
        rollout_episodes=args.prompt_bank_episodes,
        max_steps=args.max_steps,
        use_tools=args.use_tools,
        seed=args.seed,
        world_split=args.train_world_split,
        hard_multiplier=args.prompt_bank_hard_multiplier,
        borderline_bonus=args.prompt_bank_borderline_bonus,
    )
    if not prompt_bank:
        raise RuntimeError("Prompt bank is empty; cannot start GRPO training.")
    dataset = Dataset.from_list(prompt_bank)
    reward_fn = build_grpo_reward_fn()

    # save_every <= 0 disables checkpointing entirely.
    checkpointing = args.save_every > 0
    config_kwargs = {
        "output_dir": str(output_dir),
        "learning_rate": args.lr,
        "per_device_train_batch_size": args.per_device_batch_size,
        "gradient_accumulation_steps": args.gradient_accumulation_steps,
        "num_train_epochs": args.grpo_epochs,
        # Prompt + completion budgets together fill MAX_SEQ_LEN.
        "max_prompt_length": MAX_SEQ_LEN - 256,
        "max_completion_length": 256,
        "num_generations": args.num_generations,
        "logging_steps": 1,
        "save_strategy": "steps" if checkpointing else "no",
        "report_to": "none",
        "remove_unused_columns": False,
        "bf16": bf16_supported,
        "fp16": not bf16_supported,
        "max_grad_norm": 1.0,
        "seed": args.seed,
    }
    if checkpointing:
        config_kwargs["save_steps"] = args.save_every
    # _filter_supported_kwargs drops keys the installed TRL version does not accept.
    grpo_config = GRPOConfig(**_filter_supported_kwargs(GRPOConfig, config_kwargs))

    trainer = GRPOTrainer(
        **_filter_supported_kwargs(
            GRPOTrainer,
            {
                "model": model,
                "reward_funcs": [reward_fn],
                "args": grpo_config,
                "train_dataset": dataset,
                # Both spellings offered; the filter keeps whichever TRL supports.
                "processing_class": tokenizer,
                "tokenizer": tokenizer,
            },
        )
    )
    trainer.train()

    final_dir = output_dir / "final"
    model.save_pretrained(final_dir)
    tokenizer.save_pretrained(final_dir)

    trainer_state = getattr(trainer, "state", None)
    log_history = list(getattr(trainer_state, "log_history", []) or [])
    train_rows = _trainer_log_rows(log_history, selected_task=args.task)
    if not train_rows:
        # NOTE(review): these are placeholder rows with a fixed 0.50 score, one
        # per epoch, not measured results. Presumably they exist so downstream
        # metrics/plotting get non-empty input; confirm before trusting curves.
        placeholder_task = "mixed" if args.task == "all" else args.task
        train_rows = [
            {"episode": epoch_no, "task": placeholder_task, "stage": "grpo", "score": 0.50}
            for epoch_no in range(1, max(1, args.grpo_epochs) + 1)
        ]

    # NOTE(review): both suites start from args.heldout_seed; only the world
    # split differs. Confirm that is intended for the in-distribution suite.
    eval_kwargs = {
        "model": model,
        "tokenizer": tokenizer,
        "selected_task": args.task,
        "eval_episodes": args.eval_episodes,
        "max_steps": args.max_steps,
        "use_tools": args.use_tools,
        "seed_start": args.heldout_seed,
    }
    try:
        in_dist_rows = evaluate_model_suite(world_split=args.train_world_split, **eval_kwargs)
    except Exception as exc:
        print(f"GRPO in-distribution evaluation failed: {exc}")
        in_dist_rows = []
    try:
        heldout_rows = evaluate_model_suite(world_split=args.heldout_world_split, **eval_kwargs)
    except Exception as exc:
        print(f"GRPO held-out evaluation failed: {exc}")
        heldout_rows = []

    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=train_rows,
        model_name=model_name,
        episodes=max(1, len(train_rows)),
        curriculum=args.curriculum,
        use_tools=args.use_tools,
        trainer="grpo",
        evaluation_rows=in_dist_rows,
        heldout_evaluation_rows=heldout_rows,
        prompt_bank_size=len(prompt_bank),
        extra={
            "train_world_split": args.train_world_split,
            "heldout_world_split": args.heldout_world_split,
            "heldout_seed": args.heldout_seed,
            "base_model": model_name,
        },
    )
    if args.plot:
        maybe_plot(metrics_path, output_dir)

    print("GRPO training complete.")
    print(f"Prompt bank size: {len(prompt_bank)}")
    for heading, section_rows in (
        ("Post-train online evaluation:", in_dist_rows),
        ("Held-out family evaluation:", heldout_rows),
    ):
        print(heading)
        for row in section_rows:
            print(
                f" task={row['task']:<20} score={row['score']:.3f} "
                f"steps={row['steps']} tools={row['tool_calls']}"
            )
    if log_history:
        print(f"Trainer log keys: {sorted(log_history[-1].keys())}")
    print(f"Metrics saved to: {metrics_path}")
| def _looks_like_adapter_path(model_name: str) -> bool: | |
| path = Path(str(model_name)) | |
| return path.exists() and (path / "adapter_config.json").exists() | |
def run_fallback_smoke(args: argparse.Namespace) -> None:
    """Run the dependency-free smoke trainer and record its metrics.

    With ``--use-tools`` this hands off entirely to ``run_tool_fallback_smoke``.
    Otherwise:

    - ``run_smoke_training`` runs for ``min(--episodes, --smoke-episodes)``
      episodes with fixed epsilon/learning-rate settings.
    - Each returned row is copied and tagged with the stage reported by
      ``task_for_episode``.
    - Metrics are saved under the model name ``smoke-tabular-policy`` and
      plotted when ``--plot`` is set.
    """
    if args.use_tools:
        run_tool_fallback_smoke(args)
        return

    total = min(args.episodes, args.smoke_episodes)
    output_dir = Path(args.output)

    if not args.curriculum:
        tasks = TASKS if args.task == "all" else [args.task]
    else:
        # One scheduled task per episode, following the curriculum.
        tasks = [
            task_for_episode(
                episode=episode_no,
                total_episodes=total,
                selected_task=args.task,
                curriculum=True,
            )[0]
            for episode_no in range(1, total + 1)
        ]

    smoke_rows = run_smoke_training(
        tasks=tasks,
        episodes=total,
        output=output_dir / "train_smoke.csv",
        seed=args.seed,
        epsilon=0.85,
        epsilon_decay=0.94,
        epsilon_floor=0.08,
        lr=0.35,
        max_steps=args.max_steps,
    )

    metrics_rows = []
    for smoke_row in smoke_rows:
        tagged = dict(smoke_row)
        _, stage = task_for_episode(
            episode=int(tagged["episode"]),
            total_episodes=total,
            selected_task=args.task,
            curriculum=args.curriculum,
        )
        tagged["stage"] = stage
        metrics_rows.append(tagged)

    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=metrics_rows,
        model_name="smoke-tabular-policy",
        episodes=total,
        curriculum=args.curriculum,
        use_tools=False,
    )
    print(f"Metrics saved to: {metrics_path}")
    if args.plot:
        maybe_plot(metrics_path, output_dir)
def run_tool_fallback_smoke(args: argparse.Namespace) -> None:
    """Rehearse the tool-aware flow without a GPU.

    This checks that the pipeline runs end to end; it does not train a model.

    - Runs ``tool_baseline.run_task`` once per episode for
      ``min(--episodes, --smoke-episodes)`` episodes, cycling through the task
      list.
    - Prints one line per episode.
    - Saves metrics under the model name ``tool-aware-smoke-policy`` and plots
      them when ``--plot`` is set.
    """
    from tool_baseline import run_task as run_tool_task

    total = min(args.episodes, args.smoke_episodes)
    if not args.curriculum:
        tasks = TASKS if args.task == "all" else [args.task]
    else:
        tasks = [
            task_for_episode(
                episode=episode_no,
                total_episodes=total,
                selected_task=args.task,
                curriculum=True,
            )[0]
            for episode_no in range(1, total + 1)
        ]

    for header_line in (
        "AdaptShield tool-aware smoke evaluation",
        "Mode: no-GPU flow validation, not model learning",
        f"Tasks: {', '.join(tasks)}",
        f"Episodes: {total}",
        "",
    ):
        print(header_line)

    rows: List[Dict[str, Any]] = []
    for episode in range(1, total + 1):
        # Cycle through the task list when there are more episodes than tasks.
        task = tasks[(episode - 1) % len(tasks)]
        outcome = run_tool_task(task, emit_logs=False)
        metadata = outcome.get("metadata", {})
        trace = metadata.get("tool_trace", []) if isinstance(metadata, dict) else []
        tool_calls = len(trace)
        _, stage = task_for_episode(
            episode=episode,
            total_episodes=total,
            selected_task=args.task,
            curriculum=args.curriculum,
        )
        rewards = outcome["rewards"]
        reward_total = sum(rewards)
        row = {
            "episode": episode,
            "task": task,
            "stage": stage,
            "score": outcome["score"],
            "steps": outcome["steps"],
            "reward_sum": reward_total,
            "mean_reward": reward_total / len(rewards) if rewards else 0.0,
            "tool_calls": tool_calls,
            "status": "PASS" if outcome["success"] else "FAIL",
        }
        rows.append(row)
        print(
            f"episode={episode:03d} task={task:<20} "
            f"score={row['score']:.3f} steps={row['steps']:02d} "
            f"tools={tool_calls:02d} {row['status']}"
        )

    output_dir = Path(args.output)
    metrics_path = save_metrics(
        output_dir=output_dir,
        rows=rows,
        model_name="tool-aware-smoke-policy",
        episodes=total,
        curriculum=args.curriculum,
        use_tools=True,
    )
    print(f"Metrics saved to: {metrics_path}")
    if args.plot:
        maybe_plot(metrics_path, output_dir)
def parse_args(argv: List[str] | None = None) -> argparse.Namespace:
    """Build the AdaptShield training CLI and parse its arguments.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) parses ``sys.argv[1:]``, exactly as before. Passing an
            explicit list lets tests or other scripts configure the harness
            without mutating ``sys.argv``.

    Returns:
        The parsed namespace consumed by ``main`` and the trainer functions.
    """
    parser = argparse.ArgumentParser(description="AdaptShield training harness.")
    # Task / model selection.
    parser.add_argument("--task", default="direct-triage", choices=TASKS + ["all"])
    parser.add_argument("--model", default=DEFAULT_MODEL, choices=list(MODEL_CHOICES))
    parser.add_argument("--model-path", default="", help="Optional local/HF adapter path to continue training from.")
    # Run length, output and optimisation basics.
    parser.add_argument("--episodes", type=int, default=60)
    parser.add_argument("--max-steps", type=int, default=30)
    parser.add_argument("--output", default="checkpoints/adaptshield")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--lr", type=float, default=1e-5)
    # save_every <= 0 disables checkpoint saving in the GRPO trainer.
    parser.add_argument("--save-every", type=int, default=20)
    # Smoke / fallback mode.
    parser.add_argument("--smoke", action="store_true", help="Force dependency-free smoke mode.")
    parser.add_argument("--smoke-episodes", type=int, default=30)
    parser.add_argument("--curriculum", action="store_true", help="Train direct -> dual -> hard instead of fixed/round-robin tasks.")
    parser.add_argument("--use-tools", action="store_true", help="Let GPU training query SOC tools before hard-task actions.")
    parser.add_argument("--plot", action="store_true", help="Generate reward_curve.png from metrics.json after training.")
    # Trainer backend and GRPO-specific knobs.
    parser.add_argument("--trainer", default="auto", choices=["auto", "pg", "grpo"], help="Training backend: safe policy-gradient fallback or TRL GRPO.")
    parser.add_argument("--prompt-bank-episodes", type=int, default=24, help="Reference rollout episodes used to build the GRPO prompt bank.")
    parser.add_argument("--prompt-bank-hard-multiplier", type=int, default=2, help="Duplicate hard-task GRPO prompts this many times to emphasize difficult slices.")
    parser.add_argument("--prompt-bank-borderline-bonus", type=int, default=1, help="Extra copies for degraded-handoff / borderline GRPO prompts.")
    parser.add_argument("--grpo-epochs", type=int, default=1, help="Number of epochs over the prompt bank for GRPO runs.")
    parser.add_argument("--num-generations", type=int, default=4, help="GRPO generations per prompt when TRL path is active.")
    parser.add_argument("--per-device-batch-size", type=int, default=1, help="Per-device batch size for GRPO training.")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=4, help="Gradient accumulation for GRPO training.")
    # Evaluation.
    parser.add_argument("--eval-episodes", type=int, default=2, help="Online environment episodes per task after GPU training.")
    parser.add_argument("--train-world-split", default="train", choices=["train", "eval"], help="World split used for training/prompt-bank generation.")
    parser.add_argument("--heldout-world-split", default="eval", choices=["train", "eval"], help="World split used for held-out evaluation.")
    parser.add_argument("--heldout-seed", type=int, default=314, help="Seed offset used for held-out evaluation episodes.")
    return parser.parse_args(argv)
def main() -> int:
    """CLI entry point: parse arguments and run the best available trainer.

    ``--smoke`` goes straight to the dependency-free smoke run. Otherwise the
    trainer comes from ``--trainer``; ``auto`` picks GRPO when ``datasets`` and
    ``trl`` import cleanly, else policy gradient.

    Fallback chain on ``ImportError``: GRPO -> policy gradient -> smoke run.
    Exceptions other than ``ImportError`` propagate.

    Returns:
        0 whenever the function returns normally, including after a fallback.
    """
    args = parse_args()
    if args.smoke:
        run_fallback_smoke(args)
        return 0

    trainer_choice = args.trainer
    if trainer_choice == "auto":
        try:
            import datasets  # noqa: F401
            import trl  # noqa: F401
        except ImportError:
            trainer_choice = "pg"
        else:
            trainer_choice = "grpo"

    # Fallbacks deliberately run *after* each except clause has finished. While a
    # handler is active, the caught exception's traceback keeps the failed
    # trainer's frames (and everything they reference) alive. The fallback run
    # does not need that state.
    try:
        if trainer_choice == "grpo":
            train_grpo(args)
        else:
            train_policy_gradient(args)
    except ImportError as exc:
        print(f"GPU training dependency missing for trainer={trainer_choice}: {exc}")
    else:
        return 0

    if trainer_choice == "grpo":
        print("Falling back to policy-gradient GPU trainer.")
        try:
            train_policy_gradient(args)
        except ImportError as nested_exc:
            print(f"Policy-gradient fallback also unavailable: {nested_exc}")
        else:
            return 0

    print("Falling back to dependency-free smoke training.")
    run_fallback_smoke(args)
    return 0
if __name__ == "__main__":
    # sys.exit raises SystemExit carrying main()'s return value as the exit status.
    sys.exit(main())