TNT / core /ip_parser.py
Fred808's picture
Upload 48 files
50d86e3 verified
"""
IP Parser/Assembler Module
Handles IPv4 packet parsing and construction:
- Parse IPv4, UDP, and TCP headers
- Calculate and verify checksums
- Handle packet fragmentation and reassembly
- Support various IP options
"""
import struct
import socket
from typing import Dict, List, Optional, Tuple
from dataclasses import dataclass
from enum import Enum
class IPProtocol(Enum):
    """IP protocol numbers understood by this module.

    The values are what appears in the IPv4 header's ``protocol`` field.
    """

    ICMP = 1  # Internet Control Message Protocol
    TCP = 6   # Transmission Control Protocol
    UDP = 17  # User Datagram Protocol
@dataclass
class IPv4Header:
    """Fields of an IPv4 header, held as plain Python values.

    Addresses are dotted-quad strings. ``flags`` holds the 3-bit flags field
    and ``fragment_offset`` the 13-bit offset field (in 8-byte units, as the
    fragmenter in this module uses it).
    """

    version: int = 4
    ihl: int = 5  # header length in 32-bit words; 5 means no options
    tos: int = 0  # type of service
    total_length: int = 0  # header + payload, in bytes
    identification: int = 0
    flags: int = 0  # 0x4 reserved, 0x2 Don't Fragment, 0x1 More Fragments
    fragment_offset: int = 0
    ttl: int = 64  # time to live
    protocol: int = 0  # an IPProtocol value for the protocols handled here
    header_checksum: int = 0
    source_ip: str = '0.0.0.0'
    dest_ip: str = '0.0.0.0'
    options: bytes = b''

    @property
    def header_length(self) -> int:
        """Header size in bytes (``ihl`` converted from 32-bit words)."""
        return 4 * self.ihl

    @property
    def dont_fragment(self) -> bool:
        """True when the Don't Fragment (DF) bit is set."""
        return (self.flags & 0x2) != 0

    @property
    def more_fragments(self) -> bool:
        """True when the More Fragments (MF) bit is set."""
        return (self.flags & 0x1) != 0

    @property
    def is_fragment(self) -> bool:
        """True for any piece of a fragmented datagram: non-zero offset or MF set."""
        return self.fragment_offset > 0 or self.more_fragments
@dataclass
class TCPHeader:
    """TCP header fields, held as plain Python values.

    ``flags`` stores all nine control bits (NS, CWR, ECE, URG, ACK, PSH, RST,
    SYN, FIN). The boolean properties read single bits; :meth:`set_flag`
    writes them by name.
    """

    source_port: int = 0
    dest_port: int = 0
    seq_num: int = 0
    ack_num: int = 0
    data_offset: int = 5  # Header length in 32-bit words
    reserved: int = 0
    flags: int = 0  # 9 bits: NS, CWR, ECE, URG, ACK, PSH, RST, SYN, FIN
    window_size: int = 65535
    checksum: int = 0
    urgent_pointer: int = 0
    options: bytes = b''

    # Bit value of each control flag inside ``flags``. Deliberately left
    # unannotated so @dataclass treats it as a class constant, not a field.
    _FLAG_BITS = {
        'fin': 0x01, 'syn': 0x02, 'rst': 0x04, 'psh': 0x08,
        'ack': 0x10, 'urg': 0x20, 'ece': 0x40, 'cwr': 0x80, 'ns': 0x100,
    }

    @property
    def header_length(self) -> int:
        """Header length in bytes (``data_offset`` is in 32-bit words)."""
        return self.data_offset * 4

    # TCP flag properties
    @property
    def fin(self) -> bool:
        return bool(self.flags & 0x01)

    @property
    def syn(self) -> bool:
        return bool(self.flags & 0x02)

    @property
    def rst(self) -> bool:
        return bool(self.flags & 0x04)

    @property
    def psh(self) -> bool:
        return bool(self.flags & 0x08)

    @property
    def ack(self) -> bool:
        return bool(self.flags & 0x10)

    @property
    def urg(self) -> bool:
        return bool(self.flags & 0x20)

    # The three remaining flags: set_flag has always accepted them, and these
    # properties give them matching read accessors.
    @property
    def ece(self) -> bool:
        return bool(self.flags & 0x40)

    @property
    def cwr(self) -> bool:
        return bool(self.flags & 0x80)

    @property
    def ns(self) -> bool:
        return bool(self.flags & 0x100)

    def set_flag(self, flag_name: str, value: bool = True):
        """Set (``value=True``) or clear (``value=False``) the named TCP flag.

        *flag_name* is case-insensitive, e.g. ``'syn'`` or ``'ACK'``.
        Unrecognized names are silently ignored.
        """
        bit = self._FLAG_BITS.get(flag_name.lower())
        if bit is None:
            return
        if value:
            self.flags |= bit
        else:
            self.flags &= ~bit
@dataclass
class UDPHeader:
    """UDP header fields. UDP headers have no options and a fixed size."""

    source_port: int = 0
    dest_port: int = 0
    length: int = 8  # header + payload, in bytes
    checksum: int = 0

    # Unannotated, so not a dataclass field: the constant UDP header size.
    _HEADER_SIZE = 8

    @property
    def header_length(self) -> int:
        """Header length in bytes; always 8 for UDP."""
        return self._HEADER_SIZE
@dataclass
class ParsedPacket:
    """Result of parsing one IPv4 packet into its layers."""

    ip_header: IPv4Header
    # A TCPHeader or UDPHeader when the transport layer was parsed, else None.
    transport_header: Optional[object] = None
    # Bytes after the transport header, or the raw IP payload when no
    # transport header was parsed.
    payload: bytes = b''
    # The bytes handed to the parser.
    raw_packet: bytes = b''
class IPParser:
    """IPv4 packet parser and assembler.

    Stateless: every method is a class or static method. The parsers take raw
    network-byte-order bytes and return this module's header dataclasses. The
    builders serialize those dataclasses and overwrite their length fields
    (``ihl``, ``data_offset``, ``length``, ``total_length``) to match the
    bytes actually emitted.
    """

    @staticmethod
    def calculate_checksum(data: bytes) -> int:
        """Return the 16-bit Internet checksum of *data*.

        The data is read as big-endian 16-bit words (an odd trailing byte is
        padded with zero) and summed with end-around carry. The one's
        complement of that sum is returned.
        """
        if len(data) % 2:
            # Build a new object: ``data += ...`` would extend a caller's
            # bytearray in place.
            data = bytes(data) + b'\x00'
        total = sum(struct.unpack(f'!{len(data) // 2}H', data))
        # Fold carries back into the low 16 bits.
        while total >> 16:
            total = (total & 0xFFFF) + (total >> 16)
        return (~total) & 0xFFFF

    @staticmethod
    def verify_checksum(data: bytes, checksum: int) -> bool:
        """Return True if *checksum* is the Internet checksum of *data*.

        *data* must be the covered bytes with the checksum field zeroed.
        0x0000 and 0xFFFF are both one's-complement zero, so either one
        matches the other.
        """
        calculated = IPParser.calculate_checksum(data)
        if calculated == checksum:
            return True
        # Only the two encodings of zero are interchangeable. The former test
        # ``calculated + checksum == 0xFFFF`` also accepted the uncomplemented
        # sum, i.e. most wrong checksums.
        return {calculated, checksum} == {0x0000, 0xFFFF}

    @classmethod
    def parse_ipv4_header(cls, data: bytes) -> Tuple[IPv4Header, int]:
        """Parse the IPv4 header at the start of *data*.

        Returns:
            ``(header, header_length)``, where ``header_length`` is in bytes
            and includes options.

        Raises:
            ValueError: If *data* is too short for the fixed header or the
                declared options, the version is not 4, or IHL is below 5.
        """
        if len(data) < 20:
            raise ValueError("IPv4 header too short")
        (version_ihl, tos, total_length, identification, flags_fragment,
         ttl, protocol, header_checksum, src, dst) = struct.unpack('!BBHHHBBH4s4s', data[:20])
        header = IPv4Header(
            version=(version_ihl >> 4) & 0xF,
            ihl=version_ihl & 0xF,
            tos=tos,
            total_length=total_length,
            identification=identification,
            flags=(flags_fragment >> 13) & 0x7,  # top 3 bits
            fragment_offset=flags_fragment & 0x1FFF,  # low 13 bits
            ttl=ttl,
            protocol=protocol,
            header_checksum=header_checksum,
            source_ip=socket.inet_ntoa(src),
            dest_ip=socket.inet_ntoa(dst),
        )
        if header.version != 4:
            raise ValueError(f"Unsupported IP version: {header.version}")
        # IHL counts 32-bit words. Below 5 it cannot cover the fixed 20-byte
        # header, and the payload slice would overlap the header.
        if header.ihl < 5:
            raise ValueError(f"Invalid IPv4 header length (IHL={header.ihl})")
        options_length = header.header_length - 20
        if options_length > 0:
            if len(data) < 20 + options_length:
                raise ValueError("IPv4 options truncated")
            header.options = data[20:20 + options_length]
        return header, header.header_length

    @classmethod
    def parse_tcp_header(cls, data: bytes) -> Tuple[TCPHeader, int]:
        """Parse the TCP header at the start of *data*.

        Returns:
            ``(header, header_length)``, where ``header_length`` is in bytes
            and includes options.

        Raises:
            ValueError: If *data* is too short for the fixed header or the
                declared options, or the data offset is below 5.
        """
        if len(data) < 20:
            raise ValueError("TCP header too short")
        (source_port, dest_port, seq_num, ack_num, offset_reserved, flag_byte,
         window_size, checksum, urgent_pointer) = struct.unpack('!HHIIBBHHH', data[:20])
        header = TCPHeader(
            source_port=source_port,
            dest_port=dest_port,
            seq_num=seq_num,
            ack_num=ack_num,
            data_offset=(offset_reserved >> 4) & 0xF,
            reserved=(offset_reserved >> 1) & 0x7,
            # NS is the low bit of the offset byte; the other 8 flags fill the next byte.
            flags=((offset_reserved & 0x1) << 8) | flag_byte,
            window_size=window_size,
            checksum=checksum,
            urgent_pointer=urgent_pointer,
        )
        # Data offset counts 32-bit words and cannot be smaller than the fixed header.
        if header.data_offset < 5:
            raise ValueError(f"Invalid TCP data offset: {header.data_offset}")
        options_length = header.header_length - 20
        if options_length > 0:
            if len(data) < 20 + options_length:
                raise ValueError("TCP options truncated")
            header.options = data[20:20 + options_length]
        return header, header.header_length

    @classmethod
    def parse_udp_header(cls, data: bytes) -> Tuple[UDPHeader, int]:
        """Parse the 8-byte UDP header at the start of *data*.

        Returns:
            ``(header, 8)``.

        Raises:
            ValueError: If *data* is shorter than 8 bytes.
        """
        if len(data) < 8:
            raise ValueError("UDP header too short")
        source_port, dest_port, length, checksum = struct.unpack('!HHHH', data[:8])
        header = UDPHeader(
            source_port=source_port,
            dest_port=dest_port,
            length=length,
            checksum=checksum,
        )
        return header, 8

    @classmethod
    def parse_packet(cls, data: bytes) -> ParsedPacket:
        """Parse a complete IPv4 packet into its headers and payload.

        The IP payload is bounded by ``total_length``. A TCP/UDP header is
        parsed only when the packet is not a non-initial fragment. In every
        other case the IP payload is returned unparsed in ``payload``.

        Raises:
            ValueError: Propagated from the header parsers on malformed input.
        """
        ip_header, ip_header_len = cls.parse_ipv4_header(data)
        # ``ip_header`` is a required ParsedPacket field, so the packet can
        # only be constructed once it has been parsed.
        packet = ParsedPacket(ip_header=ip_header, raw_packet=data)
        ip_payload = data[ip_header_len:ip_header.total_length]

        # Only the fragment at offset 0 starts with the transport header;
        # later fragments carry a slice from the middle of the payload.
        if ip_header.fragment_offset != 0:
            packet.payload = ip_payload
            return packet

        if ip_header.protocol == IPProtocol.TCP.value:
            packet.transport_header, transport_header_len = cls.parse_tcp_header(ip_payload)
            packet.payload = ip_payload[transport_header_len:]
        elif ip_header.protocol == IPProtocol.UDP.value:
            packet.transport_header, transport_header_len = cls.parse_udp_header(ip_payload)
            packet.payload = ip_payload[transport_header_len:]
        else:
            # Unsupported protocol, treat as raw payload
            packet.payload = ip_payload
        return packet

    @classmethod
    def build_ipv4_header(cls, header: IPv4Header) -> bytes:
        """Serialize *header* to bytes with a freshly computed header checksum.

        ``header.ihl`` is overwritten to match the options, which are
        zero-padded to a 32-bit boundary. ``header.header_checksum`` is
        neither read nor updated.

        Raises:
            ValueError: If the options exceed 40 bytes (IHL caps at 15 words).
        """
        if len(header.options) > 40:
            raise ValueError("IPv4 options exceed 40 bytes")
        header.ihl = (20 + len(header.options) + 3) // 4  # round up to 32-bit words
        version_ihl = (header.version << 4) | header.ihl
        flags_fragment = (header.flags << 13) | header.fragment_offset
        header_data = struct.pack(
            '!BBHHHBBH4s4s',
            version_ihl, header.tos, header.total_length,
            header.identification, flags_fragment,
            header.ttl, header.protocol, 0,  # checksum = 0 while computing it
            socket.inet_aton(header.source_ip),
            socket.inet_aton(header.dest_ip),
        )
        header_data += header.options
        header_data += b'\x00' * (header.ihl * 4 - len(header_data))
        checksum = cls.calculate_checksum(header_data)
        # The checksum occupies bytes 10-11 of the header.
        return header_data[:10] + struct.pack('!H', checksum) + header_data[12:]

    @classmethod
    def build_tcp_header(cls, header: TCPHeader, source_ip: str, dest_ip: str, payload: bytes) -> bytes:
        """Serialize *header*, with a checksum over pseudo-header, header and *payload*.

        ``header.data_offset`` is overwritten to match the options, which are
        zero-padded to a 32-bit boundary. ``header.checksum`` is neither read
        nor updated. *payload* is used only for the checksum and is not part
        of the returned bytes.

        Raises:
            ValueError: If the options exceed 40 bytes (data offset caps at 15 words).
        """
        if len(header.options) > 40:
            raise ValueError("TCP options exceed 40 bytes")
        header.data_offset = (20 + len(header.options) + 3) // 4  # round up to 32-bit words
        offset_reserved_flags = (header.data_offset << 12) | (header.reserved << 9) | header.flags
        # Eight fields, 20 bytes. The previous format '!HHIIHHH' described only
        # seven fields, so struct.pack always raised.
        header_data = struct.pack(
            '!HHIIHHHH',
            header.source_port, header.dest_port,
            header.seq_num, header.ack_num,
            offset_reserved_flags, header.window_size,
            0, header.urgent_pointer,  # checksum = 0 while computing it
        )
        header_data += header.options
        header_data += b'\x00' * (header.data_offset * 4 - len(header_data))
        pseudo_header = struct.pack(
            '!4s4sBBH',
            socket.inet_aton(source_ip),
            socket.inet_aton(dest_ip),
            0, IPProtocol.TCP.value,
            len(header_data) + len(payload),
        )
        checksum = cls.calculate_checksum(pseudo_header + header_data + payload)
        # The checksum occupies bytes 16-17 of the TCP header.
        return header_data[:16] + struct.pack('!H', checksum) + header_data[18:]

    @classmethod
    def build_udp_header(cls, header: UDPHeader, source_ip: str, dest_ip: str, payload: bytes) -> bytes:
        """Serialize *header*, setting ``header.length`` from *payload*.

        A non-zero ``header.checksum`` on input means "compute a checksum"
        over pseudo-header, header and *payload*. Zero leaves the (optional
        for IPv4) checksum field as 0. *payload* is not part of the returned
        bytes.
        """
        header.length = 8 + len(payload)
        header_data = struct.pack(
            '!HHHH',
            header.source_port, header.dest_port,
            header.length, 0,  # checksum = 0 while computing it
        )
        if header.checksum != 0:
            pseudo_header = struct.pack(
                '!4s4sBBH',
                socket.inet_aton(source_ip),
                socket.inet_aton(dest_ip),
                0, IPProtocol.UDP.value,
                header.length,
            )
            checksum = cls.calculate_checksum(pseudo_header + header_data + payload)
            # On the wire, 0 means "no checksum" (RFC 768), so a computed zero
            # is sent as its one's-complement equivalent 0xFFFF.
            if checksum == 0:
                checksum = 0xFFFF
            header_data = header_data[:6] + struct.pack('!H', checksum)
        return header_data

    @classmethod
    def build_packet(cls, ip_header: IPv4Header, transport_header: Optional[object] = None, payload: bytes = b'') -> bytes:
        """Build a complete IPv4 packet: IP header + transport header + *payload*.

        *transport_header* may be a TCPHeader, a UDPHeader, or None. Any other
        type is ignored. ``ip_header.ihl`` and ``ip_header.total_length`` are
        overwritten to match the emitted bytes. ``ip_header.protocol`` is not
        changed.
        """
        transport_data = b''
        if transport_header is not None:
            if isinstance(transport_header, TCPHeader):
                transport_data = cls.build_tcp_header(
                    transport_header, ip_header.source_ip, ip_header.dest_ip, payload
                )
            elif isinstance(transport_header, UDPHeader):
                transport_data = cls.build_udp_header(
                    transport_header, ip_header.source_ip, ip_header.dest_ip, payload
                )
        # Sync ihl with the options *before* computing total_length. Otherwise
        # a stale ihl (e.g. the default 5 with options present) gives a
        # total_length that disagrees with the header actually emitted.
        ip_header.ihl = (20 + len(ip_header.options) + 3) // 4
        ip_header.total_length = ip_header.header_length + len(transport_data) + len(payload)
        return cls.build_ipv4_header(ip_header) + transport_data + payload
class PacketFragmenter:
    """Fragment oversized IPv4 packets and reassemble received fragments.

    Reassembly state is kept per datagram, keyed by
    ``(source_ip, dest_ip, identification)``.

    NOTE(review): fragments of datagrams that never complete are buffered
    forever. Add a timeout/eviction policy before feeding this untrusted
    traffic.
    """

    def __init__(self, mtu: int = 1500):
        """Create a fragmenter whose default fragment size limit is *mtu* bytes."""
        self.mtu = mtu
        # (src, dst, id) -> [(byte offset, fragment payload)]
        self.fragments: Dict[Tuple[str, str, int], List[Tuple[int, bytes]]] = {}
        # (src, dst, id) -> total payload length of the datagram, known once
        # the fragment without the More Fragments flag has arrived.
        self._payload_lengths: Dict[Tuple[str, str, int], int] = {}
        # (src, dst, id) -> header of the offset-0 fragment, used as the
        # template for the reassembled datagram's header.
        self._first_headers: Dict[Tuple[str, str, int], IPv4Header] = {}

    @staticmethod
    def _ip_payload(packet: bytes, ip_header: IPv4Header, header_len: int) -> bytes:
        """Return the bytes after the IP header, bounded by ``total_length``.

        If ``total_length`` is smaller than the header itself it cannot be a
        real datagram length, so the rest of *packet* is used instead.
        """
        end = ip_header.total_length
        if end < header_len:
            end = len(packet)
        return packet[header_len:end]

    @staticmethod
    def _assemble(pieces: List[Tuple[int, bytes]], total: int) -> Optional[bytes]:
        """Join fragment payloads covering bytes ``[0, total)``.

        Returns None while any gap remains. Where fragments overlap, the bytes
        from the lower-offset fragment are kept.
        """
        buf = bytearray()
        for offset, data in sorted(pieces, key=lambda piece: piece[0]):
            if offset > len(buf):
                return None  # gap: a fragment is still missing
            buf += data[len(buf) - offset:]
        if len(buf) < total:
            return None
        return bytes(buf[:total])

    def _discard(self, key: Tuple[str, str, int]) -> None:
        """Drop all buffered reassembly state for *key*."""
        self.fragments.pop(key, None)
        self._payload_lengths.pop(key, None)
        self._first_headers.pop(key, None)

    def fragment_packet(self, packet: bytes, mtu: Optional[int] = None) -> List[bytes]:
        """Split *packet* into IPv4 fragments that each fit within *mtu* bytes.

        Args:
            packet: A complete IPv4 packet, header included.
            mtu: Maximum fragment size in bytes. Defaults to ``self.mtu``.

        Returns:
            ``[packet]`` unchanged if it already fits, otherwise the fragment
            packets in offset order.

        Raises:
            ValueError: If the packet must be fragmented but has Don't
                Fragment set, or if *mtu* leaves room for fewer than 8 payload
                bytes per fragment.
        """
        if mtu is None:
            mtu = self.mtu
        if len(packet) <= mtu:
            return [packet]

        # Only the IP header is needed; any transport header is fragmented as
        # part of the payload.
        ip_header, header_len = IPParser.parse_ipv4_header(packet)
        if ip_header.dont_fragment:
            raise ValueError("Packet too large and Don't Fragment flag is set")

        # Fragment offsets are in 8-byte units, so every fragment except the
        # last must carry a multiple of 8 payload bytes.
        chunk_size = ((mtu - header_len) // 8) * 8
        if chunk_size <= 0:
            # A zero chunk size would never advance the loop below.
            raise ValueError(f"MTU {mtu} too small to fragment a packet with a {header_len}-byte header")

        payload = self._ip_payload(packet, ip_header, header_len)
        # Non-zero when re-fragmenting something that is already a fragment.
        base_offset = ip_header.fragment_offset * 8

        fragments = []
        for start in range(0, len(payload), chunk_size):
            chunk = payload[start:start + chunk_size]
            is_last = start + len(chunk) >= len(payload)
            # NOTE(review): every option is copied into every fragment. IPv4
            # only copies options whose "copied" flag is set into non-first
            # fragments; confirm whether that matters for callers.
            frag_header = IPv4Header(
                version=ip_header.version,
                ihl=ip_header.ihl,
                tos=ip_header.tos,
                identification=ip_header.identification,
                ttl=ip_header.ttl,
                protocol=ip_header.protocol,
                source_ip=ip_header.source_ip,
                dest_ip=ip_header.dest_ip,
                options=ip_header.options,
            )
            frag_header.fragment_offset = (base_offset + start) // 8
            frag_header.flags = ip_header.flags & ~0x1
            # MF stays set on the last piece when the input was itself a
            # non-final fragment.
            if not is_last or ip_header.more_fragments:
                frag_header.flags |= 0x1
            fragments.append(IPParser.build_packet(frag_header, payload=chunk))
        return fragments

    def reassemble_packet(self, packet: bytes) -> Optional[bytes]:
        """Feed one received packet into reassembly.

        Fragments may arrive in any order, repeated, or overlapping.

        Returns:
            *packet* unchanged if it is not a fragment. The rebuilt datagram
            once every fragment of it has been seen. Otherwise None, with the
            fragment buffered.

        Raises:
            ValueError: Propagated from header parsing on malformed input.
        """
        # Parse only the IP header: non-initial fragments do not start with a
        # transport header.
        ip_header, header_len = IPParser.parse_ipv4_header(packet)
        if not ip_header.is_fragment:
            return packet

        key = (ip_header.source_ip, ip_header.dest_ip, ip_header.identification)
        offset = ip_header.fragment_offset * 8
        data = self._ip_payload(packet, ip_header, header_len)

        # A retransmitted fragment replaces, rather than duplicates, the
        # stored one at the same offset.
        stored = [piece for piece in self.fragments.get(key, []) if piece[0] != offset]
        stored.append((offset, data))
        self.fragments[key] = stored

        if offset == 0:
            self._first_headers[key] = ip_header
        if not ip_header.more_fragments:
            # The final fragment fixes the datagram's total payload length,
            # however late it arrives.
            self._payload_lengths[key] = offset + len(data)

        total = self._payload_lengths.get(key)
        first = self._first_headers.get(key)
        if total is None or first is None:
            return None
        complete_payload = self._assemble(stored, total)
        if complete_payload is None:
            return None

        complete_header = IPv4Header(
            version=first.version,
            ihl=first.ihl,
            tos=first.tos,
            identification=first.identification,
            flags=first.flags & ~0x1,  # Clear More Fragments
            fragment_offset=0,
            ttl=first.ttl,
            protocol=first.protocol,
            source_ip=first.source_ip,
            dest_ip=first.dest_ip,
            options=first.options,
        )
        self._discard(key)
        return IPParser.build_packet(complete_header, payload=complete_payload)