""" TCP Engine Module Implements a complete TCP state machine in user-space: - Full TCP state machine (SYN, SYN-ACK, ESTABLISHED, FIN, RST) - Sequence and acknowledgment number tracking - Sliding window implementation - Retransmission and timeout handling - Congestion control """ import time import threading import random from typing import Dict, List, Optional, Tuple, Callable from dataclasses import dataclass, field from enum import Enum from collections import deque from .ip_parser import TCPHeader, IPv4Header, IPParser class TCPState(Enum): CLOSED = "CLOSED" LISTEN = "LISTEN" SYN_SENT = "SYN_SENT" SYN_RECEIVED = "SYN_RECEIVED" ESTABLISHED = "ESTABLISHED" FIN_WAIT_1 = "FIN_WAIT_1" FIN_WAIT_2 = "FIN_WAIT_2" CLOSE_WAIT = "CLOSE_WAIT" CLOSING = "CLOSING" LAST_ACK = "LAST_ACK" TIME_WAIT = "TIME_WAIT" @dataclass class TCPSegment: """Represents a TCP segment""" seq_num: int ack_num: int flags: int window: int data: bytes timestamp: float = field(default_factory=time.time) retransmit_count: int = 0 @property def data_length(self) -> int: """Get data length""" return len(self.data) @property def seq_end(self) -> int: """Get sequence number after this segment""" length = self.data_length # SYN and FIN consume one sequence number if self.flags & 0x02: # SYN length += 1 if self.flags & 0x01: # FIN length += 1 return self.seq_num + length @dataclass class TCPConnection: """Represents a TCP connection state""" # Connection identification local_ip: str local_port: int remote_ip: str remote_port: int # State state: TCPState = TCPState.CLOSED # Sequence numbers local_seq: int = field(default_factory=lambda: random.randint(0, 0xFFFFFFFF)) local_ack: int = 0 remote_seq: int = 0 remote_ack: int = 0 initial_seq: int = 0 # Window management local_window: int = 65535 remote_window: int = 65535 window_scale: int = 0 # Buffers send_buffer: deque = field(default_factory=deque) recv_buffer: deque = field(default_factory=deque) out_of_order_buffer: Dict[int, bytes] = field(default_factory=dict) # Retransmission unacked_segments: Dict[int, TCPSegment] = field(default_factory=dict) retransmit_timer: Optional[float] = None rto: float = 1.0 # Retransmission timeout srtt: float = 0.0 # Smoothed round-trip time rttvar: float = 0.0 # Round-trip time variation # Congestion control cwnd: int = 1 # Congestion window (in MSS) ssthresh: int = 65535 # Slow start threshold dupacks: int = 0 # Duplicate ACK count mss: int = 1460 # Maximum segment size # Callbacks on_data_received: Optional[Callable[[bytes], None]] = None on_state_change: Optional[Callable[[TCPState], None]] = None def __post_init__(self): self.initial_seq = self.local_seq def handle_packet(self, packet: bytes): """Process incoming TCP packet""" try: # Parse headers ip_header, payload = IPParser.parse_ipv4_header(packet) tcp_header, data = IPParser.parse_tcp_header(payload) # Process based on current state if self.state == TCPState.LISTEN: self._handle_listen(tcp_header, data) elif self.state == TCPState.SYN_SENT: self._handle_syn_sent(tcp_header, data) elif self.state == TCPState.SYN_RECEIVED: self._handle_syn_received(tcp_header, data) elif self.state == TCPState.ESTABLISHED: self._handle_established(tcp_header, data) elif self.state in (TCPState.FIN_WAIT_1, TCPState.FIN_WAIT_2): self._handle_fin_wait(tcp_header, data) elif self.state == TCPState.CLOSE_WAIT: self._handle_close_wait(tcp_header, data) elif self.state == TCPState.LAST_ACK: self._handle_last_ack(tcp_header, data) # Update RTT if this is an ACK for a sent packet if tcp_header.ack and tcp_header.ack_num > self.local_seq: self._update_rtt(tcp_header.ack_num) # Handle retransmission timer self._manage_retransmission_timer() except Exception as e: print(f"Error handling packet: {e}") def send_data(self, data: bytes): """Send data over the connection""" if self.state != TCPState.ESTABLISHED: return False # Add to send buffer self.send_buffer.extend(data) # Try to send what we can self._send_from_buffer() return True def close(self): """Initiate connection close""" if self.state == TCPState.ESTABLISHED: self._send_fin() self._set_state(TCPState.FIN_WAIT_1) elif self.state == TCPState.CLOSE_WAIT: self._send_fin() self._set_state(TCPState.LAST_ACK) def _set_state(self, new_state: TCPState): """Change connection state""" if new_state != self.state: self.state = new_state if self.on_state_change: self.on_state_change(new_state) def _send_packet(self, flags: int, data: bytes = b''): """Send TCP packet""" segment = TCPSegment( seq_num=self.local_seq, ack_num=self.local_ack, flags=flags, window=self.local_window, data=data ) # Add to unacked segments if not pure ACK if data or flags != 0x10: # Not pure ACK self.unacked_segments[self.local_seq] = segment # Update sequence number self.local_seq = (self.local_seq + len(data)) % 0x100000000 if flags & 0x02: # SYN self.local_seq = (self.local_seq + 1) % 0x100000000 if flags & 0x01: # FIN self.local_seq = (self.local_seq + 1) % 0x100000000 # TODO: Actually send the packet def _handle_listen(self, header: TCPHeader, data: bytes): """Handle LISTEN state""" if header.syn: self.remote_seq = header.seq_num self.local_ack = (header.seq_num + 1) % 0x100000000 self._send_packet(0x12) # SYN-ACK self._set_state(TCPState.SYN_RECEIVED) def _handle_syn_sent(self, header: TCPHeader, data: bytes): """Handle SYN_SENT state""" if header.syn and header.ack: if header.ack_num == (self.initial_seq + 1) % 0x100000000: self.remote_seq = header.seq_num self.local_ack = (header.seq_num + 1) % 0x100000000 self._send_packet(0x10) # ACK self._set_state(TCPState.ESTABLISHED) def _handle_established(self, header: TCPHeader, data: bytes): """Handle ESTABLISHED state""" if data: if header.seq_num == self.local_ack: # In-order segment if self.on_data_received: self.on_data_received(data) self.local_ack = (self.local_ack + len(data)) % 0x100000000 self._send_packet(0x10) # ACK elif header.seq_num > self.local_ack: # Out-of-order segment self.out_of_order_buffer[header.seq_num] = data self._send_packet(0x10) # ACK else: # Duplicate segment self._send_packet(0x10) # ACK if header.ack: # Process acknowledgments self._handle_ack(header.ack_num) if header.fin: self.local_ack = (self.local_ack + 1) % 0x100000000 self._send_packet(0x10) # ACK self._set_state(TCPState.CLOSE_WAIT) def _handle_ack(self, ack_num: int): """Handle incoming acknowledgment""" # Remove acknowledged segments acknowledged = [seq for seq in self.unacked_segments.keys() if seq < ack_num] for seq in acknowledged: del self.unacked_segments[seq] # Update congestion window if self.cwnd < self.ssthresh: # Slow start self.cwnd += 1 else: # Congestion avoidance self.cwnd += 1 / self.cwnd # Try to send more data self._send_from_buffer() def _send_from_buffer(self): """Send data from send buffer""" while self.send_buffer: # Calculate how much we can send window = min(self.remote_window, self.cwnd * self.mss) if not window: break # Get data to send data = bytes(list(self.send_buffer)[:window]) if not data: break # Remove from buffer and send for _ in range(len(data)): self.send_buffer.popleft() self._send_packet(0x18, data) # PSH-ACK def _update_rtt(self, ack_num: int): """Update RTT estimation""" for seq, segment in self.unacked_segments.items(): if seq == ack_num - 1: rtt = time.time() - segment.timestamp if self.srtt == 0: self.srtt = rtt self.rttvar = rtt / 2 else: self.rttvar = (0.75 * self.rttvar + 0.25 * abs(self.srtt - rtt)) self.srtt = 0.875 * self.srtt + 0.125 * rtt self.rto = self.srtt + max(4 * self.rttvar, 0.5) break def _manage_retransmission_timer(self): """Manage retransmission timer""" if not self.unacked_segments: self.retransmit_timer = None return current_time = time.time() if self.retransmit_timer is None: self.retransmit_timer = current_time + self.rto elif current_time >= self.retransmit_timer: # Timeout occurred self._handle_timeout() def _handle_timeout(self): """Handle retransmission timeout""" # Exponential backoff self.rto *= 2 # Reset congestion window self.ssthresh = max(2, self.cwnd // 2) self.cwnd = 1 # Retransmit oldest unacked segment if self.unacked_segments: oldest_seq = min(self.unacked_segments.keys()) segment = self.unacked_segments[oldest_seq] if segment.retransmit_count < 5: segment.retransmit_count += 1 self._send_packet(segment.flags, segment.data) else: # Too many retransmissions, close connection self._set_state(TCPState.CLOSED) # Reset timer self.retransmit_timer = time.time() + self.rto def _send_fin(self): """Send FIN packet""" self._send_packet(0x11) # FIN-ACK def _handle_fin_wait(self, header: TCPHeader, data: bytes): """Handle FIN_WAIT states""" if self.state == TCPState.FIN_WAIT_1: if header.ack and header.ack_num == self.local_seq: self._set_state(TCPState.FIN_WAIT_2) if header.fin: self.local_ack = (header.seq_num + 1) % 0x100000000 self._send_packet(0x10) # ACK if self.state == TCPState.FIN_WAIT_1: self._set_state(TCPState.CLOSING) else: # FIN_WAIT_2 self._set_state(TCPState.TIME_WAIT) def _handle_close_wait(self, header: TCPHeader, data: bytes): """Handle CLOSE_WAIT state""" if header.ack: self._handle_ack(header.ack_num) def _handle_last_ack(self, header: TCPHeader, data: bytes): """Handle LAST_ACK state""" if header.ack and header.ack_num == self.local_seq: self._set_state(TCPState.CLOSED)