# SOC-env / inference.py
# Last commit: "Upload folder using huggingface_hub" by ApoorvaBadoni (4c8f878, verified)
import os, json
from typing import List, Optional
from dotenv import load_dotenv
from openai import OpenAI
# Load settings from a local .env file before any of them are read below.
load_dotenv(override=True)

# Credential for the chat-completions endpoint: the first non-empty value among
# HF_TOKEN, OPENAI_API_KEY and API_KEY is used.
API_KEY = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("OPENAI_API_KEY")
    or os.environ.get("API_KEY")
)

# Endpoint and model name may be overridden through the environment.
API_BASE_URL = os.environ.get("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
# Two import layouts are supported: flat module paths first, and the
# SOC_env-prefixed package paths as a fallback when those are not importable.
try:
    from server.SOC_env_environment import SOCEnvironment
    from models import SOCAction
except ImportError:
    from SOC_env.server.SOC_env_environment import SOCEnvironment
    from SOC_env.models import SOCAction
# Name printed in the [START] line of every episode.
BENCHMARK = "SOC_env"
# Upper bound on agent steps per episode.
MAX_STEPS = 12
# An episode counts as a success when its computed score reaches this value.
SUCCESS_SCORE_THRESHOLD = 0.60
# Chat client; assigned under the __main__ guard. While it is None the
# scripted baseline policy chooses every action.
client = None

# One row per task: (task name, difficulty, pinned scenario id, baseline actions).
_TASK_SPECS = (
    (
        "task_easy",
        "easy",
        "easy_false_positive_vpn",
        ["investigate", "ignore"],
    ),
    (
        "task_medium",
        "medium",
        "medium_insider_threat",
        ["investigate", "block_account", "collect_forensics", "escalate"],
    ),
    (
        "task_hard",
        "hard",
        "hard_apt_lateral_movement",
        ["investigate", "isolate_device", "block_ip", "collect_forensics", "escalate"],
    ),
)

# Task names in the order they are run.
TASKS = [spec[0] for spec in _TASK_SPECS]
# Task name -> scenario id pinned when the environment is built.
TASK_SCENARIOS = {name: scenario for name, _, scenario, _ in _TASK_SPECS}
# Task name -> difficulty passed to the environment.
TASK_DIFFICULTY = {name: level for name, level, _, _ in _TASK_SPECS}
# Task name -> scripted action sequence used by the baseline policy.
BASELINE = {name: steps for name, _, _, steps in _TASK_SPECS}
# System message sent with every chat-completion request made by llm_decide.
# It asks the model to answer with a JSON object holding "decision" and
# "reasoning", and lists the playbook rules and the meaning of each action.
SYSTEM_PROMPT = '''You are an expert SOC Tier-1 analyst following NIST SP 800-61 incident response procedures.
Respond ONLY with JSON: {"decision": "<action>", "reasoning": "<one sentence>"}
SOC Playbook Rules:
1. ALWAYS investigate first when context is empty — never act blind
2. Use query_logs or check_threat_intel for additional context before acting
3. If alert is authorized/normal activity (VPN, pentest, scheduled scan) -> ignore
4. Account compromise or credential theft -> block_account, then request_mfa
5. Active malware or C2 beacon confirmed -> isolate_device immediately
6. Malicious IP confirmed by threat intel -> block_ip
7. Phishing with credential risk -> request_mfa first, then monitor
8. Supply chain / vulnerable package -> patch_system
9. Evidence needed for legal/forensics -> collect_forensics before destructive actions
10. Beyond Tier-1 (APT, ransomware, legal hold, nation-state) -> escalate
11. NEVER repeat an action already taken
12. NEVER escalate on low/medium severity without investigation
Action meanings:
- investigate: Pull SIEM logs, review user history, check endpoint telemetry
- query_logs: Deep SIEM query — firewall, proxy, DNS, authentication logs
- check_threat_intel: Query threat intel platforms (VirusTotal, Shodan, MISP, Mandiant)
- run_sandbox: Detonate suspicious file in isolated sandbox environment
- block_ip: Block at perimeter firewall — use when malicious IP confirmed
- block_account: Disable user account — use when compromise confirmed
- isolate_device: Network quarantine — use when active malware/C2 confirmed
- escalate: Hand to Tier-2/IR team — use for APT, ransomware, legal exposure
- request_mfa: Force MFA re-enrollment — use after credential theft
- patch_system: Remove malicious package or apply security patch
- collect_forensics: Preserve disk image, memory dump, logs for investigation
- monitor: Passive watch — only appropriate for low-severity ambiguous alerts
- ignore: Close alert as false positive — only when clearly benign'''
def choose_action_baseline(task_name, step, history):
    """Pick the next action from the scripted BASELINE playbook.

    Args:
        task_name: Key into BASELINE. Unknown tasks fall back to the
            two-step sequence ["investigate", "escalate"].
        step: 1-based step number within the episode.
        history: Actions already taken in this episode.

    Returns:
        A ``(decision, reasoning)`` tuple of strings.
    """
    seq = BASELINE.get(task_name, ["investigate", "escalate"])
    idx = step - 1
    # The lower bound matters: without it, step <= 0 produces a negative index
    # that silently wraps around to the END of the playbook.
    if 0 <= idx < len(seq) and seq[idx] not in history:
        return seq[idx], "baseline policy"
    # The scheduled slot is out of range or already used: take the first
    # scripted action that has not been taken yet.
    for action in seq:
        if action not in history:
            return action, "baseline policy"
    return "escalate", "baseline exhausted"
def _extract_json_payload(text):
    """Return the JSON-object portion of an LLM reply, dropping fences and surrounding prose."""
    text = text.strip()
    if "```" in text:
        # Fenced reply: keep what sits between the first pair of fences and
        # drop an optional "json" language tag.
        text = text.split("```")[1]
        if text.startswith("json"):
            text = text[4:]
    text = text.strip()
    # Models sometimes wrap the object in a sentence; keep the outermost braces.
    first, last = text.find("{"), text.rfind("}")
    if first != -1 and last > first:
        text = text[first:last + 1]
    return text


def llm_decide(obs_dict, history, task_name, step):
    """Ask the chat model for the next SOC action, falling back to the baseline.

    Args:
        obs_dict: Observation as a plain dict. The keys read here are
            ``available_actions``, ``context``, ``alert_type``, ``severity``,
            ``phase``, ``step``, ``max_steps``, ``signals`` and ``feedback``;
            missing or ``None`` values are tolerated.
        history: Actions already taken in this episode.
        task_name: Task key, used only for the baseline fallback.
        step: 1-based step number, used only for the baseline fallback.

    Returns:
        A ``(decision, reasoning, error)`` tuple. ``error`` is ``None`` when
        the model reply was used, otherwise the text of the exception that
        triggered the baseline fallback.
    """
    # `or` (rather than a .get default) also covers keys present with None.
    available = obs_dict.get("available_actions") or [
        "ignore", "monitor", "investigate", "query_logs", "check_threat_intel",
        "run_sandbox", "block_ip", "block_account", "isolate_device", "escalate",
        "request_mfa", "patch_system", "collect_forensics",
    ]
    context = obs_dict.get("context") or {}
    # default=str: the context is only rendered for the prompt, so values that
    # are not JSON-native are shown as text instead of raising TypeError.
    context_str = (
        json.dumps(context, indent=2, default=str)
        if context
        else "(empty — run investigate or query_logs first)"
    )
    severity = obs_dict.get("severity") or ""
    if not isinstance(severity, str):
        severity = str(severity)
    signals = obs_dict.get("signals") or []
    feedback = obs_dict.get("feedback") or ""
    user_msg = (
        f"=== ACTIVE ALERT ===\n"
        f"Type : {obs_dict.get('alert_type','')}\n"
        f"Severity: {severity.upper()}\n"
        f"Phase : {obs_dict.get('phase','')}\n"
        f"Step : {obs_dict.get('step',0)}/{obs_dict.get('max_steps',12)}\n\n"
        f"=== SIGNALS ===\n" +
        "\n".join(f" [{i+1}] {s}" for i, s in enumerate(signals)) +
        f"\n\n=== INVESTIGATION CONTEXT ===\n{context_str}\n\n"
        f"=== LAST FEEDBACK ===\n{feedback}\n\n"
        f"=== AVAILABLE ACTIONS ===\n{', '.join(available)}\n\n"
        f"=== ALREADY TAKEN ===\n{', '.join(history) if history else 'none'}\n\n"
        f"Based on the signals and context above, what is the SINGLE best next action?\n"
        f"Respond ONLY with JSON."
    )
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME, max_tokens=300, temperature=0.1,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_msg},
            ],
        )
        raw = response.choices[0].message.content or ""
        parsed = json.loads(_extract_json_payload(raw))
        decision = parsed.get("decision", "investigate")
        if isinstance(decision, str):
            decision = decision.strip()
            # Accept a differently-cased action name when it is otherwise valid.
            if decision not in available and decision.lower() in available:
                decision = decision.lower()
        if decision not in available:
            decision = "investigate"
        # Always hand a string back, even if the model sent null or a non-string.
        reasoning = parsed.get("reasoning") or ""
        return decision, str(reasoning), None
    except Exception as exc:
        # Deliberate best-effort boundary: any transport, parsing or shape
        # problem degrades to the scripted policy instead of ending the run.
        fallback, why = choose_action_baseline(task_name, step, history)
        return fallback, why, str(exc)
def compute_score(task_name, actions):
    """Grade the list of actions taken in an episode.

    Args:
        task_name: One of "task_easy", "task_medium", "task_hard"; any other
            name scores a flat 0.50.
        actions: Action names taken during the episode.

    Returns:
        A float score. Medium and hard scores are clamped to [0.01, 0.99]
        and rounded to two decimals.
    """
    def clamp(raw):
        return round(min(0.99, max(0.01, raw)), 2)

    if task_name == "task_easy":
        # The easy scenario is a false positive: closing it is right,
        # containment or escalation is wrong.
        contained = "block_account" in actions or "isolate_device" in actions
        if "ignore" in actions and not contained:
            return 0.98
        if contained or "escalate" in actions:
            return 0.02
        return 0.40 if "investigate" in actions else 0.10

    if task_name == "task_medium":
        total = 0.0
        if "investigate" in actions or "query_logs" in actions:
            total += 0.20
        for action, credit in (
            ("block_account", 0.25),
            ("collect_forensics", 0.20),
            ("escalate", 0.25),
        ):
            if action in actions:
                total += credit
        return clamp(total)

    if task_name == "task_hard":
        credits = {
            "investigate": 0.15,
            "isolate_device": 0.20,
            "block_ip": 0.20,
            "collect_forensics": 0.20,
            "escalate": 0.15,
        }
        total = sum(credit for action, credit in credits.items() if action in actions)
        # query_logs earns the investigate credit when investigate was skipped.
        if "query_logs" in actions and "investigate" not in actions:
            total += 0.15
        return clamp(total)

    return 0.50
def run_episode(task_name):
    """Run one pinned scenario and print its [START] / [STEP] / [END] log lines.

    Args:
        task_name: Key into TASK_DIFFICULTY and TASK_SCENARIOS.

    Returns:
        A ``(success, steps, rewards, score)`` tuple: whether the score reached
        SUCCESS_SCORE_THRESHOLD, the number of steps taken, the per-step
        rewards, and the score from compute_score. If the environment cannot
        be created or reset the result is ``(False, 0, [], 0.01)``.
    """
    import sys  # local import: only needed for the setup diagnostic below

    print(f"[START] task={task_name} env={BENCHMARK} model={MODEL_NAME}", flush=True)
    rewards, actions, step = [], [], 0
    try:
        env = SOCEnvironment(
            difficulty=TASK_DIFFICULTY[task_name],
            pinned_scenario_id=TASK_SCENARIOS[task_name],
        )
        obs = env.reset()
    except Exception as exc:
        # Report the cause on stderr so the stdout record format is untouched.
        print(f"[DEBUG] Environment setup failed for {task_name}: {exc}", file=sys.stderr, flush=True)
        print("[END] success=false steps=0 score=0.01 rewards=", flush=True)
        return False, 0, [], 0.01
    done = obs.done
    while not done and step < MAX_STEPS:
        step += 1
        obs_dict = obs.model_dump()
        if client is not None:
            decision, reasoning, llm_error = llm_decide(obs_dict, actions, task_name, step)
        else:
            decision, reasoning = choose_action_baseline(task_name, step, actions)
            llm_error = "no client"
        try:
            action = SOCAction(decision=decision, reasoning=reasoning)
            obs = env.step(action)
            # A missing reward is counted as 0.0; letting float(None) raise
            # here would end the episode even though the step succeeded.
            reward = float(obs.reward) if obs.reward is not None else 0.0
            done = obs.done
        except Exception as exc:
            # A failed step ends the episode with zero reward for this step.
            reward, done, llm_error = 0.0, True, str(exc)
        rewards.append(reward)
        actions.append(decision)
        # Collapse whitespace so multi-line exception text cannot split the
        # one-line [STEP] record.
        error_str = " ".join(str(llm_error).split()) if llm_error else "null"
        print(f"[STEP] step={step} action={decision} reward={reward:.2f} done={'true' if done else 'false'} error={error_str}", flush=True)
    score = compute_score(task_name, actions)
    success = score >= SUCCESS_SCORE_THRESHOLD
    print(f"[END] success={'true' if success else 'false'} steps={step} score={score:.2f} rewards={','.join(f'{r:.2f}' for r in rewards)}", flush=True)
    return success, step, rewards, score
def main() -> None:
    """Run every task in TASKS once, then print a '#'-prefixed summary."""
    results = []
    for task_name in TASKS:
        success, steps, rewards, score = run_episode(task_name)
        results.append({"task": task_name, "success": success, "steps": steps, "score": score, "total_reward": round(sum(rewards), 2)})
        # NOTE(review): assumed to be a blank separator line after each
        # episode; the loop nesting of this call should be confirmed.
        print(flush=True)
    print("# SUMMARY", flush=True)
    # One aligned row per task, followed by the pass count.
    for r in results:
        print(f"# {r['task']:20s} {'SUCCESS' if r['success'] else 'FAIL':8s} steps={r['steps']:2d} score={r['score']:.2f} total_reward={r['total_reward']:.2f}", flush=True)
    print(f"# Tasks passed: {sum(1 for r in results if r['success'])}/{len(results)}", flush=True)
if __name__ == "__main__":
    # Build the chat client only when a credential is configured; with no
    # client, every episode is driven by the scripted baseline policy.
    if not API_KEY:
        print("[DEBUG] No API key found. Using baseline policy.", flush=True)
        client = None
    else:
        try:
            client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
            print("[DEBUG] OpenAI client initialized.", flush=True)
        except Exception as init_err:
            # Client construction failed: report it and run on the baseline.
            print(f"[DEBUG] Failed to initialize OpenAI client: {init_err}", flush=True)
            client = None
    main()