vir_env / server /app.py
arun-misra's picture
Upload folder using huggingface_hub
34a06f7 verified
import os, json, re, textwrap
from openai import OpenAI
import gradio as gr
try:
import wandb
_WANDB_AVAILABLE = callable(getattr(wandb, "init", None))
except ImportError:
_WANDB_AVAILABLE = False
try:
from dotenv import load_dotenv
load_dotenv(dotenv_path=os.path.join(os.path.dirname(os.path.dirname(__file__)), ".env"), override=False)
except ImportError:
pass
try:
from openenv.core.env_server.http_server import create_app
except Exception as e:
raise ImportError("openenv is required.") from e
try:
from ..models import NetworkAction, NetworkObservation
from .vir_env_environment import NetworkEnvironment
except (ImportError, ValueError):
try:
from models import NetworkAction, NetworkObservation
from server.vir_env_environment import NetworkEnvironment
except ImportError:
import sys
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models import NetworkAction, NetworkObservation
from server.vir_env_environment import NetworkEnvironment
# ---------------------------------------------------------------------------
# LLM configuration
# ---------------------------------------------------------------------------
def _first_env(*names, default):
    """Return the value of the first env var in *names* that is set and non-empty.

    Falls back to *default* when every variable is unset or empty, mirroring
    ``os.getenv(a) or os.getenv(b) or default``.
    """
    for name in names:
        value = os.getenv(name)
        if value:
            return value
    return default


# OpenAI-compatible endpoint settings. The fallbacks (api key "ollama", port
# 11434) look like a local Ollama server — confirm if deploying elsewhere.
_API_KEY = _first_env("API_KEY", "HF_TOKEN", default="ollama")
_API_BASE_URL = _first_env("API_BASE_URL", default="http://127.0.0.1:11434/v1")
_MODEL_NAME = _first_env("MODEL_NAME", default="gemma3:4b")
# System message sent with every LLM call. It describes the 20-node network,
# the three tools the agent may call, per-difficulty rules and a priority
# policy, and it demands a single JSON object as the reply (parsed by
# _call_llm). The ★ nodes listed here are the same names as HIGH_VALUE.
# NOTE: any edit to this text changes agent behavior — treat it as code.
_SYSTEM_PROMPT = textwrap.dedent("""
You are an AI SOC analyst racing to contain a spreading virus.
=== NETWORK NODES (20-NODE ENTERPRISE) ===
DMZ: Gateway, Firewall, LoadBalancer
Web Tier: Web1, Web2, Web3
App Tier: App1, App2, App3
Data Tier: DB_Primary ★, DB_Replica ★, Backup ★, Cache1, Cache2
Auth/Mgmt: Auth ★, LDAP ★, Monitor, SIEM
Internal: Workstation1, Workstation2
★ = critical asset — losing these fails the mission.
=== TOOLS ===
1. scan_network — Refresh network state.
2. isolate_node(target) — Sever ALL connections instantly. Stops spread from that node.
WARNING: isolating DB_Primary costs -0.50 penalty.
3. deploy_patch(target) — Remove malware AND reconnect node to network.
=== DIFFICULTY RULES ===
- EASY: 1 entry point. Virus spreads randomly (50%). 30 steps. No breach timer.
- MEDIUM: 2 entry points. Virus hunts DB_Primary. DB_Primary infected 5+ consecutive steps = DATA BREACH = instant loss. 20 steps.
- HARD: 3 entry points. Auth infected = all spread rates DOUBLE. DB_Primary breach in 2 steps. 15 steps.
=== PRIORITY ===
1. Look at INFECTED NODES list — those are your ONLY targets.
2. DB_Primary infected → patch IMMEDIATELY (db_infected_steps counts to breach; do NOT isolate — use deploy_patch).
3. Auth infected → isolate IMMEDIATELY (doubles all spread rates).
4. Any other infected node → isolate it.
5. Once all infections cleared → patch isolated nodes to restore network.
⚠ CRITICAL: "AT RISK" = virus MIGHT spread there — they are still CLEAN.
NEVER isolate an AT RISK node. ONLY isolate INFECTED nodes.
RESPOND ONLY with valid JSON:
{"action_type": "scan_network|isolate_node|deploy_patch", "target": "NodeName or null", "reasoning": "one sentence"}
""").strip()
# ---------------------------------------------------------------------------
# Shared environment state for the AI web interface
# ---------------------------------------------------------------------------
# One environment instance backs the Gradio page mounted at /ai. It is module
# state, so every browser session using that page drives the same episode.
_env = NetworkEnvironment()
# One-element lists act as mutable cells: the Gradio handlers update them via
# ``name[0] = ...`` without needing ``global`` declarations.
_current_obs = [None]   # latest observation rendered in the UI
_episode_done = [True]  # starts True so the first AI Step auto-resets
_last_task = ["easy"]   # difficulty reused when AI Step auto-resets
# Critical (★) assets; an infection on any of them makes the threat CRITICAL.
HIGH_VALUE = {"DB_Primary", "DB_Replica", "Auth", "LDAP", "Backup"}
# Human-readable names for the reward components in ``obs.reward_breakdown``;
# keys not listed here are printed verbatim.
_BREAKDOWN_LABELS = {
    "step": "Time penalty",
    "infected_nodes": "Infected node damage",
    "db_at_risk": "DB exposure penalty",
    "auth_compromised": "Auth compromise penalty",
    "db_isolate": "DB isolation downtime",
    "data_breach": "DATA BREACH",
    "fast_containment": "Fast containment bonus",
    "win": "Mission success bonus",
}


def _threat_level(infected_set, infected_count) -> str:
    """Return the threat banner: CRITICAL if a HIGH_VALUE node is infected, else by count."""
    if any(node in HIGH_VALUE for node in infected_set):
        return "🔴 CRITICAL"
    if infected_count >= 3:
        return "🟠 HIGH"
    if infected_count >= 1:
        return "🟡 MEDIUM"
    return "🟢 CLEAR"


def _node_row(node, info, at_risk) -> str:
    """Format one NODE STATUS row: status icon, node name (★ if critical), neighbours."""
    status = info.get("status", "clean")
    if node in at_risk:  # at_risk only ever contains clean nodes
        icon = "🟡 AT RISK "
    elif status == "infected":
        icon = "🔴 INFECTED "
    elif status == "isolated":
        icon = "⬛ ISOLATED "
    else:
        icon = "🟢 ONLINE "
    conns = info.get("connections", [])
    conn_str = ", ".join(conns) if conns else "no connections"
    name = f"{node} ★" if node in HIGH_VALUE else node
    # Pad to the longest node name (12 chars, e.g. "Workstation1") plus the
    # star, and always emit a separator. Without it, long names run straight
    # into their neighbour list ("Workstation1Web1"), which also garbles the
    # node names in the prompt sent to the LLM.
    return f" {icon} {name:<14} → {conn_str}"


def _warning_lines(obs) -> list:
    """Return the ⚠ lines for DB exposure and Auth compromise (possibly empty)."""
    warnings = []
    db_steps = getattr(obs, "db_infected_steps", None) or 0
    if db_steps > 0:
        breach = getattr(obs, "breach_threshold", None)
        limit = f"/{breach}" if breach else ""
        warnings.append(f"⚠ DB EXPOSED {db_steps}{limit} steps — patch immediately!")
    if getattr(obs, "auth_compromised", False):
        warnings.append("⚠ AUTH WAS COMPROMISED — spread rates may be doubled!")
    return warnings


def _breakdown_lines(breakdown) -> list:
    """Return the LAST STEP PENALTIES/REWARDS section, or [] when nothing was reported."""
    if not breakdown:
        return []
    lines = ["", " LAST STEP PENALTIES/REWARDS", " ─────────────────────────────────────"]
    for key, value in breakdown.items():
        lines.append(f" {_BREAKDOWN_LABELS.get(key, key):30s} {value:+.2f}")
    return lines


def _spread_lines(events) -> list:
    """Return the LATERAL MOVEMENT LOG section for the five most recent spread events."""
    if not events:
        return []
    lines = ["", " LATERAL MOVEMENT LOG", " ─────────────────────────────────────"]
    for ev in events[-5:]:
        crit = " ⚠" if ev.get("critical") else ""
        lines.append(f" Step {ev['step']:>2} {ev['from']:12s} ──► {ev['to']}{crit}")
    return lines


def _build_display(obs) -> str:
    """Render *obs* as the SOC dashboard text.

    The same text is shown in the UI and sent to the LLM as its user message
    (see _call_llm), so node names must stay unambiguous.

    Args:
        obs: observation exposing ``network_state`` (node -> dict with
            ``status``/``connections``), ``infected_count``, ``isolated_count``,
            ``clean_count``, ``step``, ``max_steps``, ``task``, ``message`` and
            ``spread_events``; optional fields such as ``scenario_name`` or
            ``reward_breakdown`` are read with fallbacks.

    Returns:
        The multi-line dashboard string.
    """
    net = obs.network_state
    infected_set = {n for n, d in net.items() if d.get("status") == "infected"}
    # Clean nodes directly connected to an infected one: possible next hops.
    at_risk = {
        n for n, d in net.items()
        if d.get("status") == "clean"
        and any(c in infected_set for c in d.get("connections", []))
    }
    # Evaluate the task fallback lazily and also when scenario_name is None/"".
    scenario = getattr(obs, "scenario_name", None) or obs.task.upper()
    score = getattr(obs, "cumulative_score", None) or 0.0
    threat = _threat_level(infected_set, obs.infected_count)
    infected_list = ", ".join(sorted(infected_set)) if infected_set else "NONE — network clean!"

    lines = [
        "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━",
        f" SOC OPERATIONS CENTER [{scenario}]",
        f" Step {obs.step}/{obs.max_steps} | Threat: {threat} | Score: {score:+.2f}",
        "━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━",
        "",
        f" >> INFECTED NODES (YOUR TARGETS): {infected_list} <<",
        "",
        " NODE STATUS",
        " ─────────────────────────────────────",
    ]
    lines += [_node_row(node, info, at_risk) for node, info in net.items()]
    lines += [
        "",
        f" Infected: {obs.infected_count} | At Risk: {len(at_risk)} "
        f"| Isolated: {obs.isolated_count} | Secure: {obs.clean_count - len(at_risk)}",
    ]

    warnings = _warning_lines(obs)
    if warnings:
        lines += [""] + [f" {w}" for w in warnings]
    lines += _breakdown_lines(getattr(obs, "reward_breakdown", None))
    lines += _spread_lines(getattr(obs, "spread_events", None))
    if at_risk:
        lines += ["", f" ⚡ VIRUS MAY SPREAD TO (still clean, do NOT isolate): {', '.join(sorted(at_risk))}"]
    lines += ["", f" {obs.message}", ""]
    return "\n".join(lines)
def _parse_llm_reply(raw: str) -> dict:
    """Extract NetworkAction keyword arguments from the model's raw reply.

    Accepts bare JSON, JSON inside a ``` / ```json fence, or JSON embedded in
    surrounding prose. A missing or null ``action_type`` falls back to
    "scan_network". A ``target`` of "", "null" or "none" (any case) becomes
    None, and other string targets are whitespace-stripped. A null
    ``reasoning`` becomes "".

    Raises:
        ValueError: if the reply contains no ``{...}`` span, or if that span
            is not valid JSON (json.JSONDecodeError subclasses ValueError).
    """
    text = (raw or "").strip()
    if text.startswith("```"):
        fenced = text.split("```")[1]
        # Slice off a literal "json" tag. str.lstrip("json") would strip any
        # run of the characters j/s/o/n instead.
        if fenced.startswith("json"):
            fenced = fenced[len("json"):]
        text = fenced.strip()
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match:
        raise ValueError("No JSON in LLM response")
    parsed = json.loads(match.group())

    target = parsed.get("target")
    if isinstance(target, str):
        target = target.strip()
        # The system prompt shows `"target": "NodeName or null"`, so models
        # often emit the *string* "null" rather than JSON null.
        if target.lower() in ("null", "none"):
            target = None
    reasoning = parsed.get("reasoning")
    return {
        "action_type": parsed.get("action_type") or "scan_network",
        "target": target or None,
        "reasoning": "" if reasoning is None else str(reasoning),
    }


def _call_llm(obs) -> NetworkAction:
    """Ask the configured LLM for the next action to take given *obs*.

    The rendered dashboard (_build_display) is printed to stdout and sent as
    the user message, with _SYSTEM_PROMPT as the system message. Any failure
    inside the request/parse/construct path is caught: a network error, an
    empty choices list, an unparsable reply, or NetworkAction rejecting the
    fields (presumably it validates action_type — confirm). The failure is
    printed and a ``scan_network`` action is returned whose reasoning carries
    the error text, so the UI keeps running.
    """
    client = OpenAI(base_url=_API_BASE_URL, api_key=_API_KEY)
    prompt = _build_display(obs)
    print("=" * 60)
    print("[LLM INPUT]")
    print(prompt)
    print("=" * 60)
    try:
        resp = client.chat.completions.create(
            model=_MODEL_NAME,
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                # Reuse the prompt rendered above instead of rendering twice.
                {"role": "user", "content": prompt},
            ],
            temperature=0.3,
            max_tokens=400,
        )
        fields = _parse_llm_reply(resp.choices[0].message.content or "")
        return NetworkAction(**fields)
    except Exception as e:  # boundary: degrade to a safe action, but leave a trace
        print(f"[LLM ERROR] {e}")
        return NetworkAction(action_type="scan_network", reasoning=f"LLM error: {e}")
# ---------------------------------------------------------------------------
# Gradio button handlers
# ---------------------------------------------------------------------------
def web_reset(task: str):
    """Start a fresh episode at difficulty *task* (Reset button handler).

    Updates the shared UI state and, if wandb is importable, tries to open a
    new W&B run for the episode.

    Returns:
        (dashboard text, "", "", status message) for the four output boxes.
    """
    _last_task[0] = task
    obs = _env.reset(task=task)
    _current_obs[0] = obs
    _episode_done[0] = False
    if _WANDB_AVAILABLE:
        from uuid import uuid4
        try:
            wandb.init(project="vir_env", name=f"{task}-{uuid4().hex[:6]}", reinit=True)
        except Exception as exc:
            # _WANDB_AVAILABLE only means "wandb is installed". init() can
            # still fail, e.g. no API key in a non-interactive server.
            # Telemetry must not block the episode from starting.
            print(f"[WANDB] init failed, continuing without logging: {exc}")
    return _build_display(obs), "", "", f"Episode started [{task.upper()}]. Click AI Step."
def _auto_reset_episode(task) -> None:
    """Reset the shared env for *task* when AI Step is pressed with no live episode.

    Also tries to open a new W&B run, so the new episode's metrics are not
    appended to the previous episode's run. W&B failures are printed, not raised.
    """
    _current_obs[0] = _env.reset(task=task)
    _episode_done[0] = False
    if _WANDB_AVAILABLE:
        from uuid import uuid4
        try:
            wandb.init(project="vir_env", name=f"{task}-{uuid4().hex[:6]}", reinit=True)
        except Exception as exc:
            print(f"[WANDB] init failed, continuing without logging: {exc}")


def _log_step_to_wandb(action, result) -> None:
    """Best-effort W&B logging of one step, plus the episode summary when it ends."""
    if not (_WANDB_AVAILABLE and wandb.run is not None):
        return
    try:
        wandb.log({
            "step": result.step,
            "infected_count": result.infected_count,
            "isolated_count": result.isolated_count,
            "reward": result.reward,
            "cumulative_score": result.cumulative_score,
            "action_type": action.action_type.value,
            "target": action.target or "none",
            "db_infected_steps": result.db_infected_steps,
            "auth_compromised": result.auth_compromised,
            "done": result.done,
            "task": _last_task[0],
        })
        if result.done:
            wandb.log({
                "episode/won": result.infected_count == 0,
                "episode/total_steps": result.step,
                "episode/final_score": result.cumulative_score,
            })
    except Exception as exc:  # telemetry must never break the UI
        print(f"[WANDB] log failed: {exc}")


def _outcome_banner(result) -> str:
    """Return the end-of-episode verdict appended to the reasoning box."""
    if result.infected_count == 0:
        return "✅ NETWORK SECURED!"
    if result.step >= result.max_steps:
        return "❌ Step budget exhausted."
    # Ended before the step budget with infections left. Presumably a loss
    # condition such as the DATA BREACH rule described in _SYSTEM_PROMPT.
    return "❌ Mission failed."


def web_ai_step():
    """Let the LLM choose and apply one action (AI Step button handler).

    If no episode is live (never started, or the last one finished), a new
    episode is started first with the last selected difficulty.

    Returns:
        (dashboard text, action type, target, reasoning) for the four output boxes.
    """
    if _episode_done[0] or _current_obs[0] is None:
        _auto_reset_episode(_last_task[0])
    obs = _current_obs[0]
    action = _call_llm(obs)
    result = _env.step(action)
    _current_obs[0] = result
    _episode_done[0] = result.done
    _log_step_to_wandb(action, result)

    # The model may omit reasoning; guard before concatenating.
    reasoning = action.reasoning or ""
    if result.done:
        reasoning += "\n\n" + _outcome_banner(result)
    return _build_display(result), action.action_type.value, action.target or "—", reasoning
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="AI-SOAR: Network Defense", theme=gr.themes.Monochrome()) as ai_interface:
    gr.Markdown("# 🛡️ AI-SOAR Network Defense\nWatch the AI agent chase and contain the spreading virus in real time.")

    # Control bar: difficulty picker and the two episode buttons.
    with gr.Row():
        task_dd = gr.Dropdown(["easy", "medium", "hard"], value="easy", label="Difficulty")
        reset_btn = gr.Button("🔄 Reset", variant="secondary")
        step_btn = gr.Button("🤖 AI Step", variant="primary")

    # Output panel: dashboard on the left, the agent's last decision on the right.
    with gr.Row():
        net_display = gr.Textbox(label="Network State", lines=22, interactive=False)
        with gr.Column():
            action_type_box = gr.Textbox(label="Action Type", lines=1, interactive=False)
            target_box = gr.Textbox(label="Target", lines=1, interactive=False)
            reasoning_box = gr.Textbox(label="Reasoning", lines=4, interactive=False)

    # Both handlers return (dashboard, action type, target, reasoning) in this order.
    panel_outputs = [net_display, action_type_box, target_box, reasoning_box]
    reset_btn.click(web_reset, inputs=[task_dd], outputs=panel_outputs)
    step_btn.click(web_ai_step, inputs=[], outputs=panel_outputs)
# ---------------------------------------------------------------------------
# App assembly
# ---------------------------------------------------------------------------
# OpenEnv HTTP server for the environment. max_concurrent_envs=5 presumably
# caps simultaneous environment instances served over HTTP (confirm in the
# openenv docs).
_openenv_app = create_app(
    NetworkEnvironment,
    NetworkAction,
    NetworkObservation,
    env_name="vir_env",
    max_concurrent_envs=5,
)
# Serve the Gradio dashboard under /ai on that same ASGI application.
app = gr.mount_gradio_app(_openenv_app, ai_interface, path="/ai")
def main():
    """Serve ``app`` with uvicorn; bind address comes from --host / --port."""
    import argparse
    import uvicorn

    cli = argparse.ArgumentParser()
    cli.add_argument("--host", default="0.0.0.0")
    cli.add_argument("--port", type=int, default=8000)
    opts = cli.parse_args()
    uvicorn.run(app, host=opts.host, port=opts.port)


if __name__ == "__main__":
    main()