"""
IP Parser/Assembler Module

Handles IPv4 packet parsing and construction:
- Parse IPv4, UDP, and TCP headers
- Calculate and verify checksums
- Handle packet fragmentation and reassembly
- Support various IP options
"""

import struct
import socket
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum


# IPv4 IHL and TCP data offset are 4-bit word counts (max 15 words = 60 bytes),
# so at most 40 bytes remain for options after the 20-byte fixed part.
_MAX_OPTIONS_LEN = 40

# IPv4 flag bits as stored in IPv4Header.flags (the 3 bits above the offset).
_IP_FLAG_DF = 0x2
_IP_FLAG_MF = 0x1

# TCP flag name -> bit in TCPHeader.flags (9 bits; NS is the high bit).
_TCP_FLAG_BITS = {
    'fin': 0x01,
    'syn': 0x02,
    'rst': 0x04,
    'psh': 0x08,
    'ack': 0x10,
    'urg': 0x20,
    'ece': 0x40,
    'cwr': 0x80,
    'ns': 0x100,
}


class IPProtocol(Enum):
    """IP protocol numbers handled by this module."""
    ICMP = 1
    TCP = 6
    UDP = 17


@dataclass
class IPv4Header:
    """IPv4 header structure."""
    version: int = 4
    ihl: int = 5  # Internet Header Length (in 32-bit words)
    tos: int = 0  # Type of Service
    total_length: int = 0  # Header + payload length in bytes
    identification: int = 0
    flags: int = 0  # 3 bits: Reserved, Don't Fragment, More Fragments
    fragment_offset: int = 0  # 13 bits, in units of 8 bytes
    ttl: int = 64  # Time to Live
    protocol: int = 0
    header_checksum: int = 0
    source_ip: str = '0.0.0.0'
    dest_ip: str = '0.0.0.0'
    options: bytes = b''

    @property
    def header_length(self) -> int:
        """Header length in bytes (``ihl`` 32-bit words)."""
        return self.ihl * 4

    @property
    def dont_fragment(self) -> bool:
        """True if the Don't Fragment flag is set."""
        return bool(self.flags & _IP_FLAG_DF)

    @property
    def more_fragments(self) -> bool:
        """True if the More Fragments flag is set."""
        return bool(self.flags & _IP_FLAG_MF)

    @property
    def is_fragment(self) -> bool:
        """True if this header describes a fragment (MF set or non-zero offset)."""
        return self.more_fragments or self.fragment_offset > 0


@dataclass
class TCPHeader:
    """TCP header structure."""
    source_port: int = 0
    dest_port: int = 0
    seq_num: int = 0
    ack_num: int = 0
    data_offset: int = 5  # Header length in 32-bit words
    reserved: int = 0  # 3 reserved bits
    flags: int = 0  # 9 bits: NS, CWR, ECE, URG, ACK, PSH, RST, SYN, FIN
    window_size: int = 65535
    checksum: int = 0
    urgent_pointer: int = 0
    options: bytes = b''

    @property
    def header_length(self) -> int:
        """Header length in bytes (``data_offset`` 32-bit words)."""
        return self.data_offset * 4

    # TCP flag properties
    @property
    def fin(self) -> bool:
        return bool(self.flags & _TCP_FLAG_BITS['fin'])

    @property
    def syn(self) -> bool:
        return bool(self.flags & _TCP_FLAG_BITS['syn'])

    @property
    def rst(self) -> bool:
        return bool(self.flags & _TCP_FLAG_BITS['rst'])

    @property
    def psh(self) -> bool:
        return bool(self.flags & _TCP_FLAG_BITS['psh'])

    @property
    def ack(self) -> bool:
        return bool(self.flags & _TCP_FLAG_BITS['ack'])

    @property
    def urg(self) -> bool:
        return bool(self.flags & _TCP_FLAG_BITS['urg'])

    @property
    def ece(self) -> bool:
        return bool(self.flags & _TCP_FLAG_BITS['ece'])

    @property
    def cwr(self) -> bool:
        return bool(self.flags & _TCP_FLAG_BITS['cwr'])

    @property
    def ns(self) -> bool:
        return bool(self.flags & _TCP_FLAG_BITS['ns'])

    def set_flag(self, flag_name: str, value: bool = True):
        """Set (or clear, when *value* is false) the TCP flag named *flag_name*.

        The name is case-insensitive. Unknown names are ignored.
        """
        bit = _TCP_FLAG_BITS.get(flag_name.lower())
        if bit is None:
            return
        if value:
            self.flags |= bit
        else:
            self.flags &= ~bit


@dataclass
class UDPHeader:
    """UDP header structure."""
    source_port: int = 0
    dest_port: int = 0
    length: int = 8  # Header + data length
    checksum: int = 0

    @property
    def header_length(self) -> int:
        """Header length in bytes (always 8 for UDP)."""
        return 8


@dataclass
class ParsedPacket:
    """Parsed packet structure."""
    ip_header: IPv4Header
    transport_header: Optional[object] = None  # TCPHeader or UDPHeader
    payload: bytes = b''
    raw_packet: bytes = b''


class IPParser:
    """IPv4 packet parser and assembler."""

    @staticmethod
    def calculate_checksum(data: bytes) -> int:
        """Return the Internet checksum of *data*.

        The data is summed as big-endian 16-bit words with one's-complement
        addition; an odd trailing byte is padded with a zero byte. The
        one's complement of that sum is returned.
        """
        if len(data) % 2:
            # Build a new object: ``+=`` would mutate a caller's bytearray.
            data = bytes(data) + b'\x00'
        total = sum(struct.unpack(f'!{len(data) // 2}H', data))
        # Fold carries back into the low 16 bits (end-around carry).
        while total >> 16:
            total = (total & 0xFFFF) + (total >> 16)
        return (~total) & 0xFFFF

    @staticmethod
    def verify_checksum(data: bytes, checksum: int) -> bool:
        """Return True if *checksum* is the correct checksum for *data*.

        *data* must be the checksummed bytes with the checksum field itself
        zeroed. For TCP/UDP, that includes the pseudo-header.
        """
        return IPParser.calculate_checksum(data) == checksum

    @staticmethod
    def _header_words(options: bytes, kind: str) -> int:
        """Return the header length in 32-bit words for a 20-byte fixed part plus *options*.

        Raises:
            ValueError: if *options* cannot be encoded in the 4-bit length field.
        """
        if len(options) > _MAX_OPTIONS_LEN:
            raise ValueError(
                f"{kind} options too long: {len(options)} > {_MAX_OPTIONS_LEN} bytes"
            )
        return (20 + len(options) + 3) // 4  # Round up to 32-bit boundary

    @classmethod
    def parse_ipv4_header(cls, data: bytes) -> Tuple[IPv4Header, int]:
        """Parse an IPv4 header from the start of *data*.

        Returns:
            ``(header, header_length_in_bytes)``.

        Raises:
            ValueError: if *data* is too short, the version is not 4, the IHL
                is below the 5-word minimum, or the options are truncated.
        """
        if len(data) < 20:
            raise ValueError("IPv4 header too short")

        (version_ihl, tos, total_length, identification, flags_fragment,
         ttl, protocol, checksum, src, dst) = struct.unpack('!BBHHHBBH4s4s', data[:20])

        header = IPv4Header(
            version=(version_ihl >> 4) & 0xF,
            ihl=version_ihl & 0xF,
            tos=tos,
            total_length=total_length,
            identification=identification,
            flags=(flags_fragment >> 13) & 0x7,
            fragment_offset=flags_fragment & 0x1FFF,
            ttl=ttl,
            protocol=protocol,
            header_checksum=checksum,
            source_ip=socket.inet_ntoa(src),
            dest_ip=socket.inet_ntoa(dst),
        )

        if header.version != 4:
            raise ValueError(f"Unsupported IP version: {header.version}")
        # An IHL below 5 claims a header shorter than its own fixed part and
        # would shift every later offset; reject it instead of mis-slicing.
        if header.ihl < 5:
            raise ValueError(f"Invalid IPv4 header length: {header.ihl} words")

        options_length = header.header_length - 20
        if options_length > 0:
            if len(data) < 20 + options_length:
                raise ValueError("IPv4 options truncated")
            header.options = data[20:20 + options_length]

        return header, header.header_length

    @classmethod
    def parse_tcp_header(cls, data: bytes) -> Tuple[TCPHeader, int]:
        """Parse a TCP header from the start of *data*.

        Returns:
            ``(header, header_length_in_bytes)``.

        Raises:
            ValueError: if *data* is too short, the data offset is below the
                5-word minimum, or the options are truncated.
        """
        if len(data) < 20:
            raise ValueError("TCP header too short")

        (source_port, dest_port, seq_num, ack_num, offset_reserved, flags_low,
         window_size, checksum, urgent_pointer) = struct.unpack('!HHIIBBHHH', data[:20])

        header = TCPHeader(
            source_port=source_port,
            dest_port=dest_port,
            seq_num=seq_num,
            ack_num=ack_num,
            data_offset=(offset_reserved >> 4) & 0xF,
            reserved=(offset_reserved >> 1) & 0x7,
            # The NS flag is the low bit of the offset/reserved byte.
            flags=((offset_reserved & 0x1) << 8) | flags_low,
            window_size=window_size,
            checksum=checksum,
            urgent_pointer=urgent_pointer,
        )

        if header.data_offset < 5:
            raise ValueError(f"Invalid TCP data offset: {header.data_offset} words")

        options_length = header.header_length - 20
        if options_length > 0:
            if len(data) < 20 + options_length:
                raise ValueError("TCP options truncated")
            header.options = data[20:20 + options_length]

        return header, header.header_length

    @classmethod
    def parse_udp_header(cls, data: bytes) -> Tuple[UDPHeader, int]:
        """Parse a UDP header from the start of *data*.

        Returns:
            ``(header, 8)``.

        Raises:
            ValueError: if *data* is shorter than 8 bytes.
        """
        if len(data) < 8:
            raise ValueError("UDP header too short")

        source_port, dest_port, length, checksum = struct.unpack('!HHHH', data[:8])
        header = UDPHeader(
            source_port=source_port,
            dest_port=dest_port,
            length=length,
            checksum=checksum,
        )
        return header, 8

    @classmethod
    def parse_packet(cls, data: bytes) -> ParsedPacket:
        """Parse a complete IPv4 packet, including a TCP/UDP header when present.

        Non-first fragments (non-zero fragment offset) carry no transport
        header. Their IP payload is returned raw with ``transport_header=None``.
        """
        ip_header, ip_header_len = cls.parse_ipv4_header(data)
        packet = ParsedPacket(ip_header=ip_header, raw_packet=data)

        # Bytes beyond total_length (e.g. link-layer padding) are not part of the datagram.
        ip_payload = data[ip_header_len:ip_header.total_length]

        if ip_header.fragment_offset > 0:
            # Mid-segment slice: there is no transport header to parse here.
            packet.payload = ip_payload
        elif ip_header.protocol == IPProtocol.TCP.value:
            packet.transport_header, transport_len = cls.parse_tcp_header(ip_payload)
            packet.payload = ip_payload[transport_len:]
        elif ip_header.protocol == IPProtocol.UDP.value:
            packet.transport_header, transport_len = cls.parse_udp_header(ip_payload)
            packet.payload = ip_payload[transport_len:]
        else:
            # Unsupported protocol, treat as raw payload
            packet.payload = ip_payload

        return packet

    @classmethod
    def build_ipv4_header(cls, header: IPv4Header) -> bytes:
        """Serialize *header* to bytes with a freshly computed header checksum.

        ``header.ihl`` is updated from the options length. Options are
        zero-padded to a 32-bit boundary.

        Raises:
            ValueError: if the options exceed 40 bytes.
        """
        header.ihl = cls._header_words(header.options, "IPv4")

        version_ihl = (header.version << 4) | header.ihl
        flags_fragment = (header.flags << 13) | header.fragment_offset
        header_data = struct.pack(
            '!BBHHHBBH4s4s',
            version_ihl,
            header.tos,
            header.total_length,
            header.identification,
            flags_fragment,
            header.ttl,
            header.protocol,
            0,  # Checksum = 0 for calculation
            socket.inet_aton(header.source_ip),
            socket.inet_aton(header.dest_ip),
        ) + header.options
        # Pad to 32-bit boundary
        header_data += b'\x00' * (header.ihl * 4 - len(header_data))

        checksum = cls.calculate_checksum(header_data)
        return header_data[:10] + struct.pack('!H', checksum) + header_data[12:]

    @classmethod
    def build_tcp_header(cls, header: TCPHeader, source_ip: str, dest_ip: str,
                         payload: bytes) -> bytes:
        """Serialize *header* to bytes with a checksum over pseudo-header, header and *payload*.

        ``header.data_offset`` is updated from the options length.

        Raises:
            ValueError: if the options exceed 40 bytes.
        """
        header.data_offset = cls._header_words(header.options, "TCP")

        offset_reserved_flags = (header.data_offset << 12) | (header.reserved << 9) | header.flags
        header_data = struct.pack(
            '!HHIIHHHH',
            header.source_port,
            header.dest_port,
            header.seq_num,
            header.ack_num,
            offset_reserved_flags,
            header.window_size,
            0,  # Checksum = 0 for calculation
            header.urgent_pointer,
        ) + header.options
        # Pad to 32-bit boundary
        header_data += b'\x00' * (header.data_offset * 4 - len(header_data))

        pseudo_header = struct.pack(
            '!4s4sBBH',
            socket.inet_aton(source_ip),
            socket.inet_aton(dest_ip),
            0,
            IPProtocol.TCP.value,
            len(header_data) + len(payload),
        )
        checksum = cls.calculate_checksum(pseudo_header + header_data + payload)

        # The checksum field sits at bytes 16-17 of the TCP header.
        return header_data[:16] + struct.pack('!H', checksum) + header_data[18:]

    @classmethod
    def build_udp_header(cls, header: UDPHeader, source_ip: str, dest_ip: str,
                         payload: bytes) -> bytes:
        """Serialize *header* to bytes, setting ``header.length`` from *payload*.

        The checksum is optional for UDP over IPv4. It is computed only when
        ``header.checksum`` is non-zero on entry; otherwise 0 ("no checksum")
        is emitted.
        """
        header.length = 8 + len(payload)

        checksum = 0
        if header.checksum != 0:  # If checksum is required
            pseudo_header = struct.pack(
                '!4s4sBBH',
                socket.inet_aton(source_ip),
                socket.inet_aton(dest_ip),
                0,
                IPProtocol.UDP.value,
                header.length,
            )
            zeroed = struct.pack('!HHHH', header.source_port, header.dest_port, header.length, 0)
            # A computed checksum of 0 is sent as 0xFFFF, since 0 on the wire
            # means "no checksum" (RFC 768).
            checksum = cls.calculate_checksum(pseudo_header + zeroed + payload) or 0xFFFF

        return struct.pack('!HHHH', header.source_port, header.dest_port, header.length, checksum)

    @classmethod
    def build_packet(cls, ip_header: IPv4Header, transport_header: Optional[object] = None,
                     payload: bytes = b'') -> bytes:
        """Build a complete packet: IPv4 header + optional TCP/UDP header + *payload*.

        ``ip_header.ihl`` and ``ip_header.total_length`` are updated in place.
        Transport header objects of other types are ignored.
        """
        transport_data = b''
        if isinstance(transport_header, TCPHeader):
            transport_data = cls.build_tcp_header(
                transport_header, ip_header.source_ip, ip_header.dest_ip, payload
            )
        elif isinstance(transport_header, UDPHeader):
            transport_data = cls.build_udp_header(
                transport_header, ip_header.source_ip, ip_header.dest_ip, payload
            )

        # Derive IHL from the options *before* computing total_length; a stale
        # ihl would make the length field disagree with the emitted header.
        ip_header.ihl = cls._header_words(ip_header.options, "IPv4")
        ip_header.total_length = ip_header.header_length + len(transport_data) + len(payload)

        return cls.build_ipv4_header(ip_header) + transport_data + payload


class PacketFragmenter:
    """Handle packet fragmentation and reassembly."""

    def __init__(self, mtu: int = 1500):
        self.mtu = mtu
        # (src, dst, id) -> [(byte offset, fragment payload)], sorted, one entry per offset
        self.fragments: Dict[Tuple[str, str, int], List[Tuple[int, bytes]]] = {}
        # (src, dst, id) -> header of the offset-0 fragment, once it has arrived
        self._first_headers: Dict[Tuple[str, str, int], IPv4Header] = {}
        # (src, dst, id) -> total payload length, known once the MF-clear fragment arrives
        self._total_lengths: Dict[Tuple[str, str, int], int] = {}
        # NOTE(review): incomplete datagrams are never expired here; a caller
        # feeding untrusted traffic should add a reassembly timeout.

    @staticmethod
    def _ip_payload(packet: bytes, header: IPv4Header, header_len: int) -> bytes:
        """Return the bytes after the IP header, trimmed to total_length when it is plausible."""
        end = header.total_length if header.total_length >= header_len else len(packet)
        return packet[header_len:end]

    def _discard(self, key: Tuple[str, str, int]) -> None:
        """Drop all reassembly state for *key*."""
        self.fragments.pop(key, None)
        self._first_headers.pop(key, None)
        self._total_lengths.pop(key, None)

    def fragment_packet(self, packet: bytes, mtu: Optional[int] = None) -> List[bytes]:
        """Split *packet* into fragments of at most *mtu* bytes (default ``self.mtu``).

        A packet that already fits is returned unchanged as a one-element list.
        Fragmenting a packet that is itself a non-last fragment keeps MF set
        on every resulting piece.

        Raises:
            ValueError: if the Don't Fragment flag is set, or the MTU leaves
                room for fewer than 8 payload bytes per fragment.
        """
        if mtu is None:
            mtu = self.mtu

        if len(packet) <= mtu:
            return [packet]

        # Only the IP header is needed; the transport header travels as opaque payload.
        ip_header, header_len = IPParser.parse_ipv4_header(packet)

        if ip_header.dont_fragment:
            raise ValueError("Packet too large and Don't Fragment flag is set")

        # Fragment offsets count 8-byte units, so every non-last piece must be a multiple of 8.
        chunk_size = ((mtu - header_len) // 8) * 8
        if chunk_size <= 0:
            raise ValueError(
                f"MTU {mtu} too small to fragment a packet with a {header_len}-byte header"
            )

        payload = self._ip_payload(packet, ip_header, header_len)
        base_offset = ip_header.fragment_offset * 8
        # NOTE(review): RFC 791 copies only options with the "copied" bit to
        # non-first fragments; this implementation copies all options.
        fragments = []
        for start in range(0, len(payload), chunk_size):
            chunk = payload[start:start + chunk_size]
            is_last = start + len(chunk) >= len(payload)

            flags = ip_header.flags & ~_IP_FLAG_MF
            # Pieces before the end always have MF; the final piece inherits the
            # original MF bit (set if the input was itself a non-last fragment).
            if not is_last or ip_header.more_fragments:
                flags |= _IP_FLAG_MF

            frag_header = IPv4Header(
                version=ip_header.version,
                ihl=ip_header.ihl,
                tos=ip_header.tos,
                identification=ip_header.identification,
                flags=flags,
                fragment_offset=(base_offset + start) // 8,
                ttl=ip_header.ttl,
                protocol=ip_header.protocol,
                source_ip=ip_header.source_ip,
                dest_ip=ip_header.dest_ip,
                options=ip_header.options,
            )
            fragments.append(IPParser.build_packet(frag_header, payload=chunk))

        return fragments

    def reassemble_packet(self, packet: bytes) -> Optional[bytes]:
        """Feed one packet into reassembly.

        Returns:
            *packet* unchanged if it is not a fragment, the rebuilt datagram
            once every fragment of it has arrived (in any order), otherwise
            None.
        """
        ip_header, header_len = IPParser.parse_ipv4_header(packet)

        # If not a fragment, return as-is
        if not ip_header.is_fragment:
            return packet

        key = (ip_header.source_ip, ip_header.dest_ip, ip_header.identification)
        byte_offset = ip_header.fragment_offset * 8
        data = self._ip_payload(packet, ip_header, header_len)

        # Keep one entry per offset so a retransmitted duplicate cannot block completion.
        pieces = [piece for piece in self.fragments.get(key, []) if piece[0] != byte_offset]
        pieces.append((byte_offset, data))
        pieces.sort(key=lambda piece: piece[0])
        self.fragments[key] = pieces

        if byte_offset == 0:
            self._first_headers[key] = ip_header
        if not ip_header.more_fragments:
            # The MF-clear fragment fixes the datagram's total payload length.
            self._total_lengths[key] = byte_offset + len(data)

        first = self._first_headers.get(key)
        total = self._total_lengths.get(key)
        if first is None or total is None:
            return None

        assembled = bytearray()
        for offset, chunk in pieces:
            if offset > len(assembled):
                return None  # Gap: a middle fragment is still missing
            # On overlap keep the bytes already placed and append only the new tail.
            assembled += chunk[len(assembled) - offset:]
        if len(assembled) < total:
            return None

        # Rebuild from the offset-0 fragment's header, which carries the full option set.
        complete_header = IPv4Header(
            version=first.version,
            ihl=first.ihl,
            tos=first.tos,
            identification=first.identification,
            flags=first.flags & ~_IP_FLAG_MF,  # Clear More Fragments
            fragment_offset=0,
            ttl=first.ttl,
            protocol=first.protocol,
            source_ip=first.source_ip,
            dest_ip=first.dest_ip,
            options=first.options,
        )
        complete_packet = IPParser.build_packet(complete_header, payload=bytes(assembled[:total]))

        self._discard(key)
        return complete_packet