Spaces:
Sleeping
Sleeping
| import sys | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional | |
| from uuid import uuid4 | |
| sys.path.insert(0, str(Path(__file__).parent.parent)) | |
| from openenv.core.env_server.interfaces import Environment | |
| from openenv.core.env_server.types import State | |
| from models import ( | |
| NetworkForensicsAction, | |
| NetworkForensicsObservation, | |
| PacketRecord, | |
| Reward, | |
| TaskConfig, | |
| GroundTruth, | |
| ) | |
| from src.pcap_generator import PCAPGenerator | |
| from src.tasks.easy import EasyTask | |
| from src.reward import compute_reward | |
| from src.graph import ConnectionGraph | |
class NetworkForensicsEnvironment(Environment):
    """Environment in which an agent analyses a generated packet capture.

    Each episode generates a list of packets plus a ground truth via
    ``PCAPGenerator``. The agent then issues actions (``flag_as_suspicious``,
    ``group_into_session``, ``tag_pattern``, ``identify_entry_point``,
    ``submit_report``); every action is scored by ``compute_reward`` and the
    bookkeeping kept on this object is echoed back in the observation.
    """

    SUPPORTS_CONCURRENT_SESSIONS: bool = True

    def __init__(self, task_id: str = "easy"):
        """Create an environment with empty episode state.

        Args:
            task_id: Task selected on the next ``reset``. ``"medium"`` and
                ``"hard"`` select the matching task classes; any other value
                falls back to ``EasyTask``.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._task_id = task_id
        self._packets: list[PacketRecord] = []
        self._ground_truth: Optional[GroundTruth] = None
        # Set by ``reset``; initialised here so the attribute always exists.
        self._annotation: Any = {}
        self._flagged_packets: set = set()
        self._grouped_sessions: Dict[str, list] = {}
        self._tagged_patterns: Dict[str, str] = {}
        self._claimed_entry_point: Optional[str] = None
        self._reward_state: Dict[str, Any] = {}
        self._current_score: float = 0.0
        self._reward_history: list[float] = []
        self._max_steps: int = 50
        self._connection_graph: ConnectionGraph = ConnectionGraph()

    def config(self) -> Dict[str, Any]:
        """Return the active task id and the step budget."""
        return {"task_id": self._task_id, "max_steps": self._max_steps}

    def _build_graph(self) -> None:
        """Build the connection graph from all packets."""
        self._connection_graph = ConnectionGraph()
        for packet in self._packets:
            self._connection_graph.add_packet(packet)

    def _get_graph_summary(self) -> Dict[str, Any]:
        """Return a compact graph summary for the observation."""
        full_summary = self._connection_graph.get_summary()
        # Include top-level stats and top-N nodes/edges to keep payload manageable
        top_nodes = sorted(
            full_summary.get("nodes", []),
            key=lambda n: n.get("packet_count", 0),
            reverse=True,
        )[:15]
        top_edges = sorted(
            full_summary.get("edges", []),
            key=lambda e: e.get("packet_count", 0),
            reverse=True,
        )[:20]
        return {
            "node_count": full_summary.get("node_count", 0),
            "edge_count": full_summary.get("edge_count", 0),
            "top_talkers": top_nodes,
            "top_flows": top_edges,
        }

    def _snapshot_packets(self, *, force_unrevealed: bool = False) -> list[PacketRecord]:
        """Copy every packet into a fresh ``PacketRecord`` for an observation.

        ``full_payload`` is only included for packets whose ``is_revealed``
        flag is set on the stored packet.

        Args:
            force_unrevealed: When true, the copies report
                ``is_revealed=False`` regardless of the stored flag (used by
                ``reset``).
        """
        return [
            PacketRecord(
                packet_id=p.packet_id,
                timestamp=p.timestamp,
                src_ip=p.src_ip,
                dst_ip=p.dst_ip,
                src_port=p.src_port,
                dst_port=p.dst_port,
                protocol=p.protocol,
                payload_size=p.payload_size,
                ttl=p.ttl,
                flags=p.flags,
                is_revealed=False if force_unrevealed else p.is_revealed,
                payload_preview=p.payload_preview,
                full_payload=p.full_payload if p.is_revealed else None,
            )
            for p in self._packets
        ]

    def reset(
        self, seed: Optional[int] = None, episode_id: Optional[str] = None, **kwargs: Any
    ) -> NetworkForensicsObservation:
        """Start a new episode and return its initial observation.

        Args:
            seed: Seed handed to the packet generator. ``None`` means "use the
                task config's seed"; every other value, including ``0``, is
                used as given.
            episode_id: Identifier for the new episode; a random UUID is used
                when it is missing or empty.
            **kwargs: ``task_id`` may be ``"easy"``, ``"medium"`` or ``"hard"``
                to switch tasks; any other value is ignored.

        Returns:
            The observation for step 0 with all agent bookkeeping cleared.
        """
        requested_task = kwargs.get("task_id")
        # A tuple (not a set) so an unhashable value is ignored, not a TypeError.
        if requested_task in ("easy", "medium", "hard"):
            self._task_id = requested_task
        self._state = State(
            episode_id=episode_id or str(uuid4()),
            step_count=0,
        )

        # Medium/hard task modules are imported lazily, only when selected.
        if self._task_id == "medium":
            from src.tasks.medium import MediumTask
            task = MediumTask()
        elif self._task_id == "hard":
            from src.tasks.hard import HardTask
            task = HardTask()
        else:
            task = EasyTask()

        config = task.get_config()
        if hasattr(task, 'get_annotation'):
            self._annotation = task.get_annotation()
            generator = PCAPGenerator(config, self._annotation)
        else:
            self._annotation = {}
            generator = PCAPGenerator(config)

        # ``seed or config.seed`` would discard an explicit seed of 0.
        effective_seed = seed if seed is not None else config.seed
        self._packets, self._ground_truth = generator.generate(seed=effective_seed)

        self._flagged_packets = set()
        self._grouped_sessions = {}
        self._tagged_patterns = {}
        self._claimed_entry_point = None
        self._reward_state = {}
        self._current_score = 0.0
        self._reward_history = []
        self._max_steps = config.max_steps

        # Build the connection graph from all packets
        self._build_graph()

        return NetworkForensicsObservation(
            step_number=0,
            steps_remaining=self._max_steps,
            total_packets=len(self._packets),
            visible_packets=self._snapshot_packets(force_unrevealed=True),
            flagged_packet_ids=[],
            grouped_sessions={},
            tagged_patterns={},
            claimed_entry_point=None,
            connection_graph_summary=self._get_graph_summary(),
            current_score_estimate=0.0,
            final_metrics={},
            done=False,
            reward=0.0,
        )

    def step(
        self, action: NetworkForensicsAction, timeout_s: Optional[float] = None, **kwargs: Any
    ) -> NetworkForensicsObservation:
        """Score one action, record its effect and return the new observation.

        The reward is computed before the action is recorded, so
        ``compute_reward`` receives the bookkeeping as it was prior to this
        action.

        Args:
            action: The agent's action for this step.
            timeout_s: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The observation after the action. ``done`` is true when the action
            is ``submit_report`` or the step budget is exhausted.
        """
        # NOTE(review): nothing prevents calling step() before reset(); in that
        # case ground_truth is None and packets is empty -- confirm that
        # compute_reward tolerates this.
        self._state.step_count += 1
        action_result = compute_reward(
            action=action,
            packets=self._packets,
            ground_truth=self._ground_truth,
            flagged_packets=self._flagged_packets,
            grouped_sessions=self._grouped_sessions,
            tagged_patterns=self._tagged_patterns,
            reward_state=self._reward_state,
            task_id=self._task_id,
        )

        if action.action_type == "flag_as_suspicious" and action.packet_id:
            self._flagged_packets.add(action.packet_id)
            # Mark the node as flagged in the connection graph
            packet_map = {p.packet_id: p for p in self._packets}
            pkt = packet_map.get(action.packet_id)
            if pkt:
                # NOTE(review): this reaches into ConnectionGraph's private
                # ``_node_attributes``; a public setter would be preferable if
                # the graph class offers one.
                node_attributes = self._connection_graph._node_attributes
                for ip in (pkt.src_ip, pkt.dst_ip):
                    if ip in node_attributes:
                        node_attributes[ip]["flagged"] = True
        elif action.action_type == "group_into_session":
            if action.session_name and action.packet_ids:
                self._grouped_sessions[action.session_name] = action.packet_ids
        elif action.action_type == "tag_pattern":
            if action.session_name and action.pattern_type:
                self._tagged_patterns[action.session_name] = action.pattern_type
        elif action.action_type == "identify_entry_point":
            self._claimed_entry_point = action.claimed_entry_point

        # The score estimate is the running mean of all step rewards so far.
        self._reward_history.append(action_result.step_reward)
        self._current_score = sum(self._reward_history) / len(self._reward_history)

        done = (
            action.action_type == "submit_report"
            or self._state.step_count >= self._max_steps
        )
        return NetworkForensicsObservation(
            step_number=self._state.step_count,
            steps_remaining=max(0, self._max_steps - self._state.step_count),
            total_packets=len(self._packets),
            visible_packets=self._snapshot_packets(),
            flagged_packet_ids=list(self._flagged_packets),
            # Shallow copies so the observation does not alias internal state.
            grouped_sessions=dict(self._grouped_sessions),
            tagged_patterns=dict(self._tagged_patterns),
            claimed_entry_point=self._claimed_entry_point,
            connection_graph_summary=self._get_graph_summary(),
            current_score_estimate=self._current_score,
            final_metrics=action_result.breakdown,
            done=done,
            reward=action_result.step_reward,
            metadata=action_result.breakdown,
        )

    # NOTE(review): defined as a plain method; confirm whether the Environment
    # base class expects ``state`` to be a property before changing it.
    def state(self) -> State:
        """Return the current episode state (episode id and step count)."""
        return self._state