from __future__ import annotations import json import time from typing import Any, Tuple import gradio as gr try: from ..inference import build_client, build_fallback_action, choose_action, packet_payload_text, sync_agent_state from ..models import NetworkForensicsAction, NetworkForensicsObservation from .network_forensics_environment import NetworkForensicsEnvironment except ImportError: from inference import build_client, build_fallback_action, choose_action, packet_payload_text, sync_agent_state from models import NetworkForensicsAction, NetworkForensicsObservation from server.network_forensics_environment import NetworkForensicsEnvironment # --------------------------------------------------------------------------- # Global state (single-session; fine for HF Spaces single-user demo) # --------------------------------------------------------------------------- env: NetworkForensicsEnvironment | None = None current_obs: NetworkForensicsObservation | None = None agent_state: dict[str, Any] = {} last_step_reward: float = 0.0 last_final_meta: dict[str, Any] = {} PATTERN_CHOICES = [ "ddos", "web_bruteforce", "web_xss", "web_sql_injection", "dos_hulk", "dos_goldeneye", "dos_slowloris", "dos_slowhttptest", "heartbleed", ] MODEL_CHOICES = [ "openai/gpt-oss-120b", "mistralai/mistral-small-4-119b-2603", "mistralai/mamba-codestral-7b-v0.1", "nvidia/nvidia-nemotron-nano-9b-v2", ] ACTION_TYPES = [ "inspect_packet", "flag_as_suspicious", "group_into_session", "tag_pattern", "identify_entry_point", "submit_report", ] # --------------------------------------------------------------------------- # Formatting helpers # --------------------------------------------------------------------------- def _parse_packet_ids(packet_ids: Any) -> list[str] | None: if packet_ids is None or packet_ids == "": return None if isinstance(packet_ids, list): values = [str(v).strip() for v in packet_ids if str(v).strip()] return values or None values = [v.strip() for v in str(packet_ids).split(",") if v.strip()] return values or None def _format_packets(obs: NetworkForensicsObservation) -> list[list[Any]]: rows: list[list[Any]] = [] flagged = set(obs.flagged_packet_ids) grouped = { packet_id for packet_ids in obs.grouped_sessions.values() for packet_id in packet_ids } for packet in obs.visible_packets[:30]: preview = packet_payload_text(packet) status = "" if packet.packet_id in flagged: status = "FLAG" elif packet.packet_id in grouped: status = "GROUP" rows.append([ status, packet.packet_id, packet.src_ip, packet.dst_ip, packet.dst_port, packet.protocol, packet.ttl, packet.payload_size, "full" if packet.is_revealed else "preview", (preview or "")[:120], ]) return rows def _format_summary(obs: NetworkForensicsObservation) -> str: pct_flagged = ( round(len(obs.flagged_packet_ids) / max(1, obs.total_packets) * 100, 1) ) lines = [ "### Episode Status", f"| Metric | Value |", f"|--------|-------|", f"| Step | **{obs.step_number}** (remaining: {obs.steps_remaining}) |", f"| Running Score | **{obs.current_score_estimate:.3f}** |", f"| Total Packets | **{obs.total_packets}** |", f"| Flagged | **{len(obs.flagged_packet_ids)}** ({pct_flagged}%) |", f"| Sessions | **{len(obs.grouped_sessions)}** |", f"| Tagged Patterns | **{len(obs.tagged_patterns)}** |", ] if obs.claimed_entry_point: lines.append(f"| Entry Point | `{obs.claimed_entry_point}` |") if obs.tagged_patterns: lines.append("\n**Tags:**") for session, tag in obs.tagged_patterns.items(): lines.append(f"- `{session}` -> `{tag}`") return "\n".join(lines) def _format_graph(obs: NetworkForensicsObservation) -> str: g = obs.connection_graph_summary if not g: return "_No graph data yet. Inspect packets to build the topology._" lines = ["### Connection Graph Summary"] # Top talkers talkers = g.get("top_talkers", []) if talkers: lines.append("\n**Top Talkers (by packet count)**") lines.append("| IP | Packets |") lines.append("|----|---------|") for entry in talkers[:10]: ip = entry.get("ip", entry) if isinstance(entry, dict) else str(entry) count = entry.get("packet_count", entry.get("count", "")) if isinstance(entry, dict) else "" lines.append(f"| `{ip}` | {count} |") # Top flows flows = g.get("top_flows", []) if flows: lines.append("\n**Top Flows**") lines.append("| Src -> Dst | Protocol | Packets |") lines.append("|-----------|----------|---------|") for flow in flows[:12]: if isinstance(flow, dict): src = flow.get("src", "?") dst = flow.get("dst", "?") protocols = flow.get("protocols", flow.get("protocol", "?")) proto = ", ".join(protocols) if isinstance(protocols, list) else str(protocols) count = flow.get("packet_count", flow.get("count", "")) lines.append(f"| `{src}` -> `{dst}` | {proto} | {count} |") else: lines.append(f"| {flow} | | |") # Stats stats = g.get("stats", {}) if stats: lines.append("\n**Graph Stats**") for k, v in stats.items(): lines.append(f"- **{k}**: {v}") return "\n".join(lines) def _format_final_scores(meta: dict[str, Any]) -> str: if not meta: return "_Submit an incident report to see final evaluation scores._" keys = [ ("final_precision", "Precision"), ("final_recall", "Recall"), ("final_logic", "Logic"), ("final_llm_report", "LLM Report Quality"), ("final_session_overlap", "Session Overlap"), ("final_pattern_score", "Pattern Score"), ("final_entry_score", "Entry Point Score"), ("final_score", "**FINAL SCORE**"), ] lines = ["### Final Evaluation Scores", "| Metric | Score |", "|--------|-------|"] for key, label in keys: if key in meta: val = meta[key] bar = "█" * int(float(val) * 10) + "░" * (10 - int(float(val) * 10)) lines.append(f"| {label} | {float(val):.3f} `{bar}` |") success = meta.get("success_threshold_met", 0) lines.append(f"\n**Success:** {'YES' if success else 'NO'}") return "\n".join(lines) def _final_metrics(obs: NetworkForensicsObservation | None) -> dict[str, Any]: if obs is None: return {} return getattr(obs, "final_metrics", None) or getattr(obs, "metadata", None) or {} def _control_updates(obs: NetworkForensicsObservation) -> tuple: packet_choices = [p.packet_id for p in obs.visible_packets] session_choices = list(obs.grouped_sessions.keys()) return ( gr.Dropdown(choices=packet_choices, value=None), gr.Dropdown(choices=packet_choices, value=[]), gr.Dropdown(choices=session_choices, value=None), gr.Dropdown(choices=PATTERN_CHOICES, value=None), gr.Dropdown(choices=packet_choices, value=None), ) def _mode_updates(mode: str) -> tuple: manual = mode == "Manual" return ( gr.Dropdown(interactive=manual), gr.Dropdown(interactive=manual), gr.Dropdown(interactive=manual), gr.Dropdown(interactive=manual), gr.Dropdown(interactive=manual), gr.Dropdown(interactive=manual), gr.Button(interactive=manual), gr.Button(interactive=manual), gr.Button(interactive=not manual), gr.Button(interactive=not manual), ) # --------------------------------------------------------------------------- # Event handlers # --------------------------------------------------------------------------- def reset_env(task_name: str): global env, current_obs, agent_state, last_step_reward, last_final_meta env = NetworkForensicsEnvironment(task_id=task_name) current_obs = env.reset() agent_state = {} last_step_reward = 0.0 last_final_meta = {} sync_agent_state(current_obs, agent_state) return ( _format_summary(current_obs), _format_packets(current_obs), _format_graph(current_obs), _format_final_scores({}), f"Episode reset for **{task_name}** task.", *_control_updates(current_obs), ) def set_mode(mode: str) -> tuple: msg = ( "**Manual mode** - pick actions yourself to explore reward shaping." if mode == "Manual" else "**Agent mode** - use Run Agent Step / Replay to watch the policy." ) return (*_mode_updates(mode), msg) def suggest_action(task_name: str, model_name: str): global current_obs, agent_state if current_obs is None: return "{}", None, [], None, None, None client = build_client() action = choose_action(client, task_name, current_obs, agent_state, model_name=model_name) payload = action.model_dump(exclude_none=True, exclude_defaults=True) payload.pop("metadata", None) return ( json.dumps(payload, indent=2), action.packet_id, action.packet_ids or [], action.session_name, action.pattern_type, action.claimed_entry_point, ) def run_agent_step(task_name: str, model_name: str): global current_obs, agent_state, env, last_step_reward, last_final_meta if env is None or current_obs is None: reset_env(task_name) client = build_client() try: action = choose_action(client, task_name, current_obs, agent_state, model_name=model_name) except Exception: action = build_fallback_action(task_name, current_obs, agent_state) payload = action.model_dump(exclude_none=True, exclude_defaults=True) payload.pop("metadata", None) current_obs = env.step(action) reward = current_obs.reward last_step_reward = reward meta = _final_metrics(current_obs) if meta: last_final_meta = dict(meta) sync_agent_state(current_obs, agent_state) log_line = f"Step {current_obs.step_number}: {json.dumps(payload)} -> reward {reward:.3f}" status = ( f"Episode finished. Step reward: **{reward:.3f}**" if current_obs.done else f"Agent step done. Reward: **{reward:.3f}**" ) return ( _format_summary(current_obs), _format_packets(current_obs), _format_graph(current_obs), _format_final_scores(last_final_meta), status, json.dumps(payload, indent=2), log_line, *_control_updates(current_obs), ) def replay_agent(task_name: str, model_name: str): global current_obs, agent_state, env, last_step_reward, last_final_meta if env is None or current_obs is None or current_obs.done: reset_env(task_name) client = build_client() replay_lines: list[str] = [] max_steps = current_obs.steps_remaining or 50 for _ in range(max_steps): if current_obs.done: break try: action = choose_action(client, task_name, current_obs, agent_state, model_name=model_name) except Exception: action = build_fallback_action(task_name, current_obs, agent_state) payload = action.model_dump(exclude_none=True, exclude_defaults=True) payload.pop("metadata", None) current_obs = env.step(action) reward = float(getattr(current_obs, 'reward', 0.0)) meta = _final_metrics(current_obs) if action.action_type == "submit_report" and meta: last_final_meta = dict(meta) elif meta: last_final_meta = dict(meta) sync_agent_state(current_obs, agent_state) replay_lines.append(f"Step {current_obs.step_number}: {json.dumps(payload)} -> {reward:.3f}") status = ( f"Replay complete. Final reward: **{reward:.3f}**" if current_obs.done else f"Replaying... step {current_obs.step_number} reward {reward:.3f}" ) yield ( _format_summary(current_obs), _format_packets(current_obs), _format_graph(current_obs), _format_final_scores(last_final_meta), status, json.dumps(payload, indent=2), "\n".join(replay_lines), *_control_updates(current_obs), ) time.sleep(0.3) def step_env_manual( action_type: str, packet_id: str, packet_ids: Any, session_name: str, pattern_type: str, claimed_entry_point: str, incident_summary: str, ): global env, current_obs, last_final_meta if env is None: return ( "### No episode running", [], "_No graph yet._", "_No scores yet._", "Choose a task and click **Reset Episode** first.", gr.Dropdown(), gr.Dropdown(), gr.Dropdown(), gr.Dropdown(), gr.Dropdown(), ) action = NetworkForensicsAction( action_type=action_type, packet_id=packet_id or None, packet_ids=_parse_packet_ids(packet_ids), session_name=session_name or None, pattern_type=pattern_type or None, claimed_entry_point=claimed_entry_point or None, incident_summary=incident_summary or None, ) current_obs = env.step(action) reward = float(getattr(current_obs, 'reward', 0.0)) meta = _final_metrics(current_obs) if action.action_type == "submit_report" and meta: last_final_meta = dict(meta) elif meta: last_final_meta = dict(meta) sync_agent_state(current_obs, agent_state) status = ( f"Episode complete. Step reward: **{reward:.3f}**" if current_obs.done else f"Action applied. Step reward: **{reward:.3f}**" ) return ( _format_summary(current_obs), _format_packets(current_obs), _format_graph(current_obs), _format_final_scores(last_final_meta), status, *_control_updates(current_obs), ) # --------------------------------------------------------------------------- # UI layout # --------------------------------------------------------------------------- def create_demo() -> gr.Blocks: css = """ body, .gradio-container { background: #0a0f1e !important; } .app-shell { max-width: 1600px; margin: 0 auto; } .panel { border: 1px solid rgba(99,179,237,0.15); border-radius: 16px; padding: 16px; background: rgba(10,20,40,0.85); backdrop-filter: blur(8px); } .hero { padding: 20px 28px; border-radius: 20px; background: linear-gradient(135deg, #05090f 0%, #0d2240 50%, #0a3060 100%); border: 1px solid rgba(99,179,237,0.2); margin-bottom: 12px; } .hero h1 { color: #63b3ed; margin: 0; font-size: 1.6rem; } .hero p { opacity: 0.7; margin-top: 6px; color: #a0c4e8; } .score-good { color: #68d391 !important; } .score-bad { color: #fc8181 !important; } """ with gr.Blocks( title="NetForensics-RL · Analyst Console", ) as demo: with gr.Column(elem_classes=["app-shell"]): gr.HTML(f"") gr.HTML("""

NetForensics-RL  ·  Analyst Console

Investigate network attacks with an AI agent or step through manually. Watch the connection graph build in real-time as packets are revealed.

""") with gr.Row(): # ── Left sidebar ──────────────────────────────────────────── with gr.Column(scale=1, min_width=280, elem_classes=["panel"]): gr.Markdown("### ⚙️ Episode Control") mode = gr.Radio(["Manual", "Agent"], label="Mode", value="Manual") task_select = gr.Radio(["easy", "medium", "hard"], label="Task", value="easy") model_name = gr.Dropdown( choices=MODEL_CHOICES, value=MODEL_CHOICES[0], label="LLM Model", ) reset_btn = gr.Button("Reset Episode", variant="primary") gr.Markdown("---") gr.Markdown("### Agent Controls") suggest_btn = gr.Button("Suggest Action (LLM)") agent_step_btn = gr.Button("Run Agent Step", interactive=False) replay_btn = gr.Button("Run Agent Replay", interactive=False) gr.Markdown("---") gr.Markdown("### Manual Action") action_type = gr.Dropdown(ACTION_TYPES, label="Action Type", value="inspect_packet") packet_id = gr.Dropdown(label="Packet ID", choices=[], value=None) packet_ids = gr.Dropdown(label="Packet IDs (multi)", choices=[], value=[], multiselect=True) session_name = gr.Dropdown(label="Session Name", choices=[], value=None, allow_custom_value=True) pattern_type = gr.Dropdown(label="Pattern Type", choices=PATTERN_CHOICES, value=None) claimed_entry_point = gr.Dropdown(label="Entry Point Packet", choices=[], value=None) incident_summary = gr.Textbox( label="Incident Summary (for submit_report)", lines=4, placeholder="Describe the attack: actors, targets, techniques, timeline…", ) step_btn = gr.Button("Apply Action", variant="secondary") # ── Main content area ──────────────────────────────────────── with gr.Column(scale=3): # Top row: status + LLM output with gr.Row(): with gr.Column(scale=2, elem_classes=["panel"]): summary = gr.Markdown("Click **Reset Episode** to begin.") status = gr.Markdown("") with gr.Column(scale=1, elem_classes=["panel"]): llm_json = gr.Code(label="LLM Action JSON", language="json", value="{}") # Middle: packet table with gr.Row(): with gr.Column(elem_classes=["panel"]): packets = gr.Dataframe( headers=["Status", "ID", "Src IP", "Dst IP", "Port", "Protocol", "TTL", "Size", "Payload Source", "Payload"], datatype=["str", "str", "str", "str", "number", "str", "number", "number", "str", "str"], interactive=False, wrap=True, label="Packet Stream", ) # Bottom: graph + scores + replay log with gr.Row(): with gr.Column(scale=2, elem_classes=["panel"]): graph_md = gr.Markdown("_No graph data yet._", label="") gr.Markdown("#### Connection Graph", visible=False) # label handled above with gr.Column(scale=1, elem_classes=["panel"]): scores_md = gr.Markdown("_Submit a report to see scores._") with gr.Column(scale=2, elem_classes=["panel"]): replay_log = gr.Code(label="Agent Replay Log", language="markdown", value="") # ── Common output list helpers ─────────────────────────────────────── # Order: summary, packets, graph, scores, status, packet_id, packet_ids, # session_name, pattern_type, claimed_entry_point common_outs = [summary, packets, graph_md, scores_md, status, packet_id, packet_ids, session_name, pattern_type, claimed_entry_point] # ── Wiring ────────────────────────────────────────────────────────── reset_btn.click( reset_env, inputs=task_select, outputs=common_outs, ) reset_btn.click(lambda: "", outputs=replay_log) step_btn.click( step_env_manual, inputs=[action_type, packet_id, packet_ids, session_name, pattern_type, claimed_entry_point, incident_summary], outputs=common_outs, ) suggest_btn.click( suggest_action, inputs=[task_select, model_name], outputs=[llm_json, packet_id, packet_ids, session_name, pattern_type, claimed_entry_point], ) agent_step_btn.click( run_agent_step, inputs=[task_select, model_name], outputs=[summary, packets, graph_md, scores_md, status, llm_json, replay_log, packet_id, packet_ids, session_name, pattern_type, claimed_entry_point], ) replay_btn.click( replay_agent, inputs=[task_select, model_name], outputs=[summary, packets, graph_md, scores_md, status, llm_json, replay_log, packet_id, packet_ids, session_name, pattern_type, claimed_entry_point], ) mode.change( set_mode, inputs=mode, outputs=[action_type, packet_id, packet_ids, session_name, pattern_type, claimed_entry_point, step_btn, suggest_btn, agent_step_btn, replay_btn, status], ) task_select.change(lambda: "", outputs=replay_log) demo.load( set_mode, inputs=mode, outputs=[action_type, packet_id, packet_ids, session_name, pattern_type, claimed_entry_point, step_btn, suggest_btn, agent_step_btn, replay_btn, status], ) return demo