Spaces:
Sleeping
Sleeping
| import os, json, re, textwrap | |
| from openai import OpenAI | |
| import gradio as gr | |
| try: | |
| import wandb | |
| _WANDB_AVAILABLE = callable(getattr(wandb, "init", None)) | |
| except ImportError: | |
| _WANDB_AVAILABLE = False | |
| try: | |
| from dotenv import load_dotenv | |
| load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env"), override=False) | |
| except ImportError: | |
| pass | |
| try: | |
| from openenv.core.env_server.http_server import create_app | |
| except Exception as e: | |
| raise ImportError("openenv is required.") from e | |
| try: | |
| from ..models import NetworkAction, NetworkObservation | |
| from .vir_env_environment import NetworkEnvironment | |
| except (ImportError, ValueError): | |
| try: | |
| from models import NetworkAction, NetworkObservation | |
| from server.vir_env_environment import NetworkEnvironment | |
| except ImportError: | |
| import sys | |
| sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from models import NetworkAction, NetworkObservation | |
| from server.vir_env_environment import NetworkEnvironment | |
# ---------------------------------------------------------------------------
# LLM configuration
# ---------------------------------------------------------------------------
# Every setting comes from the environment. An unset *or empty* variable falls
# through to the next candidate and finally to the hard-coded default.
_API_KEY = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN") or "ollama"
_API_BASE_URL = os.environ.get("API_BASE_URL") or "http://127.0.0.1:11434/v1"
_MODEL_NAME = os.environ.get("MODEL_NAME") or "gemma3:4b"
# System message sent with every chat-completion request. The literal is
# indented for readability; dedent() removes the common indent and strip()
# drops the surrounding blank space, so the model sees flush-left text.
_SYSTEM_PROMPT = textwrap.dedent(
    """\
    You are an AI SOC analyst racing to contain a spreading virus.
    === NETWORK NODES (20-NODE ENTERPRISE) ===
    DMZ: Gateway, Firewall, LoadBalancer
    Web Tier: Web1, Web2, Web3
    App Tier: App1, App2, App3
    Data Tier: DB_Primary ★, DB_Replica ★, Backup ★, Cache1, Cache2
    Auth/Mgmt: Auth ★, LDAP ★, Monitor, SIEM
    Internal: Workstation1, Workstation2
    ★ = critical asset — losing these fails the mission.
    === TOOLS ===
    1. scan_network — Refresh network state.
    2. isolate_node(target) — Sever ALL connections instantly. Stops spread from that node.
    WARNING: isolating DB_Primary costs -0.50 penalty.
    3. deploy_patch(target) — Remove malware AND reconnect node to network.
    === DIFFICULTY RULES ===
    - EASY: 1 entry point. Virus spreads randomly (50%). 30 steps. No breach timer.
    - MEDIUM: 2 entry points. Virus hunts DB_Primary. DB_Primary infected 5+ consecutive steps = DATA BREACH = instant loss. 20 steps.
    - HARD: 3 entry points. Auth infected = all spread rates DOUBLE. DB_Primary breach in 2 steps. 15 steps.
    === PRIORITY ===
    1. Look at INFECTED NODES list — those are your ONLY targets.
    2. DB_Primary infected → patch IMMEDIATELY (db_infected_steps counts to breach; do NOT isolate — use deploy_patch).
    3. Auth infected → isolate IMMEDIATELY (doubles all spread rates).
    4. Any other infected node → isolate it.
    5. Once all infections cleared → patch isolated nodes to restore network.
    ⚠ CRITICAL: "AT RISK" = virus MIGHT spread there — they are still CLEAN.
    NEVER isolate an AT RISK node. ONLY isolate INFECTED nodes.
    RESPOND ONLY with valid JSON:
    {"action_type": "scan_network|isolate_node|deploy_patch", "target": "NodeName or null", "reasoning": "one sentence"}
    """
).strip()
# ---------------------------------------------------------------------------
# Shared environment state for the AI web interface
# ---------------------------------------------------------------------------
# A single environment instance backs the web UI. The one-element lists are
# mutable cells: the button handlers assign to index 0, which lets them update
# module state without `global` statements.
_env = NetworkEnvironment()
_current_obs, _episode_done, _last_task = [None], [True], ["easy"]
# Nodes rendered with a star; an infection on any of them forces the CRITICAL threat level.
HIGH_VALUE = {"DB_Primary", "DB_Replica", "Auth", "LDAP", "Backup"}

# Human-readable names for known reward-breakdown keys (unknown keys are shown raw).
_REWARD_LABELS = {
    "step": "Time penalty",
    "infected_nodes": "Infected node damage",
    "db_at_risk": "DB exposure penalty",
    "auth_compromised": "Auth compromise penalty",
    "db_isolate": "DB isolation downtime",
    "data_breach": "DATA BREACH",
    "fast_containment": "Fast containment bonus",
    "win": "Mission success bonus",
}


def _build_display(obs) -> str:
    """Render an observation as the plain-text SOC dashboard.

    The same text is shown in the UI and sent to the LLM as the user prompt.

    Args:
        obs: Observation object. Must expose ``network_state`` (mapping of
            node name to a dict with ``status`` and ``connections``),
            ``infected_count``, ``isolated_count``, ``clean_count``, ``step``,
            ``max_steps``, ``spread_events`` and ``message``. The attributes
            ``scenario_name``, ``task``, ``cumulative_score``,
            ``db_infected_steps``, ``auth_compromised``, ``breach_threshold``
            and ``reward_breakdown`` are optional.

    Returns:
        The dashboard as a single newline-joined string.
    """
    net = obs.network_state
    infected_set = {name for name, info in net.items() if info.get("status") == "infected"}
    # "At risk" = still clean but directly connected to at least one infected node.
    # `or ()` tolerates a node whose connections entry is None.
    at_risk = {
        name
        for name, info in net.items()
        if info.get("status") == "clean"
        and any(peer in infected_set for peer in (info.get("connections") or ()))
    }

    if any(name in HIGH_VALUE for name in infected_set):
        threat = "🔴 CRITICAL"
    elif obs.infected_count >= 3:
        threat = "🟠 HIGH"
    elif obs.infected_count >= 1:
        threat = "🟡 MEDIUM"
    else:
        threat = "🟢 CLEAR"

    # Resolve the header label lazily: the task-based fallback is only touched
    # when no scenario name is available, so an observation without `task`
    # no longer raises just because the default expression was evaluated.
    scenario = getattr(obs, "scenario_name", None)
    if not scenario:
        scenario = str(getattr(obs, "task", None) or "unknown").upper()

    score = getattr(obs, "cumulative_score", 0.0)
    db_steps = getattr(obs, "db_infected_steps", 0)
    auth_comp = getattr(obs, "auth_compromised", False)
    breakdown = getattr(obs, "reward_breakdown", {})

    heavy_rule = "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━"
    light_rule = " ─────────────────────────────────────"
    infected_list = ", ".join(sorted(infected_set)) if infected_set else "NONE — network clean!"

    lines = [
        heavy_rule,
        f" SOC OPERATIONS CENTER [{scenario}]",
        f" Step {obs.step}/{obs.max_steps} | Threat: {threat} | Score: {score:+.2f}",
        heavy_rule,
        "",
        f" >> INFECTED NODES (YOUR TARGETS): {infected_list} <<",
        "",
        " NODE STATUS",
        light_rule,
    ]

    for node, info in net.items():
        status = info.get("status", "clean")
        conns = info.get("connections", [])
        # Membership in at_risk already implies the node is clean.
        if node in at_risk:
            icon = "🟡 AT RISK "
        elif status == "infected":
            icon = "🔴 INFECTED "
        elif status == "isolated":
            icon = "⬛ ISOLATED "
        else:
            icon = "🟢 ONLINE "
        conn_str = ", ".join(conns) if conns else "no connections"
        hv = " ★" if node in HIGH_VALUE else ""
        lines.append(f" {icon} {node:10s}{hv} → {conn_str}")

    lines += [
        "",
        f" Infected: {obs.infected_count} | At Risk: {len(at_risk)} "
        f"| Isolated: {obs.isolated_count} | Secure: {obs.clean_count - len(at_risk)}",
    ]

    # Warnings
    warnings = []
    if db_steps > 0:
        breach = getattr(obs, "breach_threshold", None)
        limit = f"/{breach}" if breach else ""
        warnings.append(f"⚠ DB EXPOSED {db_steps}{limit} steps — patch immediately!")
    if auth_comp:
        warnings.append("⚠ AUTH WAS COMPROMISED — spread rates may be doubled!")
    if warnings:
        lines += [""] + [f" {w}" for w in warnings]

    # Reward breakdown
    if breakdown:
        lines += ["", " LAST STEP PENALTIES/REWARDS", light_rule]
        for key, value in breakdown.items():
            label = _REWARD_LABELS.get(key, key)
            lines.append(f" {label:30s} {value:+.2f}")

    # Spread log (only the five most recent events)
    if obs.spread_events:
        lines += ["", " LATERAL MOVEMENT LOG", light_rule]
        for ev in obs.spread_events[-5:]:
            crit = " ⚠" if ev.get("critical") else ""
            lines.append(f" Step {ev['step']:>2} {ev['from']:10s} ──► {ev['to']}{crit}")

    if at_risk:
        lines += ["", f" ⚡ VIRUS MAY SPREAD TO (still clean, do NOT isolate): {', '.join(sorted(at_risk))}"]

    lines += ["", f" {obs.message}", ""]
    return "\n".join(lines)
def _extract_json_object(raw: str) -> dict:
    """Return the first JSON object embedded anywhere in *raw*.

    Scans for each ``{`` and lets the JSON decoder consume exactly one value
    from that position, so code fences, leading chatter and trailing text
    (even text containing braces) are ignored.

    Raises:
        ValueError: if no position in *raw* decodes to a JSON object.
    """
    decoder = json.JSONDecoder()
    start = raw.find("{")
    while start != -1:
        try:
            candidate, _end = decoder.raw_decode(raw, start)
        except json.JSONDecodeError:
            candidate = None
        if isinstance(candidate, dict):
            return candidate
        start = raw.find("{", start + 1)
    raise ValueError("No JSON in LLM response")


def _call_llm(obs) -> NetworkAction:
    """Ask the configured chat model for the next action on *obs*.

    The rendered dashboard is printed to stdout and sent as the user message
    together with the system prompt. The reply is parsed into a
    ``NetworkAction``.

    Args:
        obs: Observation accepted by ``_build_display``.

    Returns:
        The action described by the model. If the request fails or the reply
        cannot be turned into an action, a ``scan_network`` action is returned
        whose reasoning carries the error text.
    """
    client = OpenAI(base_url=_API_BASE_URL, api_key=_API_KEY)
    # Render once so the logged prompt is exactly what the model receives.
    prompt = _build_display(obs)
    print("=" * 60)
    print("[LLM INPUT]")
    print(prompt)
    print("=" * 60)
    try:
        resp = client.chat.completions.create(
            model=_MODEL_NAME,
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": prompt},
            ],
            temperature=0.3,
            max_tokens=400,
        )
        raw = (resp.choices[0].message.content or "").strip()
        parsed = _extract_json_object(raw)

        target = parsed.get("target")
        if isinstance(target, str):
            target = target.strip()
            # The prompt's schema shows "NodeName or null", so models often
            # answer with the *string* "null"; treat that as "no target".
            if target.lower() in {"", "null", "none"}:
                target = None

        return NetworkAction(
            action_type=parsed.get("action_type") or "scan_network",
            target=target or None,
            reasoning=parsed.get("reasoning") or "",
        )
    except Exception as e:
        # Deliberate best-effort boundary: any transport, parsing or validation
        # failure degrades to a harmless scan instead of breaking the UI.
        return NetworkAction(action_type="scan_network", reasoning=f"LLM error: {e}")
# ---------------------------------------------------------------------------
# Gradio button handlers
# ---------------------------------------------------------------------------
def web_reset(task: str):
    """Start a new episode for *task* and return the four UI field values.

    Remembers the task, resets the shared environment, stores the new
    observation and clears the done flag. When wandb is available a new run
    named ``<task>-<6 hex chars>`` is started.

    Returns:
        Tuple ``(network display, action type, target, reasoning)``; the two
        middle entries are empty strings and the last is a status message.
    """
    _last_task[0] = task
    fresh_obs = _env.reset(task=task)
    _current_obs[0], _episode_done[0] = fresh_obs, False

    if _WANDB_AVAILABLE:
        from uuid import uuid4

        run_suffix = uuid4().hex[:6]
        wandb.init(project="vir_env", name=f"{task}-{run_suffix}", reinit=True)

    display = _build_display(fresh_obs)
    status_msg = f"Episode started [{task.upper()}]. Click AI Step."
    return display, "", "", status_msg
def web_ai_step():
    """Let the LLM take one action and return the four UI field values.

    If no episode is active (never started, or the previous one finished) the
    environment is first reset with the last selected task. The chosen action
    is applied, shared state is updated, and — when a wandb run is active —
    per-step metrics plus an end-of-episode summary are logged.

    Returns:
        Tuple ``(network display, action type, target, reasoning)``. When the
        episode ends, an outcome line is appended to the reasoning text.
    """
    no_active_episode = _episode_done[0] or _current_obs[0] is None
    if no_active_episode:
        # Auto-reset instead of blocking
        _current_obs[0] = _env.reset(task=_last_task[0])
        _episode_done[0] = False

    action = _call_llm(_current_obs[0])
    result = _env.step(action)
    _current_obs[0] = result
    _episode_done[0] = result.done

    if _WANDB_AVAILABLE and wandb.run is not None:
        step_metrics = {
            "step": result.step,
            "infected_count": result.infected_count,
            "isolated_count": result.isolated_count,
            "reward": result.reward,
            "cumulative_score": result.cumulative_score,
            "action_type": action.action_type.value,
            "target": action.target or "none",
            "db_infected_steps": result.db_infected_steps,
            "auth_compromised": result.auth_compromised,
            "done": result.done,
            "task": _last_task[0],
        }
        wandb.log(step_metrics)
        if result.done:
            episode_summary = {
                "episode/won": result.infected_count == 0,
                "episode/total_steps": result.step,
                "episode/final_score": result.cumulative_score,
            }
            wandb.log(episode_summary)

    action_type = action.action_type.value
    target = action.target or "—"
    reasoning = action.reasoning
    if result.done:
        if result.infected_count == 0:
            outcome = "✅ NETWORK SECURED!"
        else:
            outcome = "❌ Step budget exhausted."
        reasoning += "\n\n" + outcome
    return _build_display(result), action_type, target, reasoning
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# NOTE(review): the nesting below (controls in the first row; network display
# beside a column of action fields in the second row) is reconstructed from
# statement order — confirm against the deployed layout.
with gr.Blocks(title="AI-SOAR: Network Defense", theme=gr.themes.Monochrome()) as ai_interface:
    gr.Markdown(
        "# 🛡️ AI-SOAR Network Defense\n"
        "Watch the AI agent chase and contain the spreading virus in real time."
    )

    with gr.Row():
        task_dd = gr.Dropdown(choices=["easy", "medium", "hard"], value="easy", label="Difficulty")
        reset_btn = gr.Button("🔄 Reset", variant="secondary")
        step_btn = gr.Button("🤖 AI Step", variant="primary")

    with gr.Row():
        net_display = gr.Textbox(label="Network State", lines=22, interactive=False)
        with gr.Column():
            action_type_box = gr.Textbox(label="Action Type", lines=1, interactive=False)
            target_box = gr.Textbox(label="Target", lines=1, interactive=False)
            reasoning_box = gr.Textbox(label="Reasoning", lines=4, interactive=False)

    # Both handlers return the same 4-tuple, so both buttons feed the same fields.
    reset_btn.click(
        fn=web_reset,
        inputs=[task_dd],
        outputs=[net_display, action_type_box, target_box, reasoning_box],
    )
    step_btn.click(
        fn=web_ai_step,
        inputs=[],
        outputs=[net_display, action_type_box, target_box, reasoning_box],
    )
# ---------------------------------------------------------------------------
# App assembly
# ---------------------------------------------------------------------------
# Build the environment server application first, then mount the Gradio
# interface on it under /ai; `app` ends up bound to the combined application.
_env_server_app = create_app(
    NetworkEnvironment,
    NetworkAction,
    NetworkObservation,
    env_name="vir_env",
    max_concurrent_envs=5,
)
app = gr.mount_gradio_app(_env_server_app, ai_interface, path="/ai")
def main():
    """Serve `app` with uvicorn, taking --host and --port from the command line.

    Defaults are host ``0.0.0.0`` and port ``8000``.
    """
    import argparse
    import uvicorn

    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8000)
    options = cli.parse_args()

    uvicorn.run(app, host=options.host, port=options.port)


if __name__ == "__main__":
    main()