| """
|
TCP Engine Module

Implements a simplified TCP state machine in user space:

- TCP connection states and transitions (SYN, SYN-ACK, ESTABLISHED, FIN, RST)
- Sequence and acknowledgment number tracking
- Sliding window implementation
- Retransmission and timeout handling
- Congestion control (slow start and congestion avoidance)
|
| """
|
|
|
import logging
import random
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable, Dict, List, Optional, Tuple

from .ip_parser import IPParser, IPv4Header, TCPHeader
|
|
|
|
|
class TCPState(Enum):
    """States a TCP connection can occupy.

    Every member's value is simply its own name as a string, so
    ``TCPState("LISTEN") is TCPState.LISTEN``.
    """

    # Not connected / waiting for a peer.
    CLOSED = "CLOSED"
    LISTEN = "LISTEN"

    # Three-way handshake in progress.
    SYN_SENT = "SYN_SENT"
    SYN_RECEIVED = "SYN_RECEIVED"

    # Handshake complete; data may flow both ways.
    ESTABLISHED = "ESTABLISHED"

    # Connection teardown.
    FIN_WAIT_1 = "FIN_WAIT_1"
    FIN_WAIT_2 = "FIN_WAIT_2"
    CLOSE_WAIT = "CLOSE_WAIT"
    CLOSING = "CLOSING"
    LAST_ACK = "LAST_ACK"
    TIME_WAIT = "TIME_WAIT"
|
|
|
|
|
@dataclass
class TCPSegment:
    """A TCP segment as recorded by the sender.

    ``timestamp`` is the creation time of the segment object and
    ``retransmit_count`` counts how many times it has been resent.
    """

    seq_num: int
    ack_num: int
    flags: int
    window: int
    data: bytes
    timestamp: float = field(default_factory=time.time)
    retransmit_count: int = 0

    @property
    def data_length(self) -> int:
        """Number of payload bytes carried by this segment."""
        return len(self.data)

    @property
    def seq_end(self) -> int:
        """Sequence number immediately following this segment, modulo 2**32.

        SYN (flag bit 0x02) and FIN (flag bit 0x01) each consume one unit of
        sequence space in addition to the payload bytes.
        """
        length = self.data_length
        if self.flags & 0x02:  # SYN
            length += 1
        if self.flags & 0x01:  # FIN
            length += 1
        # TCP sequence numbers are 32-bit and wrap around, like every other
        # sequence computation in this module.
        return (self.seq_num + length) % 0x100000000
|
|
|
|
|
_SEQ_MOD = 0x100000000  # TCP sequence numbers are 32-bit and wrap around
_HALF_SEQ_SPACE = 0x80000000

# TCP header flag bits used by TCPConnection.
_FIN = 0x01
_SYN = 0x02
_PSH = 0x08
_ACK = 0x10

_logger = logging.getLogger(__name__)


def _seq_leq(a: int, b: int) -> bool:
    """Return True if sequence number ``a`` is at or before ``b``.

    Uses serial-number arithmetic so the comparison stays correct when the
    32-bit sequence space wraps around.
    """
    return (b - a) % _SEQ_MOD < _HALF_SEQ_SPACE


@dataclass
class TCPConnection:
    """State and protocol logic for a single TCP connection.

    Incoming IPv4 packets are fed to :meth:`handle_packet`. Outgoing segments
    are built by :meth:`_send_packet` and passed to ``on_send_segment`` when
    that hook is set (otherwise they are only recorded). Segments that occupy
    sequence space are kept in ``unacked_segments`` until acknowledged.
    """

    # Retransmissions of a single segment allowed before aborting.
    _MAX_RETRANSMITS = 5
    # Cap on the backed-off retransmission timeout, in seconds
    # (RFC 6298 permits a maximum RTO of at least 60 s).
    _MAX_RTO = 60.0

    # Endpoint addresses and ports.
    local_ip: str
    local_port: int
    remote_ip: str
    remote_port: int

    state: TCPState = TCPState.CLOSED

    # Sequence space. local_seq is the next sequence number we will send
    # (starting at a random ISN); local_ack is the next sequence number we
    # expect from the peer; remote_seq holds the peer's ISN; remote_ack is
    # the last cumulative ACK from the peer that acknowledged new data.
    # NOTE(review): `random` is predictable; RFC 6528 recommends an
    # unpredictable ISN source.
    local_seq: int = field(default_factory=lambda: random.randint(0, 0xFFFFFFFF))
    local_ack: int = 0
    remote_seq: int = 0
    remote_ack: int = 0
    initial_seq: int = 0

    # Flow control. NOTE(review): nothing in this class refreshes
    # remote_window from incoming headers — TODO wire up the peer's
    # advertised window once the TCPHeader field for it is confirmed.
    local_window: int = 65535
    remote_window: int = 65535
    window_scale: int = 0

    # send_buffer holds queued outgoing bytes (as ints); out_of_order_buffer
    # maps a starting sequence number to payload received beyond a gap.
    send_buffer: deque = field(default_factory=deque)
    recv_buffer: deque = field(default_factory=deque)
    out_of_order_buffer: Dict[int, bytes] = field(default_factory=dict)

    # Retransmission state (seconds).
    unacked_segments: Dict[int, TCPSegment] = field(default_factory=dict)
    retransmit_timer: Optional[float] = None
    rto: float = 1.0
    srtt: float = 0.0
    rttvar: float = 0.0

    # Congestion control. cwnd is measured in segments and grows
    # fractionally during congestion avoidance, hence float.
    cwnd: float = 1
    ssthresh: int = 65535
    dupacks: int = 0
    mss: int = 1460

    # User hooks.
    on_data_received: Optional[Callable[[bytes], None]] = None
    on_state_change: Optional[Callable[[TCPState], None]] = None
    on_send_segment: Optional[Callable[[TCPSegment], None]] = None

    # Set by close() while queued data must still precede our FIN.
    _fin_pending: bool = field(default=False, init=False, repr=False, compare=False)

    def __post_init__(self):
        # Remember the ISN so handshake ACKs can be validated against it.
        self.initial_seq = self.local_seq

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def handle_packet(self, packet: bytes) -> None:
        """Parse a raw IPv4 packet carrying TCP and run it through the state machine.

        Any exception raised while parsing or handling the packet (including
        from user callbacks) is logged with its traceback and not propagated.
        """
        try:
            _ip_header, payload = IPParser.parse_ipv4_header(packet)
            tcp_header, data = IPParser.parse_tcp_header(payload)

            handlers = {
                TCPState.LISTEN: self._handle_listen,
                TCPState.SYN_SENT: self._handle_syn_sent,
                TCPState.SYN_RECEIVED: self._handle_syn_received,
                TCPState.ESTABLISHED: self._handle_established,
                TCPState.FIN_WAIT_1: self._handle_fin_wait,
                TCPState.FIN_WAIT_2: self._handle_fin_wait,
                TCPState.CLOSE_WAIT: self._handle_close_wait,
                TCPState.CLOSING: self._handle_closing,
                TCPState.LAST_ACK: self._handle_last_ack,
                TCPState.TIME_WAIT: self._handle_time_wait,
            }
            handler = handlers.get(self.state)
            if handler is not None:
                handler(tcp_header, data)

            # RTT samples are taken inside _handle_ack, before acknowledged
            # segments are discarded, so no separate RTT step is needed here.
            self._manage_retransmission_timer()
        except Exception:
            # Best effort: one bad packet must not take down the caller.
            _logger.exception("Error handling packet")

    def send_data(self, data: bytes) -> bool:
        """Queue ``data`` for transmission and send what the window allows.

        Returns False (queuing nothing) unless the connection is ESTABLISHED.
        """
        if self.state != TCPState.ESTABLISHED:
            return False
        self.send_buffer.extend(data)
        self._send_from_buffer()
        return True

    def close(self) -> None:
        """Initiate an orderly close from ESTABLISHED or CLOSE_WAIT.

        The FIN is queued behind any unsent data: it goes out immediately if
        the send buffer is empty, otherwise once the buffer drains.
        """
        if self.state == TCPState.ESTABLISHED:
            next_state = TCPState.FIN_WAIT_1
        elif self.state == TCPState.CLOSE_WAIT:
            next_state = TCPState.LAST_ACK
        else:
            return
        self._fin_pending = True
        self._send_from_buffer()
        self._set_state(next_state)

    # ------------------------------------------------------------------
    # Segment output
    # ------------------------------------------------------------------

    def _set_state(self, new_state: TCPState) -> None:
        """Switch to ``new_state`` and notify ``on_state_change`` if it changed."""
        if new_state != self.state:
            self.state = new_state
            if self.on_state_change:
                self.on_state_change(new_state)

    def _send_packet(self, flags: int, data: bytes = b'') -> None:
        """Build a segment at ``local_seq``, track it if needed, and emit it.

        Segments that consume sequence space (payload, SYN or FIN) are stored
        in ``unacked_segments`` and arm the retransmission timer if it is
        idle; bare ACKs are not tracked. ``local_seq`` is then advanced past
        the segment.
        """
        segment = TCPSegment(
            seq_num=self.local_seq,
            ack_num=self.local_ack,
            flags=flags,
            window=self.local_window,
            data=data,
        )

        if data or flags & (_SYN | _FIN):
            self.unacked_segments[self.local_seq] = segment
            if self.retransmit_timer is None:
                self.retransmit_timer = time.time() + self.rto

        advance = len(data)
        if flags & _SYN:
            advance += 1
        if flags & _FIN:
            advance += 1
        self.local_seq = (self.local_seq + advance) % _SEQ_MOD

        self._transmit(segment)

    def _transmit(self, segment: TCPSegment) -> None:
        """Pass ``segment`` to the ``on_send_segment`` hook, if one is set."""
        if self.on_send_segment:
            self.on_send_segment(segment)

    def _send_fin(self) -> None:
        """Send a FIN (with ACK) to close our direction of the connection."""
        self._send_packet(_FIN | _ACK)

    def _send_from_buffer(self) -> None:
        """Send queued bytes as far as the send window allows, then any pending FIN.

        The usable window is ``min(remote_window, cwnd * mss)`` minus the
        payload bytes already in flight, and each segment carries at most
        ``mss`` bytes.
        """
        in_flight = sum(seg.data_length for seg in self.unacked_segments.values())
        while self.send_buffer:
            # int(): cwnd may be fractional in congestion avoidance.
            window = int(min(self.remote_window, self.cwnd * self.mss))
            size = min(window - in_flight, self.mss, len(self.send_buffer))
            if size <= 0:
                break
            chunk = bytes(self.send_buffer.popleft() for _ in range(size))
            self._send_packet(_PSH | _ACK, chunk)
            in_flight += size

        # A FIN queued by close() must follow every queued byte.
        if self._fin_pending and not self.send_buffer:
            self._fin_pending = False
            self._send_fin()

    # ------------------------------------------------------------------
    # Handshake states
    # ------------------------------------------------------------------

    def _handle_listen(self, header: TCPHeader, data: bytes) -> None:
        """Handle LISTEN: answer a SYN with SYN-ACK and enter SYN_RECEIVED."""
        if header.syn:
            self.remote_seq = header.seq_num
            self.local_ack = (header.seq_num + 1) % _SEQ_MOD
            self._send_packet(_SYN | _ACK)
            self._set_state(TCPState.SYN_RECEIVED)

    def _handle_syn_sent(self, header: TCPHeader, data: bytes) -> None:
        """Handle SYN_SENT: on a SYN-ACK for our SYN, ACK it and become ESTABLISHED."""
        if (header.syn and header.ack
                and header.ack_num == (self.initial_seq + 1) % _SEQ_MOD):
            self.remote_seq = header.seq_num
            self.local_ack = (header.seq_num + 1) % _SEQ_MOD
            # Retire our SYN so it is not retransmitted.
            self._handle_ack(header.ack_num)
            self._send_packet(_ACK)
            self._set_state(TCPState.ESTABLISHED)

    def _handle_syn_received(self, header: TCPHeader, data: bytes) -> None:
        """Handle SYN_RECEIVED: the ACK of our SYN-ACK completes the handshake.

        The same segment may already carry data or a FIN, so it is processed
        further as an ESTABLISHED segment (which also retires our SYN-ACK).
        """
        if header.ack and header.ack_num == (self.initial_seq + 1) % _SEQ_MOD:
            self._set_state(TCPState.ESTABLISHED)
            self._handle_established(header, data)

    # ------------------------------------------------------------------
    # Data transfer
    # ------------------------------------------------------------------

    def _handle_established(self, header: TCPHeader, data: bytes) -> None:
        """Handle ESTABLISHED: receive data, process ACKs, accept the peer's FIN."""
        if data:
            self._receive_data(header.seq_num, data)
        if header.ack:
            self._handle_ack(header.ack_num)
        if header.fin and self._fin_in_order(header, data):
            self.local_ack = (self.local_ack + 1) % _SEQ_MOD
            self._send_packet(_ACK)
            self._set_state(TCPState.CLOSE_WAIT)

    def _receive_data(self, seq: int, data: bytes) -> None:
        """Accept payload starting at sequence number ``seq`` and ACK it.

        New in-order bytes (after trimming any prefix already received) are
        delivered immediately, followed by buffered segments they make
        contiguous. Segments starting beyond a gap are buffered only if they
        fall within ``local_window``. Every data segment triggers an ACK.
        """
        offset = (self.local_ack - seq) % _SEQ_MOD
        if offset < _HALF_SEQ_SPACE:
            # Segment starts at or before the next expected byte.
            if offset < len(data):
                self._deliver(data[offset:])
                self._drain_out_of_order()
            # Otherwise it is a full duplicate; the ACK below re-syncs the peer.
        elif (seq - self.local_ack) % _SEQ_MOD < self.local_window:
            self.out_of_order_buffer[seq] = data
        self._send_packet(_ACK)

    def _deliver(self, data: bytes) -> None:
        """Hand in-order bytes to ``on_data_received`` and advance ``local_ack``."""
        if self.on_data_received:
            self.on_data_received(data)
        self.local_ack = (self.local_ack + len(data)) % _SEQ_MOD

    def _drain_out_of_order(self) -> None:
        """Deliver buffered segments that have become contiguous with ``local_ack``."""
        while True:
            ready = [seq for seq in self.out_of_order_buffer
                     if _seq_leq(seq, self.local_ack)]
            if not ready:
                return
            for seq in ready:
                chunk = self.out_of_order_buffer.pop(seq)
                # Skip bytes already delivered; fully stale chunks are dropped.
                offset = (self.local_ack - seq) % _SEQ_MOD
                if offset < len(chunk):
                    self._deliver(chunk[offset:])

    def _fin_in_order(self, header: TCPHeader, data: bytes) -> bool:
        """Return True if the segment's FIN sits exactly at ``local_ack``.

        The FIN occupies the sequence number right after the payload, so it
        is acceptable only once every preceding byte has been received.
        """
        return (header.seq_num + len(data)) % _SEQ_MOD == self.local_ack

    def _handle_ack(self, ack_num: int) -> None:
        """Process a cumulative acknowledgment from the peer.

        Segments whose entire sequence range lies at or before ``ack_num`` are
        retired. If that acknowledges anything new, an RTT sample is taken,
        the retransmission timer is restarted and the congestion window grows
        (slow start below ``ssthresh``, congestion avoidance above it).
        ACKs for sequence numbers we have not sent yet are ignored.
        """
        if not _seq_leq(ack_num, self.local_seq):
            return

        acked = [seq for seq, seg in self.unacked_segments.items()
                 if _seq_leq(seg.seq_end % _SEQ_MOD, ack_num)]
        if acked:
            # Sample RTT first: it needs the segments about to be dropped.
            self._update_rtt(ack_num)
            for seq in acked:
                del self.unacked_segments[seq]
            self.remote_ack = ack_num
            # RFC 6298 (5.3): restart the timer whenever new data is ACKed.
            self.retransmit_timer = (time.time() + self.rto
                                     if self.unacked_segments else None)
            if self.cwnd < self.ssthresh:
                self.cwnd += 1
            else:
                self.cwnd += 1 / self.cwnd

        self._send_from_buffer()

    # ------------------------------------------------------------------
    # RTT estimation and retransmission
    # ------------------------------------------------------------------

    def _update_rtt(self, ack_num: int) -> None:
        """Update SRTT, RTTVAR and RTO from an ACK (RFC 6298, section 2).

        Must be called before the acknowledged segments are removed. Per
        Karn's algorithm, retransmitted segments are never sampled.
        """
        sent_times = [seg.timestamp for seg in self.unacked_segments.values()
                      if seg.retransmit_count == 0
                      and _seq_leq(seg.seq_end % _SEQ_MOD, ack_num)]
        if not sent_times:
            return
        # Use the newest acknowledged segment: the ACK was most likely
        # triggered by it.
        rtt = time.time() - max(sent_times)
        if self.srtt == 0:
            self.srtt = rtt
            self.rttvar = rtt / 2
        else:
            # RTTVAR is updated with the previous SRTT, as RFC 6298 requires.
            self.rttvar = 0.75 * self.rttvar + 0.25 * abs(self.srtt - rtt)
            self.srtt = 0.875 * self.srtt + 0.125 * rtt
        self.rto = self.srtt + max(4 * self.rttvar, 0.5)

    def _manage_retransmission_timer(self) -> None:
        """Arm, clear, or fire the retransmission timer.

        Within this class it is only called from handle_packet.
        NOTE(review): if no packets arrive, the timer fires only if external
        code also calls this periodically — confirm the owner does.
        """
        if not self.unacked_segments:
            self.retransmit_timer = None
            return

        now = time.time()
        if self.retransmit_timer is None:
            self.retransmit_timer = now + self.rto
        elif now >= self.retransmit_timer:
            self._handle_timeout()

    def _handle_timeout(self) -> None:
        """Handle expiry of the retransmission timer.

        Backs off RTO (capped at ``_MAX_RTO``), collapses cwnd to one segment
        and resends the oldest unacknowledged segment with its original
        sequence number. Once that segment has been retransmitted
        ``_MAX_RETRANSMITS`` times, the connection is aborted.
        """
        self.rto = min(self.rto * 2, self._MAX_RTO)
        self.ssthresh = max(2, int(self.cwnd) // 2)
        self.cwnd = 1

        if not self.unacked_segments:
            self.retransmit_timer = None
            return

        # Oldest = furthest behind the next sequence number (wrap-safe).
        oldest_seq = max(self.unacked_segments,
                         key=lambda seq: (self.local_seq - seq) % _SEQ_MOD)
        segment = self.unacked_segments[oldest_seq]
        if segment.retransmit_count >= self._MAX_RETRANSMITS:
            self._abort()
            return

        segment.retransmit_count += 1
        # Resend the same bytes at the same sequence number, but carry the
        # current ACK and window rather than the stale ones.
        segment.ack_num = self.local_ack
        segment.window = self.local_window
        self._transmit(segment)
        self.retransmit_timer = time.time() + self.rto

    def _abort(self) -> None:
        """Discard all pending send/receive state and move to CLOSED."""
        self.unacked_segments.clear()
        self.send_buffer.clear()
        self.out_of_order_buffer.clear()
        self._fin_pending = False
        self.retransmit_timer = None
        self._set_state(TCPState.CLOSED)

    # ------------------------------------------------------------------
    # Teardown states
    # ------------------------------------------------------------------

    def _handle_fin_wait(self, header: TCPHeader, data: bytes) -> None:
        """Handle FIN_WAIT_1 and FIN_WAIT_2 (we have closed; the peer may still send)."""
        if data:
            # Half-close: the peer may keep sending until its own FIN.
            self._receive_data(header.seq_num, data)

        if header.ack:
            self._handle_ack(header.ack_num)
            fin_acked = not self._fin_pending and header.ack_num == self.local_seq
            if self.state == TCPState.FIN_WAIT_1 and fin_acked:
                self._set_state(TCPState.FIN_WAIT_2)

        if header.fin and self._fin_in_order(header, data):
            self.local_ack = (self.local_ack + 1) % _SEQ_MOD
            self._send_packet(_ACK)
            if self.state == TCPState.FIN_WAIT_1:
                # Simultaneous close: our FIN is still unacknowledged.
                self._set_state(TCPState.CLOSING)
            else:
                self._set_state(TCPState.TIME_WAIT)

    def _handle_closing(self, header: TCPHeader, data: bytes) -> None:
        """Handle CLOSING: once our FIN is acknowledged, enter TIME_WAIT."""
        if header.ack:
            self._handle_ack(header.ack_num)
            if not self._fin_pending and header.ack_num == self.local_seq:
                self._set_state(TCPState.TIME_WAIT)

    def _handle_time_wait(self, header: TCPHeader, data: bytes) -> None:
        """Handle TIME_WAIT: re-acknowledge a retransmitted FIN.

        NOTE(review): nothing in this class leaves TIME_WAIT (there is no
        2*MSL timer); presumably the owner discards the connection — confirm.
        """
        if header.fin:
            self._send_packet(_ACK)

    def _handle_close_wait(self, header: TCPHeader, data: bytes) -> None:
        """Handle CLOSE_WAIT: the peer has closed; keep processing its ACKs."""
        if header.ack:
            self._handle_ack(header.ack_num)
        if header.fin:
            # A repeated FIN means our ACK was lost; acknowledge it again.
            self._send_packet(_ACK)

    def _handle_last_ack(self, header: TCPHeader, data: bytes) -> None:
        """Handle LAST_ACK: close once our FIN is acknowledged."""
        if header.ack:
            self._handle_ack(header.ack_num)
            if not self._fin_pending and header.ack_num == self.local_seq:
                self._set_state(TCPState.CLOSED)
                return
        if header.fin:
            # The peer is still retransmitting its FIN; re-acknowledge it.
            self._send_packet(_ACK)
|
|
|