|
|
""" |
|
|
IP Parser/Assembler Module |
|
|
|
|
|
Handles IPv4 packet parsing and construction: |
|
|
- Parse IPv4, UDP, and TCP headers |
|
|
- Calculate and verify checksums |
|
|
- Handle packet fragmentation and reassembly |
|
|
- Preserve raw IPv4 option bytes when parsing and building headers
|
|
""" |
|
|
|
|
|
import struct |
|
|
import socket |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
|
|
|
|
|
|
class IPProtocol(Enum):
    """Transport protocol numbers recognised in the IPv4 ``protocol`` field."""

    ICMP = 1
    TCP = 6
    UDP = 17
|
|
|
|
|
|
|
|
@dataclass
class IPv4Header:
    """Decoded fields of an IPv4 header.

    ``flags`` is the 3-bit flags field, where 0x2 is Don't Fragment and 0x1
    is More Fragments. ``fragment_offset`` counts 8-byte units.
    """

    # Bit masks inside ``flags``. They are unannotated so @dataclass does not
    # treat them as fields.
    _DONT_FRAGMENT = 0x2
    _MORE_FRAGMENTS = 0x1

    version: int = 4
    ihl: int = 5                 # header length in 32-bit words
    tos: int = 0
    total_length: int = 0
    identification: int = 0
    flags: int = 0
    fragment_offset: int = 0
    ttl: int = 64
    protocol: int = 0
    header_checksum: int = 0
    source_ip: str = '0.0.0.0'
    dest_ip: str = '0.0.0.0'
    options: bytes = b''

    @property
    def header_length(self) -> int:
        """Header size in bytes (IHL words times four)."""
        return 4 * self.ihl

    @property
    def dont_fragment(self) -> bool:
        """True when the Don't Fragment bit is set in ``flags``."""
        return (self.flags & self._DONT_FRAGMENT) != 0

    @property
    def more_fragments(self) -> bool:
        """True when the More Fragments bit is set in ``flags``."""
        return (self.flags & self._MORE_FRAGMENTS) != 0

    @property
    def is_fragment(self) -> bool:
        """True for any piece of a fragmented datagram.

        That means a non-zero offset or the More Fragments bit set.
        """
        return self.fragment_offset > 0 or self.more_fragments
|
|
|
|
|
|
|
|
@dataclass
class TCPHeader:
    """TCP header structure.

    ``flags`` holds the nine control bits, with FIN in bit 0 (0x01) up to
    NS in bit 8 (0x100); see ``_FLAG_BITS``.
    """

    # Bit value of each named control flag inside ``flags``. It is
    # unannotated so @dataclass does not treat it as a field.
    _FLAG_BITS = {
        'fin': 0x01, 'syn': 0x02, 'rst': 0x04, 'psh': 0x08,
        'ack': 0x10, 'urg': 0x20, 'ece': 0x40, 'cwr': 0x80, 'ns': 0x100,
    }

    source_port: int = 0
    dest_port: int = 0
    seq_num: int = 0
    ack_num: int = 0
    data_offset: int = 5         # header length in 32-bit words
    reserved: int = 0
    flags: int = 0
    window_size: int = 65535
    checksum: int = 0
    urgent_pointer: int = 0
    options: bytes = b''

    @property
    def header_length(self) -> int:
        """Get header length in bytes (data offset words times four)."""
        return self.data_offset * 4

    def _has_flag(self, name: str) -> bool:
        """Return True if the control bit called *name* is set in ``flags``."""
        return bool(self.flags & self._FLAG_BITS[name])

    @property
    def fin(self) -> bool:
        """True if the FIN bit (0x01) is set."""
        return self._has_flag('fin')

    @property
    def syn(self) -> bool:
        """True if the SYN bit (0x02) is set."""
        return self._has_flag('syn')

    @property
    def rst(self) -> bool:
        """True if the RST bit (0x04) is set."""
        return self._has_flag('rst')

    @property
    def psh(self) -> bool:
        """True if the PSH bit (0x08) is set."""
        return self._has_flag('psh')

    @property
    def ack(self) -> bool:
        """True if the ACK bit (0x10) is set."""
        return self._has_flag('ack')

    @property
    def urg(self) -> bool:
        """True if the URG bit (0x20) is set."""
        return self._has_flag('urg')

    @property
    def ece(self) -> bool:
        """True if the ECE bit (0x40) is set."""
        return self._has_flag('ece')

    @property
    def cwr(self) -> bool:
        """True if the CWR bit (0x80) is set."""
        return self._has_flag('cwr')

    @property
    def ns(self) -> bool:
        """True if the NS bit (0x100) is set."""
        return self._has_flag('ns')

    def set_flag(self, flag_name: str, value: bool = True) -> None:
        """Set (``value=True``) or clear a TCP control flag by name.

        Args:
            flag_name: one of the keys of ``_FLAG_BITS``, matched
                case-insensitively. Unrecognised names are silently ignored.
            value: whether the bit should end up set or cleared.
        """
        bit = self._FLAG_BITS.get(flag_name.lower())
        if bit is None:
            return
        if value:
            self.flags |= bit
        else:
            self.flags &= ~bit
|
|
|
|
|
|
|
|
@dataclass
class UDPHeader:
    """Fixed-size UDP header: ports, datagram length and checksum."""

    # Size of every UDP header in bytes. It is unannotated so @dataclass
    # does not treat it as a field.
    _HEADER_SIZE = 8

    source_port: int = 0
    dest_port: int = 0
    length: int = 8              # header plus payload, in bytes
    checksum: int = 0

    @property
    def header_length(self) -> int:
        """Header size in bytes; a UDP header is always 8 bytes long."""
        return self._HEADER_SIZE
|
|
|
|
|
|
|
|
@dataclass
class ParsedPacket:
    """Result of parsing a packet: its IPv4 header plus what followed it."""

    ip_header: IPv4Header
    # Parsed TCPHeader/UDPHeader, or None when no transport header was parsed.
    transport_header: Optional[object] = None
    # Bytes after the transport header, or after the IP header if none was parsed.
    payload: bytes = b''
    # The unmodified input bytes the packet was parsed from.
    raw_packet: bytes = b''
|
|
|
|
|
|
|
|
class IPParser:
    """IPv4 packet parser and assembler.

    Every method is a static or class method, so the class holds no state and
    acts as a namespace. The ``build_*`` methods update the length/offset
    fields of the header objects passed to them.
    """

    @staticmethod
    def calculate_checksum(data: bytes) -> int:
        """Return the 16-bit Internet (one's-complement) checksum of *data*.

        Odd-length input is treated as if padded with one zero byte. The
        caller's buffer is never modified.
        """
        if len(data) % 2:
            # Build a new object: ``data += ...`` would extend a caller's
            # bytearray in place.
            data = bytes(data) + b'\x00'

        total = 0
        for i in range(0, len(data), 2):
            total += (data[i] << 8) + data[i + 1]

        # Fold carries out of the low 16 bits back in (end-around carry).
        while total >> 16:
            total = (total & 0xFFFF) + (total >> 16)

        return (~total) & 0xFFFF

    @staticmethod
    def verify_checksum(data: bytes, checksum: int) -> bool:
        """Return True if *checksum* is the Internet checksum of *data*.

        *data* is expected to hold the covered bytes with the checksum field
        itself zeroed. 0x0000 and 0xFFFF are both one's-complement encodings
        of zero, so either is accepted in place of the other. Any other
        mismatch is rejected.
        """
        calculated = IPParser.calculate_checksum(data)
        if calculated == checksum:
            return True
        return {calculated, checksum} == {0x0000, 0xFFFF}

    @staticmethod
    def _header_words(options: bytes, protocol_name: str) -> int:
        """Return the length in 32-bit words of a 20-byte header plus *options*.

        IPv4 IHL and TCP data offset are both 4-bit word counts, so such a
        header is at most 60 bytes, which leaves room for 40 bytes of options.

        Raises:
            ValueError: if *options* is longer than 40 bytes.
        """
        if len(options) > 40:
            raise ValueError(
                f"{protocol_name} options too long: {len(options)} bytes (max 40)"
            )
        return (20 + len(options) + 3) // 4

    @staticmethod
    def _read_options(data: bytes, header_length: int, protocol_name: str) -> bytes:
        """Return the option bytes between offset 20 and *header_length*.

        Raises:
            ValueError: if *data* ends before *header_length*.
        """
        if len(data) < header_length:
            raise ValueError(f"{protocol_name} options truncated")
        return data[20:header_length]

    @classmethod
    def parse_ipv4_header(cls, data: bytes) -> Tuple[IPv4Header, int]:
        """Parse an IPv4 header from the start of *data*.

        Returns:
            ``(header, header_length_in_bytes)``.

        Raises:
            ValueError: if *data* is shorter than 20 bytes, the version is
                not 4, the IHL is below the 5-word minimum, or the options
                are truncated.
        """
        if len(data) < 20:
            raise ValueError("IPv4 header too short")

        (version_ihl, tos, total_length, identification, flags_fragment,
         ttl, protocol, checksum, src, dst) = struct.unpack('!BBHHHBBH4s4s', data[:20])

        header = IPv4Header(
            version=(version_ihl >> 4) & 0xF,
            ihl=version_ihl & 0xF,
            tos=tos,
            total_length=total_length,
            identification=identification,
            # The top 3 bits are the flags; the low 13 bits are the offset.
            flags=(flags_fragment >> 13) & 0x7,
            fragment_offset=flags_fragment & 0x1FFF,
            ttl=ttl,
            protocol=protocol,
            header_checksum=checksum,
            source_ip=socket.inet_ntoa(src),
            dest_ip=socket.inet_ntoa(dst),
        )

        if header.version != 4:
            raise ValueError(f"Unsupported IP version: {header.version}")
        if header.ihl < 5:
            raise ValueError(f"Invalid IPv4 header length (IHL={header.ihl})")

        header.options = cls._read_options(data, header.header_length, "IPv4")
        return header, header.header_length

    @classmethod
    def parse_tcp_header(cls, data: bytes) -> Tuple[TCPHeader, int]:
        """Parse a TCP header from the start of *data*.

        Returns:
            ``(header, header_length_in_bytes)``.

        Raises:
            ValueError: if *data* is shorter than 20 bytes, the data offset
                is below 5 words, or the options are truncated.
        """
        if len(data) < 20:
            raise ValueError("TCP header too short")

        (source_port, dest_port, seq_num, ack_num, offset_reserved, flag_byte,
         window_size, checksum, urgent_pointer) = struct.unpack('!HHIIBBHHH', data[:20])

        header = TCPHeader(
            source_port=source_port,
            dest_port=dest_port,
            seq_num=seq_num,
            ack_num=ack_num,
            data_offset=(offset_reserved >> 4) & 0xF,
            reserved=(offset_reserved >> 1) & 0x7,
            # NS is the low bit of the data-offset byte. The other eight
            # flags fill the following byte.
            flags=((offset_reserved & 0x1) << 8) | flag_byte,
            window_size=window_size,
            checksum=checksum,
            urgent_pointer=urgent_pointer,
        )

        if header.data_offset < 5:
            raise ValueError(f"Invalid TCP data offset: {header.data_offset}")

        header.options = cls._read_options(data, header.header_length, "TCP")
        return header, header.header_length

    @classmethod
    def parse_udp_header(cls, data: bytes) -> Tuple[UDPHeader, int]:
        """Parse a UDP header from the start of *data*.

        Returns:
            ``(header, 8)``.

        Raises:
            ValueError: if *data* is shorter than 8 bytes.
        """
        if len(data) < 8:
            raise ValueError("UDP header too short")

        source_port, dest_port, length, checksum = struct.unpack('!HHHH', data[:8])
        header = UDPHeader(
            source_port=source_port,
            dest_port=dest_port,
            length=length,
            checksum=checksum,
        )
        return header, 8

    @classmethod
    def parse_packet(cls, data: bytes) -> ParsedPacket:
        """Parse an IPv4 packet and, when possible, its TCP/UDP header.

        The IP payload is bounded by the header's ``total_length``, so any
        bytes past it are ignored. A transport header is parsed only when
        ``fragment_offset`` is 0. Later fragments carry the middle of a
        segment, so their whole IP payload is returned as ``payload`` and
        ``transport_header`` stays ``None``.

        Raises:
            ValueError: propagated from the header parsers on malformed input.
        """
        ip_header, ip_header_len = cls.parse_ipv4_header(data)
        packet = ParsedPacket(ip_header=ip_header, raw_packet=data)

        ip_payload = data[ip_header_len:ip_header.total_length]

        transport_parser = None
        if ip_header.fragment_offset == 0:
            if ip_header.protocol == IPProtocol.TCP.value:
                transport_parser = cls.parse_tcp_header
            elif ip_header.protocol == IPProtocol.UDP.value:
                transport_parser = cls.parse_udp_header

        if transport_parser is None:
            packet.payload = ip_payload
        else:
            packet.transport_header, transport_header_len = transport_parser(ip_payload)
            packet.payload = ip_payload[transport_header_len:]

        return packet

    @classmethod
    def build_ipv4_header(cls, header: IPv4Header) -> bytes:
        """Serialize *header* to bytes with a freshly computed checksum.

        ``header.ihl`` is overwritten to match ``header.options``. The
        options are zero-padded to a 32-bit boundary.

        Raises:
            ValueError: if the options exceed 40 bytes.
        """
        header.ihl = cls._header_words(header.options, "IPv4")

        version_ihl = (header.version << 4) | header.ihl
        flags_fragment = (header.flags << 13) | header.fragment_offset

        header_data = struct.pack(
            '!BBHHHBBH4s4s',
            version_ihl, header.tos, header.total_length,
            header.identification, flags_fragment,
            header.ttl, header.protocol, 0,  # checksum is filled in below
            socket.inet_aton(header.source_ip),
            socket.inet_aton(header.dest_ip),
        )

        header_data += header.options
        header_data += b'\x00' * (header.ihl * 4 - len(header_data))

        # The checksum is computed over the header with the checksum field
        # still zeroed.
        checksum = cls.calculate_checksum(header_data)
        return header_data[:10] + struct.pack('!H', checksum) + header_data[12:]

    @classmethod
    def build_tcp_header(cls, header: TCPHeader, source_ip: str, dest_ip: str, payload: bytes) -> bytes:
        """Serialize *header* to bytes with its checksum.

        The checksum covers the IPv4 pseudo-header, the TCP header and
        *payload*. ``header.data_offset`` is overwritten to match
        ``header.options``.

        Raises:
            ValueError: if the options exceed 40 bytes.
        """
        header.data_offset = cls._header_words(header.options, "TCP")

        # Layout: data offset (4 bits) | reserved (3 bits) | 9 flag bits.
        offset_reserved_flags = (header.data_offset << 12) | (header.reserved << 9) | header.flags

        header_data = struct.pack(
            '!HHIIHHHH',
            header.source_port, header.dest_port,
            header.seq_num, header.ack_num,
            offset_reserved_flags, header.window_size,
            0, header.urgent_pointer,  # checksum is filled in below
        )

        header_data += header.options
        header_data += b'\x00' * (header.data_offset * 4 - len(header_data))

        pseudo_header = struct.pack(
            '!4s4sBBH',
            socket.inet_aton(source_ip),
            socket.inet_aton(dest_ip),
            0, IPProtocol.TCP.value,
            len(header_data) + len(payload),
        )

        checksum = cls.calculate_checksum(pseudo_header + header_data + payload)
        return header_data[:16] + struct.pack('!H', checksum) + header_data[18:]

    @classmethod
    def build_udp_header(cls, header: UDPHeader, source_ip: str, dest_ip: str, payload: bytes) -> bytes:
        """Serialize *header* to bytes.

        ``header.length`` is overwritten with ``8 + len(payload)``. A checksum
        is computed only when ``header.checksum`` is non-zero on input.
        Otherwise the checksum field is sent as 0.
        """
        header.length = 8 + len(payload)

        header_data = struct.pack(
            '!HHHH',
            header.source_port, header.dest_port,
            header.length, 0,
        )

        if header.checksum != 0:
            pseudo_header = struct.pack(
                '!4s4sBBH',
                socket.inet_aton(source_ip),
                socket.inet_aton(dest_ip),
                0, IPProtocol.UDP.value,
                header.length,
            )
            checksum = cls.calculate_checksum(pseudo_header + header_data + payload)
            # RFC 768: a computed checksum of zero is transmitted as 0xFFFF,
            # because 0 on the wire means "no checksum".
            checksum = checksum or 0xFFFF
            header_data = header_data[:6] + struct.pack('!H', checksum) + header_data[8:]

        return header_data

    @classmethod
    def build_packet(cls, ip_header: IPv4Header, transport_header: Optional[object] = None, payload: bytes = b'') -> bytes:
        """Assemble a complete IPv4 packet.

        A TCPHeader or UDPHeader *transport_header* is serialized with a
        checksum over *payload*. Any other object is ignored.
        ``ip_header.ihl`` and ``ip_header.total_length`` are updated to
        match what is emitted.
        """
        transport_data = b''
        if isinstance(transport_header, TCPHeader):
            transport_data = cls.build_tcp_header(
                transport_header, ip_header.source_ip, ip_header.dest_ip, payload
            )
        elif isinstance(transport_header, UDPHeader):
            transport_data = cls.build_udp_header(
                transport_header, ip_header.source_ip, ip_header.dest_ip, payload
            )

        # Derive ihl from the options before using header_length. Otherwise a
        # stale ihl would give a wrong total_length.
        ip_header.ihl = cls._header_words(ip_header.options, "IPv4")
        ip_header.total_length = ip_header.header_length + len(transport_data) + len(payload)

        return cls.build_ipv4_header(ip_header) + transport_data + payload
|
|
|
|
|
|
|
|
class PacketFragmenter:
    """Split IPv4 packets to fit an MTU and reassemble received fragments.

    Reassembly state is kept in memory per (source IP, destination IP,
    identification) key until every fragment of that datagram has arrived.
    Incomplete datagrams are never expired.
    """

    def __init__(self, mtu: int = 1500):
        self.mtu = mtu
        # Received fragment payloads per datagram key, as (byte offset, data).
        self.fragments: Dict[Tuple[str, str, int], List[Tuple[int, bytes]]] = {}
        # Total payload length of each datagram, known once its final
        # (More Fragments clear) fragment has arrived.
        self._datagram_lengths: Dict[Tuple[str, str, int], int] = {}
        # Header of each datagram's offset-0 fragment, used to rebuild the
        # reassembled packet.
        self._first_headers: Dict[Tuple[str, str, int], IPv4Header] = {}

    def fragment_packet(self, packet: bytes, mtu: Optional[int] = None) -> List[bytes]:
        """Split *packet* into IPv4 fragments no larger than *mtu* bytes.

        Args:
            packet: a serialized IPv4 packet, which may itself be a fragment.
            mtu: maximum fragment size in bytes; defaults to ``self.mtu``.

        Returns:
            ``[packet]`` unchanged if it already fits. Otherwise, the
            fragments in offset order.

        Raises:
            ValueError: if the Don't Fragment flag is set, or if *mtu* leaves
                room for fewer than 8 payload bytes after the IP header.
        """
        if mtu is None:
            mtu = self.mtu

        if len(packet) <= mtu:
            return [packet]

        # Only the IP header is needed. The transport header is carried as
        # opaque payload.
        ip_header, header_length = IPParser.parse_ipv4_header(packet)

        # Ignore bytes past total_length (e.g. trailing padding).
        payload = packet[header_length:ip_header.total_length]
        if header_length + len(payload) <= mtu:
            return [packet]

        if ip_header.dont_fragment:
            raise ValueError("Packet too large and Don't Fragment flag is set")

        # Offsets are expressed in 8-byte units, so every fragment except
        # the last must carry a multiple of 8 payload bytes.
        payload_mtu = ((mtu - header_length) // 8) * 8
        if payload_mtu <= 0:
            raise ValueError(
                f"MTU {mtu} is too small to fragment a packet with a "
                f"{header_length}-byte header"
            )

        # If this packet is itself a non-final fragment, its last piece must
        # keep More Fragments set.
        original_mf = ip_header.flags & 0x1
        base_offset = ip_header.fragment_offset * 8

        fragments = []
        for offset in range(0, len(payload), payload_mtu):
            chunk = payload[offset:offset + payload_mtu]
            is_last = offset + len(chunk) >= len(payload)

            frag_header = IPv4Header(
                version=ip_header.version,
                ihl=ip_header.ihl,
                tos=ip_header.tos,
                identification=ip_header.identification,
                flags=(ip_header.flags & ~0x1) | (original_mf if is_last else 0x1),
                fragment_offset=(base_offset + offset) // 8,
                ttl=ip_header.ttl,
                protocol=ip_header.protocol,
                source_ip=ip_header.source_ip,
                dest_ip=ip_header.dest_ip,
                options=ip_header.options,
            )
            fragments.append(IPParser.build_packet(frag_header, payload=chunk))

        return fragments

    def reassemble_packet(self, packet: bytes) -> Optional[bytes]:
        """Feed one received packet into reassembly.

        Fragments may arrive in any order. Duplicates and overlaps are
        tolerated.

        Returns:
            *packet* itself if it is not a fragment. The reassembled packet
            once every fragment of its datagram has arrived. ``None``
            otherwise.
        """
        ip_header, header_length = IPParser.parse_ipv4_header(packet)

        if not ip_header.is_fragment:
            return packet

        key = (ip_header.source_ip, ip_header.dest_ip, ip_header.identification)
        offset = ip_header.fragment_offset * 8
        data = packet[header_length:ip_header.total_length]

        self.fragments.setdefault(key, []).append((offset, data))
        if offset == 0:
            self._first_headers[key] = ip_header
        if not ip_header.more_fragments:
            # The final fragment fixes the datagram's payload length.
            self._datagram_lengths[key] = offset + len(data)

        total = self._datagram_lengths.get(key)
        first_header = self._first_headers.get(key)
        if total is None or first_header is None:
            return None

        complete_payload = self._join_fragments(self.fragments[key], total)
        if complete_payload is None:
            return None

        complete_header = IPv4Header(
            version=first_header.version,
            ihl=first_header.ihl,
            tos=first_header.tos,
            identification=first_header.identification,
            flags=first_header.flags & ~0x1,
            fragment_offset=0,
            ttl=first_header.ttl,
            protocol=first_header.protocol,
            source_ip=first_header.source_ip,
            dest_ip=first_header.dest_ip,
            options=first_header.options,
        )
        complete_packet = IPParser.build_packet(complete_header, payload=complete_payload)

        del self.fragments[key]
        del self._datagram_lengths[key]
        del self._first_headers[key]

        return complete_packet

    @staticmethod
    def _join_fragments(received: List[Tuple[int, bytes]], total: int) -> Optional[bytes]:
        """Join fragment payloads covering bytes [0, total), or return None on a gap.

        Where fragments overlap, the lower-offset fragment keeps the
        overlapping bytes. For equal offsets, the one received first wins,
        because the sort is stable.
        """
        pieces = []
        covered = 0
        for offset, data in sorted(received, key=lambda item: item[0]):
            if offset > covered:
                return None  # hole before this fragment
            end = offset + len(data)
            if end > covered:
                pieces.append(data[covered - offset:])
                covered = end
            if covered >= total:
                break

        if covered < total:
            return None
        return b''.join(pieces)[:total]
|
|
|
|
|
|