""" TCP Engine Module Implements a complete TCP state machine in user-space: - Full TCP state machine (SYN, SYN-ACK, ESTABLISHED, FIN, RST) - Sequence and acknowledgment number tracking - Sliding window implementation - Retransmission and timeout handling - Congestion control """ import time import threading import random from typing import Dict, List, Optional, Tuple, Callable from dataclasses import dataclass, field from enum import Enum from collections import deque from .ip_parser import TCPHeader, IPv4Header, IPParser class TCPState(Enum): CLOSED = "CLOSED" LISTEN = "LISTEN" SYN_SENT = "SYN_SENT" SYN_RECEIVED = "SYN_RECEIVED" ESTABLISHED = "ESTABLISHED" FIN_WAIT_1 = "FIN_WAIT_1" FIN_WAIT_2 = "FIN_WAIT_2" CLOSE_WAIT = "CLOSE_WAIT" CLOSING = "CLOSING" LAST_ACK = "LAST_ACK" TIME_WAIT = "TIME_WAIT" @dataclass class TCPSegment: """Represents a TCP segment""" seq_num: int ack_num: int flags: int window: int data: bytes timestamp: float = field(default_factory=time.time) retransmit_count: int = 0 @property def data_length(self) -> int: """Get data length""" return len(self.data) @property def seq_end(self) -> int: """Get sequence number after this segment""" length = self.data_length # SYN and FIN consume one sequence number if self.flags & 0x02: # SYN length += 1 if self.flags & 0x01: # FIN length += 1 return self.seq_num + length @dataclass class TCPConnection: """Represents a TCP connection state""" # Connection identification local_ip: str local_port: int remote_ip: str remote_port: int # State state: TCPState = TCPState.CLOSED # Sequence numbers local_seq: int = 0 local_ack: int = 0 remote_seq: int = 0 remote_ack: int = 0 initial_seq: int = 0 # Window management local_window: int = 65535 remote_window: int = 65535 window_scale: int = 0 # Buffers send_buffer: deque = field(default_factory=deque) recv_buffer: deque = field(default_factory=deque) out_of_order_buffer: Dict[int, bytes] = field(default_factory=dict) # Retransmission unacked_segments: Dict[int, TCPSegment] = field(default_factory=dict) retransmit_timer: Optional[float] = None rto: float = 1.0 # Retransmission timeout srtt: float = 0.0 # Smoothed round-trip time rttvar: float = 0.0 # Round-trip time variation # Congestion control cwnd: int = 1 # Congestion window (in MSS units) ssthresh: int = 65535 # Slow start threshold mss: int = 1460 # Maximum segment size # Timers last_activity: float = field(default_factory=time.time) time_wait_start: Optional[float] = None # Callbacks on_data_received: Optional[Callable[[bytes], None]] = None on_connection_closed: Optional[Callable[[], None]] = None @property def connection_id(self) -> str: """Get unique connection identifier""" return f"{self.local_ip}:{self.local_port}-{self.remote_ip}:{self.remote_port}" @property def is_established(self) -> bool: """Check if connection is established""" return self.state == TCPState.ESTABLISHED @property def can_send_data(self) -> bool: """Check if connection can send data""" return self.state in [TCPState.ESTABLISHED, TCPState.CLOSE_WAIT] @property def effective_window(self) -> int: """Get effective send window""" return min(self.remote_window, self.cwnd * self.mss) class TCPEngine: """TCP state machine implementation""" def __init__(self, config: Dict): self.config = config self.connections: Dict[str, TCPConnection] = {} self.listening_ports: Dict[int, Callable] = {} # port -> accept callback self.lock = threading.Lock() self.running = False self.timer_thread = None # Default configuration self.default_mss = config.get('mss', 1460) self.default_window = config.get('initial_window', 65535) self.max_retries = config.get('max_retries', 3) self.connection_timeout = config.get('timeout', 300) self.time_wait_timeout = config.get('time_wait_timeout', 120) def _generate_isn(self) -> int: """Generate Initial Sequence Number""" return random.randint(0, 0xFFFFFFFF) def _get_connection_key(self, local_ip: str, local_port: int, remote_ip: str, remote_port: int) -> str: """Get connection key""" return f"{local_ip}:{local_port}-{remote_ip}:{remote_port}" def _create_tcp_segment(self, conn: TCPConnection, flags: int, data: bytes = b'') -> TCPSegment: """Create TCP segment""" segment = TCPSegment( seq_num=conn.local_seq, ack_num=conn.local_ack, flags=flags, window=conn.local_window, data=data ) return segment def _build_tcp_packet(self, conn: TCPConnection, segment: TCPSegment) -> bytes: """Build complete TCP packet""" # Create IP header ip_header = IPv4Header( protocol=6, # TCP source_ip=conn.local_ip, dest_ip=conn.remote_ip, ttl=64 ) # Create TCP header tcp_header = TCPHeader( source_port=conn.local_port, dest_port=conn.remote_port, seq_num=segment.seq_num, ack_num=segment.ack_num, flags=segment.flags, window_size=segment.window ) # Build packet return IPParser.build_packet(ip_header, tcp_header, segment.data) def _update_rto(self, conn: TCPConnection, rtt: float): """Update retransmission timeout using RFC 6298""" if conn.srtt == 0: # First RTT measurement conn.srtt = rtt conn.rttvar = rtt / 2 else: # Subsequent measurements alpha = 0.125 beta = 0.25 conn.rttvar = (1 - beta) * conn.rttvar + beta * abs(conn.srtt - rtt) conn.srtt = (1 - alpha) * conn.srtt + alpha * rtt # Calculate RTO conn.rto = max(1.0, conn.srtt + 4 * conn.rttvar) conn.rto = min(conn.rto, 60.0) # Cap at 60 seconds def _update_congestion_window(self, conn: TCPConnection, acked_bytes: int): """Update congestion window (simplified congestion control)""" if conn.cwnd < conn.ssthresh: # Slow start conn.cwnd += 1 else: # Congestion avoidance conn.cwnd += max(1, conn.mss * conn.mss // conn.cwnd) def _handle_retransmission(self, conn: TCPConnection): """Handle segment retransmission""" current_time = time.time() # Find segments that need retransmission to_retransmit = [] for seq_num, segment in conn.unacked_segments.items(): if current_time - segment.timestamp > conn.rto: if segment.retransmit_count < self.max_retries: to_retransmit.append(segment) else: # Max retries exceeded, close connection self._close_connection(conn, reset=True) return # Retransmit segments for segment in to_retransmit: segment.retransmit_count += 1 segment.timestamp = current_time # Exponential backoff conn.rto = min(conn.rto * 2, 60.0) # Congestion control: reduce window conn.ssthresh = max(conn.cwnd // 2, 2) conn.cwnd = 1 # Send retransmitted segment packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) def _send_packet(self, packet: bytes): """Send packet (to be implemented by integration layer)""" # This will be connected to the packet bridge pass def _close_connection(self, conn: TCPConnection, reset: bool = False): """Close connection""" if reset: # Send RST segment = self._create_tcp_segment(conn, 0x04) # RST flag packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) conn.state = TCPState.CLOSED else: # Normal close if conn.state == TCPState.ESTABLISHED: # Send FIN segment = self._create_tcp_segment(conn, 0x01) # FIN flag packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) conn.local_seq += 1 conn.state = TCPState.FIN_WAIT_1 # Cleanup if closed if conn.state == TCPState.CLOSED: if conn.on_connection_closed: conn.on_connection_closed() with self.lock: if conn.connection_id in self.connections: del self.connections[conn.connection_id] def listen(self, port: int, accept_callback: Callable): """Listen on port for incoming connections""" with self.lock: self.listening_ports[port] = accept_callback def connect(self, local_ip: str, local_port: int, remote_ip: str, remote_port: int) -> Optional[TCPConnection]: """Initiate outbound connection""" conn_key = self._get_connection_key(local_ip, local_port, remote_ip, remote_port) # Create connection conn = TCPConnection( local_ip=local_ip, local_port=local_port, remote_ip=remote_ip, remote_port=remote_port, state=TCPState.SYN_SENT, local_seq=self._generate_isn(), mss=self.default_mss, local_window=self.default_window ) conn.initial_seq = conn.local_seq with self.lock: self.connections[conn_key] = conn # Send SYN segment = self._create_tcp_segment(conn, 0x02) # SYN flag packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) # Track unacked segment conn.unacked_segments[conn.local_seq] = segment conn.local_seq += 1 conn.retransmit_timer = time.time() return conn def send_data(self, conn: TCPConnection, data: bytes) -> bool: """Send data on established connection""" if not conn.can_send_data: return False # Add to send buffer conn.send_buffer.append(data) # Try to send immediately self._try_send_data(conn) return True def _try_send_data(self, conn: TCPConnection): """Try to send buffered data""" while conn.send_buffer and len(conn.unacked_segments) * conn.mss < conn.effective_window: data = conn.send_buffer.popleft() # Split data if larger than MSS while data: chunk = data[:conn.mss] data = data[conn.mss:] # Create and send segment segment = self._create_tcp_segment(conn, 0x18, chunk) # PSH+ACK flags packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) # Track unacked segment conn.unacked_segments[conn.local_seq] = segment conn.local_seq += len(chunk) if not data: break def process_packet(self, packet_data: bytes) -> bool: """Process incoming TCP packet""" try: # Parse packet parsed = IPParser.parse_packet(packet_data) if not isinstance(parsed.transport_header, TCPHeader): return False ip_header = parsed.ip_header tcp_header = parsed.transport_header payload = parsed.payload # Find or create connection conn_key = self._get_connection_key( ip_header.dest_ip, tcp_header.dest_port, ip_header.source_ip, tcp_header.source_port ) with self.lock: conn = self.connections.get(conn_key) # Handle new connection (SYN to listening port) if not conn and tcp_header.syn and not tcp_header.ack: if tcp_header.dest_port in self.listening_ports: conn = self._handle_new_connection(ip_header, tcp_header) if conn: self.connections[conn_key] = conn if not conn: # Send RST for unknown connection self._send_rst(ip_header, tcp_header) return False # Process segment return self._process_segment(conn, tcp_header, payload) except Exception as e: print(f"Error processing TCP packet: {e}") return False def _handle_new_connection(self, ip_header: IPv4Header, tcp_header: TCPHeader) -> Optional[TCPConnection]: """Handle new incoming connection""" accept_callback = self.listening_ports.get(tcp_header.dest_port) if not accept_callback: return None # Create connection conn = TCPConnection( local_ip=ip_header.dest_ip, local_port=tcp_header.dest_port, remote_ip=ip_header.source_ip, remote_port=tcp_header.source_port, state=TCPState.SYN_RECEIVED, local_seq=self._generate_isn(), remote_seq=tcp_header.seq_num, local_ack=tcp_header.seq_num + 1, mss=self.default_mss, local_window=self.default_window ) conn.initial_seq = conn.local_seq # Send SYN-ACK segment = self._create_tcp_segment(conn, 0x12) # SYN+ACK flags packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) # Track unacked segment conn.unacked_segments[conn.local_seq] = segment conn.local_seq += 1 conn.retransmit_timer = time.time() # Call accept callback accept_callback(conn) return conn def _process_segment(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool: """Process TCP segment based on connection state""" conn.last_activity = time.time() # Handle RST if tcp_header.rst: conn.state = TCPState.CLOSED self._close_connection(conn) return True # State machine if conn.state == TCPState.SYN_SENT: return self._handle_syn_sent(conn, tcp_header, payload) elif conn.state == TCPState.SYN_RECEIVED: return self._handle_syn_received(conn, tcp_header, payload) elif conn.state == TCPState.ESTABLISHED: return self._handle_established(conn, tcp_header, payload) elif conn.state == TCPState.FIN_WAIT_1: return self._handle_fin_wait_1(conn, tcp_header, payload) elif conn.state == TCPState.FIN_WAIT_2: return self._handle_fin_wait_2(conn, tcp_header, payload) elif conn.state == TCPState.CLOSE_WAIT: return self._handle_close_wait(conn, tcp_header, payload) elif conn.state == TCPState.CLOSING: return self._handle_closing(conn, tcp_header, payload) elif conn.state == TCPState.LAST_ACK: return self._handle_last_ack(conn, tcp_header, payload) elif conn.state == TCPState.TIME_WAIT: return self._handle_time_wait(conn, tcp_header, payload) return False def _handle_syn_sent(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool: """Handle segment in SYN_SENT state""" if tcp_header.syn and tcp_header.ack: # SYN-ACK received if tcp_header.ack_num == conn.local_seq: conn.remote_seq = tcp_header.seq_num conn.local_ack = tcp_header.seq_num + 1 conn.remote_window = tcp_header.window_size # Remove SYN from unacked segments if conn.local_seq - 1 in conn.unacked_segments: del conn.unacked_segments[conn.local_seq - 1] # Send ACK segment = self._create_tcp_segment(conn, 0x10) # ACK flag packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) conn.state = TCPState.ESTABLISHED return True return False def _handle_syn_received(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool: """Handle segment in SYN_RECEIVED state""" if tcp_header.ack and tcp_header.ack_num == conn.local_seq: # ACK for our SYN-ACK conn.remote_window = tcp_header.window_size # Remove SYN-ACK from unacked segments if conn.local_seq - 1 in conn.unacked_segments: del conn.unacked_segments[conn.local_seq - 1] conn.state = TCPState.ESTABLISHED return True return False def _handle_established(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool: """Handle segment in ESTABLISHED state""" # Handle ACK if tcp_header.ack: self._process_ack(conn, tcp_header.ack_num) # Handle data if payload and tcp_header.seq_num == conn.local_ack: conn.local_ack += len(payload) # Deliver data if conn.on_data_received: conn.on_data_received(payload) # Send ACK segment = self._create_tcp_segment(conn, 0x10) # ACK flag packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) # Handle FIN if tcp_header.fin: conn.local_ack += 1 # Send ACK segment = self._create_tcp_segment(conn, 0x10) # ACK flag packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) conn.state = TCPState.CLOSE_WAIT return True def _handle_fin_wait_1(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool: """Handle segment in FIN_WAIT_1 state""" if tcp_header.ack: self._process_ack(conn, tcp_header.ack_num) if not conn.unacked_segments: # Our FIN was ACKed conn.state = TCPState.FIN_WAIT_2 if tcp_header.fin: conn.local_ack += 1 # Send ACK segment = self._create_tcp_segment(conn, 0x10) # ACK flag packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) if conn.state == TCPState.FIN_WAIT_2: conn.state = TCPState.TIME_WAIT conn.time_wait_start = time.time() else: conn.state = TCPState.CLOSING return True def _handle_fin_wait_2(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool: """Handle segment in FIN_WAIT_2 state""" if tcp_header.fin: conn.local_ack += 1 # Send ACK segment = self._create_tcp_segment(conn, 0x10) # ACK flag packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) conn.state = TCPState.TIME_WAIT conn.time_wait_start = time.time() return True def _handle_close_wait(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool: """Handle segment in CLOSE_WAIT state""" # Application should close the connection return True def _handle_closing(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool: """Handle segment in CLOSING state""" if tcp_header.ack: self._process_ack(conn, tcp_header.ack_num) if not conn.unacked_segments: # Our FIN was ACKed conn.state = TCPState.TIME_WAIT conn.time_wait_start = time.time() return True def _handle_last_ack(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool: """Handle segment in LAST_ACK state""" if tcp_header.ack: self._process_ack(conn, tcp_header.ack_num) if not conn.unacked_segments: # Our FIN was ACKed conn.state = TCPState.CLOSED self._close_connection(conn) return True def _handle_time_wait(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool: """Handle segment in TIME_WAIT state""" # Just acknowledge any segments if tcp_header.seq_num == conn.local_ack: segment = self._create_tcp_segment(conn, 0x10) # ACK flag packet = self._build_tcp_packet(conn, segment) self._send_packet(packet) return True def _process_ack(self, conn: TCPConnection, ack_num: int): """Process ACK and remove acknowledged segments""" acked_segments = [] acked_bytes = 0 for seq_num, segment in list(conn.unacked_segments.items()): if seq_num < ack_num: acked_segments.append((seq_num, segment)) acked_bytes += segment.data_length del conn.unacked_segments[seq_num] # Update RTT and congestion window if acked_segments: # Use first acked segment for RTT calculation rtt = time.time() - acked_segments[0][1].timestamp self._update_rto(conn, rtt) self._update_congestion_window(conn, acked_bytes) # Try to send more data self._try_send_data(conn) def _send_rst(self, ip_header: IPv4Header, tcp_header: TCPHeader): """Send RST for unknown connection""" # Create RST response rst_ip = IPv4Header( protocol=6, source_ip=ip_header.dest_ip, dest_ip=ip_header.source_ip, ttl=64 ) rst_tcp = TCPHeader( source_port=tcp_header.dest_port, dest_port=tcp_header.source_port, seq_num=tcp_header.ack_num if tcp_header.ack else 0, ack_num=tcp_header.seq_num + 1 if tcp_header.syn else tcp_header.seq_num, flags=0x14 if tcp_header.ack else 0x04 # RST+ACK or RST ) packet = IPParser.build_packet(rst_ip, rst_tcp) self._send_packet(packet) def _timer_loop(self): """Timer loop for handling timeouts""" while self.running: current_time = time.time() with self.lock: connections_to_check = list(self.connections.values()) for conn in connections_to_check: # Handle retransmissions if conn.unacked_segments: self._handle_retransmission(conn) # Handle connection timeout if current_time - conn.last_activity > self.connection_timeout: self._close_connection(conn, reset=True) # Handle TIME_WAIT timeout if (conn.state == TCPState.TIME_WAIT and conn.time_wait_start and current_time - conn.time_wait_start > self.time_wait_timeout): conn.state = TCPState.CLOSED self._close_connection(conn) time.sleep(1) # Check every second def start(self): """Start TCP engine""" self.running = True self.timer_thread = threading.Thread(target=self._timer_loop, daemon=True) self.timer_thread.start() print("TCP engine started") def stop(self): """Stop TCP engine""" self.running = False if self.timer_thread: self.timer_thread.join() # Close all connections with self.lock: for conn in list(self.connections.values()): self._close_connection(conn, reset=True) print("TCP engine stopped") def get_connections(self) -> Dict[str, Dict]: """Get current connections""" with self.lock: return { conn_id: { 'local_ip': conn.local_ip, 'local_port': conn.local_port, 'remote_ip': conn.remote_ip, 'remote_port': conn.remote_port, 'state': conn.state.value, 'local_seq': conn.local_seq, 'local_ack': conn.local_ack, 'remote_seq': conn.remote_seq, 'remote_ack': conn.remote_ack, 'window_size': conn.local_window, 'cwnd': conn.cwnd, 'unacked_segments': len(conn.unacked_segments), 'last_activity': conn.last_activity } for conn_id, conn in self.connections.items() }