Spaces:
Sleeping
Sleeping
| import os, json | |
| from typing import List, Optional | |
| from dotenv import load_dotenv | |
| from openai import OpenAI | |
| load_dotenv(override=True) | |
| API_KEY = os.getenv("HF_TOKEN") or os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY") | |
| API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") | |
| MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct") | |
| try: | |
| from server.SOC_env_environment import SOCEnvironment | |
| from models import SOCAction | |
| except ImportError: | |
| from SOC_env.server.SOC_env_environment import SOCEnvironment | |
| from SOC_env.models import SOCAction | |
| BENCHMARK = "SOC_env" | |
| MAX_STEPS = 12 | |
| TASKS = ["task_easy", "task_medium", "task_hard"] | |
| SUCCESS_SCORE_THRESHOLD = 0.60 | |
| client = None | |
| TASK_SCENARIOS = { | |
| "task_easy": "easy_false_positive_vpn", | |
| "task_medium": "medium_insider_threat", | |
| "task_hard": "hard_apt_lateral_movement", | |
| } | |
| TASK_DIFFICULTY = { | |
| "task_easy": "easy", | |
| "task_medium": "medium", | |
| "task_hard": "hard", | |
| } | |
| BASELINE = { | |
| "task_easy": ["investigate", "ignore"], | |
| "task_medium": ["investigate", "block_account", "collect_forensics", "escalate"], | |
| "task_hard": ["investigate", "isolate_device", "block_ip", "collect_forensics", "escalate"], | |
| } | |
| SYSTEM_PROMPT = '''You are an expert SOC Tier-1 analyst following NIST SP 800-61 incident response procedures. | |
| Respond ONLY with JSON: {"decision": "<action>", "reasoning": "<one sentence>"} | |
| SOC Playbook Rules: | |
| 1. ALWAYS investigate first when context is empty — never act blind | |
| 2. Use query_logs or check_threat_intel for additional context before acting | |
| 3. If alert is authorized/normal activity (VPN, pentest, scheduled scan) -> ignore | |
| 4. Account compromise or credential theft -> block_account, then request_mfa | |
| 5. Active malware or C2 beacon confirmed -> isolate_device immediately | |
| 6. Malicious IP confirmed by threat intel -> block_ip | |
| 7. Phishing with credential risk -> request_mfa first, then monitor | |
| 8. Supply chain / vulnerable package -> patch_system | |
| 9. Evidence needed for legal/forensics -> collect_forensics before destructive actions | |
| 10. Beyond Tier-1 (APT, ransomware, legal hold, nation-state) -> escalate | |
| 11. NEVER repeat an action already taken | |
| 12. NEVER escalate on low/medium severity without investigation | |
| Action meanings: | |
| - investigate: Pull SIEM logs, review user history, check endpoint telemetry | |
| - query_logs: Deep SIEM query — firewall, proxy, DNS, authentication logs | |
| - check_threat_intel: Query threat intel platforms (VirusTotal, Shodan, MISP, Mandiant) | |
| - run_sandbox: Detonate suspicious file in isolated sandbox environment | |
| - block_ip: Block at perimeter firewall — use when malicious IP confirmed | |
| - block_account: Disable user account — use when compromise confirmed | |
| - isolate_device: Network quarantine — use when active malware/C2 confirmed | |
| - escalate: Hand to Tier-2/IR team — use for APT, ransomware, legal exposure | |
| - request_mfa: Force MFA re-enrollment — use after credential theft | |
| - patch_system: Remove malicious package or apply security patch | |
| - collect_forensics: Preserve disk image, memory dump, logs for investigation | |
| - monitor: Passive watch — only appropriate for low-severity ambiguous alerts | |
| - ignore: Close alert as false positive — only when clearly benign''' | |
| def choose_action_baseline(task_name, step, history): | |
| seq = BASELINE.get(task_name, ["investigate", "escalate"]) | |
| idx = step - 1 | |
| if idx < len(seq) and seq[idx] not in history: | |
| return seq[idx], "baseline policy" | |
| for a in seq: | |
| if a not in history: | |
| return a, "baseline policy" | |
| return "escalate", "baseline exhausted" | |
| def llm_decide(obs_dict, history, task_name, step): | |
| available = obs_dict.get("available_actions", [ | |
| "ignore", "monitor", "investigate", "query_logs", "check_threat_intel", | |
| "run_sandbox", "block_ip", "block_account", "isolate_device", "escalate", | |
| "request_mfa", "patch_system", "collect_forensics" | |
| ]) | |
| context = obs_dict.get("context", {}) | |
| context_str = json.dumps(context, indent=2) if context else "(empty — run investigate or query_logs first)" | |
| user_msg = ( | |
| f"=== ACTIVE ALERT ===\n" | |
| f"Type : {obs_dict.get('alert_type','')}\n" | |
| f"Severity: {obs_dict.get('severity','').upper()}\n" | |
| f"Phase : {obs_dict.get('phase','')}\n" | |
| f"Step : {obs_dict.get('step',0)}/{obs_dict.get('max_steps',12)}\n\n" | |
| f"=== SIGNALS ===\n" + | |
| "\n".join(f" [{i+1}] {s}" for i, s in enumerate(obs_dict.get("signals", []))) + | |
| f"\n\n=== INVESTIGATION CONTEXT ===\n{context_str}\n\n" | |
| f"=== LAST FEEDBACK ===\n{obs_dict.get('feedback','')}\n\n" | |
| f"=== AVAILABLE ACTIONS ===\n{', '.join(available)}\n\n" | |
| f"=== ALREADY TAKEN ===\n{', '.join(history) if history else 'none'}\n\n" | |
| f"Based on the signals and context above, what is the SINGLE best next action?\n" | |
| f"Respond ONLY with JSON." | |
| ) | |
| try: | |
| response = client.chat.completions.create( | |
| model=MODEL_NAME, max_tokens=300, temperature=0.1, | |
| messages=[ | |
| {"role": "system", "content": SYSTEM_PROMPT}, | |
| {"role": "user", "content": user_msg} | |
| ], | |
| ) | |
| text = response.choices[0].message.content.strip() | |
| if "```" in text: | |
| text = text.split("```")[1] | |
| if text.startswith("json"): | |
| text = text[4:] | |
| parsed = json.loads(text.strip()) | |
| decision = parsed.get("decision", "investigate") | |
| if decision not in available: | |
| decision = "investigate" | |
| return decision, parsed.get("reasoning", ""), None | |
| except Exception as exc: | |
| d, r = choose_action_baseline(task_name, step, history) | |
| return d, r, str(exc) | |
| def compute_score(task_name, actions): | |
| if task_name == "task_easy": | |
| if "ignore" in actions and not any(a in actions for a in ["block_account", "isolate_device"]): | |
| return 0.98 | |
| elif any(a in actions for a in ["block_account", "isolate_device", "escalate"]): | |
| return 0.02 | |
| elif "investigate" in actions: | |
| return 0.40 | |
| return 0.10 | |
| elif task_name == "task_medium": | |
| s = 0.0 | |
| if "investigate" in actions or "query_logs" in actions: s += 0.20 | |
| if "block_account" in actions: s += 0.25 | |
| if "collect_forensics" in actions: s += 0.20 | |
| if "escalate" in actions: s += 0.25 | |
| return round(min(0.99, max(0.01, s)), 2) | |
| elif task_name == "task_hard": | |
| key = ["investigate", "isolate_device", "block_ip", "collect_forensics", "escalate"] | |
| weights = [0.15, 0.20, 0.20, 0.20, 0.15] | |
| s = sum(w for a, w in zip(key, weights) if a in actions) | |
| # Also count query_logs as investigate | |
| if "query_logs" in actions and "investigate" not in actions: | |
| s += 0.15 | |
| return round(min(0.99, max(0.01, s)), 2) | |
| return 0.50 | |
| def run_episode(task_name): | |
| print(f"[START] task={task_name} env={BENCHMARK} model={MODEL_NAME}", flush=True) | |
| rewards, actions, step = [], [], 0 | |
| try: | |
| env = SOCEnvironment( | |
| difficulty=TASK_DIFFICULTY[task_name], | |
| pinned_scenario_id=TASK_SCENARIOS[task_name] | |
| ) | |
| obs = env.reset() | |
| except Exception as exc: | |
| print(f"[END] success=false steps=0 score=0.01 rewards=", flush=True) | |
| return False, 0, [], 0.01 | |
| done = obs.done | |
| while not done and step < MAX_STEPS: | |
| step += 1 | |
| obs_dict = obs.model_dump() | |
| if client is not None: | |
| decision, reasoning, llm_error = llm_decide(obs_dict, actions, task_name, step) | |
| else: | |
| decision, reasoning = choose_action_baseline(task_name, step, actions) | |
| llm_error = "no client" | |
| try: | |
| action = SOCAction(decision=decision, reasoning=reasoning) | |
| obs = env.step(action) | |
| reward = float(obs.reward) | |
| done = obs.done | |
| except Exception as exc: | |
| reward, done, llm_error = 0.0, True, str(exc) | |
| rewards.append(reward) | |
| actions.append(decision) | |
| error_str = llm_error if llm_error else "null" | |
| print(f"[STEP] step={step} action={decision} reward={reward:.2f} done={'true' if done else 'false'} error={error_str}", flush=True) | |
| score = compute_score(task_name, actions) | |
| success = score >= SUCCESS_SCORE_THRESHOLD | |
| print(f"[END] success={'true' if success else 'false'} steps={step} score={score:.2f} rewards={','.join(f'{r:.2f}' for r in rewards)}", flush=True) | |
| return success, step, rewards, score | |
| def main(): | |
| results = [] | |
| for task_name in TASKS: | |
| success, steps, rewards, score = run_episode(task_name) | |
| results.append({"task": task_name, "success": success, "steps": steps, "score": score, "total_reward": round(sum(rewards), 2)}) | |
| print(flush=True) | |
| print("# SUMMARY", flush=True) | |
| for r in results: | |
| print(f"# {r['task']:20s} {'SUCCESS' if r['success'] else 'FAIL':8s} steps={r['steps']:2d} score={r['score']:.2f} total_reward={r['total_reward']:.2f}", flush=True) | |
| print(f"# Tasks passed: {sum(1 for r in results if r['success'])}/{len(results)}", flush=True) | |
| if __name__ == "__main__": | |
| if API_KEY: | |
| try: | |
| client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) | |
| print("[DEBUG] OpenAI client initialized.", flush=True) | |
| except Exception as e: | |
| print(f"[DEBUG] Failed to initialize OpenAI client: {e}", flush=True) | |
| client = None | |
| else: | |
| print("[DEBUG] No API key found. Using baseline policy.", flush=True) | |
| client = None | |
| main() |