import os
import json
import re
import textwrap

from openai import OpenAI
import gradio as gr

# wandb is optional: metrics are only logged when the package imports and
# actually exposes a callable ``init``.
try:
    import wandb
    _WANDB_AVAILABLE = callable(getattr(wandb, "init", None))
except ImportError:
    _WANDB_AVAILABLE = False

# python-dotenv is optional: when present, load the .env file that sits one
# directory above this module without overriding variables already set.
try:
    from dotenv import load_dotenv
    load_dotenv(
        dotenv_path=os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env"),
        override=False,
    )
except ImportError:
    pass

try:
    from openenv.core.env_server.http_server import create_app
except Exception as e:
    raise ImportError("openenv is required.") from e

# Resolve the project modules whether this file is imported as part of a
# package (relative import), run with the project root on sys.path, or run
# directly (in which case the parent directory is appended to sys.path).
try:
    from ..models import NetworkAction, NetworkObservation
    from .vir_env_environment import NetworkEnvironment
except (ImportError, ValueError):
    try:
        from models import NetworkAction, NetworkObservation
        from server.vir_env_environment import NetworkEnvironment
    except ImportError:
        import sys
        sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        from models import NetworkAction, NetworkObservation
        from server.vir_env_environment import NetworkEnvironment

# ---------------------------------------------------------------------------
# LLM configuration
# ---------------------------------------------------------------------------
# Each setting falls back to a local default when the environment variable is
# unset or empty.
_API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN") or "ollama"
_API_BASE_URL = os.getenv("API_BASE_URL") or "http://127.0.0.1:11434/v1"
_MODEL_NAME = os.getenv("MODEL_NAME") or "gemma3:4b"

_SYSTEM_PROMPT = textwrap.dedent("""
    You are an AI SOC analyst racing to contain a spreading virus.

    === NETWORK NODES (20-NODE ENTERPRISE) ===
    DMZ: Gateway, Firewall, LoadBalancer
    Web Tier: Web1, Web2, Web3
    App Tier: App1, App2, App3
    Data Tier: DB_Primary ★, DB_Replica ★, Backup ★, Cache1, Cache2
    Auth/Mgmt: Auth ★, LDAP ★, Monitor, SIEM
    Internal: Workstation1, Workstation2
    ★ = critical asset — losing these fails the mission.

    === TOOLS ===
    1. scan_network — Refresh network state.
    2. isolate_node(target) — Sever ALL connections instantly. Stops spread from that node.
       WARNING: isolating DB_Primary costs -0.50 penalty.
    3. deploy_patch(target) — Remove malware AND reconnect node to network.

    === DIFFICULTY RULES ===
    - EASY: 1 entry point. Virus spreads randomly (50%). 30 steps. No breach timer.
    - MEDIUM: 2 entry points. Virus hunts DB_Primary. DB_Primary infected 5+ consecutive steps = DATA BREACH = instant loss. 20 steps.
    - HARD: 3 entry points. Auth infected = all spread rates DOUBLE. DB_Primary breach in 2 steps. 15 steps.

    === PRIORITY ===
    1. Look at INFECTED NODES list — those are your ONLY targets.
    2. DB_Primary infected → patch IMMEDIATELY (db_infected_steps counts to breach; do NOT isolate — use deploy_patch).
    3. Auth infected → isolate IMMEDIATELY (doubles all spread rates).
    4. Any other infected node → isolate it.
    5. Once all infections cleared → patch isolated nodes to restore network.

    ⚠ CRITICAL: "AT RISK" = virus MIGHT spread there — they are still CLEAN.
    NEVER isolate an AT RISK node. ONLY isolate INFECTED nodes.

    RESPOND ONLY with valid JSON:
    {"action_type": "scan_network|isolate_node|deploy_patch", "target": "NodeName or null", "reasoning": "one sentence"}
""").strip()

# ---------------------------------------------------------------------------
# Shared environment state for the AI web interface
# ---------------------------------------------------------------------------
_env = NetworkEnvironment()
# Single-element lists act as mutable cells so the button handlers can update
# them without ``global`` statements.
_current_obs = [None]
_episode_done = [True]
_last_task = ["easy"]

# Nodes whose infection raises the displayed threat level to CRITICAL and
# which are marked with a star in the node table.
HIGH_VALUE = {"DB_Primary", "DB_Replica", "Auth", "LDAP", "Backup"}

_HEAVY_RULE = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
_LIGHT_RULE = " ─────────────────────────────────────"

# Human-readable labels for the keys of an observation's reward breakdown;
# unknown keys are displayed verbatim.
_REWARD_LABELS = {
    "step": "Time penalty",
    "infected_nodes": "Infected node damage",
    "db_at_risk": "DB exposure penalty",
    "auth_compromised": "Auth compromise penalty",
    "db_isolate": "DB isolation downtime",
    "data_breach": "DATA BREACH",
    "fast_containment": "Fast containment bonus",
    "win": "Mission success bonus",
}


def _build_display(obs) -> str:
    """Render an observation as the text dashboard shown in the UI.

    The same text is sent to the LLM as its user prompt (see ``_call_llm``).

    Args:
        obs: Observation exposing ``network_state`` (mapping of node name to a
            dict with optional ``status`` and ``connections``), ``step``,
            ``max_steps``, ``infected_count``, ``isolated_count``,
            ``clean_count``, ``spread_events`` and ``message``. The attributes
            ``scenario_name``, ``cumulative_score``, ``db_infected_steps``,
            ``auth_compromised``, ``reward_breakdown`` and
            ``breach_threshold`` are optional; ``task`` is only read when
            ``scenario_name`` is missing.

    Returns:
        The multi-line dashboard string.
    """
    net = obs.network_state
    infected_set = {name for name, data in net.items() if data.get("status") == "infected"}
    # A node is "at risk" when it is clean but directly connected to an
    # infected node.
    at_risk = {
        name
        for name, data in net.items()
        if data.get("status") == "clean"
        and any(peer in infected_set for peer in data.get("connections", []))
    }

    if not HIGH_VALUE.isdisjoint(infected_set):
        threat = "🔴 CRITICAL"
    elif obs.infected_count >= 3:
        threat = "🟠 HIGH"
    elif obs.infected_count >= 1:
        threat = "🟡 MEDIUM"
    else:
        threat = "🟢 CLEAR"

    # Only touch ``obs.task`` when there is no scenario name; evaluating the
    # fallback eagerly would raise on observations that lack a usable task.
    scenario = getattr(obs, "scenario_name", None)
    if scenario is None:
        scenario = obs.task.upper()
    score = getattr(obs, "cumulative_score", 0.0)
    db_steps = getattr(obs, "db_infected_steps", 0)
    auth_comp = getattr(obs, "auth_compromised", False)
    breakdown = getattr(obs, "reward_breakdown", {})
    infected_list = ", ".join(sorted(infected_set)) if infected_set else "NONE — network clean!"

    lines = [
        _HEAVY_RULE,
        f" SOC OPERATIONS CENTER [{scenario}]",
        f" Step {obs.step}/{obs.max_steps} | Threat: {threat} | Score: {score:+.2f}",
        _HEAVY_RULE,
        "",
        f" >> INFECTED NODES (YOUR TARGETS): {infected_list} <<",
        "",
        " NODE STATUS",
        _LIGHT_RULE,
    ]

    for node, info in net.items():
        status = info.get("status", "clean")
        conns = info.get("connections", [])
        # Membership in at_risk already implies the node's status is "clean".
        if node in at_risk:
            icon = "🟡 AT RISK "
        elif status == "infected":
            icon = "🔴 INFECTED "
        elif status == "isolated":
            icon = "⬛ ISOLATED "
        else:
            icon = "🟢 ONLINE "
        conn_str = ", ".join(conns) if conns else "no connections"
        hv = " ★" if node in HIGH_VALUE else ""
        lines.append(f" {icon} {node:10s}{hv} → {conn_str}")

    lines += [
        "",
        f" Infected: {obs.infected_count} | At Risk: {len(at_risk)} "
        f"| Isolated: {obs.isolated_count} | Secure: {obs.clean_count - len(at_risk)}",
    ]

    # Warnings
    warnings = []
    if db_steps > 0:
        breach = getattr(obs, "breach_threshold", None)
        limit = f"/{breach}" if breach else ""
        warnings.append(f"⚠ DB EXPOSED {db_steps}{limit} steps — patch immediately!")
    if auth_comp:
        warnings.append("⚠ AUTH WAS COMPROMISED — spread rates may be doubled!")
    if warnings:
        lines += [""] + [f" {w}" for w in warnings]

    # Reward breakdown
    if breakdown:
        lines += ["", " LAST STEP PENALTIES/REWARDS", _LIGHT_RULE]
        for key, value in breakdown.items():
            label = _REWARD_LABELS.get(key, key)
            lines.append(f" {label:30s} {value:+.2f}")

    # Spread log (only the five most recent events are shown)
    if obs.spread_events:
        lines += ["", " LATERAL MOVEMENT LOG", _LIGHT_RULE]
        for ev in obs.spread_events[-5:]:
            crit = " ⚠" if ev.get("critical") else ""
            lines.append(f" Step {ev['step']:>2} {ev['from']:10s} ──► {ev['to']}{crit}")

    if at_risk:
        lines += [
            "",
            f" ⚡ VIRUS MAY SPREAD TO (still clean, do NOT isolate): {', '.join(sorted(at_risk))}",
        ]

    lines += ["", f" {obs.message}", ""]
    return "\n".join(lines)


def _extract_json_object(text: str) -> dict:
    """Return the first JSON object embedded in *text*.

    Decoding is attempted at each ``{`` in turn, so Markdown code fences,
    language tags and any commentary before or after the object are ignored.

    Raises:
        ValueError: if no position in *text* decodes to a JSON object.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            obj, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(obj, dict):
                return obj
        start = text.find("{", start + 1)
    raise ValueError("No JSON in LLM response")


def _call_llm(obs) -> NetworkAction:
    """Ask the configured chat model for the next action.

    The dashboard text for *obs* is printed to stdout and sent as the user
    message alongside ``_SYSTEM_PROMPT``.

    Returns:
        The ``NetworkAction`` described by the model's JSON reply. If the
        request, the parsing or the action construction raises, a
        ``scan_network`` action is returned whose reasoning carries the error
        text, so a UI step never fails because of the model.
    """
    client = OpenAI(base_url=_API_BASE_URL, api_key=_API_KEY)
    # Render once so the logged prompt is exactly the prompt that is sent.
    prompt = _build_display(obs)
    print("=" * 60)
    print("[LLM INPUT]")
    print(prompt)
    print("=" * 60)
    try:
        resp = client.chat.completions.create(
            model=_MODEL_NAME,
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            temperature=0.3,
            max_tokens=400,
        )
        raw = (resp.choices[0].message.content or "").strip()
        parsed = _extract_json_object(raw)

        target = parsed.get("target")
        if isinstance(target, str):
            target = target.strip()
            # The prompt's schema reads "NodeName or null", so the model may
            # answer with the *string* "null"; treat that as no target.
            if target.lower() in ("", "null", "none"):
                target = None
        else:
            target = target or None

        return NetworkAction(
            action_type=parsed.get("action_type", "scan_network"),
            target=target,
            reasoning=parsed.get("reasoning", ""),
        )
    except Exception as e:
        # Deliberate best-effort boundary: fall back to a harmless scan.
        return NetworkAction(action_type="scan_network", reasoning=f"LLM error: {e}")


# ---------------------------------------------------------------------------
# Gradio button handlers
# ---------------------------------------------------------------------------
def web_reset(task: str):
    """Start a new episode at the given difficulty.

    Resets the shared environment, stores the first observation, and starts a
    fresh wandb run when wandb is available.

    Args:
        task: Difficulty name passed to the environment's ``reset``.

    Returns:
        A 4-tuple for the UI outputs: dashboard text, empty action type,
        empty target, and a status message.
    """
    _last_task[0] = task
    obs = _env.reset(task=task)
    _current_obs[0] = obs
    _episode_done[0] = False
    if _WANDB_AVAILABLE:
        from uuid import uuid4
        # A short random suffix keeps run names unique per episode.
        wandb.init(project="vir_env", name=f"{task}-{uuid4().hex[:6]}", reinit=True)
    return _build_display(obs), "", "", f"Episode started [{task.upper()}]. Click AI Step."
def _episode_outcome(result) -> str:
    """Return the one-line verdict shown when an episode has finished."""
    if result.infected_count == 0:
        return "✅ NETWORK SECURED!"
    # An episode can also end early through a data breach; report that
    # instead of claiming the step budget ran out.
    breakdown = getattr(result, "reward_breakdown", None) or {}
    if breakdown.get("data_breach"):
        return "❌ DATA BREACH — mission failed."
    return "❌ Step budget exhausted."


def web_ai_step():
    """Let the LLM take one action and advance the shared environment.

    If no episode is active (never started, or the last one finished) the
    environment is reset with the most recently selected task first.

    Returns:
        A 4-tuple for the UI outputs: dashboard text, action type, target
        (``"—"`` when the action has none) and the model's reasoning, with an
        outcome line appended once the episode is done.
    """
    if _episode_done[0] or _current_obs[0] is None:
        # Auto-reset instead of blocking
        obs = _env.reset(task=_last_task[0])
        _current_obs[0] = obs
        _episode_done[0] = False

    obs = _current_obs[0]
    action = _call_llm(obs)
    result = _env.step(action)
    _current_obs[0] = result
    _episode_done[0] = result.done

    # Resolve the action name once; tolerate both enum members and plain
    # strings for ``action_type``.
    action_type = getattr(action.action_type, "value", action.action_type)

    if _WANDB_AVAILABLE and wandb.run is not None:
        wandb.log({
            "step": result.step,
            "infected_count": result.infected_count,
            "isolated_count": result.isolated_count,
            "reward": result.reward,
            "cumulative_score": result.cumulative_score,
            "action_type": action_type,
            "target": action.target or "none",
            "db_infected_steps": result.db_infected_steps,
            "auth_compromised": result.auth_compromised,
            "done": result.done,
            "task": _last_task[0],
        })
        if result.done:
            wandb.log({
                "episode/won": result.infected_count == 0,
                "episode/total_steps": result.step,
                "episode/final_score": result.cumulative_score,
            })

    target = action.target or "—"
    reasoning = action.reasoning
    if result.done:
        reasoning += "\n\n" + _episode_outcome(result)
    return _build_display(result), action_type, target, reasoning


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="AI-SOAR: Network Defense", theme=gr.themes.Monochrome()) as ai_interface:
    gr.Markdown(
        "# 🛡️ AI-SOAR Network Defense\n"
        "Watch the AI agent chase and contain the spreading virus in real time."
    )
    with gr.Row():
        task_dd = gr.Dropdown(["easy", "medium", "hard"], value="easy", label="Difficulty")
        reset_btn = gr.Button("🔄 Reset", variant="secondary")
        step_btn = gr.Button("🤖 AI Step", variant="primary")
    with gr.Row():
        net_display = gr.Textbox(label="Network State", lines=22, interactive=False)
        with gr.Column():
            action_type_box = gr.Textbox(label="Action Type", lines=1, interactive=False)
            target_box = gr.Textbox(label="Target", lines=1, interactive=False)
            reasoning_box = gr.Textbox(label="Reasoning", lines=4, interactive=False)

    # Both handlers return the same four values, in this order.
    reset_btn.click(
        web_reset,
        inputs=[task_dd],
        outputs=[net_display, action_type_box, target_box, reasoning_box],
    )
    step_btn.click(
        web_ai_step,
        inputs=[],
        outputs=[net_display, action_type_box, target_box, reasoning_box],
    )

# ---------------------------------------------------------------------------
# App assembly
# ---------------------------------------------------------------------------
app = create_app(
    NetworkEnvironment,
    NetworkAction,
    NetworkObservation,
    env_name="vir_env",
    max_concurrent_envs=5,
)
# Serve the Gradio interface from the same application under /ai.
app = gr.mount_gradio_app(app, ai_interface, path="/ai")


def main():
    """Parse ``--host``/``--port`` and serve ``app`` with uvicorn."""
    import argparse
    import uvicorn

    parser = argparse.ArgumentParser()
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()
    uvicorn.run(app, host=args.host, port=args.port)


if __name__ == "__main__":
    main()