import json import os import sys import asyncio import inspect from pathlib import Path from typing import Any from dotenv import load_dotenv from openai import OpenAI from openenv.core.containers.runtime.providers import LocalDockerProvider sys.path.insert(0, str(Path(__file__).parent)) from client import NetworkForensicsEnv from models import NetworkForensicsAction from server.network_forensics_environment import NetworkForensicsEnvironment load_dotenv(Path(__file__).parent / ".env") API_BASE_URL = os.getenv("API_BASE_URL") MODEL_NAME = os.getenv("MODEL_NAME") API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN") LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "network-forensics-env:latest") ENV_MODE = (os.getenv("NETWORK_FORENSICS_ENV_MODE") or os.getenv("ENV_MODE") or "docker").lower() ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000") DOCKER_READY_TIMEOUT_S = float(os.getenv("DOCKER_READY_TIMEOUT_S", "120")) _ASYNC_LOOP: asyncio.AbstractEventLoop | None = None SYSTEM_PROMPT = """You are a network forensics analyst operating in an RL environment. Choose exactly one next action using this JSON schema: {"action_type":"inspect_packet|flag_as_suspicious|group_into_session|tag_pattern|identify_entry_point|submit_report","packet_id":"pkt_0001","packet_ids":["pkt_0001","pkt_0002"],"session_name":"name","pattern_type":"ddos","claimed_entry_point":"pkt_0001"} Rules: - Return JSON only. - Prefer inspecting packets with suspicious payload previews, HTTP attack strings, DDoS bursts, or repeated unusual destinations. - Flag packets only after some evidence. - Group packets into a session only when they share the same src_ip, dst_ip, dst_port, and likely role. - Tag patterns using labels like ddos, web_bruteforce, web_xss, web_sql_injection, dos_hulk, dos_goldeneye, dos_slowloris, dos_slowhttptest, heartbleed. - Identify the entry point only when you have a strong guess. - Submit the report when you have already flagged multiple suspicious packets and created at least one session.""" def build_client() -> OpenAI: return OpenAI(base_url=API_BASE_URL, api_key=API_KEY) def validate_config() -> None: missing = [] if not API_BASE_URL: missing.append("API_BASE_URL") if not MODEL_NAME: missing.append("MODEL_NAME") if not API_KEY: missing.append("API_KEY") if missing: raise RuntimeError(f"Missing required environment variables: {', '.join(missing)}") if ENV_MODE not in {"local", "server", "docker"}: raise RuntimeError("NETWORK_FORENSICS_ENV_MODE must be one of: local, server, docker") def format_action(action: NetworkForensicsAction) -> str: payload = action.model_dump(exclude_none=True, exclude_defaults=True) payload.pop("metadata", None) payload = { key: value for key, value in payload.items() if value not in ("", [], {}) } return json.dumps(payload, separators=(",", ":")) def summarize_observation(obs: Any) -> str: packets = [] for packet in obs.visible_packets[:25]: packets.append( { "packet_id": packet.packet_id, "src_ip": packet.src_ip, "dst_ip": packet.dst_ip, "dst_port": packet.dst_port, "protocol": packet.protocol, "ttl": packet.ttl, "payload_size": packet.payload_size, "payload_preview": packet.payload_preview, "revealed_payload": packet.full_payload if packet.is_revealed else None, } ) summary = { "step_number": obs.step_number, "steps_remaining": obs.steps_remaining, "current_score_estimate": obs.current_score_estimate, "total_packets": obs.total_packets, "flagged_packet_ids": obs.flagged_packet_ids, "grouped_sessions": obs.grouped_sessions, "tagged_patterns": obs.tagged_patterns, "claimed_entry_point": obs.claimed_entry_point, "visible_packets": packets, } return json.dumps(summary, separators=(",", ":")) def parse_action(raw_text: str) -> NetworkForensicsAction: text = raw_text.strip() start = text.find("{") end = text.rfind("}") if start == -1 or end == -1: raise ValueError("model did not return JSON") data = json.loads(text[start : end + 1]) data.pop("metadata", None) for key in ("session_name", "pattern_type", "claimed_entry_point"): if data.get(key) == "": data.pop(key, None) if data.get("packet_ids") == []: data.pop("packet_ids", None) return NetworkForensicsAction(**data) def sanitize_action(action: NetworkForensicsAction) -> NetworkForensicsAction: payload = {"action_type": action.action_type} if action.action_type in {"inspect_packet", "flag_as_suspicious"} and action.packet_id: payload["packet_id"] = action.packet_id elif action.action_type == "group_into_session": if action.session_name: payload["session_name"] = action.session_name if action.packet_ids: payload["packet_ids"] = action.packet_ids elif action.action_type == "tag_pattern": if action.session_name: payload["session_name"] = action.session_name if action.pattern_type: payload["pattern_type"] = action.pattern_type elif action.action_type == "identify_entry_point" and action.claimed_entry_point: payload["claimed_entry_point"] = action.claimed_entry_point return NetworkForensicsAction(**payload) def keyword_to_pattern(payload: str) -> str | None: text = payload.lower() if "slowloris" in text: return "dos_slowloris" if "slowhttptest" in text: return "dos_slowhttptest" if "goldeneye" in text: return "dos_goldeneye" if "hulk" in text: return "dos_hulk" if "heartbeat" in text or "tls" in text: return "heartbleed" if "xss" in text or "