"""Run the SOC_env benchmark tasks with an LLM policy and a scripted fallback.

For each task in ``TASKS`` the script creates a pinned ``SOCEnvironment``
scenario, asks a chat-completions model for one action per step (or uses the
``BASELINE`` sequence when no client is configured or the model call fails),
and prints ``[START]`` / ``[STEP]`` / ``[END]`` records followed by a summary.
"""

import json
import os
from typing import Any, Dict, List, Optional, Tuple

from dotenv import load_dotenv
from openai import OpenAI

load_dotenv(override=True)

API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")

# The environment package is importable under two layouts; try the flat one
# first and fall back to the package-qualified one.
try:
    from server.SOC_env_environment import SOCEnvironment
    from models import SOCAction
except ImportError:
    from SOC_env.server.SOC_env_environment import SOCEnvironment
    from SOC_env.models import SOCAction

BENCHMARK = "SOC_env"
MAX_STEPS = 12
TASKS = ["task_easy", "task_medium", "task_hard"]
SUCCESS_SCORE_THRESHOLD = 0.60

# Set in the ``__main__`` block; ``None`` means "use the baseline policy".
client: Optional[OpenAI] = None

# Scenario id pinned for each task.
TASK_SCENARIOS = {
    "task_easy": "easy_false_positive_vpn",
    "task_medium": "medium_insider_threat",
    "task_hard": "hard_apt_lateral_movement",
}

# Difficulty passed to the environment for each task.
TASK_DIFFICULTY = {
    "task_easy": "easy",
    "task_medium": "medium",
    "task_hard": "hard",
}

# Scripted action sequence used when the LLM is unavailable or fails.
BASELINE = {
    "task_easy": ["investigate", "ignore"],
    "task_medium": ["investigate", "block_account", "collect_forensics", "escalate"],
    "task_hard": ["investigate", "isolate_device", "block_ip", "collect_forensics", "escalate"],
}

# Actions offered to the model when the observation does not list any.
DEFAULT_ACTIONS = [
    "ignore", "monitor", "investigate", "query_logs", "check_threat_intel",
    "run_sandbox", "block_ip", "block_account", "isolate_device", "escalate",
    "request_mfa", "patch_system", "collect_forensics",
]

SYSTEM_PROMPT = '''You are an expert SOC Tier-1 analyst following NIST SP 800-61 incident response procedures.
Respond ONLY with JSON: {"decision": "", "reasoning": ""}

SOC Playbook Rules:
1. ALWAYS investigate first when context is empty — never act blind
2. Use query_logs or check_threat_intel for additional context before acting
3. If alert is authorized/normal activity (VPN, pentest, scheduled scan) -> ignore
4. Account compromise or credential theft -> block_account, then request_mfa
5. Active malware or C2 beacon confirmed -> isolate_device immediately
6. Malicious IP confirmed by threat intel -> block_ip
7. Phishing with credential risk -> request_mfa first, then monitor
8. Supply chain / vulnerable package -> patch_system
9. Evidence needed for legal/forensics -> collect_forensics before destructive actions
10. Beyond Tier-1 (APT, ransomware, legal hold, nation-state) -> escalate
11. NEVER repeat an action already taken
12. NEVER escalate on low/medium severity without investigation

Action meanings:
- investigate: Pull SIEM logs, review user history, check endpoint telemetry
- query_logs: Deep SIEM query — firewall, proxy, DNS, authentication logs
- check_threat_intel: Query threat intel platforms (VirusTotal, Shodan, MISP, Mandiant)
- run_sandbox: Detonate suspicious file in isolated sandbox environment
- block_ip: Block at perimeter firewall — use when malicious IP confirmed
- block_account: Disable user account — use when compromise confirmed
- isolate_device: Network quarantine — use when active malware/C2 confirmed
- escalate: Hand to Tier-2/IR team — use for APT, ransomware, legal exposure
- request_mfa: Force MFA re-enrollment — use after credential theft
- patch_system: Remove malicious package or apply security patch
- collect_forensics: Preserve disk image, memory dump, logs for investigation
- monitor: Passive watch — only appropriate for low-severity ambiguous alerts
- ignore: Close alert as false positive — only when clearly benign'''


def _one_line(text: str) -> str:
    """Collapse line breaks so *text* fits in a single-line log record."""
    return " ".join(text.splitlines())


def choose_action_baseline(task_name: str, step: int, history: List[str]) -> Tuple[str, str]:
    """Pick the next scripted action for *task_name*.

    Args:
        task_name: Key into ``BASELINE``; unknown tasks use
            ``["investigate", "escalate"]``.
        step: 1-based step number of the episode.
        history: Actions already taken in this episode.

    Returns:
        ``(action, reasoning)``. The action scheduled for this step is used
        when it has not been taken yet; otherwise the first untaken action of
        the sequence; otherwise ``"escalate"``.
    """
    seq = BASELINE.get(task_name, ["investigate", "escalate"])
    idx = step - 1
    if idx < len(seq) and seq[idx] not in history:
        return seq[idx], "baseline policy"
    for candidate in seq:
        if candidate not in history:
            return candidate, "baseline policy"
    return "escalate", "baseline exhausted"


def _build_user_message(obs_dict: Dict[str, Any], history: List[str], available: List[str]) -> str:
    """Render the observation, action list and history as the user prompt."""
    context = obs_dict.get("context", {})
    # default=str keeps prompt rendering from failing on values json cannot
    # encode natively; the text is only shown to the model, never parsed back.
    context_str = (
        json.dumps(context, indent=2, default=str)
        if context
        else "(empty — run investigate or query_logs first)"
    )
    severity = str(obs_dict.get("severity") or "").upper()
    signals = obs_dict.get("signals") or []
    signal_lines = "\n".join(f" [{i}] {s}" for i, s in enumerate(signals, start=1))
    return (
        f"=== ACTIVE ALERT ===\n"
        f"Type : {obs_dict.get('alert_type', '')}\n"
        f"Severity: {severity}\n"
        f"Phase : {obs_dict.get('phase', '')}\n"
        f"Step : {obs_dict.get('step', 0)}/{obs_dict.get('max_steps', 12)}\n\n"
        f"=== SIGNALS ===\n"
        + signal_lines
        + f"\n\n=== INVESTIGATION CONTEXT ===\n{context_str}\n\n"
        f"=== LAST FEEDBACK ===\n{obs_dict.get('feedback', '')}\n\n"
        f"=== AVAILABLE ACTIONS ===\n{', '.join(available)}\n\n"
        f"=== ALREADY TAKEN ===\n{', '.join(history) if history else 'none'}\n\n"
        f"Based on the signals and context above, what is the SINGLE best next action?\n"
        f"Respond ONLY with JSON."
    )


def _parse_llm_json(text: str) -> Dict[str, Any]:
    """Extract the JSON object from a model reply.

    Handles a bare JSON reply and a Markdown-fenced one; if neither parses,
    falls back to the span between the first ``{`` and the last ``}``.

    Raises:
        ValueError: If no JSON object can be recovered
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    cleaned = text.strip()
    if "```" in cleaned:
        cleaned = cleaned.split("```")[1]
        if cleaned.startswith("json"):
            cleaned = cleaned[4:]
    try:
        parsed = json.loads(cleaned.strip())
    except json.JSONDecodeError:
        # Models sometimes wrap the object in prose or emit a stray fence.
        start, end = text.find("{"), text.rfind("}")
        if start == -1 or end <= start:
            raise
        parsed = json.loads(text[start:end + 1])
    if not isinstance(parsed, dict):
        raise ValueError(f"expected a JSON object, got {type(parsed).__name__}")
    return parsed


def llm_decide(
    obs_dict: Dict[str, Any],
    history: List[str],
    task_name: str,
    step: int,
) -> Tuple[str, str, Optional[str]]:
    """Ask the LLM for the next action, falling back to the baseline on failure.

    Args:
        obs_dict: Observation as a plain dict (read via ``.get``).
        history: Actions already taken in this episode.
        task_name: Task key, used only for the baseline fallback.
        step: 1-based step number, used only for the baseline fallback.

    Returns:
        ``(decision, reasoning, error)``. ``error`` is ``None`` when the model
        reply was used, otherwise the exception text and the decision comes
        from ``choose_action_baseline``. A decision outside the available
        actions is replaced by ``"investigate"``.
    """
    available = obs_dict.get("available_actions", DEFAULT_ACTIONS)
    # Deliberately broad: any failure (prompt rendering, transport, malformed
    # reply) must degrade to the baseline policy rather than abort the episode.
    try:
        user_msg = _build_user_message(obs_dict, history, available)
        response = client.chat.completions.create(
            model=MODEL_NAME,
            max_tokens=300,
            temperature=0.1,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
        )
        parsed = _parse_llm_json(response.choices[0].message.content)
        decision = parsed.get("decision", "investigate")
        if decision not in available:
            decision = "investigate"
        reasoning = parsed.get("reasoning", "")
        if not isinstance(reasoning, str):
            reasoning = "" if reasoning is None else str(reasoning)
        return decision, reasoning, None
    except Exception as exc:
        fallback_decision, fallback_reason = choose_action_baseline(task_name, step, history)
        return fallback_decision, fallback_reason, str(exc)


def compute_score(task_name: str, actions: List[str]) -> float:
    """Score an episode from the set of actions taken.

    Args:
        task_name: One of ``task_easy``, ``task_medium``, ``task_hard``.
        actions: Actions taken during the episode, in order (order is not
            used; only membership matters).

    Returns:
        A score in ``[0.01, 0.99]`` for the known tasks, ``0.50`` for any
        other task name.
    """
    if task_name == "task_easy":
        containment = ("block_account", "isolate_device")
        # NOTE(review): "ignore" together with "escalate" still scores 0.98
        # because only containment actions are excluded here — confirm intent.
        if "ignore" in actions and not any(a in actions for a in containment):
            return 0.98
        if any(a in actions for a in (*containment, "escalate")):
            return 0.02
        if "investigate" in actions:
            return 0.40
        return 0.10

    if task_name == "task_medium":
        s = 0.0
        if "investigate" in actions or "query_logs" in actions:
            s += 0.20
        if "block_account" in actions:
            s += 0.25
        if "collect_forensics" in actions:
            s += 0.20
        if "escalate" in actions:
            s += 0.25
        return round(min(0.99, max(0.01, s)), 2)

    if task_name == "task_hard":
        key = ["investigate", "isolate_device", "block_ip", "collect_forensics", "escalate"]
        weights = [0.15, 0.20, 0.20, 0.20, 0.15]
        s = sum(w for a, w in zip(key, weights) if a in actions)
        # Also count query_logs as investigate
        if "query_logs" in actions and "investigate" not in actions:
            s += 0.15
        return round(min(0.99, max(0.01, s)), 2)

    return 0.50


def run_episode(task_name: str) -> Tuple[bool, int, List[float], float]:
    """Run one task to completion and print its ``[START]/[STEP]/[END]`` records.

    Args:
        task_name: Key into ``TASK_DIFFICULTY`` and ``TASK_SCENARIOS``.

    Returns:
        ``(success, steps, rewards, score)``. If the environment cannot be
        created or reset the result is ``(False, 0, [], 0.01)``.
    """
    print(f"[START] task={task_name} env={BENCHMARK} model={MODEL_NAME}", flush=True)
    rewards: List[float] = []
    actions: List[str] = []
    step = 0

    try:
        env = SOCEnvironment(
            difficulty=TASK_DIFFICULTY[task_name],
            pinned_scenario_id=TASK_SCENARIOS[task_name],
        )
        obs = env.reset()
    except Exception as exc:
        # Report the cause instead of failing silently.
        print(f"[DEBUG] Environment setup failed: {_one_line(str(exc))}", flush=True)
        print("[END] success=false steps=0 score=0.01 rewards=", flush=True)
        return False, 0, [], 0.01

    done = obs.done
    while not done and step < MAX_STEPS:
        step += 1
        obs_dict = obs.model_dump()

        if client is not None:
            decision, reasoning, llm_error = llm_decide(obs_dict, actions, task_name, step)
        else:
            decision, reasoning = choose_action_baseline(task_name, step, actions)
            llm_error = "no client"

        try:
            action = SOCAction(decision=decision, reasoning=reasoning)
            obs = env.step(action)
            reward = float(obs.reward)
            done = obs.done
        except Exception as exc:
            # A failed step ends the episode with zero reward for that step.
            reward, done, llm_error = 0.0, True, str(exc)

        rewards.append(reward)
        actions.append(decision)
        # Keep the record on one line even when the error text is multi-line.
        error_str = (_one_line(llm_error) if llm_error else "") or "null"
        print(
            f"[STEP] step={step} action={decision} reward={reward:.2f} "
            f"done={'true' if done else 'false'} error={error_str}",
            flush=True,
        )

    score = compute_score(task_name, actions)
    success = score >= SUCCESS_SCORE_THRESHOLD
    print(
        f"[END] success={'true' if success else 'false'} steps={step} "
        f"score={score:.2f} rewards={','.join(f'{r:.2f}' for r in rewards)}",
        flush=True,
    )
    return success, step, rewards, score


def main() -> None:
    """Run every task in ``TASKS`` and print a summary table."""
    results = []
    for task_name in TASKS:
        success, steps, rewards, score = run_episode(task_name)
        results.append({
            "task": task_name,
            "success": success,
            "steps": steps,
            "score": score,
            "total_reward": round(sum(rewards), 2),
        })
        print(flush=True)

    print("# SUMMARY", flush=True)
    for r in results:
        print(
            f"# {r['task']:20s} {'SUCCESS' if r['success'] else 'FAIL':8s} "
            f"steps={r['steps']:2d} score={r['score']:.2f} "
            f"total_reward={r['total_reward']:.2f}",
            flush=True,
        )
    print(f"# Tasks passed: {sum(1 for r in results if r['success'])}/{len(results)}", flush=True)


def _create_client() -> Optional[OpenAI]:
    """Build the chat client from the environment, or return ``None``."""
    if not API_KEY:
        print("[DEBUG] No API key found. Using baseline policy.", flush=True)
        return None
    try:
        new_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    except Exception as e:
        print(f"[DEBUG] Failed to initialize OpenAI client: {e}", flush=True)
        return None
    print("[DEBUG] OpenAI client initialized.", flush=True)
    return new_client


if __name__ == "__main__":
    client = _create_client()
    main()