Spaces:
Sleeping
Sleeping
| import json | |
| from pathlib import Path | |
| from typing import List, Tuple, Dict, Any | |
| from scapy.all import rdpcap, IP, TCP, UDP, ICMP, DNS, Raw | |
| from faker import Faker | |
| from models import PacketRecord, TaskConfig, GroundTruth | |
# Shared Faker instance. PCAPGenerator uses it for synthetic payload previews
# (fake.sha256) and for random IPv4 addresses (fake.ipv4).
fake = Faker()
def parse_packets(pcap_path: str) -> Tuple[List[PacketRecord], GroundTruth]:
    """Parse a PCAP file into ``PacketRecord`` objects.

    Only packets that carry an IP layer are converted; all other frames are
    skipped.  Packet ids come from the frame's position in the capture
    (``pkt_0001`` is the first frame), so skipped frames leave gaps in the id
    sequence.

    Packets carrying a DNS layer are labelled ``"DNS"`` while keeping the
    ports read from their TCP/UDP header.

    Args:
        pcap_path: Path of the capture file to read.

    Returns:
        A ``(packets, ground_truth)`` tuple.  This function never populates
        ``ground_truth``; it is returned with empty fields.  When the file is
        missing or cannot be read, ``packets`` is an empty list.
    """
    # (bit mask, label) pairs, listed in the order the labels are emitted.
    tcp_flag_bits = (
        (0x02, "SYN"),
        (0x10, "ACK"),
        (0x01, "FIN"),
        (0x04, "RST"),
        (0x08, "PSH"),
    )

    packets: List[PacketRecord] = []
    ground_truth = GroundTruth(
        malicious_packets=[],
        packet_roles={},
        sessions={},
        session_roles={},
        entry_point=None,
    )

    try:
        scapy_packets = rdpcap(pcap_path)
    except FileNotFoundError:
        # A missing capture is treated as "no packets", not as an error.
        return packets, ground_truth
    except Exception as e:
        print(f"Error reading PCAP: {e}")
        return packets, ground_truth

    for idx, pkt in enumerate(scapy_packets):
        if IP not in pkt:
            continue

        ip_layer = pkt[IP]
        src_port = 0
        dst_port = 0
        protocol = "OTHER"
        flags: List[str] = []

        if TCP in pkt:
            protocol = "TCP"
            tcp_layer = pkt[TCP]
            src_port = tcp_layer.sport
            dst_port = tcp_layer.dport
            flags = [label for mask, label in tcp_flag_bits if tcp_layer.flags & mask]
        elif UDP in pkt:
            protocol = "UDP"
            udp_layer = pkt[UDP]
            src_port = udp_layer.sport
            dst_port = udp_layer.dport
        elif ICMP in pkt:
            protocol = "ICMP"

        # DNS sits on top of TCP/UDP, so it must be checked in addition to the
        # transport layer; as a trailing ``elif`` it could never match.
        if DNS in pkt and protocol != "ICMP":
            if protocol == "OTHER":
                # No transport header to read ports from: fall back to the
                # well-known DNS port.
                dst_port = 53
            protocol = "DNS"

        # Serialise the IP payload once; it is needed for the size and as the
        # fallback payload when there is no Raw layer.
        ip_payload = bytes(ip_layer.payload)
        raw_payload = bytes(pkt[Raw].load) if Raw in pkt else ip_payload
        payload_size = len(ip_payload)

        payload_preview = ""
        full_payload = None
        if raw_payload:
            # Hex of the first 20 bytes, i.e. at most 40 characters.
            payload_preview = raw_payload[:20].hex()
            # errors="replace" never raises, so no fallback path is needed.
            full_payload = raw_payload.decode("utf-8", errors="replace")

        packets.append(PacketRecord(
            packet_id=f"pkt_{idx+1:04d}",
            timestamp=float(getattr(pkt, "time", idx)),
            src_ip=ip_layer.src,
            dst_ip=ip_layer.dst,
            src_port=src_port,
            dst_port=dst_port,
            protocol=protocol,
            payload_size=payload_size,
            ttl=getattr(ip_layer, "ttl", 64),
            flags=flags,
            is_revealed=False,
            payload_preview=payload_preview,
            full_payload=full_payload,
        ))

    return packets, ground_truth
def load_task_annotation(annotation_path: str) -> Dict[str, Any]:
    """Load a task annotation JSON file.

    Args:
        annotation_path: Path of the JSON annotation file.

    Returns:
        The decoded JSON document, or an empty dict when the path does not
        exist.

    Raises:
        json.JSONDecodeError: If the file exists but is not valid JSON.
    """
    path = Path(annotation_path)
    if not path.exists():
        return {}
    # Read as UTF-8 explicitly so decoding does not depend on the platform's
    # locale encoding.
    with path.open("r", encoding="utf-8") as f:
        return json.load(f)
class RealPCAPGenerator:
    """Build packets and ground truth from a real capture plus its annotation.

    The annotation dict supplies the capture file name (``pcap_file``) and the
    labels (``malicious_packets``, ``packet_roles``, ``sessions``,
    ``session_roles``, ``entry_point``).
    """

    def __init__(self, config: TaskConfig, annotation: Dict[str, Any]):
        self.config = config
        self.annotation = annotation
        self.pcap_file = annotation.get("pcap_file", "")

    def generate(self, seed: int = None) -> Tuple[List[PacketRecord], GroundTruth]:
        """Parse the annotated capture and attach the annotated labels.

        Args:
            seed: Accepted for signature parity with ``PCAPGenerator.generate``;
                not used here.

        Returns:
            ``(packets, ground_truth)``.  Both are empty when the annotation
            names no capture file.
        """
        if not self.pcap_file:
            # Built with the same explicit fields parse_packets uses for its
            # empty ground truth.
            return [], GroundTruth(
                malicious_packets=[],
                packet_roles={},
                sessions={},
                session_roles={},
                entry_point=None,
            )

        base_dir = Path(__file__).parent.parent / "pcaps"
        packets, ground_truth = parse_packets(str(base_dir / self.pcap_file))

        normalize = self._normalize_packet_id
        # ``or`` guards against keys that are present but null in the JSON.
        malicious_ids = [
            normalize(pid) for pid in (self.annotation.get("malicious_packets") or [])
        ]
        packet_roles = {
            normalize(pid): role
            for pid, role in (self.annotation.get("packet_roles") or {}).items()
        }
        sessions = {
            session_name: [normalize(pid) for pid in packet_ids]
            for session_name, packet_ids in (self.annotation.get("sessions") or {}).items()
        }
        session_roles = dict(self.annotation.get("session_roles") or {})

        ground_truth.malicious_packets = malicious_ids
        ground_truth.packet_roles = packet_roles
        ground_truth.sessions = sessions
        ground_truth.session_roles = session_roles
        entry_point = self.annotation.get("entry_point")
        ground_truth.entry_point = normalize(entry_point) if entry_point else None

        # Mirror the labels onto the packet records themselves; ids that are
        # annotated but absent from the capture are ignored.
        packet_lookup = {packet.packet_id: packet for packet in packets}
        for packet_id in malicious_ids:
            packet = packet_lookup.get(packet_id)
            if packet is not None:
                packet.is_malicious = True
                packet.attack_role = packet_roles.get(packet_id)

        return packets, ground_truth

    @staticmethod
    def _normalize_packet_id(value: Any) -> str:
        """Return *value* in canonical ``pkt_NNNN`` form.

        Values already starting with ``pkt_`` are returned unchanged, plain
        numbers are zero-padded to four digits, and anything else is returned
        as its string form.

        Declared static because it takes no ``self`` but is called through the
        instance.
        """
        text = str(value)
        if text.startswith("pkt_"):
            return text
        # isdecimal() only accepts characters that int() can parse.
        if text.isdecimal():
            return f"pkt_{int(text):04d}"
        return text
class PCAPGenerator:
    """Produce packets and ground truth for a task.

    When the config names a ``pcap_file`` the work is delegated to
    ``RealPCAPGenerator``; otherwise a synthetic scan + C2 + background-noise
    trace is generated.
    """

    def __init__(self, config: TaskConfig, annotation: Dict[str, Any] = None):
        self.config = config
        self.annotation = annotation or {}

    def generate(self, seed: int = None) -> Tuple[List[PacketRecord], GroundTruth]:
        """Return ``(packets, ground_truth)`` for the configured task.

        Args:
            seed: Seed for the synthetic generator.  ``None`` means "use
                ``config.seed``"; an explicit ``0`` is honoured.
        """
        pcap_file = getattr(self.config, "pcap_file", None)
        if pcap_file:
            # The annotation loaded from disk replaces any annotation passed
            # to the constructor.
            annotation_path = Path(__file__).parent.parent / "pcaps" / f"{pcap_file}.json"
            self.annotation = load_task_annotation(str(annotation_path))
            return RealPCAPGenerator(self.config, self.annotation).generate(seed)
        return self._generate_synthetic(seed)

    def _generate_synthetic(self, seed: int = None) -> Tuple[List[PacketRecord], GroundTruth]:
        """Build a synthetic trace: ~10% port scan, ~30% C2, ~60% benign noise."""
        import random

        # ``seed or config.seed`` would silently discard an explicit seed of 0.
        rng = random.Random(self.config.seed if seed is None else seed)
        # NOTE(review): the module-level ``fake`` is not seeded in the visible
        # code, so Faker-derived values (payload previews, some IPs) presumably
        # still vary between runs with the same seed -- confirm.

        packets = []
        ground_truth = GroundTruth(
            malicious_packets=[],
            packet_roles={},
            sessions={},
            session_roles={},
            entry_point=None,
        )

        attacker_ip = f"10.{rng.randint(1, 254)}.{rng.randint(1, 254)}.{rng.randint(1, 254)}"
        target_network = "192.168.1"
        target_ip = f"{target_network}.{rng.randint(1, 254)}"
        total_packets = self.config.total_packets

        # --- Port scan: one SYN per (host, port), attacker -> target network.
        scan_count = int(total_packets * 0.1)
        for i in range(scan_count):
            pkt_id = f"pkt_{i+1:04d}"
            packets.append(PacketRecord(
                packet_id=pkt_id,
                timestamp=1000.0 + i * 0.001,
                src_ip=attacker_ip,
                # Wrap so large scans still yield valid host octets and ports.
                dst_ip=f"{target_network}.{i % 254 + 1}",
                src_port=rng.randint(40000, 60000),
                dst_port=i % 65535 + 1,
                protocol="TCP",
                payload_size=0,
                ttl=rng.randint(32, 64),
                flags=["SYN"],
                is_revealed=False,
                payload_preview="",
                is_malicious=True,
                attack_role="scan",
            ))
            ground_truth.malicious_packets.append(pkt_id)
            ground_truth.packet_roles[pkt_id] = "scan"
            ground_truth.scan_packets.append(pkt_id)
            if i == 0:
                ground_truth.entry_point = pkt_id

        # --- C2 channel: one packet per second to a fixed port.
        c2_count = int(total_packets * 0.3)
        c2_port = 4444
        for i in range(c2_count):
            pkt_id = f"pkt_{scan_count + i+1:04d}"
            packets.append(PacketRecord(
                packet_id=pkt_id,
                timestamp=1001.0 + i * 1.0,
                src_ip=attacker_ip,
                dst_ip=target_ip,
                src_port=rng.randint(40000, 60000),
                dst_port=c2_port,
                protocol="TCP",
                payload_size=rng.randint(32, 128),
                ttl=128,
                flags=["PSH", "ACK"],
                is_revealed=False,
                payload_preview=fake.sha256()[:20],
                is_malicious=True,
                attack_role="c2",
            ))
            ground_truth.malicious_packets.append(pkt_id)
            ground_truth.packet_roles[pkt_id] = "c2"

        # --- Benign noise spread over the same time window.
        noise_count = int(total_packets * 0.6)
        base_idx = scan_count + c2_count
        for i in range(noise_count):
            protocol = rng.choice(["TCP", "UDP", "DNS", "HTTPS"])
            pkt_id = f"pkt_{base_idx + i+1:04d}"
            if protocol == "DNS":
                dst_ip = "8.8.8.8"
                dst_port = 53
            elif protocol == "HTTPS":
                dst_ip = fake.ipv4()
                dst_port = 443
            else:
                dst_ip = target_ip
                dst_port = rng.choice([80, 443, 445, 3389])
            packets.append(PacketRecord(
                packet_id=pkt_id,
                timestamp=1000.0 + rng.uniform(0, 100),
                src_ip=target_ip if rng.random() > 0.3 else fake.ipv4(),
                dst_ip=dst_ip,
                src_port=rng.randint(40000, 60000),
                dst_port=dst_port,
                protocol=protocol,
                payload_size=rng.randint(40, 1500),
                ttl=rng.choice([64, 128, 255]),
                flags=[],
                is_revealed=False,
                payload_preview="",
                is_malicious=False,
            ))

        # Interleave everything chronologically and renumber.  The ground truth
        # was recorded with the pre-sort ids, so it must be remapped too;
        # otherwise it would point at the wrong packets.
        packets.sort(key=lambda p: p.timestamp)
        renumbered = {}
        for position, packet in enumerate(packets, start=1):
            new_id = f"pkt_{position:04d}"
            renumbered[packet.packet_id] = new_id
            packet.packet_id = new_id

        ground_truth.malicious_packets = [
            renumbered[pid] for pid in ground_truth.malicious_packets
        ]
        ground_truth.packet_roles = {
            renumbered[pid]: role for pid, role in ground_truth.packet_roles.items()
        }
        ground_truth.scan_packets = [
            renumbered[pid] for pid in ground_truth.scan_packets
        ]
        if ground_truth.entry_point is not None:
            ground_truth.entry_point = renumbered[ground_truth.entry_point]

        return packets, ground_truth