"""
AdaptShield Inference Script

Single task per run. Emits mandatory [START]/[STEP]/[END] stdout format.
All credentials read from environment — never hardcoded.

Required env vars (injected by evaluator):
    API_KEY:       Evaluator's LiteLLM proxy key (checked first)
    API_BASE_URL:  LLM endpoint
    MODEL_NAME:    Model identifier

Optional env vars:
    HF_TOKEN:               Fallback if API_KEY not set
    ADAPTSHIELD_TASK:       Task name (default: direct-triage)
    ENV_BASE_URL:           Environment server URL (default: localhost:7860)
    ADAPTSHIELD_USE_TOOLS:  "1/true/yes/on" or "0/false/no/off" to force the
                            SOC tool loop on or off (default: auto)
"""

import json
import math
import os
import sys
import textwrap
import urllib.error
import urllib.request
from typing import Any, Dict, List, Optional

from openai import OpenAI

from client import AdaptshieldEnv
from models import AdaptShieldAction
from soc_tools import attach_tool_results, investigate_http, summarize_tool_results

# ── Configuration — read from env, NEVER hardcode ──────────────────────────
API_KEY = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN", "")
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
TASK_NAME = os.environ.get("ADAPTSHIELD_TASK", "direct-triage")
ENV_BASE_URL = os.environ.get("ENV_BASE_URL", "http://localhost:7860").rstrip("/")
BENCHMARK = "adaptshield"
MAX_STEPS = 25
SUCCESS_THRESHOLD = 0.50
USE_TOOLS_SETTING = os.environ.get("ADAPTSHIELD_USE_TOOLS", "auto").lower()


# ── Mandatory stdout format ────────────────────────────────────────────────

def log_start(task: str, env: str, model: str) -> None:
    """Print the single [START] line that opens a run."""
    print(f"[START] task={task} env={env} model={model}", flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one [STEP] line.

    The error field is collapsed onto a single line so that a multi-line
    exception message cannot break the line-oriented output format. A missing
    or blank error is printed as ``null``.
    """
    error_field = " ".join(error.split()) if error else ""
    if not error_field:
        error_field = "null"
    print(
        f"[STEP] step={step} action={action} "
        f"reward={reward:.2f} done={str(done).lower()} error={error_field}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the single [END] line that closes a run."""
    rs = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rs}",
        flush=True,
    )


# ── Environment calls ──────────────────────────────────────────────────────

def env_post(path: str, data: Dict) -> Dict:
    """POST ``data`` as JSON to ``ENV_BASE_URL + path`` and return the decoded reply.

    Network/HTTP failures (``urllib.error.URLError``) and invalid JSON in the
    response propagate to the caller.
    """
    url = f"{ENV_BASE_URL}{path}"
    body = json.dumps(data).encode()
    # Supplying ``data`` makes urllib issue a POST request.
    req = urllib.request.Request(
        url, data=body, headers={"Content-Type": "application/json"}
    )
    with urllib.request.urlopen(req, timeout=60) as r:
        return json.loads(r.read())


def obs_to_dict(obs: Any) -> Dict[str, Any]:
    """Convert Pydantic observations from the persistent client to JSON dicts."""
    if hasattr(obs, "model_dump"):
        return obs.model_dump(mode="json")
    return dict(obs)


def build_env_action(parsed: Dict[str, Any], phase: int) -> AdaptShieldAction:
    """Validate model output and fall back to a phase-correct safe action.

    Any failure while constructing ``AdaptShieldAction`` from ``parsed``
    (including ``parsed`` not being a mapping) yields a conservative
    ``monitor`` action shaped for the given phase.
    """
    try:
        return AdaptShieldAction(**parsed)
    except Exception:
        if phase == 1:
            return AdaptShieldAction(
                threat_type="brute_force",
                confidence=0.5,
                target_node="auth_service",
                recommended_action="monitor",
                reasoning="validated fallback",
            )
        return AdaptShieldAction(
            action="monitor",
            target_node="auth_service",
            reasoning="validated fallback",
        )


def _format_action(parsed: Any) -> str:
    """Render the model output as compact JSON, truncated to 100 chars for logs."""
    action_str = json.dumps(parsed, separators=(",", ":"))
    if len(action_str) > 100:
        action_str = action_str[:97] + "..."
    return action_str


def _pick_reward(*candidates: Any) -> float:
    """Return the first candidate that is not None as a float, else 0.0.

    ``dict.get(key, default)`` only applies the default when the key is
    missing, so an explicit ``None`` reward would otherwise reach ``float()``
    and raise.
    """
    for value in candidates:
        if value is not None:
            return float(value)
    return 0.0


# ── Score computation — strictly (0.01, 0.99) ─────────────────────────────

def _normalized_score(meta: Any) -> Optional[float]:
    """Return a finite ``normalized_score`` from ``meta``, or None if unusable."""
    if not isinstance(meta, dict) or "normalized_score" not in meta:
        return None
    try:
        value = float(meta["normalized_score"])
    except (TypeError, ValueError):
        return None
    # NaN must be rejected explicitly: min(0.99, nan) evaluates to 0.99, which
    # would turn an invalid score into the best possible one.
    return value if math.isfinite(value) else None


def safe_score(rewards: List[float], meta: Dict) -> float:
    """Compute the final episode score, clamped to [0.01, 0.99].

    Uses ``meta["normalized_score"]`` when it is a finite number. Otherwise
    the score is estimated from the rewards (sum of rewards above 0.50 divided
    by the maximum attainable ``len(rewards) * 0.99``), or 0.50 when there are
    no rewards at all.
    """
    raw = _normalized_score(meta)
    if raw is None:
        if rewards:
            pos = sum(r for r in rewards if r > 0.50)
            raw = pos / (len(rewards) * 0.99)
        else:
            raw = 0.50
    return max(0.01, min(0.99, raw))


# ── System prompts ─────────────────────────────────────────────────────────

PHASE1_SYS = textwrap.dedent("""
    You are a Threat Analyst for a 4-node enterprise network.
    Analyze the SIEM metrics and alerts. Identify the threat type.

    Attack strategies: brute_force, lateral_movement, exfiltration, supply_chain, benign
    If SOC tool evidence is provided, use it to update your belief before classifying.

    Respond ONLY with valid JSON:
    {"threat_type":"...","confidence":0.0,"target_node":"...","recommended_action":"...","reasoning":"..."}

    Nodes: auth_service, payment_service, database, api_gateway
    Actions: rate_limit, isolate, honeypot, patch, monitor
""").strip()

PHASE2_SYS = textwrap.dedent("""
    You are a Tactical Executor. Act on the threat assessment provided.
    You cannot see raw network data. Use the analyst assessment plus any SOC tool trace.

    rate_limit=throttle traffic, isolate=take offline,
    honeypot=redirect attacker, patch=fix vulnerability, monitor=observe only

    Respond ONLY with valid JSON:
    {"action":"...","target_node":"...","reasoning":"..."}

    Nodes: auth_service, payment_service, database, api_gateway
""").strip()


def _extract_json_text(text: str) -> str:
    """Isolate the JSON object inside an LLM reply.

    Handles markdown fences (with or without a language tag) and prose around
    the object by keeping the span from the first ``{`` to the last ``}``.
    When no such span exists the stripped text is returned unchanged so that
    ``json.loads`` reports the failure.
    """
    candidate = text.strip()
    if "```" in candidate:
        for part in candidate.split("```"):
            if "{" in part:
                candidate = part
                break
    start = candidate.find("{")
    end = candidate.rfind("}")
    if start != -1 and end > start:
        return candidate[start:end + 1]
    return candidate.strip()


def _row_turn(row: Any) -> Optional[int]:
    """Return the turn recorded on a tool-trace row (-1 if absent, None if invalid)."""
    if not isinstance(row, dict):
        return None
    try:
        return int(row.get("turn", -1))
    except (TypeError, ValueError):
        return None


def get_action(client: OpenAI, obs: Dict) -> Dict[str, Any]:
    """Call LLM for current phase. Falls back gracefully on parse error.

    Phase 1 prompts with the network state, alerts, SOC tool evidence and
    history; any other phase prompts with the analyst assessment and the tool
    trace rows recorded for the current turn. Any failure in the LLM call, or
    a reply that is not a JSON object, is logged as a ``[DEBUG]`` line and
    replaced by a phase-appropriate ``monitor`` fallback dict.
    """
    phase = obs.get("phase", 1)
    if phase == 1:
        sys_msg = PHASE1_SYS
        # Tolerate a null alert list and non-string alert entries.
        alerts = obs.get("active_alerts") or []
        user_msg = "\n".join([
            "Network nodes:",
            json.dumps(obs.get("network_nodes", {}), indent=2),
            "\nAlerts:",
            "\n".join(str(alert) for alert in alerts),
            "\nSOC tool evidence:",
            summarize_tool_results(obs.get("tool_results", [])),
            "\nHistory:",
            json.dumps(obs.get("history", []), indent=2),
            "\nClassify the threat:",
        ])
        fallback = {
            "threat_type": "brute_force",
            "confidence": 0.5,
            "target_node": "auth_service",
            "recommended_action": "monitor",
            "reasoning": "fallback",
        }
    else:
        sys_msg = PHASE2_SYS
        metadata = obs.get("metadata", {})
        if not isinstance(metadata, dict):
            metadata = {}
        current_turn = int(obs.get("turn", 0) or 0)
        # Only show the executor the tool calls made during this turn;
        # malformed rows are skipped instead of aborting the episode.
        tool_trace = [
            row for row in (metadata.get("tool_trace") or [])
            if _row_turn(row) == current_turn
        ]
        user_msg = "\n".join([
            "Threat assessment from analyst:",
            json.dumps(obs.get("phase1_assessment", {}), indent=2),
            "\nSOC tool trace for this turn:",
            json.dumps(tool_trace, indent=2),
            "\nChoose your defensive action:",
        ])
        fallback = {
            "action": "monitor",
            "target_node": "auth_service",
            "reasoning": "fallback",
        }

    try:
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": sys_msg},
                {"role": "user", "content": user_msg},
            ],
            temperature=0.1,
            max_tokens=300,
            stream=False,
        )
        text = (resp.choices[0].message.content or "").strip()
        parsed = json.loads(_extract_json_text(text))
        if not isinstance(parsed, dict):
            # Keep the logged action in sync with what is actually sent.
            raise ValueError("model reply is not a JSON object")
        return parsed
    except Exception as exc:
        print(f"[DEBUG] phase={phase} parse error: {exc}", flush=True)
        return fallback


def should_use_tools(task_name: str) -> bool:
    """Decide whether to run the SOC tool loop.

    ``ADAPTSHIELD_USE_TOOLS`` forces the choice when set to a recognised
    on/off value; otherwise tools are used only for ``polymorphic-zero-day``.
    """
    if USE_TOOLS_SETTING in ("1", "true", "yes", "on"):
        return True
    if USE_TOOLS_SETTING in ("0", "false", "no", "off"):
        return False
    return task_name == "polymorphic-zero-day"


def run_soc_episode(client: OpenAI, use_tools: bool) -> tuple[List[float], int, Dict[str, Any]]:
    """Run one episode through the HTTP ``/soc/reset`` and ``/soc/step`` endpoints.

    Before each model call ``investigate_http`` is invoked and its results are
    attached to the observation shown to the model. A failing step is logged
    with reward 0.0 and ends the episode.

    Returns:
        ``(rewards, steps_taken, last_observation)``.
    """
    rewards: List[float] = []
    steps_taken = 0

    reset = env_post("/soc/reset", {"task": TASK_NAME})
    session_id = str(reset.get("session_id", ""))
    obs = dict(reset.get("observation", {}))
    done = bool(obs.get("done", False))

    for step in range(1, MAX_STEPS + 1):
        if done:
            break

        tool_results = investigate_http(
            env_base_url=ENV_BASE_URL,
            session_id=session_id,
            obs=obs,
            use_tools=use_tools,
            thorough=True,
        )
        obs_for_model = attach_tool_results(obs, tool_results)
        parsed = get_action(client, obs_for_model)
        action_str = _format_action(parsed)

        try:
            action = build_env_action(parsed, phase=int(obs.get("phase", 1)))
            action_payload = action.model_dump(
                mode="json",
                exclude_none=True,
                exclude_defaults=True,
            )
            result = env_post("/soc/step", {"session_id": session_id, "action": action_payload})
            obs = dict(result.get("observation", {}))
            # Prefer the top-level reward, then the observation's; a null
            # value counts as absent rather than raising in float().
            reward = _pick_reward(result.get("reward"), obs.get("reward"))
            done = bool(result.get("done", obs.get("done", False)))
            error = None
        except Exception as exc:
            reward = 0.0
            done = True
            error = str(exc)[:80]

        rewards.append(reward)
        steps_taken = step
        log_step(step=step, action=action_str, reward=reward, done=done, error=error)

        if done:
            break

    return rewards, steps_taken, obs


def run_openenv_episode(client: OpenAI) -> tuple[List[float], int, Dict[str, Any]]:
    """Run one episode through the ``AdaptshieldEnv`` client, without SOC tools.

    A failing step is logged with reward 0.0 and ends the episode.

    Returns:
        ``(rewards, steps_taken, last_observation)``.
    """
    rewards: List[float] = []
    steps_taken = 0
    obs: Dict[str, Any] = {}

    env = AdaptshieldEnv(base_url=ENV_BASE_URL).sync()
    with env:
        result = env.reset(task_name=TASK_NAME)
        obs = obs_to_dict(result.observation)
        done = bool(result.done or obs.get("done", False))

        for step in range(1, MAX_STEPS + 1):
            if done:
                break

            parsed = get_action(client, obs)
            action_str = _format_action(parsed)

            try:
                action = build_env_action(parsed, phase=int(obs.get("phase", 1)))
                sr = env.step(action)
                obs = obs_to_dict(sr.observation)
                # Both the step result and the observation may carry a null
                # reward; treat that as 0.0 instead of failing the step.
                reward = _pick_reward(sr.reward, obs.get("reward"))
                done = bool(sr.done or obs.get("done", False))
                error = None
            except Exception as exc:
                reward = 0.0
                done = True
                error = str(exc)[:80]

            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_str, reward=reward, done=done, error=error)

            if done:
                break

    return rewards, steps_taken, obs


def main() -> None:
    """Run a single task and emit the [START]/[STEP]/[END] lines.

    Everything after ``[START]``, including construction of the LLM client,
    runs inside one guarded region so that ``[END]`` is printed even when
    setup or the episode fails; in that case the score is 0.10 and
    ``success`` stays false.
    """
    rewards: List[float] = []
    steps_taken: int = 0
    score: float = 0.50
    success: bool = False
    obs: Dict = {}

    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)

    try:
        client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
        if should_use_tools(TASK_NAME):
            rewards, steps_taken, obs = run_soc_episode(client, use_tools=True)
        else:
            rewards, steps_taken, obs = run_openenv_episode(client)

        score = safe_score(rewards, obs.get("metadata", {}))
        success = score >= SUCCESS_THRESHOLD
    except Exception as exc:
        print(f"[DEBUG] episode error: {exc}", flush=True)
        score = 0.10

    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)


if __name__ == "__main__":
    main()