# network_forensics / inference.py
# WHOAM-EYE's picture
# Upload folder using huggingface_hub
# ded853c verified
import json
import os
import sys
import asyncio
import inspect
import random
import time
from pathlib import Path
from typing import Any
from dotenv import load_dotenv
from openai import OpenAI
from openenv.core.containers.runtime.providers import LocalDockerProvider
sys.path.insert(0, str(Path(__file__).parent))
from client import NetworkForensicsEnv
from models import NetworkForensicsAction
load_dotenv(Path(__file__).parent / ".env")
# Runtime configuration, read once at import time from the process environment
# (load_dotenv above has already merged the sibling .env file into it).
API_BASE_URL = os.getenv("API_BASE_URL")  # no default; validate_config() reports it when missing
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b")
# The first non-empty value among the three credential variables wins.
API_KEY = os.getenv("OPENAI_API_KEY") or os.getenv("API_KEY") or os.getenv("HF_TOKEN")
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "network-forensics-env:latest")
# Lower-cased mode string; validate_config() accepts only "server", "docker" or "hf".
# Defaults to "hf" when neither variable is set.
ENV_MODE = (
os.getenv("NETWORK_FORENSICS_ENV_MODE") or os.getenv("ENV_MODE") or "hf"
).lower()
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
HF_SPACE_ID = (
os.getenv("HF_SPACE_ID") or os.getenv("SPACE_ID") or "WHOAM-EYE/network_forensics"
)
HF_SPACE_URL = os.getenv("HF_SPACE_URL", "https://whoam-eye-network-forensics.hf.space")
# Parsed as float; presumably the number of seconds to wait for the Docker
# environment to become ready -- usage is not visible in this part of the file.
DOCKER_READY_TIMEOUT_S = float(os.getenv("DOCKER_READY_TIMEOUT_S", "120"))
# Module-level slot for an asyncio event loop; starts empty. NOTE(review): it is
# not assigned in this part of the file -- presumably created lazily elsewhere.
_ASYNC_LOOP: asyncio.AbstractEventLoop | None = None
# Instruction text for the chat model: scoring rules, the intended workflow, the
# accepted pattern_type values and one JSON example per action type.
# NOTE(review): presumably sent as the system message -- the call site is not
# visible in this part of the file. The text is runtime data; edit with care.
SYSTEM_PROMPT = """You are a senior Network Forensics Analyst. Your goal is to investigate malicious network traffic and achieve a 100% detection score.
### SCORING RULES:
- You MUST identify and `flag_as_suspicious` EVERY malicious packet to maximize RECALL (very important!).
- Only grouped packets or flagged packets contribute towards your score.
- If RECALL is < 0.5, your score will be 0.0. DO NOT stop until you have flagged/grouped at least 60% of visible malicious packets.
- Entry point must be the EARLIEST packet that initiated the attack (often in first group).
- For HARD tasks: wrong entry point = score 0. Always identify_entry_point before submitting.
### WORKFLOW:
1. **Explore**: `inspect_packet` on suspicious samples.
2. **Flag**: `flag_as_suspicious` on ALL revealed malicious packets.
3. **Correlate**: `group_into_session` with descriptive names.
4. **Classify**: `tag_pattern` with a valid type.
5. **Root Cause**: `identify_entry_point` with the earliest malicious packet.
6. **Report**: `submit_report` ONLY when you have covered all visible malicious sessions.
### VALID PATTERN TYPES:
ddos, dos_slowloris, dos_slowhttptest, dos_goldeneye, dos_hulk, heartbleed, web_sql_injection, web_xss, web_bruteforce, c2, exfiltration, scan, lateral
### JSON SCHEMA EXAMPLES (Use these exactly):
- Inspect: {"action_type":"inspect_packet","packet_id":"pkt_0001"}
- Flag: {"action_type":"flag_as_suspicious","packet_id":"pkt_0001"}
- Group: {"action_type":"group_into_session","session_name":"DDoS_Burst_2","packet_ids":["pkt_0001","pkt_0002"]}
- Tag: {"action_type":"tag_pattern","session_name":"DDoS_Burst_2","pattern_type":"ddos"}
- Entry: {"action_type":"identify_entry_point","claimed_entry_point":"pkt_0001"}
- Report: {"action_type":"submit_report","incident_summary":"Detailed incident summary here.","claimed_entry_point":"pkt_0001"}"""
# Maximum number of serialized actions kept in agent_state["previous_actions"].
HISTORY_WINDOW = 20
# Number of trailing history entries compared when detecting a repeated action.
REPEAT_ACTION_LIMIT = 3
# Maximum number of correction messages kept and shown in the observation summary.
CORRECTION_WINDOW = 5
# Grouping of new sessions pauses while more than this many sessions are untagged.
UNTAGGED_BACKLOG_LIMIT = 6
# Share of recent actions that are inspect_packet at or above which the
# "over-inspecting" strategy hints are produced.
INSPECT_SOFT_RATIO_THRESHOLD = 0.60
# Per-task step budgets. NOTE(review): not referenced in this part of the file;
# presumably consumed by the episode loop -- confirm before changing.
SOFT_STEP_BUDGETS = {"easy": 14, "medium": 28, "hard": 40}
HARD_STEP_CAPS = {"easy": 30, "medium": 50, "hard": 65}
# Per-task thresholds on obs.current_score_estimate and on the
# flagged/visible packet ratio used for the submit-readiness checks.
TASK_SCORE_TARGETS = {"easy": 0.70, "medium": 0.68, "hard": 0.66}
TASK_COVERAGE_TARGETS = {"easy": 0.32, "medium": 0.24, "hard": 0.20}
# Wall-clock limits in seconds, overridable through the environment.
# NOTE(review): enforcement is not visible in this part of the file.
MAX_TASK_SECONDS = float(os.getenv("MAX_TASK_SECONDS", "780"))
TASK_TIME_BUDGET_SECONDS = {
"easy": float(os.getenv("EASY_MAX_SECONDS", "150")),
"medium": float(os.getenv("MEDIUM_MAX_SECONDS", "220")),
"hard": float(os.getenv("HARD_MAX_SECONDS", "320")),
}
def build_client() -> OpenAI:
    """Create the OpenAI client from the module-level base URL and API key."""
    client = OpenAI(api_key=API_KEY, base_url=API_BASE_URL)
    return client
def validate_config() -> None:
    """Check the module-level configuration and fail fast when it is unusable.

    Raises:
        RuntimeError: when a required environment variable is unset, or when
            ENV_MODE is not one of ``server``, ``docker`` or ``hf``.
    """
    # (label shown to the user, whether the setting is satisfied)
    requirements = (
        ("API_BASE_URL", bool(API_BASE_URL)),
        ("OPENAI_API_KEY/API_KEY/HF_TOKEN", bool(API_KEY)),
        (
            "HF_SPACE_URL or HF_SPACE_ID/SPACE_ID",
            ENV_MODE != "hf" or bool(HF_SPACE_URL or HF_SPACE_ID),
        ),
    )
    absent = [label for label, satisfied in requirements if not satisfied]
    if absent:
        raise RuntimeError(
            f"Missing required environment variables: {', '.join(absent)}"
        )

    if ENV_MODE not in {"server", "docker", "hf"}:
        raise RuntimeError(
            "NETWORK_FORENSICS_ENV_MODE must be one of: server, docker, hf"
        )
def format_action(action: NetworkForensicsAction) -> str:
    """Serialize an action to compact JSON, omitting metadata and empty fields.

    Fields that are None or left at their default are excluded by the model
    dump; empty strings, lists and dicts are dropped afterwards.
    """
    dumped = action.model_dump(exclude_none=True, exclude_defaults=True)
    dumped.pop("metadata", None)

    compact = {}
    for field_name, field_value in dumped.items():
        if field_value in ("", [], {}):
            continue
        compact[field_name] = field_value
    return json.dumps(compact, separators=(",", ":"))
def summarize_observation(obs: Any, agent_state: dict[str, Any]) -> str:
    """Provide a compact structured summary for low-latency policy learning.

    Args:
        obs: Environment observation. Reads ``visible_packets``,
            ``grouped_sessions``, ``tagged_patterns``, ``flagged_packet_ids``,
            ``current_score_estimate``, ``step_number`` and ``steps_remaining``.
        agent_state: Mutable per-episode state. Reads ``last_step_reward``,
            ``last_reward_feedback``, ``recent_corrections``,
            ``strategy_hints`` and ``current_task_name``.

    Returns:
        A newline-joined text block: progress header, sessions pending
        tagging, optional corrections and hints, revealed payload excerpts
        and up to ten packets that are still hidden.
    """
    packets = obs.visible_packets
    revealed = [p for p in packets if p.is_revealed]
    revealed_ids = [p.packet_id for p in revealed]
    sessions = obs.grouped_sessions or {}
    tags = obs.tagged_patterns or {}
    untagged_sessions = [name for name in sessions if name not in tags]

    last_reward = agent_state.get("last_step_reward")
    reward_feedback = agent_state.get("last_reward_feedback", "n/a")
    recent_corrections = agent_state.get("recent_corrections", [])[-CORRECTION_WINDOW:]
    strategy_hints = agent_state.get("strategy_hints", [])
    task_name = agent_state.get("current_task_name", "")

    flagged_count = len(obs.flagged_packet_ids)
    total_visible = max(1, len(obs.visible_packets))  # guards the division below
    coverage = flagged_count / total_visible
    coverage_target = TASK_COVERAGE_TARGETS.get(task_name, 0.25)
    score_target = TASK_SCORE_TARGETS.get(task_name, 0.65)
    grouped_count = len(sessions)
    tagged_count = len(tags)
    # Easy tasks are exempt from the grouping/tagging requirements.
    ready_to_submit = (
        obs.current_score_estimate >= score_target
        and coverage >= coverage_target
        and (task_name == "easy" or grouped_count >= 2)
        and (task_name == "easy" or tagged_count >= 1)
    )

    summary = [
        f"Step: {obs.step_number}/{obs.step_number + obs.steps_remaining}",
        f"Current Progress: {obs.current_score_estimate:.2f}",
        f"Coverage: {flagged_count}/{total_visible} ({coverage:.2%}) | target {coverage_target:.0%}",
        f"Sessions: grouped={grouped_count}, tagged={tagged_count}",
        f"Submit Readiness: {'READY' if ready_to_submit else 'KEEP INVESTIGATING'}",
        f"Last Step Reward: {last_reward:.2f}" if isinstance(last_reward, (int, float)) else "Last Step Reward: n/a",
        f"Last Reward Feedback: {reward_feedback}",
        f"ALREADY REVEALED: {', '.join(revealed_ids[-6:])} " + ("..." if len(revealed_ids) > 6 else ""),
        "\n### SESSIONS PENDING TAGGING:",
    ]

    # The pending-session bullets must follow their header directly; emitting
    # them after the corrections/hints sections would list them under the
    # wrong heading and leave this one empty.
    if untagged_sessions:
        for name in untagged_sessions[:6]:
            summary.append(f"- {name} ({len(sessions[name])} packets)")
    else:
        summary.append("- [No pending sessions]")

    if recent_corrections:
        summary.append("\n### RECENT CORRECTIONS:")
        summary.extend(f"- {reason}" for reason in recent_corrections)

    if strategy_hints:
        summary.append("\n### STRATEGY HINTS:")
        summary.extend(f"- {hint}" for hint in strategy_hints)

    summary.append("\n### REVEALED INDICATORS:")
    for p in revealed[-4:]:
        payload = (p.full_payload or "")[:150]
        if payload:
            summary.append(f"- {p.packet_id}: {payload}")

    summary.append("\n### UNKNOWN PACKETS (Must Inspect):")
    unknown = [p for p in packets if not p.is_revealed][:10]
    for p in unknown:
        summary.append(f"- {p.packet_id} | {p.src_ip} -> {p.dst_ip} | Proto: {p.protocol}")

    return "\n".join(summary)
def parse_action(raw_text: str) -> NetworkForensicsAction:
    """Extract the JSON object embedded in a model reply and build an action.

    The span from the first ``{`` to the last ``}`` is parsed. ``metadata`` is
    discarded, and optional fields sent as empty strings / an empty list are
    removed before the action model is constructed.

    Raises:
        ValueError: when the reply contains no brace pair, or the span is not
            valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    text = raw_text.strip()
    first_brace = text.find("{")
    last_brace = text.rfind("}")
    if -1 in (first_brace, last_brace):
        raise ValueError("model did not return JSON")

    fields = json.loads(text[first_brace : last_brace + 1])
    fields.pop("metadata", None)

    blank_keys = [
        name
        for name in ("session_name", "pattern_type", "claimed_entry_point")
        if fields.get(name) == ""
    ]
    for name in blank_keys:
        del fields[name]
    if fields.get("packet_ids") == []:
        del fields["packet_ids"]

    return NetworkForensicsAction(**fields)
def sanitize_action(action: NetworkForensicsAction) -> NetworkForensicsAction:
    """Rebuild an action keeping only the fields relevant to its action_type.

    Falsy field values are left out; unknown action types keep only
    ``action_type``.
    """
    kind = action.action_type
    cleaned = {"action_type": kind}

    def keep(*field_names: str) -> None:
        # Copy each named attribute across when it carries a truthy value.
        for field_name in field_names:
            value = getattr(action, field_name)
            if value:
                cleaned[field_name] = value

    if kind in {"inspect_packet", "flag_as_suspicious"}:
        keep("packet_id")
    elif kind == "group_into_session":
        keep("session_name", "packet_ids")
    elif kind == "tag_pattern":
        keep("session_name", "pattern_type")
    elif kind == "identify_entry_point":
        keep("claimed_entry_point")
    elif kind == "submit_report":
        keep("incident_summary", "claimed_entry_point")

    return NetworkForensicsAction(**cleaned)
def decode_payload_preview(payload_preview: str) -> str:
    """Return the preview decoded from hex when possible, else the trimmed text.

    Whitespace is removed before the hex attempt. The decoded text is used only
    when it is non-empty after ignoring undecodable UTF-8 bytes and stripping.
    """
    preview = (payload_preview or "").strip()
    hex_digits = "".join(preview.split())

    # An empty or odd-length string cannot be a whole number of hex bytes.
    if not hex_digits or len(hex_digits) % 2:
        return preview

    try:
        raw = bytes.fromhex(hex_digits)
    except ValueError:
        return preview

    text = raw.decode("utf-8", errors="ignore").strip()
    return text or preview
def packet_payload_text(packet: Any) -> str:
    """Return the packet's full payload, or its decoded preview when that is empty."""
    full_text = packet.full_payload
    if full_text:
        return full_text
    return decode_payload_preview(packet.payload_preview)
def keyword_to_pattern(payload: str) -> str | None:
    """Map payload text to an attack pattern name via substring heuristics.

    Matching is case-insensitive and rules are tried in a fixed order; the
    first rule that matches decides the result.

    Returns:
        The pattern name, or None when no rule matches.
    """
    text = payload.lower()

    def has_any(*needles: str) -> bool:
        return any(needle in text for needle in needles)

    # --- DoS / DDoS variants ---
    if has_any("slowloris"):
        return "dos_slowloris"
    if has_any("slowhttptest", "slow http"):
        return "dos_slowhttptest"
    if has_any("goldeneye", "golden eye"):
        return "dos_goldeneye"
    if has_any("hulk"):
        return "dos_hulk"
    if has_any("heartbeat", "heartbleed") or ("tls" in text and "ext" in text):
        return "heartbleed"
    if has_any("flood", "burst", "ddos"):
        return "ddos"

    # HTTP flood indicators (repeated GET/POST to same endpoint)
    looks_like_http_request = text.startswith(("get /", "post /", "get http"))
    if looks_like_http_request and has_any("accept-encoding", "connection", "keep-alive"):
        return "ddos"

    # SYN flood / connection flood
    if "syn" in text and "ack" not in text and len(text) < 30:
        return "ddos"

    # ICMP flood
    if "icmp" in text and (has_any("echo", "request") or len(text) < 20):
        return "ddos"

    # --- Web attacks ---
    if has_any(
        "xss",
        "<script>",
        "<scrip",
        "/search?q=",
        "onerror",
        "onload",
        "javascript:",
        "alert(",
        "%3cscript",
    ):
        return "web_xss"

    if has_any(
        "or 1=1",
        "%20or",
        "/items?id=",
        "1=1",
        "' or ",
        "'--",
        "union select",
        "union all select",
        "drop table",
        "select * from",
        "sql",
        "%27",  # URL-encoded single quote
        "' and ",
        "admin'--",
    ):
        return "web_sql_injection"

    if has_any(
        "login",
        "username=admin",
        "password=",
        "passwd=",
        "user=admin",
        "brute",
        "/login",
        "/signin",
        "/auth",
        "post /login",
        "post /sign",
    ):
        return "web_bruteforce"

    # --- C2 / exfil / scan / lateral ---
    if has_any("c2", "command", "shell", "cmd", "/bin/", "reverse"):
        return "c2"
    if has_any("exfil", "exfiltrat", "data_leak", "dns_tunnel"):
        return "exfiltration"
    if has_any("scan", "nmap", "port_scan", "recon"):
        return "scan"
    if has_any("lateral", "pivot", "spread", "propagat"):
        return "lateral"

    return None
def packet_sort_key(packet_id: str) -> int:
    """Return the integer after the last underscore of a packet id, or 0.

    Ids whose trailing component is not an integer sort as 0.
    """
    trailing = packet_id.rpartition("_")[2]
    try:
        return int(trailing)
    except ValueError:
        return 0
def packet_signature(packet: Any, pattern: str) -> tuple[str, str, int, str]:
    """Build the session key (src_ip, dst_ip, dst_port, pattern) for a packet."""
    source, destination, port = packet.src_ip, packet.dst_ip, packet.dst_port
    return source, destination, port, pattern
# Destination ports and protocols that the flow heuristics treat as suspicious.
SUSPICIOUS_PORTS = {22, 23, 445, 1433, 3306, 5432, 4444, 5555, 6666, 6667, 7777, 8888, 9999, 31337}
SUSPICIOUS_PROTOCOLS = {"ICMP"}


def _infer_flow_pattern(packet: Any, flow_size: int) -> str | None:
    """Heuristic pattern inference from flow characteristics when keyword matching fails.

    Args:
        packet: Representative packet of the flow; ``dst_port``, ``protocol``
            and the optional ``flags`` attribute are read.
        flow_size: Number of packets in the flow.

    Returns:
        "ddos", "c2", "lateral" or "scan", or None when no rule applies.
    """
    port = packet.dst_port
    tcp_flags = getattr(packet, "flags", []) or []
    dense = flow_size >= 5

    # High-density flows to web ports, or SYN-only floods, look like DDoS.
    if dense and (port in {80, 8080, 443, 8443} or tcp_flags == ["SYN"]):
        return "ddos"

    # Suspicious ports: backdoor-style ports map to C2, file-share and
    # database ports to lateral movement; the rest fall through.
    if port in SUSPICIOUS_PORTS:
        if port in {4444, 5555, 6666, 7777, 31337}:
            return "c2"
        if port in {445, 1433, 3306, 5432}:
            return "lateral"

    # ICMP flood
    if flow_size >= 3 and packet.protocol in SUSPICIOUS_PROTOCOLS:
        return "ddos"

    # High-density flow to a non-standard port
    if flow_size >= 8 and port not in {53, 80, 443, 8080}:
        return "scan"

    return None
def session_candidates(obs: Any) -> list[tuple[tuple[str, str, int, str], list[Any]]]:
    """Propose attack sessions from the currently visible packets.

    Keyword-matched packets are grouped by (src_ip, dst_ip, dst_port, pattern)
    and joined by their reverse-direction responses. Packets left over are
    bucketed by flow and kept when a flow heuristic names a pattern.

    Returns:
        ``(session_key, packets)`` pairs for groups of at least two entries.
        Packets are de-duplicated by id and sorted by packet number; the pairs
        are ordered by the packet number of their first packet.
    """
    packets = obs.visible_packets
    sessions: dict[tuple[str, str, int, str], list[Any]] = {}
    client_ports: dict[tuple[str, str, int, str], set[int]] = {}

    # Phase 1: keyword-based grouping (high confidence)
    for pkt in packets:
        label = keyword_to_pattern(packet_payload_text(pkt))
        if not label:
            continue
        signature = packet_signature(pkt, label)
        sessions.setdefault(signature, []).append(pkt)
        client_ports.setdefault(signature, set()).add(pkt.src_port)

    # Attach responses travelling back from the target's service port to one
    # of the source ports seen in the keyword-matched session.
    for signature, ports in client_ports.items():
        origin_ip, target_ip, service_port, _label = signature
        responses = [
            pkt
            for pkt in packets
            if pkt.src_ip == target_ip
            and pkt.dst_ip == origin_ip
            and pkt.src_port == service_port
            and pkt.dst_port in ports
        ]
        sessions[signature].extend(responses)

    # Phase 2: flow-based grouping for packets without a keyword match
    taken_ids: set[str] = {
        pkt.packet_id for members in sessions.values() for pkt in members
    }
    flows: dict[tuple[str, str, int], list[Any]] = {}
    for pkt in packets:
        if pkt.packet_id not in taken_ids:
            flows.setdefault((pkt.src_ip, pkt.dst_ip, pkt.dst_port), []).append(pkt)

    for flow_key, members in flows.items():
        if len(members) < 2:
            continue
        label = _infer_flow_pattern(members[0], len(members))
        if label:
            sessions.setdefault((*flow_key, label), []).extend(members)
            taken_ids.update(pkt.packet_id for pkt in members)

    result = []
    for signature, members in sessions.items():
        if len(members) < 2:
            continue
        # Keying by packet id removes duplicates while keeping one object each.
        unique = list({pkt.packet_id: pkt for pkt in members}.values())
        unique.sort(key=lambda pkt: packet_sort_key(pkt.packet_id))
        result.append((signature, unique))

    result.sort(key=lambda entry: packet_sort_key(entry[1][0].packet_id))
    return result
def required_tag_count(task_name: str, total_sessions: int) -> int:
    """Return how many sessions must be tagged: half (rounded up) for "hard", else 0."""
    return (total_sessions + 1) // 2 if task_name == "hard" else 0
def select_inspect_packet(
    obs: Any,
    inspected_ids: set[str],
    flagged_ids: set[str] | None = None,
) -> str | None:
    """Choose the next hidden packet to inspect, preferring dense flows.

    Candidates are visible packets that are not revealed, not previously
    inspected and not flagged. They are ranked by how many visible packets
    share their (src_ip, dst_ip, dst_port) flow, then by packet number, and one
    of the top four is picked with a generator seeded from the step number and
    candidate counts, so the choice is repeatable for a given state.

    Returns:
        The chosen packet id, or None when there is no candidate.
    """
    skip_flagged = flagged_ids or set()

    hidden = []
    for pkt in obs.visible_packets:
        if pkt.is_revealed:
            continue
        if pkt.packet_id in inspected_ids or pkt.packet_id in skip_flagged:
            continue
        hidden.append(pkt)
    if not hidden:
        return None

    def flow_of(pkt: Any) -> tuple[str, str, int]:
        return (pkt.src_ip, pkt.dst_ip, pkt.dst_port)

    density: dict[tuple[str, str, int], int] = {}
    for pkt in obs.visible_packets:
        flow = flow_of(pkt)
        density[flow] = density.get(flow, 0) + 1

    # Bias toward denser flows first to speed up session construction.
    ordered = sorted(
        hidden,
        key=lambda pkt: (-density.get(flow_of(pkt), 0), packet_sort_key(pkt.packet_id)),
    )
    shortlist = ordered[:4]
    seed = f"{obs.step_number}:{len(inspected_ids)}:{len(hidden)}"
    return random.Random(seed).choice(shortlist).packet_id
def append_action_history(agent_state: dict[str, Any], action: NetworkForensicsAction) -> None:
    """Record an action in the bounded history and track inspected packet ids.

    The serialized action is appended to ``agent_state["previous_actions"]``,
    which is trimmed in place to the newest HISTORY_WINDOW entries. Inspected
    packet ids are added to ``agent_state["inspected_ids"]``.
    """
    log = agent_state.setdefault("previous_actions", [])
    log.append(format_action(action))

    if action.action_type == "inspect_packet" and action.packet_id:
        agent_state.setdefault("inspected_ids", set()).add(action.packet_id)

    overflow = len(log) - HISTORY_WINDOW
    if overflow > 0:
        del log[:overflow]
def record_correction(agent_state: dict[str, Any], reason: str) -> None:
    """Append a correction message, keeping only the newest CORRECTION_WINDOW entries."""
    log = agent_state.setdefault("recent_corrections", [])
    log.append(reason)

    excess = len(log) - CORRECTION_WINDOW
    if excess > 0:
        del log[:excess]
def candidate_evidence(
    candidate_packets: list[Any],
    flagged_ids: set[str],
    visible_by_id: dict[str, Any],
) -> tuple[int, int, int]:
    """Count the evidence backing a candidate session.

    Each candidate is replaced by the packet object of the same id in
    ``visible_by_id`` when one exists.

    Returns:
        ``(flagged, revealed, malicious_revealed)``: how many packets are
        already flagged, how many are revealed, and how many revealed packets
        have a payload that matches an attack keyword.
    """
    flagged_total = 0
    revealed_total = 0
    malicious_total = 0

    for candidate in candidate_packets:
        current = visible_by_id.get(candidate.packet_id, candidate)
        flagged_total += current.packet_id in flagged_ids
        if not current.is_revealed:
            continue
        revealed_total += 1
        if keyword_to_pattern(packet_payload_text(current)):
            malicious_total += 1

    return flagged_total, revealed_total, malicious_total
def group_meets_evidence_gate(
    candidate_packets: list[Any],
    flagged_ids: set[str],
    visible_by_id: dict[str, Any],
    task_name: str,
    trusted_pattern: bool = False,
) -> bool:
    """Decide whether a candidate session has enough evidence to be grouped.

    Args:
        candidate_packets: Packets proposed for the session.
        flagged_ids: Ids the agent has flagged so far.
        visible_by_id: Current packet objects keyed by id.
        task_name: "easy", "medium", or anything else for the strictest rules.
        trusted_pattern: True when a session with the same pattern has already
            been tagged.

    Returns:
        True when any acceptance rule below holds.
    """
    flagged, revealed, malicious_revealed = candidate_evidence(
        candidate_packets, flagged_ids, visible_by_id
    )
    size = len(candidate_packets)
    trusted_and_large = trusted_pattern and size >= 3

    # Lowered thresholds for more aggressive grouping: one flagged packet is
    # required once the group reaches the per-task size, none below it.
    size_needing_flag = 2 if task_name in ("easy", "medium") else 3
    min_flagged = 1 if size >= size_needing_flag else 0
    if trusted_and_large:
        min_flagged = 1

    accepted = (
        flagged >= min_flagged,
        # Strong revealed malicious evidence, with per-task leniency.
        task_name == "easy" and (malicious_revealed >= 1 or revealed >= 1),
        task_name == "medium" and malicious_revealed >= 1 and revealed >= 1,
        malicious_revealed >= 1 and revealed >= min(2, size),
        # After a pattern has been confirmed by tagging, allow structure-first grouping.
        trusted_and_large,
        # Large flows are very likely attack sessions - allow with minimal evidence.
        size >= 6 and (flagged >= 1 or revealed >= 2 or malicious_revealed >= 1),
    )
    return any(accepted)
def trusted_patterns(
    session_map: dict[tuple[str, str, int, str], str], tagged_sessions: set[str]
) -> set[str]:
    """Return the pattern names of every mapped session that has been tagged."""
    confirmed: set[str] = set()
    for session_key, session_name in session_map.items():
        if session_name in tagged_sessions:
            confirmed.add(session_key[3])
    return confirmed
def derive_strategy_hints(obs: Any, agent_state: dict[str, Any]) -> list[str]:
    """Produce advisory hints from recent behavior and the current observation.

    Reads ``previous_actions``, ``sessions``, ``tagged_sessions`` and
    ``current_task_name`` from ``agent_state``; nothing is modified.

    Returns:
        Zero or more hint sentences, in a fixed order.
    """
    advice: list[str] = []
    history = agent_state.get("previous_actions", [])
    window = history[-HISTORY_WINDOW:]

    inspect_ratio = 0.0
    if window:
        inspect_ratio = sum('"inspect_packet"' in entry for entry in window) / len(window)
    inspect_heavy = inspect_ratio >= INSPECT_SOFT_RATIO_THRESHOLD

    revealed_total = sum(1 for pkt in obs.visible_packets if pkt.is_revealed)
    flagged_total = len(obs.flagged_packet_ids)

    # The reveal threshold scales with the packet count, clamped to 6..14.
    reveal_floor = max(6, min(14, len(obs.visible_packets) // 15))
    if revealed_total >= reveal_floor and inspect_heavy:
        advice.append(
            "Inspection is high. Prefer flagging suspicious revealed packets, then group/tag before further inspection."
        )

    if flagged_total == 0 and revealed_total >= 4:
        advice.append(
            "You have enough revealed packets. Start flagging suspicious packets before creating more sessions."
        )

    session_total = len(agent_state.get("sessions", {}))
    tagged_total = len(agent_state.get("tagged_sessions", set()))
    if max(0, session_total - tagged_total) > UNTAGGED_BACKLOG_LIMIT:
        advice.append(
            "Tag pending sessions before creating new groups to avoid over-grouping."
        )

    limits_by_task = {"easy": 18, "medium": 20, "hard": 25}
    action_limit = limits_by_task.get(agent_state.get("current_task_name", ""), 15)
    if len(history) >= action_limit and inspect_heavy:
        advice.append(
            "You are over-inspecting. Shift to flagging, grouping, tagging, or report submission unless the next packet is clearly high-value."
        )

    return advice
def should_submit_early(task_name: str, obs: Any, agent_state: dict[str, Any]) -> bool:
    """Report whether progress is good enough to submit before steps run out.

    Easy tasks need flag coverage, at least six flagged packets and one
    session. Medium tasks need score, coverage, one session and one tag. Any
    other task additionally needs two sessions and a claimed entry point.
    """
    flagged_total = len(obs.flagged_packet_ids)
    coverage = flagged_total / max(1, len(obs.visible_packets))
    score = float(obs.current_score_estimate)
    session_total = len(obs.grouped_sessions or {})
    tag_total = len(obs.tagged_patterns or {})

    # Early submission accepts 80% of the score target and 70% of the coverage target.
    score_floor = TASK_SCORE_TARGETS.get(task_name, 0.65) * 0.8
    coverage_floor = TASK_COVERAGE_TARGETS.get(task_name, 0.25) * 0.7

    if task_name == "easy":
        return (
            coverage >= max(coverage_floor, 0.20)
            and flagged_total >= 6
            and session_total >= 1
        )

    progress_ok = score >= score_floor and coverage >= coverage_floor and tag_total >= 1
    if task_name == "medium":
        return progress_ok and session_total >= 1

    return (
        progress_ok
        and session_total >= 2
        and bool(agent_state.get("claimed_entry_point") or obs.claimed_entry_point)
    )
def _pick_entry_candidate(
    obs: Any,
    session_map: dict[tuple[str, str, int, str], str],
    flagged_ids: set[str],
) -> str | None:
    """Return the best available guess for the earliest attack packet, or None.

    Sources are tried in order; the first one that yields an id wins.
    """
    entry_candidate = None

    # Strategy 1: earliest packet in any of our sessions that the environment
    # also reports as grouped.
    observed = obs.grouped_sessions or {}
    try:
        grouped_ids: set[str] = set()
        for session_name in session_map.values():
            if session_name in observed:
                grouped_ids.update(observed[session_name])
        if grouped_ids:
            entry_candidate = min(grouped_ids, key=packet_sort_key)
    except (AttributeError, TypeError, ValueError):
        # Best effort: malformed session data must not block the later
        # strategies, so fall through with no candidate.
        entry_candidate = None

    # Strategy 2: earliest flagged packet (often the first discovered attack)
    if not entry_candidate and flagged_ids:
        entry_candidate = min(flagged_ids, key=packet_sort_key)

    # Strategy 3: earliest revealed malicious packet
    if not entry_candidate:
        revealed_malicious = [
            p
            for p in obs.visible_packets
            if p.is_revealed and keyword_to_pattern(packet_payload_text(p))
        ]
        if revealed_malicious:
            entry_candidate = min(
                revealed_malicious, key=lambda p: packet_sort_key(p.packet_id)
            ).packet_id

    # Strategy 4: earliest packet in session_candidates
    if not entry_candidate:
        candidate_ids = [
            p.packet_id for _key, items in session_candidates(obs) for p in items
        ]
        if candidate_ids:
            entry_candidate = min(candidate_ids, key=packet_sort_key)

    # Strategy 5: earliest flagged packet from observation
    if not entry_candidate and obs.flagged_packet_ids:
        entry_candidate = min(obs.flagged_packet_ids, key=packet_sort_key)

    return entry_candidate


def build_fallback_action(
    task_name: str, obs: Any, agent_state: dict[str, Any]
) -> NetworkForensicsAction:
    """Smart workflow engine: Flag aggressive -> Group -> Tag -> Entry Point -> Report.

    Exactly one action is returned per call. Phases are tried in priority
    order, and the first phase with work to do produces the action.

    Args:
        task_name: Task difficulty name ("easy", "medium", "hard").
        obs: Current environment observation.
        agent_state: Mutable per-episode state. ``inspected_ids``,
            ``flagged_ids``, ``sessions`` and ``tagged_sessions`` are created
            when absent; ``flagged_ids``, ``sessions``, ``tagged_sessions`` and
            ``claimed_entry_point`` are updated to reflect the returned action.

    Returns:
        The next action to send to the environment.
    """
    inspected_ids = agent_state.setdefault("inspected_ids", set())
    flagged_ids = agent_state.setdefault("flagged_ids", set())
    session_map = agent_state.setdefault("sessions", {})  # key -> session_name
    tagged_sessions = agent_state.setdefault("tagged_sessions", set())
    claimed_entry = agent_state.get("claimed_entry_point")
    visible_by_id = {p.packet_id: p for p in obs.visible_packets}
    trusted = trusted_patterns(session_map, tagged_sessions)
    observed_sessions = obs.grouped_sessions or {}
    observed_tags = obs.tagged_patterns or {}

    def submit_report() -> NetworkForensicsAction:
        # Use the same entry-point fallback as the report text and the
        # early-submit gate, so the action never omits a claim they relied on.
        return NetworkForensicsAction(
            action_type="submit_report",
            incident_summary=_build_report_summary(obs, agent_state),
            claimed_entry_point=claimed_entry or getattr(obs, "claimed_entry_point", None),
        )

    if obs.steps_remaining <= 1 or should_submit_early(task_name, obs, agent_state):
        return submit_report()

    # PHASE 1: Flag revealed malicious packets to build recall.
    # An action carries a single packet_id, so flag one packet per call and
    # record only that one locally; the rest are picked up on later calls.
    for packet in obs.visible_packets:
        if not packet.is_revealed or packet.packet_id in flagged_ids:
            continue
        if keyword_to_pattern(packet.full_payload or ""):
            flagged_ids.add(packet.packet_id)
            return NetworkForensicsAction(
                action_type="flag_as_suspicious",
                packet_id=packet.packet_id,
            )

    # PHASE 2: Group flagged packets into sessions with evidence gate and backlog pacing.
    min_flagged_before_group = 1 if task_name == "easy" else 2
    untagged_backlog = max(0, len(session_map) - len(tagged_sessions))
    if len(flagged_ids) >= min_flagged_before_group and untagged_backlog <= UNTAGGED_BACKLOG_LIMIT:
        for key, items in session_candidates(obs):
            if key in session_map:
                continue
            if not group_meets_evidence_gate(
                items,
                flagged_ids,
                visible_by_id,
                task_name=task_name,
                trusted_pattern=key[3] in trusted,
            ):
                continue
            session_name = f"{task_name}_session_{len(session_map) + 1:02d}"
            session_map[key] = session_name
            return NetworkForensicsAction(
                action_type="group_into_session",
                session_name=session_name,
                packet_ids=[p.packet_id for p in items],
            )

    # PHASE 2.5: Recall sweep - flag packets that are already part of grouped sessions.
    # This boosts recall quickly without requiring more inspections.
    grouped_ids = {pid for ids in observed_sessions.values() for pid in ids}
    for pid in sorted(grouped_ids, key=packet_sort_key):
        if pid in flagged_ids or pid not in visible_by_id:
            continue
        flagged_ids.add(pid)
        return NetworkForensicsAction(
            action_type="flag_as_suspicious",
            packet_id=pid,
        )

    # PHASE 3: Tag every session we created, using the pattern stored in its key.
    for key, session_name in session_map.items():
        if session_name in tagged_sessions:
            continue
        tagged_sessions.add(session_name)
        return NetworkForensicsAction(
            action_type="tag_pattern",
            session_name=session_name,
            pattern_type=key[3],
        )

    # Also tag any observed sessions not yet in our session_map
    for session_name, session_data in observed_sessions.items():
        if session_name in tagged_sessions:
            continue
        if session_name in observed_tags:
            # Already tagged in the environment; just sync local state.
            tagged_sessions.add(session_name)
            continue

        # Infer pattern from the session's revealed packets first.
        pattern = None
        for pid in session_data:
            pkt = visible_by_id.get(pid)
            if pkt and pkt.is_revealed:
                pattern = keyword_to_pattern(packet_payload_text(pkt))
                if pattern:
                    break
        if not pattern:
            # Try flow-based inference on the session's first packet.
            pkt = visible_by_id.get(session_data[0]) if session_data else None
            if pkt:
                pattern = _infer_flow_pattern(pkt, len(session_data))
        if pattern:
            tagged_sessions.add(session_name)
            return NetworkForensicsAction(
                action_type="tag_pattern",
                session_name=session_name,
                pattern_type=pattern,
            )

    # PHASE 4: Identify entry point - CRITICAL for hard mode (score=0 without it)
    if not claimed_entry:
        entry_candidate = _pick_entry_candidate(obs, session_map, flagged_ids)
        if entry_candidate:
            agent_state["claimed_entry_point"] = entry_candidate
            return NetworkForensicsAction(
                action_type="identify_entry_point",
                claimed_entry_point=entry_candidate,
            )

    # PHASE 5: Inspect more unrevealed packets (to discover more malicious traffic)
    inspect_id = select_inspect_packet(obs, inspected_ids, flagged_ids)
    if inspect_id is not None:
        return NetworkForensicsAction(action_type="inspect_packet", packet_id=inspect_id)

    # PHASE 6: Submit report
    return submit_report()
def _build_report_summary(obs: Any, agent_state: dict[str, Any]) -> str:
    """Generate a detailed incident summary for high LLM judge scores.

    The text is assembled from ``agent_state`` (``flagged_ids``, ``sessions``,
    ``tagged_sessions``, ``claimed_entry_point``). The entry point falls back
    to ``obs.claimed_entry_point`` when the state holds none.
    """
    flagged_ids = agent_state.get("flagged_ids", set())
    session_map = agent_state.get("sessions", {})
    tagged_names = agent_state.get("tagged_sessions", set())
    entry_point = agent_state.get("claimed_entry_point") or getattr(obs, "claimed_entry_point", None)

    # Index sessions by pattern; keys are (src_ip, dst_ip, dst_port, pattern).
    pattern_of_session: dict[str, str] = {}
    sources: dict[str, set[str]] = {}
    targets: dict[str, set[str]] = {}
    for key, name in session_map.items():
        if len(key) < 4:
            continue
        label = key[3]
        pattern_of_session[name] = label
        sources.setdefault(label, set()).add(key[0])
        targets.setdefault(label, set()).add(key[1])

    # One line per pattern, listing at most five source and target addresses.
    pattern_lines = []
    for label in sorted(set(pattern_of_session.values())):
        session_count = sum(1 for value in pattern_of_session.values() if value == label)
        src_text = ", ".join(sorted(sources.get(label, set()))[:5])
        dst_text = ", ".join(sorted(targets.get(label, set()))[:5])
        pattern_lines.append(
            f" - {label}: {session_count} session(s) from {src_text} targeting {dst_text}"
        )
    if pattern_lines:
        pattern_section = "\n".join(pattern_lines)
    else:
        pattern_section = " - No patterns classified"

    # Tagged sessions with no known key are reported as "unknown".
    tagged_pairs = [
        f"{name}={pattern_of_session.get(name, 'unknown')}" for name in sorted(tagged_names)
    ]
    tagged_section = "; ".join(tagged_pairs) if tagged_pairs else "none"

    if entry_point:
        entry_section = f"Entry point: {entry_point}"
    else:
        entry_section = "Entry point: not identified"

    flagged_total = len(flagged_ids)
    session_total = len(session_map)
    parts = [
        "INCIDENT REPORT\n\n",
        f"Summary: Detected {flagged_total} malicious packets across ",
        f"{session_total} attack sessions.\n\n",
        f"Attack Patterns:\n{pattern_section}\n\n",
        f"Tagged Sessions: {tagged_section}\n\n",
        f"{entry_section}\n\n",
        f"Total flagged: {flagged_total} | Total sessions: {session_total} | ",
        f"Classified sessions: {len(tagged_names)}",
    ]
    return "".join(parts)
def _inspect_override_reason(
    action: NetworkForensicsAction,
    obs: Any,
    agent_state: dict[str, Any],
    task_name: str,
    visible_by_id: dict[Any, Any],
) -> str | None:
    """Return why an ``inspect_packet`` action is rejected, or None if allowed."""
    packet_id = action.packet_id
    if not packet_id:
        return "Missing packet_id for inspect_packet"
    if packet_id not in visible_by_id:
        return f"Invalid packet_id {packet_id} - not in visible_packets"
    if packet_id in agent_state.setdefault("inspected_ids", set()):
        return f"Packet {packet_id} was already inspected. Choose a different hidden packet."
    revealed_ids = {p.packet_id for p in obs.visible_packets if p.is_revealed}
    if packet_id in revealed_ids:
        return f"Packet {packet_id} is ALREADY revealed. Choose a HIDDEN packet."

    # Built once and reused by every membership test below.
    env_flagged = set(obs.flagged_packet_ids)
    if packet_id in env_flagged:
        return (
            f"Packet {packet_id} is already flagged. Inspect a new hidden unflagged packet instead."
        )

    # Recall first: anything already revealed whose payload maps to a known
    # pattern must be flagged before more inspections are spent.
    if any(
        p.is_revealed
        and p.packet_id not in env_flagged
        and keyword_to_pattern(packet_payload_text(p))
        for p in obs.visible_packets
    ):
        return (
            "Recall-first policy: revealed malicious packets exist and must be flagged before new inspection."
        )
    if any(
        pid not in env_flagged
        for packet_ids in (obs.grouped_sessions or {}).values()
        for pid in packet_ids
    ):
        return (
            "Recall-first policy: grouped session packets remain unflagged. Flag them before further inspection."
        )

    sessions = agent_state.setdefault("sessions", {})
    flagged_ids = agent_state.setdefault("flagged_ids", set())
    if task_name == "easy" and len(flagged_ids) >= 4:
        for key, items in session_candidates(obs):
            if key not in sessions and len(items) >= 4:
                return (
                    "Exploit mode: enough evidence exists. Group high-confidence attack flows before more inspection."
                )

    # The inspection budget is only enforced outside easy mode, and only once
    # the agent has something to act on (a session, a flag or four reveals).
    previous_actions = agent_state.setdefault("previous_actions", [])
    inspect_count = sum(1 for a in previous_actions if '"inspect_packet"' in a)
    revealed_count = sum(1 for p in obs.visible_packets if p.is_revealed)
    inspect_limit = {"easy": 25, "medium": 18, "hard": 25}.get(task_name, 15)
    has_material = len(sessions) > 0 or len(flagged_ids) > 0 or revealed_count >= 4
    if task_name != "easy" and inspect_count >= inspect_limit and has_material:
        return (
            f"Inspection budget reached for {task_name}. Shift to flagging, grouping, tagging, or report submission."
        )
    return None


def _flag_override_reason(
    action: NetworkForensicsAction,
    obs: Any,
    visible_by_id: dict[Any, Any],
) -> str | None:
    """Return why a ``flag_as_suspicious`` action is rejected, or None if allowed."""
    if not action.packet_id:
        return "Missing packet_id for flag_as_suspicious"
    if action.packet_id not in visible_by_id:
        return f"Invalid packet_id {action.packet_id} - not in visible_packets"
    if action.packet_id in set(obs.flagged_packet_ids):
        return f"Packet {action.packet_id} is ALREADY flagged."
    return None


def _group_override_reason(
    action: NetworkForensicsAction,
    obs: Any,
    agent_state: dict[str, Any],
    task_name: str,
    visible_by_id: dict[Any, Any],
) -> str | None:
    """Return why a ``group_into_session`` action is rejected, or None if allowed."""
    if not action.session_name:
        return "Missing session_name for group_into_session"
    if not action.packet_ids or len(action.packet_ids) < 2:
        return "Need at least 2 packet_ids to form a session"
    invalid_ids = set(action.packet_ids) - set(visible_by_id)
    if invalid_ids:
        return f"Invalid packet_ids in session: {invalid_ids}"

    sessions = agent_state.setdefault("sessions", {})
    tagged_sessions = agent_state.setdefault("tagged_sessions", set())
    flagged_ids = agent_state.setdefault("flagged_ids", set())
    if action.session_name in sessions.values():
        return f"Session name {action.session_name} is already used."

    # Every task tier currently requires the same single flag before grouping.
    min_flagged_before_group = 1
    if len(flagged_ids) < min_flagged_before_group:
        return (
            f"Group blocked until enough evidence is flagged ({len(flagged_ids)}/{min_flagged_before_group}). "
            "Inspect and flag suspicious packets first."
        )

    # Reject a group whose members are >= 80% contained in an existing session.
    new_group_ids = set(action.packet_ids)
    for existing_ids in (obs.grouped_sessions or {}).values():
        existing_set = set(existing_ids)
        if not existing_set:
            continue
        overlap = len(new_group_ids & existing_set) / max(1, len(new_group_ids))
        if overlap >= 0.8:
            return "This grouping heavily overlaps an existing session. Prioritize new evidence."

    untagged_backlog = max(0, len(sessions) - len(tagged_sessions))
    if untagged_backlog > UNTAGGED_BACKLOG_LIMIT:
        return (
            "Too many untagged sessions pending. Tag existing sessions before grouping new ones."
        )

    candidate_packets = [
        visible_by_id[pid] for pid in action.packet_ids if pid in visible_by_id
    ]
    inferred_patterns = set()
    for packet in candidate_packets:
        pattern = keyword_to_pattern(packet_payload_text(packet))
        if pattern:
            inferred_patterns.add(pattern)
    trusted = trusted_patterns(sessions, tagged_sessions)
    trusted_pattern = any(pattern in trusted for pattern in inferred_patterns)
    if not group_meets_evidence_gate(
        candidate_packets,
        flagged_ids,
        visible_by_id,
        task_name=task_name,
        trusted_pattern=trusted_pattern,
    ):
        return (
            "Insufficient evidence for grouping. Flag or reveal more suspicious packets in this flow first."
        )
    return None


def _submit_override_reason(
    obs: Any,
    agent_state: dict[str, Any],
    task_name: str,
) -> str | None:
    """Return why a ``submit_report`` action is premature, or None if allowed."""
    sessions = agent_state.setdefault("sessions", {})
    tagged_sessions = agent_state.setdefault("tagged_sessions", set())
    untagged_backlog = max(0, len(sessions) - len(tagged_sessions))
    total_visible = max(1, len(obs.visible_packets))
    flagged_count = len(obs.flagged_packet_ids)
    coverage = flagged_count / total_visible
    min_cov = TASK_COVERAGE_TARGETS.get(task_name, 0.25) * 0.6
    min_flags = 3 if task_name == "medium" else 4
    min_groups = 1 if task_name == "easy" else 2

    if (
        obs.steps_remaining > 2
        and obs.current_score_estimate < 0.40
        and not should_submit_early(task_name, obs, agent_state)
    ):
        return (
            "Premature report submission. Improve coverage and score estimate before submit_report."
        )
    if obs.steps_remaining > 1 and (coverage < min_cov or flagged_count < min_flags):
        return (
            f"Premature report submission. Need stronger recall coverage before submit_report "
            f"(coverage {coverage:.0%}/{min_cov:.0%}, flags {flagged_count}/{min_flags})."
        )
    if obs.steps_remaining > 1 and len(sessions) < min_groups:
        return (
            f"Premature report submission. Need stronger session evidence before submit_report "
            f"(grouped {len(sessions)}/{min_groups})."
        )
    if task_name == "hard" and obs.steps_remaining > 3 and untagged_backlog > 0:
        return "Premature report submission. Tag pending sessions before submitting report."
    # Hard mode scores 0.0 without an entry point, so this gate has no
    # steps_remaining escape hatch.
    if task_name == "hard" and not (
        agent_state.get("claimed_entry_point") or obs.claimed_entry_point
    ):
        return (
            "FATAL: Hard mode requires identify_entry_point before submit_report. "
            "No entry point claimed yet — score will be 0.0 without it. "
            "Use identify_entry_point with the earliest malicious packet first."
        )
    # Medium mode: the entry point is only insisted on while steps remain.
    if (
        task_name == "medium"
        and obs.steps_remaining > 5
        and not (agent_state.get("claimed_entry_point") or obs.claimed_entry_point)
    ):
        return (
            "Missing entry point. Use identify_entry_point before submit_report for higher score."
        )
    min_tagged = 1 if task_name == "medium" else 2
    if (
        task_name in {"medium", "hard"}
        and len(tagged_sessions) < min_tagged
        and obs.steps_remaining > 3
    ):
        return (
            f"Premature report submission. Need at least {min_tagged} tagged session(s) before submit_report "
            f"(currently {len(tagged_sessions)})."
        )
    return None


def _tag_override_reason(
    action: NetworkForensicsAction,
    obs: Any,
    task_name: str,
) -> str | None:
    """Return why a ``tag_pattern`` action is rejected, or None if allowed."""
    if not action.session_name:
        return "Missing session_name for tag_pattern"
    if not action.pattern_type:
        return "Missing pattern_type for tag_pattern"
    if action.session_name in (obs.tagged_patterns or {}):
        return f"Session {action.session_name} is already tagged."
    if task_name == "easy" and obs.steps_remaining > 8:
        return "For easy mode, prioritize recall actions (inspect/flag/group) before tagging."
    valid_patterns = {
        "ddos", "dos_slowloris", "dos_slowhttptest", "dos_goldeneye", "dos_hulk",
        "heartbleed", "web_sql_injection", "web_xss", "web_bruteforce",
        "c2", "exfiltration", "scan", "lateral",
    }
    if action.pattern_type.lower() not in valid_patterns:
        return f"Unknown pattern_type '{action.pattern_type}'"
    return None


def _entry_point_override_reason(
    action: NetworkForensicsAction,
    obs: Any,
    agent_state: dict[str, Any],
    task_name: str,
) -> str | None:
    """Return why an ``identify_entry_point`` action is rejected, or None if allowed."""
    if not action.claimed_entry_point:
        return "Missing claimed_entry_point for identify_entry_point"
    flagged_ids = agent_state.setdefault("flagged_ids", set())
    # Easy mode is gated more leniently than medium/hard.
    min_flags_needed = 1 if task_name == "easy" else 2
    if obs.steps_remaining > 8 and len(flagged_ids) < min_flags_needed:
        return (
            "Premature entry-point claim. Gather and flag more evidence before identify_entry_point."
        )
    return None


def should_override_action(
    action: NetworkForensicsAction,
    obs: Any,
    agent_state: dict[str, Any],
    task_name: str,
) -> str | None:
    """Decide whether a model-proposed action must be replaced.

    Args:
        action: The action parsed from the model response.
        obs: Current observation (visible packets, flags, sessions, budget).
        agent_state: Mutable agent bookkeeping. The keys ``previous_actions``,
            ``flagged_ids``, ``sessions`` and ``tagged_sessions`` are created
            with empty defaults when absent.
        task_name: Task tier name; ``"easy"``, ``"medium"`` and ``"hard"``
            select tier-specific thresholds.

    Returns:
        A human-readable reason when the action should be overridden,
        otherwise ``None``.
    """
    previous_actions = agent_state.setdefault("previous_actions", [])
    agent_state.setdefault("flagged_ids", set())
    agent_state.setdefault("sessions", {})
    agent_state.setdefault("tagged_sessions", set())
    action_repr = format_action(action)
    visible_by_id = {p.packet_id: p for p in obs.visible_packets}

    action_type = action.action_type
    if action_type not in {
        "inspect_packet",
        "flag_as_suspicious",
        "group_into_session",
        "tag_pattern",
        "identify_entry_point",
        "submit_report",
    }:
        return "Invalid action_type"

    # Loop guard: window length and message both follow REPEAT_ACTION_LIMIT,
    # so the check stays coherent if the constant is changed.
    if REPEAT_ACTION_LIMIT > 0:
        recent_actions = previous_actions[-REPEAT_ACTION_LIMIT:]
        if len(recent_actions) == REPEAT_ACTION_LIMIT and all(
            a == action_repr for a in recent_actions
        ):
            return (
                f"Identical action repeated {REPEAT_ACTION_LIMIT} times consecutively (Infinite Loop)"
            )

    if action_type == "inspect_packet":
        return _inspect_override_reason(action, obs, agent_state, task_name, visible_by_id)
    if action_type == "flag_as_suspicious":
        return _flag_override_reason(action, obs, visible_by_id)
    if action_type == "group_into_session":
        return _group_override_reason(action, obs, agent_state, task_name, visible_by_id)
    if action_type == "submit_report":
        return _submit_override_reason(obs, agent_state, task_name)
    if action_type == "tag_pattern":
        return _tag_override_reason(action, obs, task_name)
    return _entry_point_override_reason(action, obs, agent_state, task_name)
def choose_action(
    client: OpenAI,
    task_name: str,
    obs: Any,
    agent_state: dict[str, Any],
    model_name: str | None = None,
) -> NetworkForensicsAction:
    """Pick the next action, preferring the model and falling back to heuristics.

    An early ``submit_report`` is returned without consulting the model when
    ``should_submit_early`` says so. Otherwise the model is queried; a failed
    call, an unparseable reply or an overridden action all yield the result of
    ``build_fallback_action``. Whatever is returned is also appended to the
    action history.

    Args:
        client: OpenAI-compatible client used for the chat completion.
        task_name: Name of the task being played.
        obs: Current observation.
        agent_state: Mutable agent bookkeeping; updated in place.
        model_name: Optional model override; ``MODEL_NAME`` is used otherwise.

    Returns:
        The action to send to the environment.
    """
    agent_state["current_task_name"] = task_name
    agent_state["strategy_hints"] = derive_strategy_hints(obs, agent_state)

    def _remember(chosen: NetworkForensicsAction) -> NetworkForensicsAction:
        # Every returned action goes through the history exactly once.
        append_action_history(agent_state, chosen)
        return chosen

    def _fall_back() -> NetworkForensicsAction:
        return _remember(build_fallback_action(task_name, obs, agent_state))

    if should_submit_early(task_name, obs, agent_state):
        return _remember(
            NetworkForensicsAction(
                action_type="submit_report",
                incident_summary=_build_report_summary(obs, agent_state),
                claimed_entry_point=agent_state.get("claimed_entry_point")
                or obs.claimed_entry_point,
            )
        )

    recent_history = agent_state.get("previous_actions", [])[-HISTORY_WINDOW:]
    history_str = "\n".join(
        f"Step {number}: {entry}" for number, entry in enumerate(recent_history, start=1)
    )

    # Recent corrections stay in the prompt so repeated mistakes remain visible.
    recent_corrections = agent_state.get("recent_corrections", [])[-CORRECTION_WINDOW:]
    correction_text = ""
    if recent_corrections:
        bullet_list = "\n".join(f"- {item}" for item in recent_corrections)
        correction_text = (
            "\n### SYSTEM CORRECTIONS (recent):\n"
            f"{bullet_list}\n"
            "Follow the JSON schema in the system prompt."
        )

    try:
        # Built inside the try so a failure while summarising the observation
        # is handled the same way as a failed model call.
        user_prompt = (
            f"TASK: {task_name}{correction_text}\n\n### RECENT HISTORY:\n{history_str}"
            f"\n\n### CURRENT OBSERVATION:\n{summarize_observation(obs, agent_state)}"
        )
        response = client.chat.completions.create(
            model=model_name or MODEL_NAME,
            temperature=0.1,
            timeout=LLM_TIMEOUT_S,
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
        )
    except Exception as llm_exc:
        print(f"[WARN] LLM call failed/timed out: {llm_exc}")
        return _fall_back()

    content = response.choices[0].message.content or ""
    try:
        action = sanitize_action(parse_action(content))
    except Exception as parse_exc:
        record_correction(agent_state, f"Invalid JSON ({parse_exc})")
        return _fall_back()

    override_reason = should_override_action(action, obs, agent_state, task_name)
    if override_reason:
        record_correction(agent_state, override_reason)
        return _fall_back()

    return _remember(action)
def reward_feedback(action: NetworkForensicsAction, reward: float) -> str:
    """Turn a step reward into a short textual hint for the next prompt.

    Args:
        action: The action that was just executed; only ``action_type`` is read.
        reward: Reward returned for that action.

    Returns:
        For inspect/flag/group/tag actions, one message for a negative reward
        and another otherwise; a fixed message for ``submit_report``; and
        ``"Action completed."`` for anything else.
    """
    # action_type -> (message for reward < 0, message otherwise)
    graded_messages = {
        "inspect_packet": (
            "Inspect action was not useful. Try new packets or move to flag/group/tag.",
            "Inspect yielded useful signal.",
        ),
        "flag_as_suspicious": (
            "Flagging was low quality or duplicate.",
            "Flagging improved recall progress.",
        ),
        "group_into_session": (
            "Grouping did not match a strong attack session.",
            "Grouping improved session structure.",
        ),
        "tag_pattern": (
            "Tag mismatch. Re-evaluate session characteristics.",
            "Tag assignment was useful.",
        ),
    }
    kind = action.action_type
    pair = graded_messages.get(kind)
    if pair is not None:
        negative_text, positive_text = pair
        return negative_text if reward < 0 else positive_text
    if kind == "submit_report":
        return "Report submitted. Score now reflects report quality and coverage."
    return "Action completed."
def sync_agent_state(obs: Any, agent_state: dict[str, Any]) -> None:
    """Merge what the observation reports into the agent's bookkeeping.

    The update is additive: ids already recorded in ``agent_state`` are kept
    even if the observation no longer lists them.

    Args:
        obs: Observation exposing ``visible_packets``, ``flagged_packet_ids``,
            ``tagged_patterns`` (may be ``None``) and ``claimed_entry_point``.
        agent_state: Mutable agent bookkeeping. ``inspected_ids``,
            ``flagged_ids`` and ``tagged_sessions`` are created if missing.
    """
    inspected_ids = agent_state.setdefault("inspected_ids", set())
    inspected_ids.update(
        packet.packet_id for packet in obs.visible_packets if packet.is_revealed
    )

    flagged_ids = agent_state.setdefault("flagged_ids", set())
    flagged_ids.update(obs.flagged_packet_ids)

    # tagged_patterns is treated as optional elsewhere in this file
    # (``obs.tagged_patterns or {}``), so a missing mapping must not crash here.
    tagged_sessions = agent_state.setdefault("tagged_sessions", set())
    tagged_sessions.update(obs.tagged_patterns or {})

    # Only overwrite the remembered entry point when the observation has one.
    if obs.claimed_entry_point:
        agent_state["claimed_entry_point"] = obs.claimed_entry_point
def emit_step(
    step_number: int,
    action: NetworkForensicsAction,
    reward: float,
    done: bool,
    error: str | None,
) -> None:
    """Print one ``[STEP]`` log line describing an executed action.

    Args:
        step_number: Step index to report.
        action: The executed action, rendered with ``format_action``.
        reward: Step reward, printed with two decimals.
        done: Episode-finished flag, printed in lower case.
        error: Error text, or ``None`` which is printed as ``null``.
    """
    fields = (
        f"step={step_number}",
        f"action={format_action(action)}",
        f"reward={reward:.2f}",
        f"done={str(done).lower()}",
        f"error={'null' if error is None else error}",
    )
    print("[STEP] " + " ".join(fields))
def normalize_score(score: float) -> float:
    """Clamp ``score`` to the closed interval [0.0, 1.0].

    NaN is mapped to 0.0. A plain ``max(0.0, min(1.0, score))`` turns NaN into
    1.0, because every comparison with NaN is false and ``min`` keeps its
    first argument, which would report an undefined score as a perfect one.

    Args:
        score: Raw score, possibly outside [0, 1] or NaN.

    Returns:
        The clamped score.
    """
    import math

    if math.isnan(score):
        return 0.0
    return max(0.0, min(1.0, score))
def final_metrics(obs: Any) -> dict[str, Any]:
    """Return the metrics mapping attached to ``obs``.

    The first truthy value of ``obs.final_metrics`` then ``obs.metadata`` is
    returned; an empty dict is returned when neither is present or truthy.
    """
    for attribute in ("final_metrics", "metadata"):
        candidate = getattr(obs, attribute, None)
        if candidate:
            return candidate
    return {}
class ExtendedWaitDockerProvider(LocalDockerProvider):
    """Docker provider whose readiness wait uses ``DOCKER_READY_TIMEOUT_S``."""

    def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
        """Wait for the container at ``base_url`` to become ready.

        The ``timeout_s`` argument is accepted for signature compatibility but
        is not used; the module-level ``DOCKER_READY_TIMEOUT_S`` is always
        passed to the parent implementation instead.
        """
        configured_timeout = DOCKER_READY_TIMEOUT_S
        super().wait_for_ready(base_url, timeout_s=configured_timeout)
def get_async_loop() -> asyncio.AbstractEventLoop:
    """Return the module's shared event loop, creating one when needed.

    A new loop is created when none has been stored yet or when the stored
    loop has been closed; otherwise the stored loop is reused.
    """
    global _ASYNC_LOOP
    current = _ASYNC_LOOP
    if current is not None and not current.is_closed():
        return current
    _ASYNC_LOOP = asyncio.new_event_loop()
    return _ASYNC_LOOP
def resolve_maybe_awaitable(value: Any) -> Any:
    """Return ``value``, first running it to completion if it is awaitable.

    Awaitables are executed on the loop returned by ``get_async_loop``;
    anything else is handed back unchanged.
    """
    if not inspect.isawaitable(value):
        return value
    loop = get_async_loop()
    return loop.run_until_complete(value)
def create_env() -> NetworkForensicsEnv:
    """Create an environment client for the configured ``ENV_MODE``.

    * ``"hf"``: connect to ``HF_SPACE_URL`` or, when that is empty, to the
      URL derived from ``HF_SPACE_ID``.
    * ``"docker"``: start ``LOCAL_IMAGE_NAME`` through
      ``ExtendedWaitDockerProvider``.
    * ``"server"`` or any other value: connect to ``ENV_BASE_URL``.
    """
    if ENV_MODE == "hf":
        if HF_SPACE_URL:
            hf_url = HF_SPACE_URL.rstrip("/")
        else:
            space_slug = HF_SPACE_ID.lower().replace("/", "-").replace("_", "-")
            hf_url = f"https://{space_slug}.hf.space"
        return NetworkForensicsEnv(base_url=hf_url)

    if ENV_MODE == "docker":
        docker_provider = ExtendedWaitDockerProvider()
        pending_env = NetworkForensicsEnv.from_docker_image(
            LOCAL_IMAGE_NAME, provider=docker_provider
        )
        return resolve_maybe_awaitable(pending_env)

    # "server" and unrecognised modes both use the explicit base URL.
    return NetworkForensicsEnv(base_url=ENV_BASE_URL)
def create_env_with_fallback() -> NetworkForensicsEnv:
    """Create an environment, trying successively more local backends.

    In ``"server"`` mode the client for ``ENV_BASE_URL`` is returned directly.
    Otherwise the HF Space is tried first, then a local Docker container, and
    finally the in-process environment. The remote and Docker backends are
    probed with a reset of the ``"easy"`` task before being accepted; a
    backend that fails its probe is closed before the next one is tried, so
    no half-initialised connection or container is left behind.

    Returns:
        A ready-to-use environment.

    Raises:
        RuntimeError: If every backend, including the in-process one, fails.
    """
    if ENV_MODE == "server":
        print(f"[INFO] Manual Server Mode Active: Using {ENV_BASE_URL}")
        return NetworkForensicsEnv(base_url=ENV_BASE_URL)

    # 1) Try HF Space.
    env: NetworkForensicsEnv | None = None
    try:
        env = NetworkForensicsEnv(base_url=HF_SPACE_URL.rstrip("/"))
        reset_env(env, "easy")
        return env
    except Exception as exc:
        close_env(env)  # no-op when construction itself failed
        print(f"[WARN] HF space failed ({exc}); trying Docker.")

    # 2) Try Docker.
    env = None
    try:
        provider = ExtendedWaitDockerProvider()
        env = resolve_maybe_awaitable(
            NetworkForensicsEnv.from_docker_image(LOCAL_IMAGE_NAME, provider=provider)
        )
        reset_env(env, "easy")
        return env
    except Exception as exc:
        close_env(env)  # release the container if the probe reset failed
        print(f"[WARN] Docker failed ({exc}); falling back to local simulation.")

    # 3) Last resort: in-process environment.
    try:
        from server.network_forensics_environment import NetworkForensicsEnvironment

        return NetworkForensicsEnvironment(task_id="easy")  # type: ignore[return-value]
    except Exception as exc:
        raise RuntimeError(f"All environment backends failed: {exc}") from exc
def reset_env(env: NetworkForensicsEnv, task_name: str) -> Any:
    """Reset ``env`` for ``task_name`` and return the resolved reset result."""
    return resolve_maybe_awaitable(env.reset(task_id=task_name))
def step_env(env: NetworkForensicsEnv, action: NetworkForensicsAction) -> Any:
    """Send ``action`` to ``env`` and return the resolved step result."""
    return resolve_maybe_awaitable(env.step(action))
def extract_observation(result: Any) -> Any:
    """Pull the observation out of a reset/step result.

    Wrapped results expose it as ``result.observation``; a result without
    that attribute is taken to be the observation itself.

    Raises:
        RuntimeError: If the resolved observation is ``None``.
    """
    try:
        observation = result.observation
    except AttributeError:
        observation = result
    if observation is not None:
        return observation
    raise RuntimeError("Environment returned no observation")
def extract_step_reward(step_result: Any, obs: Any) -> float:
    """Return the step reward as a float.

    ``step_result.reward`` is preferred; when it is missing or ``None`` the
    value of ``obs.reward`` is used. A missing or falsy reward yields 0.0.
    """
    raw_reward = getattr(step_result, "reward", None)
    if raw_reward is None:
        raw_reward = getattr(obs, "reward", 0.0)
    if not raw_reward:
        return 0.0
    return float(raw_reward)
# Maximum number of attempts step_env_with_retry makes for a single action.
WS_RETRY_COUNT = 3
# Base back-off in seconds; step_env_with_retry sleeps this value multiplied
# by the attempt number before reconnecting.
WS_RETRY_DELAY_S = 2.0
# Value passed as ``timeout`` to the chat-completions request in choose_action.
LLM_TIMEOUT_S = 45.0
def step_env_with_retry(
    env: NetworkForensicsEnv,
    action: NetworkForensicsAction,
    task_name: str,
    agent_state: dict[str, Any],
) -> tuple[Any, NetworkForensicsEnv | None]:
    """Execute ``action``, reconnecting and retrying on connection errors.

    Up to ``WS_RETRY_COUNT`` attempts are made. When an attempt fails with an
    error whose text looks like a WebSocket/connection problem, the current
    environment is closed, a new one is created with ``create_env`` and reset
    for ``task_name``, the agent state is re-synchronised from the reset
    observation, and the step is tried again.

    Args:
        env: Environment to step.
        action: Action to execute.
        task_name: Task used to reset a replacement environment.
        agent_state: Mutable agent bookkeeping, updated after a reconnect.

    Returns:
        ``(step_result, new_env)``. ``new_env`` is the replacement
        environment when a reconnect happened and the caller must switch to
        it; it is ``None`` when the original ``env`` is still the live one.

    Raises:
        Exception: The step error itself when it is not connection related,
            or the last connection error once all attempts are used up. A
            replacement environment created along the way is closed first,
            since the caller never receives a handle to it.
        RuntimeError: If ``WS_RETRY_COUNT`` allows no attempt at all.
    """
    connection_markers = ("keepalive", "ping timeout", "1011", "websocket", "connection")
    original_env = env
    last_exc: Exception | None = None

    for attempt in range(1, WS_RETRY_COUNT + 1):
        try:
            result = step_env(env, action)
        except Exception as exc:
            last_exc = exc
            exc_str = str(exc).lower()
            if not any(marker in exc_str for marker in connection_markers):
                # Not a transport problem, so retrying cannot help.
                if env is not original_env:
                    close_env(env)
                raise
            print(
                f"[WARN] WebSocket timeout on attempt {attempt}/{WS_RETRY_COUNT}: {exc}"
            )
            if attempt >= WS_RETRY_COUNT:
                break
            time.sleep(WS_RETRY_DELAY_S * attempt)
            # close_env is best-effort and never raises.
            close_env(env)
            try:
                env = create_env()
                reset_result = reset_env(env, task_name)
                obs = extract_observation(reset_result)
                sync_agent_state(obs, agent_state)
                print(f"[INFO] Reconnected to environment, resuming task={task_name}")
            except Exception as reconnect_exc:
                print(f"[WARN] Reconnect failed: {reconnect_exc}")
        else:
            # Hand a replacement back so the caller stops using the closed env.
            return result, (env if env is not original_env else None)

    if env is not original_env:
        close_env(env)
    if last_exc is None:
        raise RuntimeError("step_env_with_retry made no attempt; WS_RETRY_COUNT must be >= 1")
    raise last_exc
def close_env(env: NetworkForensicsEnv | None) -> None:
    """Close ``env`` on a best-effort basis.

    ``None`` is accepted and ignored. Any exception raised while closing is
    swallowed so that cleanup can never mask the caller's own outcome.
    """
    if env is not None:
        try:
            resolve_maybe_awaitable(env.close())
        except Exception:
            # Deliberately ignored: closing is best-effort.
            pass
def close_async_loop() -> None:
    """Close the module's shared event loop, if open, and forget it."""
    global _ASYNC_LOOP
    loop = _ASYNC_LOOP
    if loop is not None and not loop.is_closed():
        loop.close()
    _ASYNC_LOOP = None
def run_task(task_name: str) -> None:
    """Play one task end to end and print its ``[START]``/``[STEP]``/``[END]`` log.

    The loop runs for at most the task's hard step cap. Each iteration either
    forces a ``submit_report`` (hard cap reached, time budget nearly spent, or
    soft budget reached with an early-submit verdict) or asks
    ``choose_action`` for the next action, then steps the environment and
    records the reward and score.

    Args:
        task_name: Name of the task to run.

    Raises:
        Exception: Any error raised while creating or resetting the
            environment, or while evaluating the final metrics, is re-raised
            after the ``[END]`` line has been printed. Step failures inside
            the loop only end the episode.
    """
    env: NetworkForensicsEnv | None = None
    rewards: list[float] = []
    final_steps = 0
    final_score = 0.0
    success = False
    agent_state: dict[str, Any] = {}
    client = build_client()
    print(f"[START] task={task_name} env=network_forensics model={MODEL_NAME}")
    try:
        env = create_env()
        reset_result = reset_env(env, task_name)
        obs = extract_observation(reset_result)
        sync_agent_state(obs, agent_state)

        max_steps = obs.steps_remaining or 50
        soft_budget = min(max_steps, SOFT_STEP_BUDGETS.get(task_name, max_steps))
        hard_budget = min(max_steps, HARD_STEP_CAPS.get(task_name, max_steps))
        start_ts = time.monotonic()
        task_time_budget = min(
            MAX_TASK_SECONDS,
            TASK_TIME_BUDGET_SECONDS.get(task_name, MAX_TASK_SECONDS),
        )

        for _ in range(hard_budget):
            if obs.done:
                break
            elapsed = time.monotonic() - start_ts
            ready_for_budget_submit = (
                obs.step_number >= soft_budget
                and should_submit_early(task_name, obs, agent_state)
            )
            forced_at_hard_cap = (
                obs.step_number >= max(1, hard_budget - 1)
                and (should_submit_early(task_name, obs, agent_state) or task_name != "easy")
            )
            # Leave a 12 s margin before the time budget, but never force a
            # submit during the first 20 s.
            nearing_time_limit = elapsed >= max(20.0, task_time_budget - 12.0)

            error = None
            try:
                if forced_at_hard_cap or nearing_time_limit or ready_for_budget_submit:
                    action = NetworkForensicsAction(
                        action_type="submit_report",
                        incident_summary=_build_report_summary(obs, agent_state),
                        claimed_entry_point=agent_state.get("claimed_entry_point")
                        or obs.claimed_entry_point,
                    )
                else:
                    action = choose_action(client, task_name, obs, agent_state)
            except Exception as exc:
                error = str(exc).replace("\n", " ")
                action = build_fallback_action(task_name, obs, agent_state)

            try:
                step_result, new_env = step_env_with_retry(env, action, task_name, agent_state)
                if new_env is not None:
                    env = new_env
            except Exception as exc:
                print(f"[WARN] step failure on task={task_name}: {exc}")
                break

            obs = extract_observation(step_result)
            sync_agent_state(obs, agent_state)
            step_reward = extract_step_reward(step_result, obs)
            rewards.append(step_reward)
            agent_state["last_step_reward"] = step_reward
            agent_state["last_reward_feedback"] = reward_feedback(action, step_reward)
            final_steps = obs.step_number

            # A submitted report's score is authoritative; until one exists
            # the running estimate stands in for it.
            metrics = final_metrics(obs)
            if action.action_type == "submit_report" and metrics:
                report_qs = metrics.get("final_score")
                if report_qs is not None:
                    final_score = normalize_score(float(report_qs))
            elif final_score == 0.0:
                interim_score = metrics.get("final_score") if metrics else None
                if interim_score is None:
                    # Covers both a missing key and an explicit null value.
                    interim_score = obs.current_score_estimate
                final_score = normalize_score(float(interim_score or 0.0))

            emit_step(
                obs.step_number,
                action,
                step_reward,
                bool(step_result.done),
                error,
            )
            if step_result.done:
                break

        metrics = final_metrics(obs)
        threshold_met = bool(metrics) and (
            float(metrics.get("success_threshold_met") or 0.0) >= 1.0
        )
        success = bool(obs.done and (threshold_met or final_score >= 0.6))
    except Exception:
        success = False
        raise
    finally:
        close_env(env)
        rewards_text = ",".join(f"{reward:.2f}" for reward in rewards)
        print(
            f"[END] success={str(success).lower()} steps={final_steps} "
            f"score={final_score:.2f} rewards={rewards_text}"
        )
def main() -> None:
    """Validate the configuration, then run the easy, medium and hard tasks.

    The shared event loop is closed afterwards, even when a task raises.
    """
    validate_config()
    task_order = ("easy", "medium", "hard")
    try:
        for current_task in task_order:
            run_task(current_task)
    finally:
        close_async_loop()


if __name__ == "__main__":
    main()