Spaces:
Runtime error
Runtime error
| """ | |
| TCP Engine Module | |
| Implements a complete TCP state machine in user-space: | |
| - Full TCP state machine (SYN, SYN-ACK, ESTABLISHED, FIN, RST) | |
| - Sequence and acknowledgment number tracking | |
| - Sliding window implementation | |
| - Retransmission and timeout handling | |
| - Congestion control | |
| """ | |
| import time | |
| import threading | |
| import random | |
| from typing import Dict, List, Optional, Tuple, Callable | |
| from dataclasses import dataclass, field | |
| from enum import Enum | |
| from collections import deque | |
| from .ip_parser import TCPHeader, IPv4Header, IPParser | |
class TCPState(Enum):
    """States of the TCP connection state machine (RFC 793, section 3.2).

    Every member's value is identical to its name, so ``TCPState("LISTEN")``
    maps a state's string form back to the member.
    """

    # Idle / passive open
    CLOSED = "CLOSED"
    LISTEN = "LISTEN"
    # Three-way handshake
    SYN_SENT = "SYN_SENT"
    SYN_RECEIVED = "SYN_RECEIVED"
    # Data transfer
    ESTABLISHED = "ESTABLISHED"
    # Connection teardown
    FIN_WAIT_1 = "FIN_WAIT_1"
    FIN_WAIT_2 = "FIN_WAIT_2"
    CLOSE_WAIT = "CLOSE_WAIT"
    CLOSING = "CLOSING"
    LAST_ACK = "LAST_ACK"
    TIME_WAIT = "TIME_WAIT"
@dataclass
class TCPSegment:
    """A TCP segment as recorded by the sender.

    Attributes:
        seq_num: Sequence number of the first octet (or of the SYN/FIN).
        ack_num: Acknowledgment number carried by the segment.
        flags: TCP flag bits (0x01 FIN, 0x02 SYN, 0x10 ACK, ...).
        window: Advertised receive window.
        data: Payload bytes.
        timestamp: ``time.time()`` at the moment the object was created.
        retransmit_count: Number of times the segment has been retransmitted.
    """

    seq_num: int
    ack_num: int
    flags: int
    window: int
    data: bytes
    timestamp: float = field(default_factory=time.time)
    retransmit_count: int = 0

    @property
    def data_length(self) -> int:
        """Number of payload bytes (SYN/FIN flags are not counted)."""
        return len(self.data)

    @property
    def seq_end(self) -> int:
        """Sequence number immediately after this segment, modulo 2**32.

        SYN and FIN each consume one sequence number in addition to the
        payload octets.
        """
        length = self.data_length
        if self.flags & 0x02:  # SYN
            length += 1
        if self.flags & 0x01:  # FIN
            length += 1
        # Sequence space is 32 bits; wrap like the rest of the engine does.
        return (self.seq_num + length) % 0x100000000
@dataclass
class TCPConnection:
    """State of one user-space TCP connection (its transmission control block).

    Sequence-number fields hold 32-bit values; the methods wrap arithmetic on
    them modulo 2**32.
    """

    # Connection identification
    local_ip: str
    local_port: int
    remote_ip: str
    remote_port: int

    # State
    state: TCPState = TCPState.CLOSED

    # Sequence numbers
    # Next sequence number this side will send; starts at a random ISN.
    # NOTE(review): `random` is not a CSPRNG; RFC 6528 recommends
    # unpredictable initial sequence numbers.
    local_seq: int = field(default_factory=lambda: random.randint(0, 0xFFFFFFFF))
    # Next sequence number expected from the peer (the value we acknowledge).
    local_ack: int = 0
    remote_seq: int = 0  # peer's initial sequence number, taken from its SYN
    remote_ack: int = 0
    initial_seq: int = 0  # copy of local_seq at construction (see __post_init__)

    # Window management
    local_window: int = 65535
    remote_window: int = 65535
    window_scale: int = 0

    # Buffers
    send_buffer: deque = field(default_factory=deque)  # queued outgoing bytes, one int each
    recv_buffer: deque = field(default_factory=deque)
    out_of_order_buffer: Dict[int, bytes] = field(default_factory=dict)  # seq -> payload

    # Retransmission
    unacked_segments: Dict[int, TCPSegment] = field(default_factory=dict)  # seq -> segment
    retransmit_timer: Optional[float] = None  # absolute time.time() deadline, or None
    rto: float = 1.0  # Retransmission timeout (seconds)
    srtt: float = 0.0  # Smoothed round-trip time
    rttvar: float = 0.0  # Round-trip time variation

    # Congestion control
    # cwnd is counted in segments and grows fractionally during congestion
    # avoidance, so it is a float even though it starts at 1.
    cwnd: float = 1  # Congestion window (in MSS)
    ssthresh: int = 65535  # Slow start threshold (compared against cwnd)
    dupacks: int = 0  # Duplicate ACK count
    mss: int = 1460  # Maximum segment size

    # Callbacks
    on_data_received: Optional[Callable[[bytes], None]] = None
    on_state_change: Optional[Callable[[TCPState], None]] = None

    def __post_init__(self):
        """Record the initial send sequence number chosen for local_seq."""
        self.initial_seq = self.local_seq
| def handle_packet(self, packet: bytes): | |
| """Process incoming TCP packet""" | |
| try: | |
| # Parse headers | |
| ip_header, payload = IPParser.parse_ipv4_header(packet) | |
| tcp_header, data = IPParser.parse_tcp_header(payload) | |
| # Process based on current state | |
| if self.state == TCPState.LISTEN: | |
| self._handle_listen(tcp_header, data) | |
| elif self.state == TCPState.SYN_SENT: | |
| self._handle_syn_sent(tcp_header, data) | |
| elif self.state == TCPState.SYN_RECEIVED: | |
| self._handle_syn_received(tcp_header, data) | |
| elif self.state == TCPState.ESTABLISHED: | |
| self._handle_established(tcp_header, data) | |
| elif self.state in (TCPState.FIN_WAIT_1, TCPState.FIN_WAIT_2): | |
| self._handle_fin_wait(tcp_header, data) | |
| elif self.state == TCPState.CLOSE_WAIT: | |
| self._handle_close_wait(tcp_header, data) | |
| elif self.state == TCPState.LAST_ACK: | |
| self._handle_last_ack(tcp_header, data) | |
| # Update RTT if this is an ACK for a sent packet | |
| if tcp_header.ack and tcp_header.ack_num > self.local_seq: | |
| self._update_rtt(tcp_header.ack_num) | |
| # Handle retransmission timer | |
| self._manage_retransmission_timer() | |
| except Exception as e: | |
| print(f"Error handling packet: {e}") | |
| def send_data(self, data: bytes): | |
| """Send data over the connection""" | |
| if self.state != TCPState.ESTABLISHED: | |
| return False | |
| # Add to send buffer | |
| self.send_buffer.extend(data) | |
| # Try to send what we can | |
| self._send_from_buffer() | |
| return True | |
| def close(self): | |
| """Initiate connection close""" | |
| if self.state == TCPState.ESTABLISHED: | |
| self._send_fin() | |
| self._set_state(TCPState.FIN_WAIT_1) | |
| elif self.state == TCPState.CLOSE_WAIT: | |
| self._send_fin() | |
| self._set_state(TCPState.LAST_ACK) | |
| def _set_state(self, new_state: TCPState): | |
| """Change connection state""" | |
| if new_state != self.state: | |
| self.state = new_state | |
| if self.on_state_change: | |
| self.on_state_change(new_state) | |
| def _send_packet(self, flags: int, data: bytes = b''): | |
| """Send TCP packet""" | |
| segment = TCPSegment( | |
| seq_num=self.local_seq, | |
| ack_num=self.local_ack, | |
| flags=flags, | |
| window=self.local_window, | |
| data=data | |
| ) | |
| # Add to unacked segments if not pure ACK | |
| if data or flags != 0x10: # Not pure ACK | |
| self.unacked_segments[self.local_seq] = segment | |
| # Update sequence number | |
| self.local_seq = (self.local_seq + len(data)) % 0x100000000 | |
| if flags & 0x02: # SYN | |
| self.local_seq = (self.local_seq + 1) % 0x100000000 | |
| if flags & 0x01: # FIN | |
| self.local_seq = (self.local_seq + 1) % 0x100000000 | |
| # TODO: Actually send the packet | |
| def _handle_listen(self, header: TCPHeader, data: bytes): | |
| """Handle LISTEN state""" | |
| if header.syn: | |
| self.remote_seq = header.seq_num | |
| self.local_ack = (header.seq_num + 1) % 0x100000000 | |
| self._send_packet(0x12) # SYN-ACK | |
| self._set_state(TCPState.SYN_RECEIVED) | |
| def _handle_syn_sent(self, header: TCPHeader, data: bytes): | |
| """Handle SYN_SENT state""" | |
| if header.syn and header.ack: | |
| if header.ack_num == (self.initial_seq + 1) % 0x100000000: | |
| self.remote_seq = header.seq_num | |
| self.local_ack = (header.seq_num + 1) % 0x100000000 | |
| self._send_packet(0x10) # ACK | |
| self._set_state(TCPState.ESTABLISHED) | |
| def _handle_established(self, header: TCPHeader, data: bytes): | |
| """Handle ESTABLISHED state""" | |
| if data: | |
| if header.seq_num == self.local_ack: | |
| # In-order segment | |
| if self.on_data_received: | |
| self.on_data_received(data) | |
| self.local_ack = (self.local_ack + len(data)) % 0x100000000 | |
| self._send_packet(0x10) # ACK | |
| elif header.seq_num > self.local_ack: | |
| # Out-of-order segment | |
| self.out_of_order_buffer[header.seq_num] = data | |
| self._send_packet(0x10) # ACK | |
| else: | |
| # Duplicate segment | |
| self._send_packet(0x10) # ACK | |
| if header.ack: | |
| # Process acknowledgments | |
| self._handle_ack(header.ack_num) | |
| if header.fin: | |
| self.local_ack = (self.local_ack + 1) % 0x100000000 | |
| self._send_packet(0x10) # ACK | |
| self._set_state(TCPState.CLOSE_WAIT) | |
| def _handle_ack(self, ack_num: int): | |
| """Handle incoming acknowledgment""" | |
| # Remove acknowledged segments | |
| acknowledged = [seq for seq in self.unacked_segments.keys() | |
| if seq < ack_num] | |
| for seq in acknowledged: | |
| del self.unacked_segments[seq] | |
| # Update congestion window | |
| if self.cwnd < self.ssthresh: | |
| # Slow start | |
| self.cwnd += 1 | |
| else: | |
| # Congestion avoidance | |
| self.cwnd += 1 / self.cwnd | |
| # Try to send more data | |
| self._send_from_buffer() | |
| def _send_from_buffer(self): | |
| """Send data from send buffer""" | |
| while self.send_buffer: | |
| # Calculate how much we can send | |
| window = min(self.remote_window, self.cwnd * self.mss) | |
| if not window: | |
| break | |
| # Get data to send | |
| data = bytes(list(self.send_buffer)[:window]) | |
| if not data: | |
| break | |
| # Remove from buffer and send | |
| for _ in range(len(data)): | |
| self.send_buffer.popleft() | |
| self._send_packet(0x18, data) # PSH-ACK | |
| def _update_rtt(self, ack_num: int): | |
| """Update RTT estimation""" | |
| for seq, segment in self.unacked_segments.items(): | |
| if seq == ack_num - 1: | |
| rtt = time.time() - segment.timestamp | |
| if self.srtt == 0: | |
| self.srtt = rtt | |
| self.rttvar = rtt / 2 | |
| else: | |
| self.rttvar = (0.75 * self.rttvar + | |
| 0.25 * abs(self.srtt - rtt)) | |
| self.srtt = 0.875 * self.srtt + 0.125 * rtt | |
| self.rto = self.srtt + max(4 * self.rttvar, 0.5) | |
| break | |
| def _manage_retransmission_timer(self): | |
| """Manage retransmission timer""" | |
| if not self.unacked_segments: | |
| self.retransmit_timer = None | |
| return | |
| current_time = time.time() | |
| if self.retransmit_timer is None: | |
| self.retransmit_timer = current_time + self.rto | |
| elif current_time >= self.retransmit_timer: | |
| # Timeout occurred | |
| self._handle_timeout() | |
| def _handle_timeout(self): | |
| """Handle retransmission timeout""" | |
| # Exponential backoff | |
| self.rto *= 2 | |
| # Reset congestion window | |
| self.ssthresh = max(2, self.cwnd // 2) | |
| self.cwnd = 1 | |
| # Retransmit oldest unacked segment | |
| if self.unacked_segments: | |
| oldest_seq = min(self.unacked_segments.keys()) | |
| segment = self.unacked_segments[oldest_seq] | |
| if segment.retransmit_count < 5: | |
| segment.retransmit_count += 1 | |
| self._send_packet(segment.flags, segment.data) | |
| else: | |
| # Too many retransmissions, close connection | |
| self._set_state(TCPState.CLOSED) | |
| # Reset timer | |
| self.retransmit_timer = time.time() + self.rto | |
| def _send_fin(self): | |
| """Send FIN packet""" | |
| self._send_packet(0x11) # FIN-ACK | |
| def _handle_fin_wait(self, header: TCPHeader, data: bytes): | |
| """Handle FIN_WAIT states""" | |
| if self.state == TCPState.FIN_WAIT_1: | |
| if header.ack and header.ack_num == self.local_seq: | |
| self._set_state(TCPState.FIN_WAIT_2) | |
| if header.fin: | |
| self.local_ack = (header.seq_num + 1) % 0x100000000 | |
| self._send_packet(0x10) # ACK | |
| if self.state == TCPState.FIN_WAIT_1: | |
| self._set_state(TCPState.CLOSING) | |
| else: # FIN_WAIT_2 | |
| self._set_state(TCPState.TIME_WAIT) | |
| def _handle_close_wait(self, header: TCPHeader, data: bytes): | |
| """Handle CLOSE_WAIT state""" | |
| if header.ack: | |
| self._handle_ack(header.ack_num) | |
| def _handle_last_ack(self, header: TCPHeader, data: bytes): | |
| """Handle LAST_ACK state""" | |
| if header.ack and header.ack_num == self.local_seq: | |
| self._set_state(TCPState.CLOSED) | |