# adaptshield / inference.py
# SaiManish123: Initial deploy of AdaptShield two-phase cybersecurity environment
# (commit c1060df, verified)
"""
AdaptShield Inference Script
Single task per run. Emits mandatory [START]/[STEP]/[END] stdout format.
All credentials read from environment — never hardcoded.
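Env vars injected by the evaluator:
    API_KEY:      Evaluator's LiteLLM proxy key (checked first)
    API_BASE_URL: LLM endpoint (default: https://router.huggingface.co/v1)
    MODEL_NAME:   Model identifier (default: Qwen/Qwen2.5-72B-Instruct)
Optional env vars:
    HF_TOKEN:              Fallback credential if API_KEY is not set
    ADAPTSHIELD_TASK:      Task name (default: direct-triage)
    ADAPTSHIELD_USE_TOOLS: 1/true/yes/on or 0/false/no/off forces the SOC tool
                           loop on or off; any other value (default: auto)
                           enables it only for the polymorphic-zero-day task
    ENV_BASE_URL:          Environment server URL (default: http://localhost:7860)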
"""
import json
import os
import sys
import textwrap
from typing import Any, Dict, List, Optional
import urllib.request
import urllib.error
from openai import OpenAI
from client import AdaptshieldEnv
from models import AdaptShieldAction
from soc_tools import attach_tool_results, investigate_http, summarize_tool_results
# ── Configuration — read from env, NEVER hardcode ──────────────────────────
# Credentials: API_KEY takes precedence; HF_TOKEN is used only when API_KEY is
# unset or empty.
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN", "")

# LLM endpoint and model identifier.
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

# Task selection and environment server (trailing slashes removed so request
# paths can be appended directly).
TASK_NAME = os.getenv("ADAPTSHIELD_TASK", "direct-triage")
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:7860").rstrip("/")

# Fixed run parameters.
BENCHMARK = "adaptshield"
MAX_STEPS = 25
SUCCESS_THRESHOLD = 0.50

# Tool-loop switch, lower-cased for case-insensitive matching.
USE_TOOLS_SETTING = os.getenv("ADAPTSHIELD_USE_TOOLS", "auto").lower()
# ── Mandatory stdout format ────────────────────────────────────────────────
def log_start(task: str, env: str, model: str) -> None:
    """Print the ``[START]`` line naming the task, environment and model."""
    header = "[START] task={} env={} model={}".format(task, env, model)
    print(header, flush=True)
def log_step(step: int, action: str, reward: float,
             done: bool, error: Optional[str]) -> None:
    """Print one ``[STEP]`` line.

    The reward is formatted with two decimals, ``done`` is lower-cased, and a
    missing or empty ``error`` is printed as the literal ``null``.
    """
    fields = (
        f"step={step}",
        f"action={action}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={error or 'null'}",
    )
    print("[STEP] " + " ".join(fields), flush=True)
def log_end(success: bool, steps: int, score: float,
            rewards: List[float]) -> None:
    """Print the closing ``[END]`` line.

    The score is formatted with three decimals; rewards are a comma-separated
    list with two decimals each (an empty list yields an empty field).
    """
    reward_field = ",".join(format(value, ".2f") for value in rewards)
    pieces = [
        "[END]",
        f"success={str(success).lower()}",
        f"steps={steps}",
        f"score={score:.3f}",
        f"rewards={reward_field}",
    ]
    print(" ".join(pieces), flush=True)
# ── Environment calls ──────────────────────────────────────────────────────
def env_post(path: str, data: Dict) -> Dict:
    """POST ``data`` as JSON to ``ENV_BASE_URL + path`` and decode the JSON reply.

    Supplying a request body makes urllib issue a POST. The call uses a
    60-second timeout; network and HTTP errors raised by urllib, and JSON
    decoding errors, propagate to the caller.
    """
    request = urllib.request.Request(
        ENV_BASE_URL + path,
        data=json.dumps(data).encode(),
        headers={"Content-Type": "application/json"},
    )
    with urllib.request.urlopen(request, timeout=60) as response:
        payload = response.read()
    return json.loads(payload)
def obs_to_dict(obs: Any) -> Dict[str, Any]:
    """Return a plain dict for an observation.

    Objects exposing ``model_dump`` (e.g. Pydantic models) are dumped in JSON
    mode; anything else is passed to ``dict()``.
    """
    if not hasattr(obs, "model_dump"):
        return dict(obs)
    return obs.model_dump(mode="json")
def build_env_action(parsed: Dict[str, Any], phase: int) -> AdaptShieldAction:
    """Build an ``AdaptShieldAction`` from model output, with a safe fallback.

    If constructing the action from ``parsed`` raises for any reason, a
    conservative "monitor auth_service" action is returned instead: the
    analyst-shaped variant for phase 1, the executor-shaped variant otherwise.
    """
    try:
        return AdaptShieldAction(**parsed)
    except Exception:
        # Fields shared by both fallback shapes.
        safe_fields: Dict[str, Any] = {
            "target_node": "auth_service",
            "reasoning": "validated fallback",
        }
        if phase == 1:
            safe_fields.update(
                threat_type="brute_force",
                confidence=0.5,
                recommended_action="monitor",
            )
        else:
            safe_fields["action"] = "monitor"
        return AdaptShieldAction(**safe_fields)
# ── Score computation — strictly (0.01, 0.99) ─────────────────────────────
def safe_score(rewards: List[float], meta: Dict) -> float:
    """Compute the episode score, clamped to the closed range [0.01, 0.99].

    Args:
        rewards: Per-step rewards collected during the episode.
        meta: Final observation metadata. When it is a dict holding a usable
            ``normalized_score``, that value is the raw score.

    Returns:
        The clamped score. Without a usable ``normalized_score`` the raw score
        is the sum of rewards above 0.50 divided by ``len(rewards) * 0.99``,
        or 0.50 when there are no rewards.
    """
    import math  # local import: only needed for the NaN check below

    raw = float("nan")
    # Metadata comes from the server; tolerate None / non-dict values and
    # scores that are missing, null or not numeric.
    if isinstance(meta, dict) and "normalized_score" in meta:
        try:
            raw = float(meta["normalized_score"])
        except (TypeError, ValueError):
            raw = float("nan")

    # NaN must not reach the clamp: min(0.99, nan) evaluates to 0.99, which
    # would turn an invalid score into the best possible one.
    if math.isnan(raw):
        if rewards:
            positive = sum(r for r in rewards if r > 0.50)
            raw = positive / (len(rewards) * 0.99)
        else:
            raw = 0.50

    return max(0.01, min(0.99, raw))
# ── System prompts ─────────────────────────────────────────────────────────
# System prompt for phase 1 (threat classification). The JSON template line
# shows the model the exact keys expected in its reply.
PHASE1_SYS = "\n".join(
    (
        "You are a Threat Analyst for a 4-node enterprise network.",
        "Analyze the SIEM metrics and alerts. Identify the threat type.",
        "Attack strategies: brute_force, lateral_movement, exfiltration, supply_chain, benign",
        "If SOC tool evidence is provided, use it to update your belief before classifying.",
        "Respond ONLY with valid JSON:",
        '{"threat_type":"...","confidence":0.0,"target_node":"...","recommended_action":"...","reasoning":"..."}',
        "Nodes: auth_service, payment_service, database, api_gateway",
        "Actions: rate_limit, isolate, honeypot, patch, monitor",
    )
)
# System prompt for phase 2 (choosing the defensive action). The JSON template
# line shows the model the exact keys expected in its reply.
PHASE2_SYS = "\n".join(
    (
        "You are a Tactical Executor. Act on the threat assessment provided.",
        "You cannot see raw network data. Use the analyst assessment plus any SOC tool trace.",
        "rate_limit=throttle traffic, isolate=take offline, honeypot=redirect attacker,",
        "patch=fix vulnerability, monitor=observe only",
        "Respond ONLY with valid JSON:",
        '{"action":"...","target_node":"...","reasoning":"..."}',
        "Nodes: auth_service, payment_service, database, api_gateway",
    )
)
def _coerce_int(value: Any, default: int) -> int:
    """Return ``int(value)``, or ``default`` when the value is not int-like."""
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


def _parse_json_object(text: str) -> Dict[str, Any]:
    """Decode the first JSON object found in an LLM reply.

    The reply is split on markdown fences and every ``{`` in each segment is
    tried as the start of a JSON document, so language tags, leading prose and
    trailing prose are all tolerated.

    Raises:
        ValueError: if no segment contains a decodable JSON object.
    """
    decoder = json.JSONDecoder()
    for segment in text.split("```"):
        start = segment.find("{")
        while start != -1:
            try:
                # raw_decode stops at the end of the first document, so text
                # after the closing brace does not break parsing.
                parsed, _ = decoder.raw_decode(segment[start:])
            except json.JSONDecodeError:
                parsed = None
            if isinstance(parsed, dict):
                return parsed
            start = segment.find("{", start + 1)
    raise ValueError("no JSON object found in model reply")


def get_action(client: OpenAI, obs: Dict) -> Dict[str, Any]:
    """Ask the LLM for the action matching the observation's phase.

    Phase 1 (the default when ``phase`` is absent) sends the network nodes,
    alerts, summarized SOC tool evidence and history, and expects a threat
    classification. Any other phase sends the analyst assessment plus the tool
    trace rows recorded for the current turn, and expects a defensive action.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        obs: Current observation as a plain dict.

    Returns:
        The JSON object decoded from the model reply. If the request fails or
        the reply holds no JSON object, a phase-appropriate "monitor
        auth_service" fallback dict is returned and a ``[DEBUG]`` line printed.
    """
    phase = obs.get("phase", 1)
    if phase == 1:
        sys_msg = PHASE1_SYS
        # Alerts may be missing, null, a bare string, or hold non-string
        # items; normalise so prompt building cannot raise.
        alerts = obs.get("active_alerts") or []
        if isinstance(alerts, str):
            alerts = [alerts]
        user_msg = "\n".join([
            "Network nodes:",
            json.dumps(obs.get("network_nodes", {}), indent=2),
            "\nAlerts:",
            "\n".join(str(alert) for alert in alerts),
            "\nSOC tool evidence:",
            summarize_tool_results(obs.get("tool_results", [])),
            "\nHistory:",
            json.dumps(obs.get("history", []), indent=2),
            "\nClassify the threat:",
        ])
        fallback = {
            "threat_type": "brute_force", "confidence": 0.5,
            "target_node": "auth_service", "recommended_action": "monitor",
            "reasoning": "fallback",
        }
    else:
        sys_msg = PHASE2_SYS
        metadata = obs.get("metadata")
        if not isinstance(metadata, dict):
            metadata = {}
        current_turn = _coerce_int(obs.get("turn"), 0)
        # Keep only well-formed trace rows that belong to the current turn.
        tool_trace = [
            row for row in (metadata.get("tool_trace") or [])
            if isinstance(row, dict)
            and _coerce_int(row.get("turn", -1), -1) == current_turn
        ]
        user_msg = "\n".join([
            "Threat assessment from analyst:",
            json.dumps(obs.get("phase1_assessment", {}), indent=2),
            "\nSOC tool trace for this turn:",
            json.dumps(tool_trace, indent=2),
            "\nChoose your defensive action:",
        ])
        fallback = {
            "action": "monitor",
            "target_node": "auth_service",
            "reasoning": "fallback",
        }
    try:
        resp = client.chat.completions.create(
            model=MODEL_NAME,
            messages=[
                {"role": "system", "content": sys_msg},
                {"role": "user", "content": user_msg},
            ],
            temperature=0.1,
            max_tokens=300,
            stream=False,
        )
        text = (resp.choices[0].message.content or "").strip()
        return _parse_json_object(text)
    except Exception as exc:
        # Deliberate best-effort boundary: any API or parsing failure degrades
        # to the safe fallback instead of aborting the episode.
        print(f"[DEBUG] phase={phase} parse error: {exc}", flush=True)
        return fallback
def should_use_tools(task_name: str) -> bool:
    """Decide whether the episode runs through the SOC tool loop.

    An explicit on/off value in ``USE_TOOLS_SETTING`` wins; any other value
    enables tools only for the ``polymorphic-zero-day`` task.
    """
    overrides = {
        "1": True, "true": True, "yes": True, "on": True,
        "0": False, "false": False, "no": False, "off": False,
    }
    forced = overrides.get(USE_TOOLS_SETTING)
    if forced is not None:
        return forced
    return task_name == "polymorphic-zero-day"
def run_soc_episode(client: OpenAI, use_tools: bool) -> tuple[List[float], int, Dict[str, Any]]:
    """Run one episode against the ``/soc/reset`` and ``/soc/step`` endpoints.

    Each step calls ``investigate_http`` for tool results, attaches them to
    the observation shown to the model, asks the LLM for an action, validates
    it and posts it to the server. A failure while validating or posting the
    action ends the episode with reward 0.0 and the error recorded on that
    step's ``[STEP]`` line.

    Args:
        client: OpenAI-compatible client passed to ``get_action``.
        use_tools: Forwarded to ``investigate_http``.

    Returns:
        ``(rewards, steps_taken, obs)``: per-step rewards, the number of steps
        executed, and the last observation successfully received.
    """
    rewards: List[float] = []
    steps_taken = 0
    reset = env_post("/soc/reset", {"task": TASK_NAME})
    session_id = str(reset.get("session_id", ""))
    # ``or {}`` treats a JSON null observation like a missing one.
    obs = dict(reset.get("observation") or {})
    done = bool(obs.get("done", False))
    for step in range(1, MAX_STEPS + 1):
        if done:
            break
        tool_results = investigate_http(
            env_base_url=ENV_BASE_URL,
            session_id=session_id,
            obs=obs,
            use_tools=use_tools,
            thorough=True,
        )
        obs_for_model = attach_tool_results(obs, tool_results)
        parsed = get_action(client, obs_for_model)
        # Compact JSON for the log line, capped at 100 characters.
        action_str = json.dumps(parsed, separators=(",", ":"))
        if len(action_str) > 100:
            action_str = action_str[:97] + "..."
        try:
            action = build_env_action(parsed, phase=int(obs.get("phase", 1)))
            action_payload = action.model_dump(
                mode="json",
                exclude_none=True,
                exclude_defaults=True,
            )
            result = env_post("/soc/step", {"session_id": session_id, "action": action_payload})
            obs = dict(result.get("observation") or {})
            # A null reward is treated as "not provided", matching the None
            # handling in run_openenv_episode: fall back to the observation's
            # reward, then to 0.0, rather than failing on float(None).
            reward_value = result.get("reward")
            if reward_value is None:
                reward_value = obs.get("reward")
            reward = float(reward_value) if reward_value is not None else 0.0
            # Same rule for ``done``: only a missing/null top-level value
            # defers to the observation.
            done_value = result.get("done")
            if done_value is None:
                done_value = obs.get("done", False)
            done = bool(done_value)
            error = None
        except Exception as exc:
            # Best-effort boundary: record the failure on this step and stop.
            reward = 0.0
            done = True
            error = str(exc)[:80]
        rewards.append(reward)
        steps_taken = step
        log_step(step=step, action=action_str, reward=reward, done=done, error=error)
    return rewards, steps_taken, obs
def run_openenv_episode(client: OpenAI) -> tuple[List[float], int, Dict[str, Any]]:
    """Run one episode through the synchronous ``AdaptshieldEnv`` client.

    Each step asks the LLM for an action, validates it and steps the
    environment. A failure while validating or stepping ends the episode with
    reward 0.0 and the error recorded on that step's ``[STEP]`` line.

    Returns:
        ``(rewards, steps_taken, obs)``: per-step rewards, the number of steps
        executed, and the last observation as a dict.
    """
    rewards: List[float] = []
    steps_taken = 0
    obs: Dict[str, Any] = {}
    env = AdaptshieldEnv(base_url=ENV_BASE_URL).sync()
    with env:
        first = env.reset(task_name=TASK_NAME)
        obs = obs_to_dict(first.observation)
        done = bool(first.done or obs.get("done", False))
        while not done and steps_taken < MAX_STEPS:
            step = steps_taken + 1
            parsed = get_action(client, obs)
            # Compact JSON for the log line, capped at 100 characters.
            compact = json.dumps(parsed, separators=(",", ":"))
            action_str = compact if len(compact) <= 100 else compact[:97] + "..."
            try:
                action = build_env_action(parsed, phase=int(obs.get("phase", 1)))
                sr = env.step(action)
                obs = obs_to_dict(sr.observation)
                # Prefer the step result's reward; use the observation's only
                # when the step result carries none.
                raw_reward = sr.reward
                if raw_reward is None:
                    raw_reward = obs.get("reward", 0.0)
                reward = float(raw_reward)
                done = bool(sr.done or obs.get("done", False))
            except Exception as exc:
                reward, done, error = 0.0, True, str(exc)[:80]
            else:
                error = None
            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_str, reward=reward,
                     done=done, error=error)
    return rewards, steps_taken, obs
def main() -> None:
    """Run a single task and emit the ``[START]`` ... ``[END]`` log lines.

    The SOC tool loop is used when ``should_use_tools`` says so, otherwise the
    OpenEnv client loop. If the episode raises, a ``[DEBUG]`` line is printed
    and the run is reported as unsuccessful with score 0.10; the ``[END]``
    line is printed in both cases.
    """
    llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    # Defaults reported by [END] when the episode fails before producing data.
    rewards: List[float] = []
    steps_taken: int = 0
    score: float = 0.50
    success: bool = False

    log_start(task=TASK_NAME, env=BENCHMARK, model=MODEL_NAME)
    try:
        if should_use_tools(TASK_NAME):
            episode = run_soc_episode(llm, use_tools=True)
        else:
            episode = run_openenv_episode(llm)
        rewards, steps_taken, final_obs = episode
        score = safe_score(rewards, final_obs.get("metadata", {}))
        success = score >= SUCCESS_THRESHOLD
    except Exception as exc:
        print(f"[DEBUG] episode error: {exc}", flush=True)
        score = 0.10
    log_end(success=success, steps=steps_taken, score=score, rewards=rewards)


if __name__ == "__main__":
    main()