# Source: network_forensics/server/network_forensics_environment.py
# Commit d9ac8a7 (verified) by WHOAM-EYE: "Upload folder using huggingface_hub"
import sys
from pathlib import Path
from typing import Any, Dict, Optional
from uuid import uuid4
sys.path.insert(0, str(Path(__file__).parent.parent))
from openenv.core.env_server.interfaces import Environment
from openenv.core.env_server.types import State
from models import (
NetworkForensicsAction,
NetworkForensicsObservation,
PacketRecord,
Reward,
TaskConfig,
GroundTruth,
)
from src.pcap_generator import PCAPGenerator
from src.tasks.easy import EasyTask
from src.reward import compute_reward
from src.graph import ConnectionGraph
class NetworkForensicsEnvironment(Environment):
    """Environment in which an agent analyses a generated packet capture.

    ``reset`` picks a task (``"easy"``, ``"medium"`` or ``"hard"``), generates
    packets plus ground truth with ``PCAPGenerator`` and builds a
    ``ConnectionGraph`` over them.  ``step`` scores one action through
    ``compute_reward`` and records the agent's flags, session groupings,
    pattern tags and entry-point claim, which are echoed back in every
    observation.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, task_id: str = "easy"):
        """Create an environment with empty episode state.

        Args:
            task_id: Task used by ``reset`` unless the caller overrides it
                there.  Anything other than ``"medium"`` or ``"hard"`` ends
                up selecting the easy task.
        """
        # Let the base class initialise whatever state it maintains.
        super().__init__()
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._task_id = task_id
        self._packets: list[PacketRecord] = []
        self._ground_truth: Optional[GroundTruth] = None
        # Defined here so the attribute exists before the first reset().
        self._annotation: Any = {}
        self._flagged_packets: set = set()
        self._grouped_sessions: Dict[str, list] = {}
        self._tagged_patterns: Dict[str, str] = {}
        self._claimed_entry_point: Optional[str] = None
        self._reward_state: Dict[str, Any] = {}
        self._current_score: float = 0.0
        self._reward_history: list[float] = []
        self._max_steps: int = 50
        self._connection_graph: ConnectionGraph = ConnectionGraph()

    def config(self) -> Dict[str, Any]:
        """Return the active task id and the current step budget."""
        return {"task_id": self._task_id, "max_steps": self._max_steps}

    def _build_graph(self) -> None:
        """Build the connection graph from all packets."""
        self._connection_graph = ConnectionGraph()
        for packet in self._packets:
            self._connection_graph.add_packet(packet)

    def _get_graph_summary(self) -> Dict[str, Any]:
        """Return a compact graph summary for the observation."""
        full_summary = self._connection_graph.get_summary()
        # Include top-level stats and top-N nodes/edges to keep payload manageable
        top_nodes = sorted(
            full_summary.get("nodes", []),
            key=lambda n: n.get("packet_count", 0),
            reverse=True,
        )[:15]
        top_edges = sorted(
            full_summary.get("edges", []),
            key=lambda e: e.get("packet_count", 0),
            reverse=True,
        )[:20]
        return {
            "node_count": full_summary.get("node_count", 0),
            "edge_count": full_summary.get("edge_count", 0),
            "top_talkers": top_nodes,
            "top_flows": top_edges,
        }

    def _visible_packets(self, *, report_unrevealed: bool = False) -> list[PacketRecord]:
        """Project the stored packets into the records sent to the agent.

        ``full_payload`` is only included for packets whose own
        ``is_revealed`` flag is set.

        Args:
            report_unrevealed: When true, every returned record carries
                ``is_revealed=False`` regardless of the stored flag (used for
                the initial observation).
        """
        # NOTE(review): with report_unrevealed=True a packet that is already
        # marked revealed would still expose full_payload while being
        # reported as unrevealed.  This mirrors the previous reset()
        # behaviour; confirm whether generated packets can start revealed.
        return [
            PacketRecord(
                packet_id=p.packet_id,
                timestamp=p.timestamp,
                src_ip=p.src_ip,
                dst_ip=p.dst_ip,
                src_port=p.src_port,
                dst_port=p.dst_port,
                protocol=p.protocol,
                payload_size=p.payload_size,
                ttl=p.ttl,
                flags=p.flags,
                is_revealed=False if report_unrevealed else p.is_revealed,
                payload_preview=p.payload_preview,
                full_payload=p.full_payload if p.is_revealed else None,
            )
            for p in self._packets
        ]

    def reset(
        self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any
    ) -> NetworkForensicsObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Seed handed to the packet generator.  ``None`` falls back
                to the task configuration's seed; an explicit ``0`` is
                honoured.
            episode_id: Identifier for the new episode; a random UUID is
                used when omitted or empty.
            **kwargs: ``task_id`` may be given to switch task; values other
                than ``"easy"``, ``"medium"`` and ``"hard"`` are ignored.

        Returns:
            The step-0 observation with all packets reported as unrevealed
            and no agent annotations.
        """
        requested_task = kwargs.get("task_id")
        if requested_task in {"easy", "medium", "hard"}:
            self._task_id = requested_task

        self._state = State(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
        )

        # Medium/hard task modules are imported lazily, only when selected.
        if self._task_id == "medium":
            from src.tasks.medium import MediumTask

            task = MediumTask()
        elif self._task_id == "hard":
            from src.tasks.hard import HardTask

            task = HardTask()
        else:
            task = EasyTask()

        config = task.get_config()
        if hasattr(task, "get_annotation"):
            self._annotation = task.get_annotation()
            generator = PCAPGenerator(config, self._annotation)
        else:
            self._annotation = {}
            generator = PCAPGenerator(config)

        # `seed or config.seed` would discard a legitimate seed of 0.
        effective_seed = config.seed if seed is None else seed
        self._packets, self._ground_truth = generator.generate(seed=effective_seed)

        # Clear everything accumulated during the previous episode.
        self._flagged_packets = set()
        self._grouped_sessions = {}
        self._tagged_patterns = {}
        self._claimed_entry_point = None
        self._reward_state = {}
        self._current_score = 0.0
        self._reward_history = []
        self._max_steps = config.max_steps

        # Build the connection graph from all packets
        self._build_graph()

        return NetworkForensicsObservation(
            step_number=0,
            steps_remaining=self._max_steps,
            total_packets=len(self._packets),
            visible_packets=self._visible_packets(report_unrevealed=True),
            flagged_packet_ids=[],
            grouped_sessions={},
            tagged_patterns={},
            claimed_entry_point=None,
            connection_graph_summary=self._get_graph_summary(),
            current_score_estimate=0.0,
            final_metrics={},
            done=False,
            reward=0.0,
        )

    def _apply_action(self, action: NetworkForensicsAction) -> None:
        """Record the bookkeeping side effects of ``action`` on episode state."""
        if action.action_type == "flag_as_suspicious" and action.packet_id:
            self._flagged_packets.add(action.packet_id)
            # Mark the node as flagged in the connection graph
            packet_map = {p.packet_id: p for p in self._packets}
            pkt = packet_map.get(action.packet_id)
            if pkt:
                # NOTE(review): reaches into ConnectionGraph's private
                # _node_attributes mapping; a public setter on the graph
                # would be preferable if one exists.
                node_attributes = self._connection_graph._node_attributes
                for ip in (pkt.src_ip, pkt.dst_ip):
                    if ip in node_attributes:
                        node_attributes[ip]["flagged"] = True
        elif action.action_type == "group_into_session":
            if action.session_name and action.packet_ids:
                # Store a copy so later changes to the action's list cannot
                # alter the recorded session.
                self._grouped_sessions[action.session_name] = list(action.packet_ids)
        elif action.action_type == "tag_pattern":
            if action.session_name and action.pattern_type:
                self._tagged_patterns[action.session_name] = action.pattern_type
        elif action.action_type == "identify_entry_point":
            self._claimed_entry_point = action.claimed_entry_point

    def step(
        self, action: NetworkForensicsAction, timeout_s: Optional[float] = None, **kwargs: Any
    ) -> NetworkForensicsObservation:
        """Score one agent action and return the resulting observation.

        Args:
            action: The action to evaluate and record.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            An observation carrying the step reward, the running mean of all
            step rewards as ``current_score_estimate``, and ``done=True``
            once the action is ``"submit_report"`` or the step budget is
            exhausted.
        """
        self._state.step_count += 1

        # The reward is computed before this action's bookkeeping is applied.
        action_result = compute_reward(
            action=action,
            packets=self._packets,
            ground_truth=self._ground_truth,
            flagged_packets=self._flagged_packets,
            grouped_sessions=self._grouped_sessions,
            tagged_patterns=self._tagged_patterns,
            reward_state=self._reward_state,
            task_id=self._task_id,
        )
        self._apply_action(action)

        self._reward_history.append(action_result.step_reward)
        self._current_score = sum(self._reward_history) / len(self._reward_history)

        done = (
            action.action_type == "submit_report"
            or self._state.step_count >= self._max_steps
        )
        return NetworkForensicsObservation(
            step_number=self._state.step_count,
            steps_remaining=max(0, self._max_steps - self._state.step_count),
            total_packets=len(self._packets),
            visible_packets=self._visible_packets(),
            flagged_packet_ids=list(self._flagged_packets),
            grouped_sessions=self._grouped_sessions,
            tagged_patterns=self._tagged_patterns,
            claimed_entry_point=self._claimed_entry_point,
            connection_graph_summary=self._get_graph_summary(),
            current_score_estimate=self._current_score,
            final_metrics=action_result.breakdown,
            done=done,
            reward=action_result.step_reward,
            metadata=action_result.breakdown,
        )

    @property
    def state(self) -> State:
        """Return the current episode state (episode id and step count)."""
        return self._state