# JRNET / core/tcp_engine.py
# Factor Studios
# Upload 96 files
# 6a5b8d8 verified
"""
TCP Engine Module
Implements a complete TCP state machine in user-space:
- Full TCP state machine (SYN, SYN-ACK, ESTABLISHED, FIN, RST)
- Sequence and acknowledgment number tracking
- Sliding window implementation
- Retransmission and timeout handling
- Congestion control
"""
import logging
import random
import secrets
import threading
import time
from collections import deque
from dataclasses import dataclass, field
from enum import Enum
from typing import Dict, List, Optional, Tuple, Callable

from .ip_parser import TCPHeader, IPv4Header, IPParser
class TCPState(Enum):
    """States of the TCP connection state machine.

    Every member's value is the same string as its name.
    """

    # Idle and connection establishment
    CLOSED = "CLOSED"
    LISTEN = "LISTEN"
    SYN_SENT = "SYN_SENT"
    SYN_RECEIVED = "SYN_RECEIVED"
    # Data transfer
    ESTABLISHED = "ESTABLISHED"
    # Connection teardown
    FIN_WAIT_1 = "FIN_WAIT_1"
    FIN_WAIT_2 = "FIN_WAIT_2"
    CLOSE_WAIT = "CLOSE_WAIT"
    CLOSING = "CLOSING"
    LAST_ACK = "LAST_ACK"
    TIME_WAIT = "TIME_WAIT"
@dataclass
class TCPSegment:
    """A TCP segment as tracked by the sender.

    Attributes:
        seq_num: Sequence number of the segment's first octet (or of its SYN).
        ack_num: Acknowledgment number carried by the segment.
        flags: TCP flag bits (0x02 = SYN, 0x01 = FIN are interpreted here).
        window: Advertised receive window.
        data: Payload bytes.
        timestamp: ``time.time()`` when the segment was created; used as the
            send time when sampling round-trip time.
        retransmit_count: Number of times the segment has been retransmitted.
    """
    seq_num: int
    ack_num: int
    flags: int
    window: int
    data: bytes
    timestamp: float = field(default_factory=time.time)
    retransmit_count: int = 0

    @property
    def data_length(self) -> int:
        """Number of payload bytes in the segment."""
        return len(self.data)

    @property
    def seq_end(self) -> int:
        """Sequence number immediately following this segment, modulo 2**32.

        Each payload byte consumes one sequence number, and SYN and FIN each
        consume one more. The result wraps at 2**32 like every other
        sequence number in this module.
        """
        length = self.data_length
        if self.flags & 0x02:  # SYN
            length += 1
        if self.flags & 0x01:  # FIN
            length += 1
        return (self.seq_num + length) % 0x100000000
@dataclass
class TCPConnection:
    """State of a single TCP connection between a local and a remote endpoint.

    Fields are declared in the order below. That order is also the
    positional order of the generated ``__init__``.
    """
    # Connection identification
    local_ip: str
    local_port: int
    remote_ip: str
    remote_port: int
    # State
    state: TCPState = TCPState.CLOSED
    # Sequence numbers (32-bit, arithmetic is modulo 2**32).
    # The initial send sequence number comes from the ``secrets`` CSPRNG:
    # predictable ISNs let off-path attackers forge in-window segments
    # (RFC 6528). Range is [0, 2**32 - 1].
    local_seq: int = field(default_factory=lambda: secrets.randbelow(0x100000000))
    local_ack: int = 0  # next sequence number expected from the peer
    remote_seq: int = 0  # peer's initial sequence number (set from its SYN)
    remote_ack: int = 0
    initial_seq: int = 0  # replaced by the starting local_seq in __post_init__
    # Window management
    local_window: int = 65535
    remote_window: int = 65535
    window_scale: int = 0
    # Buffers
    send_buffer: deque = field(default_factory=deque)  # queued outgoing bytes (as ints)
    recv_buffer: deque = field(default_factory=deque)
    out_of_order_buffer: Dict[int, bytes] = field(default_factory=dict)  # seq -> payload
    # Retransmission
    unacked_segments: Dict[int, TCPSegment] = field(default_factory=dict)  # start seq -> segment
    retransmit_timer: Optional[float] = None  # absolute time.time() deadline; None = disarmed
    rto: float = 1.0  # Retransmission timeout (seconds)
    srtt: float = 0.0  # Smoothed round-trip time; 0 means no sample yet
    rttvar: float = 0.0  # Round-trip time variation
    # Congestion control
    cwnd: float = 1  # Congestion window in MSS units; fractional in congestion avoidance
    ssthresh: int = 65535  # Slow start threshold
    dupacks: int = 0  # Duplicate ACK count
    mss: int = 1460  # Maximum segment size
    # Callbacks
    on_data_received: Optional[Callable[[bytes], None]] = None
    on_state_change: Optional[Callable[[TCPState], None]] = None

    def __post_init__(self):
        # Remember our ISN so a SYN-ACK can be checked for ack == ISN + 1.
        self.initial_seq = self.local_seq
def handle_packet(self, packet: bytes):
    """Process one incoming IPv4 packet carrying a TCP segment.

    The packet is parsed with ``IPParser``. An RTT sample is then taken from
    its ACK, and the segment is dispatched to the handler for the current
    state. Finally the retransmission timer is serviced. Segments arriving
    in CLOSED or TIME_WAIT are ignored.

    Any exception raised while processing is logged with its traceback
    instead of being propagated to the caller.

    Args:
        packet: Raw bytes of an IPv4 datagram containing a TCP segment.
    """
    try:
        _ip_header, payload = IPParser.parse_ipv4_header(packet)
        tcp_header, data = IPParser.parse_tcp_header(payload)

        # Sample RTT before dispatching. The state handlers retire
        # acknowledged segments from unacked_segments, and the send
        # timestamps the sample needs go with them. A valid ACK never
        # exceeds local_seq, so no "ack_num > local_seq" gate is used.
        if tcp_header.ack and self.unacked_segments:
            self._update_rtt(tcp_header.ack_num)

        state = self.state
        if state == TCPState.LISTEN:
            self._handle_listen(tcp_header, data)
        elif state == TCPState.SYN_SENT:
            self._handle_syn_sent(tcp_header, data)
        elif state == TCPState.SYN_RECEIVED:
            # NOTE(review): _handle_syn_received is not defined in the code
            # shown here; confirm it exists elsewhere in the class.
            self._handle_syn_received(tcp_header, data)
        elif state == TCPState.ESTABLISHED:
            self._handle_established(tcp_header, data)
        elif state in (TCPState.FIN_WAIT_1, TCPState.FIN_WAIT_2):
            self._handle_fin_wait(tcp_header, data)
        elif state == TCPState.CLOSE_WAIT:
            self._handle_close_wait(tcp_header, data)
        elif state == TCPState.CLOSING:
            # Both FINs have been exchanged. Once ours is acknowledged
            # (ack == local_seq) nothing is outstanding, so move to TIME_WAIT.
            if tcp_header.ack and tcp_header.ack_num == self.local_seq:
                self.unacked_segments.clear()
                self.retransmit_timer = None
                self._set_state(TCPState.TIME_WAIT)
        elif state == TCPState.LAST_ACK:
            self._handle_last_ack(tcp_header, data)

        self._manage_retransmission_timer()
    except Exception as exc:
        # Boundary handler: report the failure (with traceback) instead of
        # raising into the caller.
        logging.getLogger(__name__).exception("Error handling packet: %s", exc)
def send_data(self, data: bytes) -> bool:
    """Queue ``data`` for transmission and try to send it immediately.

    Sending is allowed in ESTABLISHED and also in CLOSE_WAIT. The peer's FIN
    only closes the peer's direction of the connection, and the SEND call in
    RFC 793 (section 3.9) lets the local side keep sending until it calls
    close().

    Args:
        data: Bytes to append to the send buffer.

    Returns:
        True if the data was queued. False if the connection is in a state
        where sending is not allowed; nothing is queued in that case.
    """
    if self.state not in (TCPState.ESTABLISHED, TCPState.CLOSE_WAIT):
        return False
    # Add to send buffer (bytes are stored as individual ints)
    self.send_buffer.extend(data)
    # Try to send what we can
    self._send_from_buffer()
    return True
def close(self):
    """Start closing the connection from our side.

    In ESTABLISHED this is an active close: send FIN and enter FIN_WAIT_1.
    In CLOSE_WAIT (the peer already sent its FIN) send our FIN and enter
    LAST_ACK. In every other state this call does nothing.
    """
    next_state = {
        TCPState.ESTABLISHED: TCPState.FIN_WAIT_1,
        TCPState.CLOSE_WAIT: TCPState.LAST_ACK,
    }.get(self.state)
    if next_state is None:
        return
    self._send_fin()
    self._set_state(next_state)
def _set_state(self, new_state: TCPState):
    """Move to ``new_state``; notify ``on_state_change`` only on a real change."""
    if new_state == self.state:
        return
    self.state = new_state
    callback = self.on_state_change
    if callback:
        callback(new_state)
def _send_packet(self, flags: int, data: bytes = b''):
    """Build the next outgoing segment and advance ``local_seq`` past it.

    The segment starts at ``local_seq`` and carries ``local_ack``,
    ``local_window``, ``flags`` and ``data``.

    Segments that occupy sequence space (payload bytes, SYN or FIN) are
    recorded in ``unacked_segments``, keyed by starting sequence number, so
    they can be retransmitted. Zero-length control segments (pure ACK, a
    PSH/ACK without payload, a bare RST) are not recorded. They consume no
    sequence space, so there is nothing to acknowledge or retransmit.

    Args:
        flags: TCP flag bits for the segment (e.g. 0x10 ACK, 0x12 SYN-ACK).
        data: Payload bytes; empty for control segments.
    """
    segment = TCPSegment(
        seq_num=self.local_seq,
        ack_num=self.local_ack,
        flags=flags,
        window=self.local_window,
        data=data,
    )
    # Sequence space used: one per payload byte plus one each for SYN and
    # FIN. The modulo keeps the difference correct whether or not
    # seq_end wraps at 2**32.
    consumed = (segment.seq_end - segment.seq_num) % 0x100000000
    if consumed:
        self.unacked_segments[self.local_seq] = segment
    self.local_seq = (self.local_seq + consumed) % 0x100000000
    # TODO: Actually send the packet
def _handle_listen(self, header: TCPHeader, data: bytes):
    """Passive open: answer an incoming SYN with a SYN-ACK.

    Records the peer's initial sequence number and moves to SYN_RECEIVED.
    Segments without SYN are ignored.
    """
    if not header.syn:
        return
    peer_isn = header.seq_num
    self.remote_seq = peer_isn
    # The peer's SYN occupies one sequence number.
    self.local_ack = (peer_isn + 1) % 0x100000000
    self._send_packet(0x12)  # SYN | ACK
    self._set_state(TCPState.SYN_RECEIVED)
def _handle_syn_sent(self, header: TCPHeader, data: bytes):
    """Active open: complete the handshake on a valid SYN-ACK.

    A SYN-ACK is accepted only if it acknowledges our SYN, i.e.
    ``ack_num == initial_seq + 1`` (mod 2**32). On acceptance:

    * the peer's ISN is recorded;
    * our SYN is retired from ``unacked_segments``, so the retransmission
      timer cannot resend an already-acknowledged SYN;
    * the peer's SYN is ACKed;
    * the connection becomes ESTABLISHED.

    Any other segment is ignored.
    """
    if not (header.syn and header.ack):
        return
    if header.ack_num != (self.initial_seq + 1) % 0x100000000:
        return
    self.remote_seq = header.seq_num
    self.local_ack = (header.seq_num + 1) % 0x100000000
    # Retire every SYN-bearing outstanding segment. Presumably only our SYN
    # can be outstanding here, since send_data refuses to queue data before
    # ESTABLISHED.
    syn_seqs = [seq for seq, segment in self.unacked_segments.items()
                if segment.flags & 0x02]
    for seq in syn_seqs:
        del self.unacked_segments[seq]
    if syn_seqs:
        # Let _manage_retransmission_timer re-arm (or disarm) from scratch.
        self.retransmit_timer = None
    self._send_packet(0x10)  # ACK
    self._set_state(TCPState.ESTABLISHED)
def _handle_established(self, header: TCPHeader, data: bytes):
    """Handle a segment received in the ESTABLISHED state.

    All sequence comparisons are modulo 2**32.

    Payload handling:

    * Bytes starting at ``local_ack`` are delivered to ``on_data_received``.
      Right after them, any segments in ``out_of_order_buffer`` that have
      become contiguous are delivered too. Buffered segments lying entirely
      behind ``local_ack`` are discarded.
    * A segment that starts before ``local_ack`` has its already-received
      prefix trimmed, and any new tail is delivered.
    * A segment that starts beyond ``local_ack`` is parked in
      ``out_of_order_buffer``.
    * Every data-bearing segment is answered with an ACK.

    ACK handling: an ACK flag is passed to ``_handle_ack``.

    FIN handling: a FIN is honoured only when it directly follows all data
    received so far. It then consumes one sequence number, is ACKed, and
    moves the connection to CLOSE_WAIT. An out-of-order FIN is ignored and
    left for the peer to retransmit.
    """
    seq_space = 0x100000000
    half_space = 0x80000000  # offsets at or beyond this lie *behind* local_ack

    def unseen_part(seq: int, chunk: bytes) -> Optional[bytes]:
        # Part of `chunk` (starting at `seq`) at or after local_ack. Returns:
        # the whole chunk if it starts exactly there; its tail if it starts
        # earlier (b"" for a pure duplicate); None if it starts later.
        offset = (seq - self.local_ack) % seq_space
        if offset == 0:
            return chunk
        if offset >= half_space:
            return chunk[seq_space - offset:]
        return None

    def deliver(chunk: bytes) -> None:
        # Hand in-order bytes to the application and advance local_ack.
        if self.on_data_received:
            self.on_data_received(chunk)
        self.local_ack = (self.local_ack + len(chunk)) % seq_space

    if data:
        fresh = unseen_part(header.seq_num, data)
        if fresh is None:
            # Segment lies beyond a gap: park it until the gap is filled.
            self.out_of_order_buffer[header.seq_num] = data
        elif fresh:
            deliver(fresh)
            # Drain parked segments that are now contiguous or already behind
            # local_ack. Each pass pops one entry, so the loop terminates.
            while True:
                ready = next(
                    (seq for seq, chunk in self.out_of_order_buffer.items()
                     if unseen_part(seq, chunk) is not None),
                    None,
                )
                if ready is None:
                    break
                tail = unseen_part(ready, self.out_of_order_buffer.pop(ready))
                if tail:
                    deliver(tail)
        # ACK in-order, out-of-order and duplicate data alike, so the peer
        # learns the current local_ack.
        self._send_packet(0x10)
    if header.ack:
        # Process acknowledgments
        self._handle_ack(header.ack_num)
    if header.fin and (header.seq_num + len(data)) % seq_space == self.local_ack:
        # The FIN consumes one sequence number.
        self.local_ack = (self.local_ack + 1) % seq_space
        self._send_packet(0x10)  # ACK
        self._set_state(TCPState.CLOSE_WAIT)
def _handle_ack(self, ack_num: int):
    """Process a cumulative acknowledgment from the peer.

    Retires every segment in ``unacked_segments`` that ``ack_num`` covers
    completely, i.e. whose ``seq_end`` is at or before ``ack_num``. The
    comparison is done modulo 2**32, so wrap-around is handled. A segment
    that is only partially acknowledged stays queued.

    Only when at least one segment is newly acknowledged does this:

    * grow the congestion window: +1 MSS in slow start, +1/cwnd in
      congestion avoidance (RFC 5681);
    * clear ``retransmit_timer``, so ``_manage_retransmission_timer``
      re-arms it from now (RFC 6298, 5.3).

    A duplicate ACK therefore leaves cwnd unchanged. Finally, tries to send
    more buffered data.

    Args:
        ack_num: Acknowledgment number from the incoming segment.
    """
    seq_space = 0x100000000
    # (ack - end) mod 2**32 in [0, 2**31) means end <= ack in sequence space.
    acknowledged = [
        seq for seq, segment in self.unacked_segments.items()
        if (ack_num - segment.seq_end) % seq_space < 0x80000000
    ]
    for seq in acknowledged:
        del self.unacked_segments[seq]

    if acknowledged:
        self.retransmit_timer = None
        if self.cwnd < self.ssthresh:
            # Slow start
            self.cwnd += 1
        else:
            # Congestion avoidance
            self.cwnd += 1 / self.cwnd

    # Try to send more data
    self._send_from_buffer()
def _send_from_buffer(self):
    """Split everything queued in ``send_buffer`` into PSH-ACK segments.

    Each segment carries at most ``min(mss, remote_window, cwnd * mss)``
    bytes, so no segment exceeds the MSS. If that size is zero, nothing is
    sent and the buffer is left untouched.

    The ``int()`` conversion matters: cwnd becomes fractional in congestion
    avoidance (``_handle_ack`` adds ``1 / cwnd``), and a float cannot be
    used as a byte count.

    NOTE(review): bytes already in flight are not subtracted from the
    window, so one call flushes the whole buffer. Enforcing the window
    would leave data queued, and close() sends its FIN immediately without
    waiting for the buffer, so both would need to change together.
    """
    segment_size = int(min(self.mss, self.remote_window, self.cwnd * self.mss))
    if segment_size <= 0:
        return
    while self.send_buffer:
        take = min(segment_size, len(self.send_buffer))
        # popleft is O(1); avoids copying the entire buffer on every segment.
        chunk = bytes(self.send_buffer.popleft() for _ in range(take))
        self._send_packet(0x18, chunk)  # PSH-ACK
def _update_rtt(self, ack_num: int):
    """Fold one round-trip sample into SRTT/RTTVAR and recompute RTO.

    The sample comes from the outstanding segment that ``ack_num``
    acknowledges exactly, i.e. whose ``seq_end`` equals ``ack_num`` modulo
    2**32. Data segments and SYN/FIN segments both qualify, and wrap-around
    is handled.

    Following Karn's algorithm, a segment that has been retransmitted
    yields no sample, because the ACK cannot be attributed to a particular
    transmission.

    Smoothing follows RFC 6298. The first sample sets SRTT = R and
    RTTVAR = R/2. Later samples use beta = 1/4 and alpha = 1/8, with
    RTTVAR updated before SRTT. RTO = SRTT + max(4 * RTTVAR, 0.5).

    Args:
        ack_num: Acknowledgment number from the incoming segment.
    """
    for segment in self.unacked_segments.values():
        if (segment.seq_end - ack_num) % 0x100000000:
            continue
        if segment.retransmit_count:
            break  # Karn: ambiguous sample, skip it
        rtt = time.time() - segment.timestamp
        if self.srtt == 0:
            self.srtt = rtt
            self.rttvar = rtt / 2
        else:
            self.rttvar = 0.75 * self.rttvar + 0.25 * abs(self.srtt - rtt)
            self.srtt = 0.875 * self.srtt + 0.125 * rtt
        self.rto = self.srtt + max(4 * self.rttvar, 0.5)
        break
def _manage_retransmission_timer(self):
    """Arm, disarm, or fire the retransmission timer.

    With nothing outstanding the timer is disarmed (set to None). Otherwise
    an unarmed timer is armed to expire ``rto`` seconds from now, and an
    armed timer whose deadline has passed triggers ``_handle_timeout``.
    """
    if self.unacked_segments:
        now = time.time()
        deadline = self.retransmit_timer
        if deadline is None:
            self.retransmit_timer = now + self.rto
        elif now >= deadline:
            self._handle_timeout()
    else:
        self.retransmit_timer = None
def _handle_timeout(self):
    """Handle expiry of the retransmission timer.

    * Doubles RTO, capped at 60 s (the smallest cap RFC 6298 permits).
    * Halves cwnd into ssthresh (minimum 2) and resets cwnd to 1.
    * Retransmits the oldest unacknowledged segment *in place*. It keeps its
      original sequence number; its retransmit_count, timestamp and ack_num
      are refreshed. ``_send_packet`` is deliberately not used, because it
      would allocate new sequence space and advance ``local_seq``.
    * After the same segment has been retransmitted five times, aborts the
      connection: outstanding segments are dropped, the timer is disarmed
      and the state becomes CLOSED.

    Otherwise the timer is re-armed ``rto`` seconds from now.
    """
    # Exponential backoff with an upper bound.
    self.rto = min(self.rto * 2, 60.0)
    # Reset congestion window; int() because cwnd may be fractional.
    self.ssthresh = max(2, int(self.cwnd // 2))
    self.cwnd = 1

    if self.unacked_segments:
        # Segments are inserted in send order and dicts keep insertion
        # order, so the first key is the oldest one even after sequence
        # numbers wrap past 2**32 (unlike min()).
        oldest_seq = next(iter(self.unacked_segments))
        segment = self.unacked_segments[oldest_seq]
        if segment.retransmit_count >= 5:
            # Too many retransmissions, close connection
            self.unacked_segments.clear()
            self.retransmit_timer = None
            self._set_state(TCPState.CLOSED)
            return
        segment.retransmit_count += 1
        segment.timestamp = time.time()
        segment.ack_num = self.local_ack
        # TODO: transmit `segment` unchanged once packets are actually sent.

    # Reset timer
    self.retransmit_timer = time.time() + self.rto
def _send_fin(self):
    """Emit a FIN segment with ACK set; it consumes one sequence number."""
    fin_ack = 0x01 | 0x10
    self._send_packet(fin_ack)
def _handle_fin_wait(self, header: TCPHeader, data: bytes):
    """Handle a segment in FIN_WAIT_1 or FIN_WAIT_2 (our FIN has been sent).

    * ACKs retire the outstanding segments they fully cover (compared
      modulo 2**32), so the retransmission timer stops resending data or a
      FIN the peer already has. In FIN_WAIT_1, an ACK of everything we
      sent, including our FIN (``ack_num == local_seq``), moves the
      connection to FIN_WAIT_2.
    * The peer may keep sending after our FIN (half-close). In-order
      payload is delivered to ``on_data_received``, and every data segment
      is ACKed.
    * A FIN is honoured only when it directly follows all data received so
      far. It is ACKed, and the state becomes CLOSING if our FIN is still
      unacknowledged (FIN_WAIT_1) or TIME_WAIT otherwise.
    """
    seq_space = 0x100000000
    if header.ack:
        covered = [
            seq for seq, segment in self.unacked_segments.items()
            if (header.ack_num - segment.seq_end) % seq_space < 0x80000000
        ]
        for seq in covered:
            del self.unacked_segments[seq]
        if covered:
            # Let _manage_retransmission_timer re-arm (or disarm) from now.
            self.retransmit_timer = None
        if self.state == TCPState.FIN_WAIT_1 and header.ack_num == self.local_seq:
            self._set_state(TCPState.FIN_WAIT_2)
    if data:
        if header.seq_num == self.local_ack:
            if self.on_data_received:
                self.on_data_received(data)
            self.local_ack = (self.local_ack + len(data)) % seq_space
        self._send_packet(0x10)  # ACK (re-ACKs duplicate/out-of-order data too)
    if header.fin and (header.seq_num + len(data)) % seq_space == self.local_ack:
        self.local_ack = (self.local_ack + 1) % seq_space
        self._send_packet(0x10)  # ACK the FIN
        if self.state == TCPState.FIN_WAIT_1:
            self._set_state(TCPState.CLOSING)
        else:  # FIN_WAIT_2
            self._set_state(TCPState.TIME_WAIT)
def _handle_close_wait(self, header: TCPHeader, data: bytes):
    """Handle a segment in CLOSE_WAIT (the peer has closed its direction).

    ACKs are processed normally so our remaining data can be retired. Peer
    payload is ignored in this state.

    A FIN seen here is a retransmission of the peer's FIN, which suggests
    our ACK of it was lost. It is re-acknowledged: RFC 793 section 3.9
    answers a segment that falls outside the receive window with an ACK.
    ``local_ack`` already covers that FIN, since it was advanced when the
    FIN first arrived.
    """
    if header.ack:
        self._handle_ack(header.ack_num)
    if header.fin:
        self._send_packet(0x10)  # re-ACK the retransmitted FIN
def _handle_last_ack(self, header: TCPHeader, data: bytes):
    """Handle a segment in LAST_ACK (we sent our FIN after the peer's FIN).

    When the peer acknowledges everything we sent, including our FIN
    (``ack_num == local_seq``), nothing remains outstanding. The
    retransmission queue and timer are cleared, because handle_packet keeps
    servicing the timer even in CLOSED, and the connection becomes CLOSED.

    A FIN seen while still waiting is the peer retransmitting its FIN (our
    ACK of it was presumably lost), so it is re-acknowledged.
    """
    if header.ack and header.ack_num == self.local_seq:
        self.unacked_segments.clear()
        self.retransmit_timer = None
        self._set_state(TCPState.CLOSED)
    elif header.fin:
        self._send_packet(0x10)  # re-ACK the retransmitted FIN