# WHOAM-EYE's picture
# Upload folder using huggingface_hub
# 4440ec1 verified
import json
import os
import sys
import asyncio
import inspect
from pathlib import Path
from typing import Any
from dotenv import load_dotenv
from openai import OpenAI
from openenv.core.containers.runtime.providers import LocalDockerProvider
sys.path.insert(0, str(Path(__file__).parent))
from client import NetworkForensicsEnv
from models import NetworkForensicsAction
from server.network_forensics_environment import NetworkForensicsEnvironment
# Load settings from a .env file that sits next to this script.
load_dotenv(Path(__file__).parent / ".env")
# Endpoint and model name used for the OpenAI-compatible chat client.
API_BASE_URL = os.getenv("API_BASE_URL")
MODEL_NAME = os.getenv("MODEL_NAME")
# HF_TOKEN is used as the credential when API_KEY is unset or empty.
API_KEY = os.getenv("API_KEY") or os.getenv("HF_TOKEN")
# Docker image handed to NetworkForensicsEnv.from_docker_image in docker mode.
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME", "network-forensics-env:latest")
# Backend selector; validate_config accepts "local", "server" or "docker".
ENV_MODE = (os.getenv("NETWORK_FORENSICS_ENV_MODE") or os.getenv("ENV_MODE") or "docker").lower()
# Base URL of the environment server used in server mode.
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://localhost:8000")
# Readiness timeout in seconds that ExtendedWaitDockerProvider passes on.
DOCKER_READY_TIMEOUT_S = float(os.getenv("DOCKER_READY_TIMEOUT_S", "120"))
# Event loop created on demand by get_async_loop and released by close_async_loop.
_ASYNC_LOOP: asyncio.AbstractEventLoop | None = None
# System message sent with every model request in choose_action.
SYSTEM_PROMPT = """You are a network forensics analyst operating in an RL environment.
Choose exactly one next action using this JSON schema:
{"action_type":"inspect_packet|flag_as_suspicious|group_into_session|tag_pattern|identify_entry_point|submit_report","packet_id":"pkt_0001","packet_ids":["pkt_0001","pkt_0002"],"session_name":"name","pattern_type":"ddos","claimed_entry_point":"pkt_0001"}
Rules:
- Return JSON only.
- Prefer inspecting packets with suspicious payload previews, HTTP attack strings, DDoS bursts, or repeated unusual destinations.
- Flag packets only after some evidence.
- Group packets into a session only when they share the same src_ip, dst_ip, dst_port, and likely role.
- Tag patterns using labels like ddos, web_bruteforce, web_xss, web_sql_injection, dos_hulk, dos_goldeneye, dos_slowloris, dos_slowhttptest, heartbleed.
- Identify the entry point only when you have a strong guess.
- Submit the report when you have already flagged multiple suspicious packets and created at least one session."""
def build_client() -> OpenAI:
    """Create the chat client from the module-level API_BASE_URL and API_KEY."""
    connection_settings = {"base_url": API_BASE_URL, "api_key": API_KEY}
    return OpenAI(**connection_settings)
def validate_config() -> None:
    """Check the module-level settings before any task is run.

    Raises:
        RuntimeError: if API_BASE_URL, MODEL_NAME or API_KEY is empty, or if
            ENV_MODE is not one of "local", "server" or "docker".
    """
    required_settings = (
        ("API_BASE_URL", API_BASE_URL),
        ("MODEL_NAME", MODEL_NAME),
        ("API_KEY", API_KEY),
    )
    # Collect every unset name so a single error reports all of them.
    unset_names = [name for name, value in required_settings if not value]
    if unset_names:
        raise RuntimeError(f"Missing required environment variables: {', '.join(unset_names)}")
    if ENV_MODE not in {"local", "server", "docker"}:
        raise RuntimeError("NETWORK_FORENSICS_ENV_MODE must be one of: local, server, docker")
def format_action(action: NetworkForensicsAction) -> str:
    """Render an action as compact JSON for logging and repeat detection.

    The action is dumped without None or default values; the "metadata" key
    and any field equal to "", [] or {} are left out of the result.
    """
    dumped = action.model_dump(exclude_none=True, exclude_defaults=True)
    empty_values = ("", [], {})
    compact = {}
    for field_name, field_value in dumped.items():
        if field_name == "metadata":
            continue
        if field_value in empty_values:
            continue
        compact[field_name] = field_value
    return json.dumps(compact, separators=(",", ":"))
def summarize_observation(obs: Any) -> str:
    """Serialize an observation into compact JSON for the model prompt.

    Only the first 25 visible packets are included. A packet's full payload
    is exposed as "revealed_payload" only when the packet is revealed;
    otherwise that field is null.
    """

    def describe(packet: Any) -> dict[str, Any]:
        # Key order is kept stable because it defines the emitted JSON layout.
        return {
            "packet_id": packet.packet_id,
            "src_ip": packet.src_ip,
            "dst_ip": packet.dst_ip,
            "dst_port": packet.dst_port,
            "protocol": packet.protocol,
            "ttl": packet.ttl,
            "payload_size": packet.payload_size,
            "payload_preview": packet.payload_preview,
            "revealed_payload": packet.full_payload if packet.is_revealed else None,
        }

    packet_rows = [describe(packet) for packet in obs.visible_packets[:25]]
    overview = {
        "step_number": obs.step_number,
        "steps_remaining": obs.steps_remaining,
        "current_score_estimate": obs.current_score_estimate,
        "total_packets": obs.total_packets,
        "flagged_packet_ids": obs.flagged_packet_ids,
        "grouped_sessions": obs.grouped_sessions,
        "tagged_patterns": obs.tagged_patterns,
        "claimed_entry_point": obs.claimed_entry_point,
        "visible_packets": packet_rows,
    }
    return json.dumps(overview, separators=(",", ":"))
def parse_action(raw_text: str) -> NetworkForensicsAction:
    """Parse the model's reply into a NetworkForensicsAction.

    The reply may surround the JSON object with prose or code fences, and may
    contain further braces after it. Decoding is attempted at each "{" in
    turn and the first complete JSON object found is used, so trailing text or
    a second object no longer makes the whole reply unparseable.

    Args:
        raw_text: Raw text content returned by the model.

    Returns:
        The action built from the decoded object, after dropping "metadata",
        empty-string optional fields and an empty "packet_ids" list.

    Raises:
        ValueError: if no JSON object can be decoded from the reply.
    """
    text = raw_text.strip()
    decoder = json.JSONDecoder()
    data = None
    start = text.find("{")
    while start != -1 and data is None:
        try:
            # raw_decode stops at the end of the first complete value, so
            # anything following the object is ignored.
            data, _ = decoder.raw_decode(text, start)
        except json.JSONDecodeError:
            start = text.find("{", start + 1)
    if data is None:
        raise ValueError("model did not return JSON")
    data.pop("metadata", None)
    # Empty placeholders are removed so they are not passed to the model class.
    for key in ("session_name", "pattern_type", "claimed_entry_point"):
        if data.get(key) == "":
            data.pop(key, None)
    if data.get("packet_ids") == []:
        data.pop("packet_ids", None)
    return NetworkForensicsAction(**data)
def sanitize_action(action: NetworkForensicsAction) -> NetworkForensicsAction:
    """Rebuild an action keeping only the fields relevant to its action type.

    Fields that are not listed for the action type, or whose value is falsy,
    are dropped; "action_type" is always kept.
    """
    relevant_fields = {
        "inspect_packet": ("packet_id",),
        "flag_as_suspicious": ("packet_id",),
        "group_into_session": ("session_name", "packet_ids"),
        "tag_pattern": ("session_name", "pattern_type"),
        "identify_entry_point": ("claimed_entry_point",),
    }
    kept = {"action_type": action.action_type}
    for field_name in relevant_fields.get(action.action_type, ()):
        field_value = getattr(action, field_name)
        if field_value:
            kept[field_name] = field_value
    return NetworkForensicsAction(**kept)
def keyword_to_pattern(payload: str) -> str | None:
text = payload.lower()
if "slowloris" in text:
return "dos_slowloris"
if "slowhttptest" in text:
return "dos_slowhttptest"
if "goldeneye" in text:
return "dos_goldeneye"
if "hulk" in text:
return "dos_hulk"
if "heartbeat" in text or "tls" in text:
return "heartbleed"
if "xss" in text or "<script>" in text:
return "web_xss"
if "or 1=1" in text or "sql" in text:
return "web_sql_injection"
if "login" in text or "username=admin" in text:
return "web_bruteforce"
if "flood" in text or "burst" in text:
return "ddos"
return None
def packet_signature(packet: Any) -> tuple[str, str, int]:
    """Return the (src_ip, dst_ip, dst_port) triple used to group packets."""
    source, destination, port = packet.src_ip, packet.dst_ip, packet.dst_port
    return source, destination, port
def build_fallback_action(task_name: str, obs: Any, agent_state: dict[str, Any]) -> NetworkForensicsAction:
    """Choose the next action with fixed heuristics instead of the model.

    The first applicable step wins, in this order:
      1. flag a revealed packet whose payload matches an attack keyword;
      2. group two or more such packets sharing a packet signature;
      3. tag a session that has not been tagged yet;
      4. claim the suspicious packet with the smallest packet_id as entry point;
      5. inspect a hidden packet that has not been inspected;
      6. submit when something is flagged and a session exists, or when at
         most 3 steps remain;
      7. inspect a hidden, unflagged packet, otherwise submit.

    agent_state is updated in place (flagged ids, sessions, tagged sessions,
    claimed entry point) at the moment the action is chosen.
    """
    already_inspected = agent_state.setdefault("inspected_ids", set())
    already_flagged = agent_state.setdefault("flagged_ids", set())
    sessions_by_flow = agent_state.setdefault("sessions", {})
    already_tagged = agent_state.setdefault("tagged_sessions", set())
    known_entry_point = agent_state.setdefault("claimed_entry_point", None)

    # Revealed packets whose payload matches an attack keyword, with the label.
    suspects = []
    for candidate in obs.visible_packets:
        body = candidate.full_payload or ""
        label = keyword_to_pattern(body) if candidate.is_revealed else None
        if label:
            suspects.append((candidate, label))

    # Step 1: flag the first suspicious packet that is not flagged yet.
    unflagged = next(
        (pkt for pkt, _ in suspects if pkt.packet_id not in already_flagged),
        None,
    )
    if unflagged is not None:
        already_flagged.add(unflagged.packet_id)
        return NetworkForensicsAction(
            action_type="flag_as_suspicious",
            packet_id=unflagged.packet_id,
        )

    # Step 2: create a session for the first ungrouped flow with >= 2 suspects.
    suspects_by_flow: dict[tuple[str, str, int], list[Any]] = {}
    for pkt, label in suspects:
        suspects_by_flow.setdefault(packet_signature(pkt), []).append((pkt, label))
    for flow, members in suspects_by_flow.items():
        if len(members) < 2 or flow in sessions_by_flow:
            continue
        new_session = f"{task_name}_session_{len(sessions_by_flow) + 1:02d}"
        sessions_by_flow[flow] = new_session
        return NetworkForensicsAction(
            action_type="group_into_session",
            session_name=new_session,
            packet_ids=[pkt.packet_id for pkt, _ in members],
        )

    # Step 3: tag the first untagged session using its first member's payload.
    for flow, existing_session in sessions_by_flow.items():
        members = suspects_by_flow.get(flow, [])
        if existing_session in already_tagged or not members:
            continue
        label = keyword_to_pattern(members[0][0].full_payload or "")
        if label:
            already_tagged.add(existing_session)
            return NetworkForensicsAction(
                action_type="tag_pattern",
                session_name=existing_session,
                pattern_type=label,
            )

    # Step 4: claim an entry point once, picking the smallest packet_id.
    if suspects and not known_entry_point:
        first_suspect = min((pkt for pkt, _ in suspects), key=lambda pkt: pkt.packet_id)
        agent_state["claimed_entry_point"] = first_suspect.packet_id
        return NetworkForensicsAction(
            action_type="identify_entry_point",
            claimed_entry_point=first_suspect.packet_id,
        )

    # Step 5: keep revealing packets that have not been inspected.
    uninspected = next(
        (
            pkt
            for pkt in obs.visible_packets
            if not pkt.is_revealed and pkt.packet_id not in already_inspected
        ),
        None,
    )
    if uninspected is not None:
        return NetworkForensicsAction(
            action_type="inspect_packet",
            packet_id=uninspected.packet_id,
        )

    # Step 6: submit when there is something to report or time is short.
    # NOTE(review): this steps_remaining check is only reached when step 5
    # found nothing to inspect; confirm that ordering is intended.
    if (already_flagged and sessions_by_flow) or obs.steps_remaining <= 3:
        return NetworkForensicsAction(action_type="submit_report")

    # Step 7: last resort before submitting an empty-handed report.
    leftover = next(
        (
            pkt
            for pkt in obs.visible_packets
            if not pkt.is_revealed and pkt.packet_id not in already_flagged
        ),
        None,
    )
    if leftover is not None:
        return NetworkForensicsAction(
            action_type="inspect_packet",
            packet_id=leftover.packet_id,
        )
    return NetworkForensicsAction(action_type="submit_report")
def should_override_action(action: NetworkForensicsAction, obs: Any, agent_state: dict[str, Any]) -> bool:
    """Decide whether a model-proposed action should be replaced.

    Returns True when the action type is unknown, when a field required by
    the action type is missing, when the action repeats work already recorded
    in agent_state (inspected, flagged, tagged, entry point claimed), or when
    the same formatted action was already chosen the last two times.
    """
    history = agent_state.setdefault("previous_actions", [])
    already_inspected = agent_state.setdefault("inspected_ids", set())
    already_flagged = agent_state.setdefault("flagged_ids", set())
    already_tagged = agent_state.setdefault("tagged_sessions", set())
    rendered = format_action(action)
    packets_by_id = {packet.packet_id: packet for packet in obs.visible_packets}

    kind = action.action_type
    known_kinds = {
        "inspect_packet",
        "flag_as_suspicious",
        "group_into_session",
        "tag_pattern",
        "identify_entry_point",
        "submit_report",
    }
    if kind not in known_kinds:
        return True

    if kind == "inspect_packet":
        if not action.packet_id:
            return True
        target = packets_by_id.get(action.packet_id)
        # Reject packets that are not visible or were already revealed.
        if target is None or target.is_revealed or action.packet_id in already_inspected:
            return True
    elif kind == "flag_as_suspicious":
        if not action.packet_id or action.packet_id in already_flagged:
            return True
    elif kind == "group_into_session":
        if not action.session_name or not action.packet_ids:
            return True
        # A session needs at least two distinct packets.
        if len(set(action.packet_ids)) < 2:
            return True
    elif kind == "tag_pattern":
        if not action.session_name or not action.pattern_type:
            return True
        if action.session_name in already_tagged:
            return True
    elif kind == "identify_entry_point":
        if not action.claimed_entry_point:
            return True
        if agent_state.get("claimed_entry_point"):
            return True

    # Break loops where the same action was already taken twice in a row.
    repeated_twice = (
        len(history) >= 2
        and history[-1] == rendered
        and history[-2] == rendered
    )
    return repeated_twice
def choose_action(client: OpenAI, task_name: str, obs: Any, agent_state: dict[str, Any]) -> NetworkForensicsAction:
    """Ask the model for the next action, replacing it when it is unusable.

    The model reply is parsed and sanitized; if should_override_action
    rejects it, build_fallback_action supplies the action instead. The
    formatted final action is appended to agent_state["previous_actions"].
    """
    user_message = f"task={task_name}\nobservation={summarize_observation(obs)}"
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_message},
    ]
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        temperature=0,
        messages=conversation,
    )
    reply_text = completion.choices[0].message.content or ""
    proposed = sanitize_action(parse_action(reply_text))
    if should_override_action(proposed, obs, agent_state):
        chosen = build_fallback_action(task_name, obs, agent_state)
    else:
        chosen = proposed
    agent_state.setdefault("previous_actions", []).append(format_action(chosen))
    return chosen
def sync_agent_state(obs: Any, agent_state: dict[str, Any]) -> None:
    """Merge what the observation reports into the agent's bookkeeping.

    Adds revealed packet ids to "inspected_ids", the observation's flagged
    ids to "flagged_ids" and the keys of its tagged patterns to
    "tagged_sessions". A non-empty claimed entry point overwrites
    "claimed_entry_point". Existing sets are updated in place.
    """
    seen_ids = agent_state.setdefault("inspected_ids", set())
    seen_ids.update(
        packet.packet_id for packet in obs.visible_packets if packet.is_revealed
    )
    agent_state.setdefault("flagged_ids", set()).update(obs.flagged_packet_ids)
    agent_state.setdefault("tagged_sessions", set()).update(obs.tagged_patterns.keys())
    entry_point = obs.claimed_entry_point
    if entry_point:
        agent_state["claimed_entry_point"] = entry_point
def emit_step(step_number: int, action: NetworkForensicsAction, reward: float, done: bool, error: str | None) -> None:
    """Print one "[STEP]" line describing the action taken and its outcome.

    The reward is printed with two decimals, done as a lower-cased string and
    a missing error as the literal "null".
    """
    error_field = "null" if error is None else error
    done_field = str(done).lower()
    line = (
        f"[STEP] step={step_number} action={format_action(action)} "
        f"reward={reward:.2f} done={done_field} error={error_field}"
    )
    print(line)
def normalize_score(score: float) -> float:
    """Clamp a score into the closed range [0.0, 1.0].

    A missing (None) or NaN score is reported as 0.0. Without this guard a
    NaN would clamp to 1.0, because min(1.0, nan) returns 1.0, and None would
    raise TypeError.

    Args:
        score: Raw score; may lie outside [0.0, 1.0].

    Returns:
        The score limited to [0.0, 1.0], or 0.0 for None/NaN.
    """
    import math  # local import keeps this fix free of file-level changes

    if score is None or math.isnan(score):
        return 0.0
    return max(0.0, min(1.0, score))
class ExtendedWaitDockerProvider(LocalDockerProvider):
    """LocalDockerProvider whose readiness wait uses DOCKER_READY_TIMEOUT_S."""

    def wait_for_ready(self, base_url: str, timeout_s: float = 30.0) -> None:
        """Wait for the container at base_url using the configured timeout.

        The timeout_s argument is accepted for signature compatibility but is
        not used; DOCKER_READY_TIMEOUT_S is always passed to the parent.
        """
        configured_timeout = DOCKER_READY_TIMEOUT_S
        super().wait_for_ready(base_url, timeout_s=configured_timeout)
def get_async_loop() -> asyncio.AbstractEventLoop:
    """Return the module's shared event loop, creating one when needed.

    A new loop is created when none exists yet or the stored one is closed.
    """
    global _ASYNC_LOOP
    current = _ASYNC_LOOP
    if current is not None and not current.is_closed():
        return current
    _ASYNC_LOOP = asyncio.new_event_loop()
    return _ASYNC_LOOP
def resolve_maybe_awaitable(value: Any) -> Any:
    """Return value as-is, or its result when it is awaitable.

    Awaitables are run to completion on the loop from get_async_loop().
    """
    if not inspect.isawaitable(value):
        return value
    return get_async_loop().run_until_complete(value)
def create_env(task_name: str) -> Any:
    """Create the environment backend selected by ENV_MODE.

    "docker" starts a client from LOCAL_IMAGE_NAME with an
    ExtendedWaitDockerProvider, "server" connects a client to ENV_BASE_URL,
    and any other mode builds an in-process NetworkForensicsEnvironment for
    the given task.
    """
    if ENV_MODE == "server":
        return NetworkForensicsEnv(base_url=ENV_BASE_URL)
    if ENV_MODE != "docker":
        return NetworkForensicsEnvironment(task_id=task_name)
    docker_provider = ExtendedWaitDockerProvider()
    # from_docker_image may hand back an awaitable; resolve it either way.
    pending_env = NetworkForensicsEnv.from_docker_image(LOCAL_IMAGE_NAME, provider=docker_provider)
    return resolve_maybe_awaitable(pending_env)
def reset_env(env: Any, task_name: str) -> Any:
    """Reset the environment and return its first observation.

    An in-process NetworkForensicsEnvironment is reset without arguments and
    its return value is used directly; any other env is reset with the task
    id and the observation is taken from the resolved result.
    """
    if not isinstance(env, NetworkForensicsEnvironment):
        outcome = resolve_maybe_awaitable(env.reset(task_id=task_name))
        return outcome.observation
    return env.reset()
def step_env(env: Any, action: NetworkForensicsAction) -> Any:
    """Apply one action and return the resulting observation.

    An in-process NetworkForensicsEnvironment returns the observation from
    step() directly; for any other env it is read from the resolved result.
    """
    if not isinstance(env, NetworkForensicsEnvironment):
        outcome = resolve_maybe_awaitable(env.step(action))
        return outcome.observation
    return env.step(action)
def close_env(env: Any) -> None:
    """Close the environment on a best-effort basis.

    Does nothing when env is None. A failure while closing never propagates,
    so cleanup cannot mask the task's own outcome, but it is reported on
    stderr instead of being silently discarded. stderr is used so the
    structured progress lines printed on stdout stay untouched.

    Args:
        env: Environment returned by create_env, or None.
    """
    if env is None:
        return
    try:
        resolve_maybe_awaitable(env.close())
    except Exception as exc:
        print(f"[WARN] failed to close environment: {exc!r}", file=sys.stderr)
def close_async_loop() -> None:
    """Close the shared event loop if it is open and forget it."""
    global _ASYNC_LOOP
    current = _ASYNC_LOOP
    still_open = current is not None and not current.is_closed()
    if still_open:
        current.close()
    _ASYNC_LOOP = None
def run_task(task_name: str) -> None:
    """Run one episode for task_name and print its progress lines.

    Prints a "[START]" line, one "[STEP]" line per action and always an
    "[END]" line, even when the episode raises. The step limit is the env's
    "_max_steps" attribute (50 if absent); when that is falsy, the first
    observation's steps_remaining is used, or 50. An error from
    choose_action is recorded on the step line and the fallback policy picks
    the action. The run counts as a success when the episode is done and the
    clamped score is at least 0.6.
    """
    env = None
    reward_history: list[float] = []
    last_step = 0
    score = 0.0
    succeeded = False
    agent_state: dict[str, Any] = {}
    client = build_client()
    print(f"[START] task={task_name} env=network_forensics model={MODEL_NAME}")
    try:
        env = create_env(task_name)
        obs = reset_env(env, task_name)
        sync_agent_state(obs, agent_state)
        step_budget = getattr(env, "_max_steps", 50) or obs.steps_remaining or 50
        for _ in range(step_budget):
            if obs.done:
                break
            error = None
            try:
                action = choose_action(client, task_name, obs, agent_state)
            except Exception as exc:
                # Keep the error on one line so the step output stays parseable.
                error = str(exc).replace("\n", " ")
                action = build_fallback_action(task_name, obs, agent_state)
            obs = step_env(env, action)
            sync_agent_state(obs, agent_state)
            step_reward = float(obs.reward or 0.0)
            reward_history.append(step_reward)
            last_step = obs.step_number
            score = normalize_score(obs.metadata.get("final_score", obs.current_score_estimate))
            emit_step(obs.step_number, action, step_reward, bool(obs.done), error)
        succeeded = bool(obs.done and score >= 0.6)
    finally:
        close_env(env)
        rewards_field = ",".join(f"{value:.2f}" for value in reward_history)
        print(
            f"[END] success={str(succeeded).lower()} steps={last_step} "
            f"score={score:.2f} rewards={rewards_field}"
        )
def main() -> None:
    """Validate the configuration, then run the easy, medium and hard tasks.

    The shared event loop is closed afterwards even if a task raises.
    """
    validate_config()
    difficulty_levels = ("easy", "medium", "hard")
    try:
        for level in difficulty_levels:
            run_task(level)
    finally:
        close_async_loop()


if __name__ == "__main__":
    main()