# HINTECH / core/tcp_engine.py
# Factor Studios
# Upload 73 files
# aaaaa79 verified
"""
TCP Engine Module
Implements a complete TCP state machine in user-space:
- Full TCP state machine (SYN, SYN-ACK, ESTABLISHED, FIN, RST)
- Sequence and acknowledgment number tracking
- Sliding window implementation
- Retransmission and timeout handling
- Congestion control
"""
import time
import threading
import random
from typing import Dict, List, Optional, Tuple, Callable
from dataclasses import dataclass, field
from enum import Enum
from collections import deque
from .ip_parser import TCPHeader, IPv4Header, IPParser
class TCPState(Enum):
    """States of the TCP connection state machine.

    Every member's value is the member's own name as a string, so
    ``state.value`` and ``TCPState("...")`` round-trip through plain text.
    """
    CLOSED = "CLOSED"              # no connection
    LISTEN = "LISTEN"              # waiting for an incoming SYN
    SYN_SENT = "SYN_SENT"          # active open: SYN sent, awaiting SYN-ACK
    SYN_RECEIVED = "SYN_RECEIVED"  # passive open: SYN-ACK sent, awaiting ACK
    ESTABLISHED = "ESTABLISHED"    # handshake complete, data may flow
    FIN_WAIT_1 = "FIN_WAIT_1"      # we sent FIN, not yet acknowledged
    FIN_WAIT_2 = "FIN_WAIT_2"      # our FIN acknowledged, awaiting peer FIN
    CLOSE_WAIT = "CLOSE_WAIT"      # peer sent FIN, we have not closed yet
    CLOSING = "CLOSING"            # both sides sent FIN simultaneously
    LAST_ACK = "LAST_ACK"          # we sent FIN after the peer's, awaiting ACK
    TIME_WAIT = "TIME_WAIT"        # lingering after close to absorb stragglers
@dataclass
class TCPSegment:
    """A TCP segment as tracked by the engine (for example while awaiting ACK).

    Holds the segment's sequence and acknowledgment numbers, flag bits,
    advertised window and payload, plus a send timestamp and how many times
    the segment has been retransmitted.
    """
    seq_num: int
    ack_num: int
    flags: int
    window: int
    data: bytes
    timestamp: float = field(default_factory=time.time)  # time.time() at creation (refreshed by the engine on resend)
    retransmit_count: int = 0

    @property
    def data_length(self) -> int:
        """Number of payload bytes carried by the segment."""
        return len(self.data)

    @property
    def seq_end(self) -> int:
        """Sequence number immediately following this segment.

        Each payload byte consumes one sequence number, and so do the SYN
        (0x02) and FIN (0x01) flags. TCP sequence numbers are 32-bit, so the
        result wraps modulo 2**32 instead of growing past 0xFFFFFFFF.
        """
        length = self.data_length
        if self.flags & 0x02:  # SYN
            length += 1
        if self.flags & 0x01:  # FIN
            length += 1
        return (self.seq_num + length) & 0xFFFFFFFF
@dataclass
class TCPConnection:
    """State of a single TCP connection, identified by its endpoint pair.

    Groups the state-machine state, sequence bookkeeping, buffers,
    retransmission/RTT estimates, congestion-control variables, timers and
    optional application callbacks for one local/remote address pair.
    """
    # --- endpoint pair (identifies the connection) ---
    local_ip: str
    local_port: int
    remote_ip: str
    remote_port: int

    # --- state machine ---
    state: TCPState = TCPState.CLOSED

    # --- sequence space ---
    local_seq: int = 0    # sequence number placed on the next segment this side sends
    local_ack: int = 0    # acknowledgment number placed on outgoing segments
    remote_seq: int = 0   # peer sequence number recorded by the engine
    remote_ack: int = 0   # acknowledgment number reported by the peer (if the engine records it)
    initial_seq: int = 0  # this side's initial sequence number

    # --- flow control ---
    local_window: int = 65535   # receive window advertised to the peer
    remote_window: int = 65535  # window advertised by the peer
    window_scale: int = 0       # NOTE(review): not applied by effective_window

    # --- buffers ---
    send_buffer: deque = field(default_factory=deque)
    recv_buffer: deque = field(default_factory=deque)
    out_of_order_buffer: Dict[int, bytes] = field(default_factory=dict)  # presumably seq -> payload

    # --- retransmission ---
    unacked_segments: Dict[int, TCPSegment] = field(default_factory=dict)  # keyed by segment seq_num
    retransmit_timer: Optional[float] = None
    rto: float = 1.0     # retransmission timeout, seconds
    srtt: float = 0.0    # smoothed round-trip time
    rttvar: float = 0.0  # round-trip time variation

    # --- congestion control ---
    cwnd: int = 1          # congestion window, in MSS units (see effective_window)
    ssthresh: int = 65535  # slow-start threshold; NOTE(review): compared with cwnd (MSS units) but looks like a byte count
    mss: int = 1460        # maximum segment size, bytes

    # --- timers ---
    last_activity: float = field(default_factory=time.time)
    time_wait_start: Optional[float] = None

    # --- application callbacks ---
    on_data_received: Optional[Callable[[bytes], None]] = None
    on_connection_closed: Optional[Callable[[], None]] = None

    @property
    def connection_id(self) -> str:
        """Key of the form ``local_ip:local_port-remote_ip:remote_port``."""
        local_end = f"{self.local_ip}:{self.local_port}"
        remote_end = f"{self.remote_ip}:{self.remote_port}"
        return f"{local_end}-{remote_end}"

    @property
    def is_established(self) -> bool:
        """True while the connection is in the ESTABLISHED state."""
        return self.state is TCPState.ESTABLISHED

    @property
    def can_send_data(self) -> bool:
        """True in the states where this side may still transmit data."""
        return self.state in (TCPState.ESTABLISHED, TCPState.CLOSE_WAIT)

    @property
    def effective_window(self) -> int:
        """Usable send window in bytes: the peer window capped by cwnd * mss."""
        congestion_limit = self.cwnd * self.mss
        if self.remote_window < congestion_limit:
            return self.remote_window
        return congestion_limit
class TCPEngine:
    """User-space TCP state machine.

    Connections live in ``self.connections`` keyed by
    ``"local_ip:local_port-remote_ip:remote_port"``. Incoming packets are fed
    through :meth:`process_packet`. Outgoing packets are handed to
    :meth:`_send_packet`, which the integration layer is expected to supply.
    A background thread started by :meth:`start` drives retransmissions, idle
    timeouts and TIME_WAIT expiry.

    All sequence/acknowledgment arithmetic is done modulo 2**32
    (RFC 793 section 3.3), so connections survive sequence-number wrap.
    """

    _SEQ_MASK = 0xFFFFFFFF  # TCP sequence numbers are 32-bit
    _SEQ_HALF = 0x80000000  # half the sequence space, used for signed comparison

    def __init__(self, config: Dict):
        """Create an engine.

        Args:
            config: Mapping of optional settings: ``mss`` (default 1460),
                ``initial_window`` (65535), ``max_retries`` (3), ``timeout``
                (idle timeout in seconds, 300) and ``time_wait_timeout``
                (seconds, 120).
        """
        self.config = config
        self.connections: Dict[str, TCPConnection] = {}
        self.listening_ports: Dict[int, Callable] = {}  # port -> accept callback
        # Re-entrant lock: _close_connection() acquires it and is reached from
        # code that may already hold it (e.g. accept callbacks run under the
        # lock in process_packet()). A plain Lock could deadlock there.
        self.lock = threading.RLock()
        self.running = False
        self.timer_thread = None
        # Lets stop() wake the timer thread instead of waiting out its sleep.
        self._stop_event = threading.Event()
        # ISNs must be hard to predict (RFC 6528): use the OS CSPRNG.
        self._isn_rng = random.SystemRandom()
        # Default configuration
        self.default_mss = config.get('mss', 1460)
        self.default_window = config.get('initial_window', 65535)
        self.max_retries = config.get('max_retries', 3)
        self.connection_timeout = config.get('timeout', 300)
        self.time_wait_timeout = config.get('time_wait_timeout', 120)

    # ------------------------------------------------------------------
    # Sequence-space helpers
    # ------------------------------------------------------------------
    @classmethod
    def _seq_add(cls, seq: int, delta: int) -> int:
        """Return ``seq + delta`` wrapped into the 32-bit sequence space."""
        return (seq + delta) & cls._SEQ_MASK

    @classmethod
    def _seq_diff(cls, a: int, b: int) -> int:
        """Return the signed distance ``a - b`` in sequence space.

        The result lies in ``[-2**31, 2**31)``. A negative value means ``a``
        precedes ``b``, even when the numbers wrapped in between.
        """
        diff = (a - b) & cls._SEQ_MASK
        return diff - (cls._SEQ_MASK + 1) if diff >= cls._SEQ_HALF else diff

    def _generate_isn(self) -> int:
        """Generate an unpredictable 32-bit Initial Sequence Number."""
        return self._isn_rng.getrandbits(32)

    def _get_connection_key(self, local_ip: str, local_port: int, remote_ip: str, remote_port: int) -> str:
        """Return the connection-table key for a local/remote endpoint pair."""
        return f"{local_ip}:{local_port}-{remote_ip}:{remote_port}"

    # ------------------------------------------------------------------
    # Segment construction and transmission
    # ------------------------------------------------------------------
    def _create_tcp_segment(self, conn: TCPConnection, flags: int, data: bytes = b'') -> TCPSegment:
        """Build a segment stamped with the connection's current seq/ack/window."""
        return TCPSegment(
            seq_num=conn.local_seq,
            ack_num=conn.local_ack,
            flags=flags,
            window=conn.local_window,
            data=data
        )

    def _build_tcp_packet(self, conn: TCPConnection, segment: TCPSegment) -> bytes:
        """Serialize ``segment`` into an IPv4/TCP packet between ``conn``'s endpoints."""
        ip_header = IPv4Header(
            protocol=6,  # TCP
            source_ip=conn.local_ip,
            dest_ip=conn.remote_ip,
            ttl=64
        )
        tcp_header = TCPHeader(
            source_port=conn.local_port,
            dest_port=conn.remote_port,
            seq_num=segment.seq_num,
            ack_num=segment.ack_num,
            flags=segment.flags,
            window_size=segment.window
        )
        return IPParser.build_packet(ip_header, tcp_header, segment.data)

    def _send_ack(self, conn: TCPConnection):
        """Send a bare ACK carrying the connection's current seq/ack numbers."""
        segment = self._create_tcp_segment(conn, 0x10)  # ACK flag
        self._send_packet(self._build_tcp_packet(conn, segment))

    def _send_packet(self, packet: bytes):
        """Send packet (to be implemented by integration layer)"""
        # This will be connected to the packet bridge
        pass

    # ------------------------------------------------------------------
    # RTT, retransmission and congestion control
    # ------------------------------------------------------------------
    def _update_rto(self, conn: TCPConnection, rtt: float):
        """Fold an RTT sample (seconds) into SRTT/RTTVAR and recompute RTO (RFC 6298).

        The RTO is clamped to the range [1.0, 60.0] seconds.
        """
        if conn.srtt == 0:
            # First RTT measurement
            conn.srtt = rtt
            conn.rttvar = rtt / 2
        else:
            alpha = 0.125
            beta = 0.25
            # RTTVAR must be updated with the *old* SRTT, hence this order.
            conn.rttvar = (1 - beta) * conn.rttvar + beta * abs(conn.srtt - rtt)
            conn.srtt = (1 - alpha) * conn.srtt + alpha * rtt
        conn.rto = max(1.0, conn.srtt + 4 * conn.rttvar)
        conn.rto = min(conn.rto, 60.0)  # Cap at 60 seconds

    def _update_congestion_window(self, conn: TCPConnection, acked_bytes: int):
        """Grow ``conn.cwnd`` (in MSS units) after new data is acknowledged.

        Slow start adds one MSS per ACK while ``cwnd < ssthresh``. Congestion
        avoidance grows the window by one MSS per full window of acknowledged
        bytes (RFC 5681 section 3.1). Because ``cwnd`` is counted in segments,
        adding a byte quantity such as ``mss * mss // cwnd`` would be a unit
        error.
        """
        if conn.cwnd < conn.ssthresh:
            conn.cwnd += 1
            return
        # Bytes acknowledged since cwnd last grew in congestion avoidance. Kept
        # as an ad-hoc attribute because TCPConnection has no field for it.
        pending = getattr(conn, '_ca_bytes_acked', 0) + acked_bytes
        window_bytes = max(conn.cwnd, 1) * conn.mss
        if pending >= window_bytes:
            pending -= window_bytes
            conn.cwnd += 1
        conn._ca_bytes_acked = pending

    def _handle_retransmission(self, conn: TCPConnection):
        """Resend every unacknowledged segment whose RTO has expired.

        If any expired segment has already been retransmitted ``max_retries``
        times, the connection is reset instead. Backoff and congestion-window
        collapse are applied once per timeout event (RFC 6298 5.5,
        RFC 5681 3.1), not once per segment.
        """
        now = time.time()
        expired = [seg for seg in conn.unacked_segments.values()
                   if now - seg.timestamp > conn.rto]
        if not expired:
            return
        if any(seg.retransmit_count >= self.max_retries for seg in expired):
            # Max retries exceeded, close connection
            self._close_connection(conn, reset=True)
            return
        # Exponential backoff and congestion response, once per timeout
        conn.rto = min(conn.rto * 2, 60.0)
        conn.ssthresh = max(conn.cwnd // 2, 2)
        conn.cwnd = 1
        for segment in expired:
            segment.retransmit_count += 1
            segment.timestamp = now
            self._send_packet(self._build_tcp_packet(conn, segment))

    # ------------------------------------------------------------------
    # Connection lifecycle (public API)
    # ------------------------------------------------------------------
    def _close_connection(self, conn: TCPConnection, reset: bool = False):
        """Abort (``reset=True``) or gracefully close ``conn``.

        A reset sends RST and moves straight to CLOSED. A graceful close from
        ESTABLISHED or CLOSE_WAIT first pushes out what the send window allows.
        It then sends a FIN+ACK, which is tracked for retransmission, and
        moves to FIN_WAIT_1 or LAST_ACK respectively. A graceful close from
        SYN_SENT simply moves to CLOSED. Whenever the connection ends up
        CLOSED, ``on_connection_closed`` is invoked and the connection is
        removed from the table.
        """
        if reset:
            segment = self._create_tcp_segment(conn, 0x04)  # RST flag
            self._send_packet(self._build_tcp_packet(conn, segment))
            conn.state = TCPState.CLOSED
        elif conn.state in (TCPState.ESTABLISHED, TCPState.CLOSE_WAIT):
            # NOTE(review): bytes still queued in send_buffer after this flush
            # are never sent, since data may not follow our FIN.
            self._try_send_data(conn)
            if conn.state == TCPState.ESTABLISHED:
                next_state = TCPState.FIN_WAIT_1
            else:
                next_state = TCPState.LAST_ACK
            # FIN must carry ACK: peers drop ACK-less segments once synchronized.
            segment = self._create_tcp_segment(conn, 0x11)  # FIN+ACK flags
            self._send_packet(self._build_tcp_packet(conn, segment))
            # Track the FIN so it is retransmitted and so "nothing unacked"
            # really means "our FIN was acknowledged".
            conn.unacked_segments[conn.local_seq] = segment
            conn.local_seq = self._seq_add(conn.local_seq, 1)  # FIN consumes one seq number
            conn.state = next_state
        elif conn.state == TCPState.SYN_SENT:
            # Nothing synchronized yet; just drop the connection.
            conn.state = TCPState.CLOSED

        # Cleanup if closed
        if conn.state == TCPState.CLOSED:
            if conn.on_connection_closed:
                conn.on_connection_closed()
            with self.lock:
                # Only remove this exact object, not a newer connection that
                # reuses the same 4-tuple.
                if self.connections.get(conn.connection_id) is conn:
                    del self.connections[conn.connection_id]

    def listen(self, port: int, accept_callback: Callable):
        """Accept incoming connections on ``port``.

        ``accept_callback(conn)`` is invoked for each new connection when its
        SYN-ACK is sent.
        """
        with self.lock:
            self.listening_ports[port] = accept_callback

    def connect(self, local_ip: str, local_port: int, remote_ip: str, remote_port: int) -> Optional[TCPConnection]:
        """Start an active open: register the connection and send a SYN.

        Returns:
            The new connection, in SYN_SENT state.
        """
        conn_key = self._get_connection_key(local_ip, local_port, remote_ip, remote_port)
        conn = TCPConnection(
            local_ip=local_ip,
            local_port=local_port,
            remote_ip=remote_ip,
            remote_port=remote_port,
            state=TCPState.SYN_SENT,
            local_seq=self._generate_isn(),
            mss=self.default_mss,
            local_window=self.default_window
        )
        conn.initial_seq = conn.local_seq
        with self.lock:
            self.connections[conn_key] = conn
        segment = self._create_tcp_segment(conn, 0x02)  # SYN flag
        self._send_packet(self._build_tcp_packet(conn, segment))
        conn.unacked_segments[conn.local_seq] = segment
        conn.local_seq = self._seq_add(conn.local_seq, 1)  # SYN consumes one seq number
        conn.retransmit_timer = time.time()
        return conn

    def send_data(self, conn: TCPConnection, data: bytes) -> bool:
        """Queue ``data`` for transmission and send what the window allows.

        Returns:
            False if the connection state does not permit sending, else True.
        """
        if not conn.can_send_data:
            return False
        if data:
            conn.send_buffer.append(data)
        self._try_send_data(conn)
        return True

    def _try_send_data(self, conn: TCPConnection):
        """Send buffered data in MSS-sized segments within the usable window.

        Bytes in flight (payload of unacknowledged segments) plus new data
        never exceed ``conn.effective_window``. Data that does not fit is kept
        at the front of ``send_buffer`` for a later call. Nothing is sent in
        states where sending is not allowed, e.g. after our FIN.
        """
        if not conn.can_send_data:
            return
        in_flight = sum(seg.data_length for seg in conn.unacked_segments.values())
        while conn.send_buffer:
            room = conn.effective_window - in_flight
            if room <= 0:
                # NOTE(review): no persist timer; a zero peer window stalls
                # until the peer sends a window update.
                break
            data = conn.send_buffer.popleft()
            if not data:
                continue
            chunk_size = min(conn.mss, room)
            chunk, rest = data[:chunk_size], data[chunk_size:]
            if rest:
                conn.send_buffer.appendleft(rest)
            segment = self._create_tcp_segment(conn, 0x18, chunk)  # PSH+ACK flags
            self._send_packet(self._build_tcp_packet(conn, segment))
            conn.unacked_segments[conn.local_seq] = segment
            conn.local_seq = self._seq_add(conn.local_seq, len(chunk))
            in_flight += len(chunk)

    # ------------------------------------------------------------------
    # Inbound processing
    # ------------------------------------------------------------------
    def process_packet(self, packet_data: bytes) -> bool:
        """Parse and dispatch one incoming IPv4 packet.

        Non-TCP packets are ignored. A bare SYN to a listening port creates a
        new connection. Any other segment for an unknown connection is
        answered with a RST.

        Returns:
            The per-state handler's result, or False if the packet was not
            TCP, matched no connection, or raised while being processed.
        """
        try:
            parsed = IPParser.parse_packet(packet_data)
            if not isinstance(parsed.transport_header, TCPHeader):
                return False
            ip_header = parsed.ip_header
            tcp_header = parsed.transport_header
            payload = parsed.payload or b''

            conn_key = self._get_connection_key(
                ip_header.dest_ip, tcp_header.dest_port,
                ip_header.source_ip, tcp_header.source_port
            )
            with self.lock:
                conn = self.connections.get(conn_key)
                # New connection: a bare SYN to a listening port
                if conn is None and tcp_header.syn and not tcp_header.ack:
                    if tcp_header.dest_port in self.listening_ports:
                        conn = self._handle_new_connection(ip_header, tcp_header)
                        if conn is not None:
                            self.connections[conn_key] = conn

            if conn is None:
                self._send_rst(ip_header, tcp_header, len(payload))
                return False

            return self._process_segment(conn, tcp_header, payload)
        except Exception as e:
            # Top-level boundary for untrusted input: report and drop the packet.
            print(f"Error processing TCP packet: {e}")
            return False

    def _handle_new_connection(self, ip_header: IPv4Header, tcp_header: TCPHeader) -> Optional[TCPConnection]:
        """Create a SYN_RECEIVED connection for an incoming SYN and send SYN-ACK.

        Returns:
            The new connection, or None if no accept callback is registered
            for the destination port.
        """
        accept_callback = self.listening_ports.get(tcp_header.dest_port)
        if not accept_callback:
            return None
        conn = TCPConnection(
            local_ip=ip_header.dest_ip,
            local_port=tcp_header.dest_port,
            remote_ip=ip_header.source_ip,
            remote_port=tcp_header.source_port,
            state=TCPState.SYN_RECEIVED,
            local_seq=self._generate_isn(),
            remote_seq=tcp_header.seq_num,
            local_ack=self._seq_add(tcp_header.seq_num, 1),  # peer SYN consumes one
            remote_window=tcp_header.window_size,
            mss=self.default_mss,
            local_window=self.default_window
        )
        conn.initial_seq = conn.local_seq
        segment = self._create_tcp_segment(conn, 0x12)  # SYN+ACK flags
        self._send_packet(self._build_tcp_packet(conn, segment))
        conn.unacked_segments[conn.local_seq] = segment
        conn.local_seq = self._seq_add(conn.local_seq, 1)
        conn.retransmit_timer = time.time()
        accept_callback(conn)
        return conn

    def _rst_acceptable(self, conn: TCPConnection, tcp_header: TCPHeader) -> bool:
        """Decide whether an incoming RST may tear down ``conn``.

        In SYN_SENT the RST must acknowledge our SYN (RFC 793). Otherwise its
        sequence number must fall inside the receive window; this is a
        simplified form of the RFC 5961 blind-reset mitigation.
        """
        if conn.state == TCPState.SYN_SENT:
            return bool(tcp_header.ack) and tcp_header.ack_num == conn.local_seq
        offset = self._seq_diff(tcp_header.seq_num, conn.local_ack)
        return 0 <= offset < max(conn.local_window, 1)

    def _process_segment(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Dispatch a segment for an existing connection to its state handler."""
        conn.last_activity = time.time()
        if tcp_header.rst:
            if not self._rst_acceptable(conn, tcp_header):
                return False  # stale or forged reset: ignore it
            conn.state = TCPState.CLOSED
            self._close_connection(conn)
            return True
        handlers = {
            TCPState.SYN_SENT: self._handle_syn_sent,
            TCPState.SYN_RECEIVED: self._handle_syn_received,
            TCPState.ESTABLISHED: self._handle_established,
            TCPState.FIN_WAIT_1: self._handle_fin_wait_1,
            TCPState.FIN_WAIT_2: self._handle_fin_wait_2,
            TCPState.CLOSE_WAIT: self._handle_close_wait,
            TCPState.CLOSING: self._handle_closing,
            TCPState.LAST_ACK: self._handle_last_ack,
            TCPState.TIME_WAIT: self._handle_time_wait,
        }
        handler = handlers.get(conn.state)
        if handler is None:
            return False
        return handler(conn, tcp_header, payload)

    def _handle_syn_sent(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """SYN_SENT: a SYN-ACK acknowledging our SYN completes the active open."""
        if tcp_header.syn and tcp_header.ack and tcp_header.ack_num == conn.local_seq:
            conn.remote_seq = tcp_header.seq_num
            conn.local_ack = self._seq_add(tcp_header.seq_num, 1)
            conn.remote_window = tcp_header.window_size
            conn.remote_ack = tcp_header.ack_num
            # Our SYN occupied the sequence number just before local_seq.
            conn.unacked_segments.pop(self._seq_add(conn.local_seq, -1), None)
            self._send_ack(conn)
            conn.state = TCPState.ESTABLISHED
            return True
        return False

    def _handle_syn_received(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """SYN_RECEIVED: the ACK of our SYN-ACK completes the passive open."""
        if tcp_header.ack and tcp_header.ack_num == conn.local_seq:
            conn.remote_window = tcp_header.window_size
            conn.remote_ack = tcp_header.ack_num
            conn.unacked_segments.pop(self._seq_add(conn.local_seq, -1), None)
            conn.state = TCPState.ESTABLISHED
            # The handshake-completing ACK may already carry data or a FIN.
            if payload or tcp_header.fin:
                self._handle_established(conn, tcp_header, payload)
            return True
        return False

    def _handle_established(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """ESTABLISHED: process ACKs, receive data, and move to CLOSE_WAIT on FIN."""
        if tcp_header.ack:
            # Update the peer window first so _process_ack's send uses it.
            conn.remote_window = tcp_header.window_size
            self._process_ack(conn, tcp_header.ack_num)
        if self._receive_segment_data(conn, tcp_header, payload):
            conn.state = TCPState.CLOSE_WAIT
        return True

    def _handle_fin_wait_1(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """FIN_WAIT_1: wait for our FIN's ACK while still receiving peer data."""
        if tcp_header.ack:
            self._process_ack(conn, tcp_header.ack_num)
            if not conn.unacked_segments:  # our FIN (sent last) was ACKed
                conn.state = TCPState.FIN_WAIT_2
        if self._receive_segment_data(conn, tcp_header, payload):
            if conn.state == TCPState.FIN_WAIT_2:
                conn.state = TCPState.TIME_WAIT
                conn.time_wait_start = time.time()
            else:
                conn.state = TCPState.CLOSING  # simultaneous close
        return True

    def _handle_fin_wait_2(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """FIN_WAIT_2: receive remaining peer data until its FIN arrives."""
        if self._receive_segment_data(conn, tcp_header, payload):
            conn.state = TCPState.TIME_WAIT
            conn.time_wait_start = time.time()
        return True

    def _handle_close_wait(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """CLOSE_WAIT: the peer is done sending, but our data still needs ACK handling."""
        if tcp_header.ack:
            conn.remote_window = tcp_header.window_size
            self._process_ack(conn, tcp_header.ack_num)
        if tcp_header.fin:
            self._send_ack(conn)  # peer retransmitted its FIN; our ACK was lost
        return True

    def _handle_closing(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """CLOSING: wait for the ACK of our FIN, then enter TIME_WAIT."""
        if tcp_header.ack:
            self._process_ack(conn, tcp_header.ack_num)
            if not conn.unacked_segments:  # our FIN was ACKed
                conn.state = TCPState.TIME_WAIT
                conn.time_wait_start = time.time()
        if tcp_header.fin:
            self._send_ack(conn)  # re-acknowledge a retransmitted peer FIN
        return True

    def _handle_last_ack(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """LAST_ACK: once our FIN is acknowledged the connection is CLOSED."""
        if tcp_header.ack:
            self._process_ack(conn, tcp_header.ack_num)
            if not conn.unacked_segments:  # our FIN was ACKed
                conn.state = TCPState.CLOSED
                self._close_connection(conn)
                return True
        if tcp_header.fin:
            self._send_ack(conn)  # re-acknowledge a retransmitted peer FIN
        return True

    def _handle_time_wait(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """TIME_WAIT: re-ACK a retransmitted peer FIN and restart the wait (RFC 793)."""
        if tcp_header.fin:
            self._send_ack(conn)
            conn.time_wait_start = time.time()
        return True

    def _receive_segment_data(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Accept an incoming segment's payload and FIN, then acknowledge it.

        Any prefix that was already received is trimmed. In-order bytes are
        delivered to ``on_data_received``. Data arriving beyond a gap is
        parked in ``out_of_order_buffer`` (only within ``local_window``) and
        delivered once the gap closes. A FIN is consumed only when it is the
        next expected sequence number. Every segment carrying data or FIN is
        ACKed, so duplicates and out-of-order arrivals produce a duplicate ACK.

        Returns:
            True if the peer's FIN was consumed.
        """
        if not payload and not tcp_header.fin:
            return False
        seg_seq = tcp_header.seq_num
        fin_seq = self._seq_add(seg_seq, len(payload))
        already_received = self._seq_diff(conn.local_ack, seg_seq)
        if already_received > 0:
            payload = payload[already_received:]
            seg_seq = conn.local_ack
        if payload:
            if seg_seq == conn.local_ack:
                self._deliver(conn, payload)
                self._drain_out_of_order(conn)
            elif self._seq_diff(seg_seq, conn.local_ack) < conn.local_window:
                conn.out_of_order_buffer.setdefault(seg_seq, payload)
        fin_consumed = False
        if tcp_header.fin and fin_seq == conn.local_ack:
            conn.local_ack = self._seq_add(conn.local_ack, 1)  # FIN consumes one
            fin_consumed = True
        self._send_ack(conn)
        return fin_consumed

    def _deliver(self, conn: TCPConnection, data: bytes):
        """Advance ``local_ack`` past ``data`` and hand it to the application."""
        conn.local_ack = self._seq_add(conn.local_ack, len(data))
        if conn.on_data_received:
            conn.on_data_received(data)

    def _drain_out_of_order(self, conn: TCPConnection):
        """Deliver parked out-of-order chunks that are now contiguous with ``local_ack``."""
        progressed = True
        while progressed and conn.out_of_order_buffer:
            progressed = False
            for seq in list(conn.out_of_order_buffer):
                overlap = self._seq_diff(conn.local_ack, seq)
                if overlap < 0:
                    continue  # still beyond a gap
                chunk = conn.out_of_order_buffer.pop(seq)
                if overlap < len(chunk):
                    self._deliver(conn, chunk[overlap:])
                    progressed = True

    def _process_ack(self, conn: TCPConnection, ack_num: int):
        """Retire fully acknowledged segments, update RTT/cwnd, then send more.

        A segment is retired only when its end (``seq_end``) is at or before
        ``ack_num``; a partial ACK leaves it queued for retransmission. ACKs
        for data never sent are ignored. RTT is sampled only from segments
        that were never retransmitted (Karn's algorithm).
        """
        if self._seq_diff(ack_num, conn.local_seq) > 0:
            return  # acknowledges data we never sent
        conn.remote_ack = ack_num
        acked = []
        for seq_num, segment in list(conn.unacked_segments.items()):
            segment_end = segment.seq_end & self._SEQ_MASK
            if self._seq_diff(segment_end, ack_num) <= 0:
                acked.append(segment)
                del conn.unacked_segments[seq_num]
        if acked:
            sample = next((seg for seg in acked if seg.retransmit_count == 0), None)
            if sample is not None:
                self._update_rto(conn, time.time() - sample.timestamp)
            self._update_congestion_window(conn, sum(seg.data_length for seg in acked))
        self._try_send_data(conn)

    def _send_rst(self, ip_header: IPv4Header, tcp_header: TCPHeader, payload_len: int = 0):
        """Answer a segment for a non-existent connection with a reset (RFC 793).

        If the offending segment carried an ACK, reply ``<SEQ=SEG.ACK><RST>``.
        Otherwise reply ``<SEQ=0><ACK=SEG.SEQ+SEG.LEN><RST,ACK>``, where
        SEG.LEN counts payload bytes plus SYN and FIN. An incoming RST is
        never answered, which avoids reset storms.
        """
        if tcp_header.rst:
            return
        rst_ip = IPv4Header(
            protocol=6,
            source_ip=ip_header.dest_ip,
            dest_ip=ip_header.source_ip,
            ttl=64
        )
        if tcp_header.ack:
            seq_num, ack_num, flags = tcp_header.ack_num, 0, 0x04  # RST
        else:
            seg_len = payload_len + (1 if tcp_header.syn else 0) + (1 if tcp_header.fin else 0)
            seq_num, ack_num, flags = 0, self._seq_add(tcp_header.seq_num, seg_len), 0x14  # RST+ACK
        rst_tcp = TCPHeader(
            source_port=tcp_header.dest_port,
            dest_port=tcp_header.source_port,
            seq_num=seq_num,
            ack_num=ack_num,
            flags=flags
        )
        packet = IPParser.build_packet(rst_ip, rst_tcp)
        self._send_packet(packet)

    # ------------------------------------------------------------------
    # Timers and engine lifecycle
    # ------------------------------------------------------------------
    def _timer_loop(self):
        """Run timer checks for every connection about once per second until stopped."""
        while self.running:
            now = time.time()
            with self.lock:
                snapshot = list(self.connections.values())
            for conn in snapshot:
                try:
                    self._check_timers(conn, now)
                except Exception as e:
                    # One bad connection must not kill the timer thread.
                    print(f"Error in TCP timer for {conn.connection_id}: {e}")
            self._stop_event.wait(1.0)  # Check every second

    def _check_timers(self, conn: TCPConnection, now: float):
        """Handle retransmission, idle timeout and TIME_WAIT expiry for one connection."""
        if conn.state == TCPState.CLOSED:
            return  # closed after the snapshot was taken
        if conn.unacked_segments:
            self._handle_retransmission(conn)
            if conn.state == TCPState.CLOSED:
                return
        if now - conn.last_activity > self.connection_timeout:
            self._close_connection(conn, reset=True)
            return
        if (conn.state == TCPState.TIME_WAIT and
                conn.time_wait_start is not None and
                now - conn.time_wait_start > self.time_wait_timeout):
            conn.state = TCPState.CLOSED
            self._close_connection(conn)

    def start(self):
        """Start the background timer thread (no-op if already running)."""
        if self.running:
            return
        self.running = True
        self._stop_event.clear()
        self.timer_thread = threading.Thread(target=self._timer_loop, daemon=True)
        self.timer_thread.start()
        print("TCP engine started")

    def stop(self):
        """Stop the timer thread and reset every remaining connection."""
        self.running = False
        self._stop_event.set()
        if self.timer_thread and self.timer_thread is not threading.current_thread():
            self.timer_thread.join()
        self.timer_thread = None
        # Snapshot under the lock, close outside it: _close_connection takes
        # the lock itself to remove each connection.
        with self.lock:
            remaining = list(self.connections.values())
        for conn in remaining:
            self._close_connection(conn, reset=True)
        print("TCP engine stopped")

    def get_connections(self) -> Dict[str, Dict]:
        """Return a snapshot of per-connection state, keyed by connection id."""
        with self.lock:
            return {
                conn_id: {
                    'local_ip': conn.local_ip,
                    'local_port': conn.local_port,
                    'remote_ip': conn.remote_ip,
                    'remote_port': conn.remote_port,
                    'state': conn.state.value,
                    'local_seq': conn.local_seq,
                    'local_ack': conn.local_ack,
                    'remote_seq': conn.remote_seq,
                    'remote_ack': conn.remote_ack,
                    'window_size': conn.local_window,
                    'cwnd': conn.cwnd,
                    'unacked_segments': len(conn.unacked_segments),
                    'last_activity': conn.last_activity
                }
                for conn_id, conn in self.connections.items()
            }