|
|
""" |
|
|
TCP Engine Module |
|
|
|
|
|
Implements a complete TCP state machine in user-space: |
|
|
- Full TCP state machine (SYN, SYN-ACK, ESTABLISHED, FIN, RST) |
|
|
- Sequence and acknowledgment number tracking |
|
|
- Sliding window implementation |
|
|
- Retransmission and timeout handling |
|
|
- Congestion control |
|
|
""" |
|
|
|
|
|
import time |
|
|
import threading |
|
|
import random |
|
|
from typing import Dict, List, Optional, Tuple, Callable |
|
|
from dataclasses import dataclass, field |
|
|
from enum import Enum |
|
|
from collections import deque |
|
|
|
|
|
from .ip_parser import TCPHeader, IPv4Header, IPParser |
|
|
|
|
|
|
|
|
class TCPState(Enum):
    """Connection states of the TCP state machine.

    Member names and values are identical strings; ``TCPEngine.get_connections``
    reports ``state.value``.
    """

    # Default state of a new TCPConnection; connections are removed from the
    # engine once they reach this state.
    CLOSED = "CLOSED"

    # Not assigned anywhere in this module: listening is tracked through
    # TCPEngine.listening_ports instead of per-connection state.
    LISTEN = "LISTEN"

    # Active open: set by TCPEngine.connect after the SYN is queued.
    SYN_SENT = "SYN_SENT"

    # Passive open: set for a connection created from an incoming SYN,
    # after the SYN-ACK is sent.
    SYN_RECEIVED = "SYN_RECEIVED"

    # Handshake complete; data may flow in both directions.
    ESTABLISHED = "ESTABLISHED"

    # We have sent our FIN and are waiting for it to be acknowledged.
    FIN_WAIT_1 = "FIN_WAIT_1"

    # Our FIN was acknowledged; waiting for the peer's FIN.
    FIN_WAIT_2 = "FIN_WAIT_2"

    # The peer's FIN was received while we were ESTABLISHED.
    CLOSE_WAIT = "CLOSE_WAIT"

    # The peer's FIN arrived while our own FIN was still unacknowledged.
    CLOSING = "CLOSING"

    # RFC 793 meaning: waiting for the ACK of our FIN after the peer closed first.
    LAST_ACK = "LAST_ACK"

    # Both FINs exchanged; the engine's timer closes the connection after
    # TCPEngine.time_wait_timeout seconds.
    TIME_WAIT = "TIME_WAIT"
|
|
|
|
|
|
|
|
@dataclass
class TCPSegment:
    """A TCP segment as sent by the engine, plus retransmission bookkeeping."""

    seq_num: int
    ack_num: int
    flags: int
    window: int
    data: bytes
    # Wall-clock send time from time.time().
    timestamp: float = field(default_factory=time.time)
    # How many times this segment has been retransmitted.
    retransmit_count: int = 0

    @property
    def data_length(self) -> int:
        """Number of payload bytes carried by this segment."""
        return len(self.data)

    @property
    def seq_end(self) -> int:
        """Return the sequence number immediately after this segment.

        SYN (flag 0x02) and FIN (flag 0x01) each consume one sequence number
        in addition to the payload bytes. The result is wrapped into the
        32-bit sequence space, so a segment that ends past 0xFFFFFFFF wraps
        around to 0 instead of producing an out-of-range value.
        """
        length = self.data_length
        if self.flags & 0x02:  # SYN occupies one sequence number
            length += 1
        if self.flags & 0x01:  # FIN occupies one sequence number
            length += 1
        return (self.seq_num + length) & 0xFFFFFFFF
|
|
|
|
|
|
|
|
@dataclass
class TCPConnection:
    """Mutable state kept for one TCP connection (one address/port 4-tuple).

    The declaration order of the fields below defines the positional
    ``__init__`` signature generated by ``@dataclass``.
    """

    # Endpoint addressing.
    local_ip: str
    local_port: int
    remote_ip: str
    remote_port: int

    # Current position in the TCP state machine.
    state: TCPState = TCPState.CLOSED

    # Sequence / acknowledgment tracking.
    local_seq: int = 0
    local_ack: int = 0
    remote_seq: int = 0
    remote_ack: int = 0
    initial_seq: int = 0

    # Flow control: our advertised window, the peer's, and a scale factor.
    local_window: int = 65535
    remote_window: int = 65535
    window_scale: int = 0

    # Outgoing data not yet sent, received data, and out-of-order data keyed
    # by its starting sequence number.
    send_buffer: deque = field(default_factory=deque)
    recv_buffer: deque = field(default_factory=deque)
    out_of_order_buffer: Dict[int, bytes] = field(default_factory=dict)

    # Retransmission: unacknowledged segments keyed by starting sequence
    # number, plus RTO estimator state (seconds).
    unacked_segments: Dict[int, TCPSegment] = field(default_factory=dict)
    retransmit_timer: Optional[float] = None
    rto: float = 1.0
    srtt: float = 0.0
    rttvar: float = 0.0

    # Congestion control. cwnd is counted in segments (see effective_window).
    cwnd: int = 1
    ssthresh: int = 65535
    mss: int = 1460

    # Wall-clock timestamps from time.time().
    last_activity: float = field(default_factory=time.time)
    time_wait_start: Optional[float] = None

    # Application hooks.
    on_data_received: Optional[Callable[[bytes], None]] = None
    on_connection_closed: Optional[Callable[[], None]] = None

    @property
    def connection_id(self) -> str:
        """Identifier ``"local_ip:local_port-remote_ip:remote_port"``."""
        local_end = f"{self.local_ip}:{self.local_port}"
        remote_end = f"{self.remote_ip}:{self.remote_port}"
        return f"{local_end}-{remote_end}"

    @property
    def is_established(self) -> bool:
        """True only while the connection is in the ESTABLISHED state."""
        return self.state is TCPState.ESTABLISHED

    @property
    def can_send_data(self) -> bool:
        """True in the states where the application may still send data."""
        sendable_states = (TCPState.ESTABLISHED, TCPState.CLOSE_WAIT)
        return self.state in sendable_states

    @property
    def effective_window(self) -> int:
        """Bytes allowed in flight: the smaller of the peer and congestion windows."""
        congestion_limit = self.cwnd * self.mss
        return min(self.remote_window, congestion_limit)
|
|
|
|
|
|
|
|
class TCPEngine:
    """User-space TCP state machine driving a set of TCPConnection objects.

    Connections are keyed by ``"local_ip:local_port-remote_ip:remote_port"``
    (the same format as ``TCPConnection.connection_id``). ``self.lock``
    guards the ``connections`` and ``listening_ports`` dicts. All sequence
    numbers are kept in the 32-bit space and compared modulo 2**32.

    NOTE(review): per-connection fields are mutated both by ``process_packet``
    callers and by the timer thread without a per-connection lock; confirm
    this is acceptable if packets are fed from another thread.
    """

    # TCP header flag bits.
    _FIN = 0x01
    _SYN = 0x02
    _RST = 0x04
    _PSH = 0x08
    _ACK = 0x10

    # Sequence numbers live in a 32-bit space and wrap around.
    _SEQ_MASK = 0xFFFFFFFF

    # OS-backed CSPRNG: ISNs must not be guessable by off-path attackers.
    _ISN_RNG = random.SystemRandom()

    # Upper bound (seconds) for the retransmission timeout.
    _MAX_RTO = 60.0

    def __init__(self, config: Dict):
        self.config = config
        self.connections: Dict[str, TCPConnection] = {}
        self.listening_ports: Dict[int, Callable] = {}
        self.lock = threading.Lock()
        self.running = False
        self.timer_thread = None

        # Tunables read from config, with defaults.
        self.default_mss = config.get('mss', 1460)
        self.default_window = config.get('initial_window', 65535)
        self.max_retries = config.get('max_retries', 3)
        self.connection_timeout = config.get('timeout', 300)
        self.time_wait_timeout = config.get('time_wait_timeout', 120)

    # ------------------------------------------------------------------
    # Sequence-number helpers
    # ------------------------------------------------------------------

    @classmethod
    def _seq_add(cls, seq: int, delta: int) -> int:
        """Return ``seq + delta`` wrapped into the 32-bit sequence space."""
        return (seq + delta) & cls._SEQ_MASK

    @classmethod
    def _seq_diff(cls, a: int, b: int) -> int:
        """Return the signed distance ``a - b`` in 32-bit sequence space.

        The result lies in ``[-2**31, 2**31)``; a negative value means *a*
        precedes *b*.
        """
        d = (a - b) & cls._SEQ_MASK
        return d - (1 << 32) if d >= (1 << 31) else d

    def _generate_isn(self) -> int:
        """Generate an unpredictable Initial Sequence Number in [0, 2**32)."""
        return self._ISN_RNG.getrandbits(32)

    def _get_connection_key(self, local_ip: str, local_port: int, remote_ip: str, remote_port: int) -> str:
        """Get connection key"""
        return f"{local_ip}:{local_port}-{remote_ip}:{remote_port}"

    # ------------------------------------------------------------------
    # Segment construction / transmission
    # ------------------------------------------------------------------

    def _create_tcp_segment(self, conn: TCPConnection, flags: int, data: bytes = b'') -> TCPSegment:
        """Create a segment stamped with the connection's current seq/ack/window."""
        return TCPSegment(
            seq_num=conn.local_seq,
            ack_num=conn.local_ack,
            flags=flags,
            window=conn.local_window,
            data=data
        )

    def _build_tcp_packet(self, conn: TCPConnection, segment: TCPSegment) -> bytes:
        """Build the complete IPv4 + TCP packet for *segment* on *conn*."""
        ip_header = IPv4Header(
            protocol=6,
            source_ip=conn.local_ip,
            dest_ip=conn.remote_ip,
            ttl=64
        )
        tcp_header = TCPHeader(
            source_port=conn.local_port,
            dest_port=conn.remote_port,
            seq_num=segment.seq_num,
            ack_num=segment.ack_num,
            flags=segment.flags,
            window_size=segment.window
        )
        return IPParser.build_packet(ip_header, tcp_header, segment.data)

    def _send_packet(self, packet: bytes):
        """Send packet (to be implemented by integration layer)"""
        pass

    def _send_ack(self, conn: TCPConnection):
        """Send a bare ACK carrying the current ``local_ack``."""
        segment = self._create_tcp_segment(conn, self._ACK)
        self._send_packet(self._build_tcp_packet(conn, segment))

    def _send_fin(self, conn: TCPConnection):
        """Send FIN|ACK and track it for retransmission until it is ACKed."""
        # FIN carries ACK: every segment after the handshake must (RFC 793).
        segment = self._create_tcp_segment(conn, self._FIN | self._ACK)
        conn.unacked_segments[conn.local_seq] = segment
        conn.local_seq = self._seq_add(conn.local_seq, 1)  # FIN uses one seq number
        self._send_packet(self._build_tcp_packet(conn, segment))

    # ------------------------------------------------------------------
    # Timers, RTO and congestion control
    # ------------------------------------------------------------------

    def _update_rto(self, conn: TCPConnection, rtt: float):
        """Update the retransmission timeout from an RTT sample (RFC 6298)."""
        if conn.srtt == 0:
            # First measurement.
            conn.srtt = rtt
            conn.rttvar = rtt / 2
        else:
            alpha = 0.125
            beta = 0.25
            # RTTVAR must be updated with the *previous* SRTT.
            conn.rttvar = (1 - beta) * conn.rttvar + beta * abs(conn.srtt - rtt)
            conn.srtt = (1 - alpha) * conn.srtt + alpha * rtt

        # RTO = SRTT + 4*RTTVAR, clamped to [1 s, _MAX_RTO].
        conn.rto = max(1.0, conn.srtt + 4 * conn.rttvar)
        conn.rto = min(conn.rto, self._MAX_RTO)

    def _update_congestion_window(self, conn: TCPConnection, acked_bytes: int):
        """Grow ``conn.cwnd`` (counted in segments) after new data is ACKed.

        In slow start (``cwnd < ssthresh``) each ACK adds one segment. In
        congestion avoidance one segment is added per full window of ACKed
        bytes (RFC 5681 §3.1). The byte credit is kept in the ad-hoc
        attribute ``conn._ca_bytes_acked``.
        """
        if conn.cwnd < conn.ssthresh:
            conn.cwnd += 1
            return
        # cwnd is in segments (effective_window uses cwnd * mss), so the
        # byte-based rule "cwnd += mss*mss/cwnd per ACK" becomes
        # "+1 segment per cwnd*mss bytes acknowledged".
        credit = getattr(conn, '_ca_bytes_acked', 0) + acked_bytes
        window_bytes = conn.cwnd * conn.mss
        if credit >= window_bytes:
            credit -= window_bytes
            conn.cwnd += 1
        conn._ca_bytes_acked = credit

    def _handle_retransmission(self, conn: TCPConnection):
        """Retransmit segments older than the RTO; reset after ``max_retries``."""
        now = time.time()

        expired = []
        for segment in list(conn.unacked_segments.values()):
            if now - segment.timestamp > conn.rto:
                if segment.retransmit_count >= self.max_retries:
                    # Peer unresponsive: abort the connection.
                    self._close_connection(conn, reset=True)
                    return
                expired.append(segment)

        if not expired:
            return

        # Back off the timer and collapse the window once per timeout event,
        # not once per expired segment (RFC 6298 §5.5, RFC 5681 §3.1).
        conn.rto = min(conn.rto * 2, self._MAX_RTO)
        conn.ssthresh = max(conn.cwnd // 2, 2)
        conn.cwnd = 1

        for segment in expired:
            segment.retransmit_count += 1
            segment.timestamp = now
            if segment.flags & self._ACK:
                # Piggyback the freshest acknowledgment on the resend.
                segment.ack_num = conn.local_ack
            self._send_packet(self._build_tcp_packet(conn, segment))

    # ------------------------------------------------------------------
    # Connection lifecycle (public API)
    # ------------------------------------------------------------------

    def _close_connection(self, conn: TCPConnection, reset: bool = False):
        """Close *conn* gracefully, or abort it with RST when *reset* is True.

        Graceful close sends FIN: ESTABLISHED/SYN_RECEIVED -> FIN_WAIT_1 and
        CLOSE_WAIT -> LAST_ACK. A SYN_SENT connection is dropped directly.
        Once the connection is CLOSED, its ``on_connection_closed`` callback
        runs and it is removed from ``self.connections``.

        NOTE(review): data still queued in ``send_buffer`` when the FIN goes
        out is never sent; flushing it first would need a deferred FIN.
        """
        if reset:
            segment = self._create_tcp_segment(conn, self._RST)
            self._send_packet(self._build_tcp_packet(conn, segment))
            conn.state = TCPState.CLOSED
        elif conn.state in (TCPState.ESTABLISHED, TCPState.SYN_RECEIVED):
            self._send_fin(conn)
            conn.state = TCPState.FIN_WAIT_1
        elif conn.state == TCPState.CLOSE_WAIT:
            self._send_fin(conn)
            conn.state = TCPState.LAST_ACK
        elif conn.state == TCPState.SYN_SENT:
            conn.state = TCPState.CLOSED

        if conn.state != TCPState.CLOSED:
            return

        # Drop pending work so a timer snapshot cannot resend for a dead TCB.
        conn.unacked_segments.clear()
        conn.send_buffer.clear()
        conn.out_of_order_buffer.clear()

        if conn.on_connection_closed:
            conn.on_connection_closed()

        with self.lock:
            if self.connections.get(conn.connection_id) is conn:
                del self.connections[conn.connection_id]

    def listen(self, port: int, accept_callback: Callable):
        """Listen on *port*; *accept_callback(conn)* runs for each new connection."""
        with self.lock:
            self.listening_ports[port] = accept_callback

    def connect(self, local_ip: str, local_port: int, remote_ip: str, remote_port: int) -> Optional[TCPConnection]:
        """Initiate an outbound connection by sending SYN.

        Returns the new connection (in SYN_SENT), or None if a connection for
        the same 4-tuple already exists.
        """
        conn_key = self._get_connection_key(local_ip, local_port, remote_ip, remote_port)

        conn = TCPConnection(
            local_ip=local_ip,
            local_port=local_port,
            remote_ip=remote_ip,
            remote_port=remote_port,
            state=TCPState.SYN_SENT,
            local_seq=self._generate_isn(),
            mss=self.default_mss,
            local_window=self.default_window
        )
        conn.initial_seq = conn.local_seq

        with self.lock:
            if conn_key in self.connections:
                # Do not silently orphan a live connection on the same 4-tuple.
                return None
            self.connections[conn_key] = conn

        segment = self._create_tcp_segment(conn, self._SYN)
        conn.unacked_segments[conn.local_seq] = segment
        conn.local_seq = self._seq_add(conn.local_seq, 1)  # SYN uses one seq number
        conn.retransmit_timer = time.time()
        self._send_packet(self._build_tcp_packet(conn, segment))

        return conn

    def send_data(self, conn: TCPConnection, data: bytes) -> bool:
        """Queue *data* and send what the window allows; False if not sendable."""
        if not conn.can_send_data:
            return False
        conn.send_buffer.append(data)
        self._try_send_data(conn)
        return True

    def _try_send_data(self, conn: TCPConnection):
        """Transmit queued bytes while they fit in ``conn.effective_window``.

        Bytes in flight are the payload bytes of unacknowledged segments.
        Each segment carries at most ``conn.mss`` bytes. Data that does not
        fit is put back at the front of ``send_buffer``.
        """
        if not conn.can_send_data:
            # Never emit data before the handshake or after our FIN.
            return

        in_flight = sum(seg.data_length for seg in conn.unacked_segments.values())
        while conn.send_buffer:
            budget = conn.effective_window - in_flight
            if budget <= 0:
                break
            data = conn.send_buffer.popleft()
            if not data:
                continue  # ignore empty writes
            size = min(conn.mss, budget)
            if size <= 0:
                # Degenerate mss; keep the data rather than spin forever.
                conn.send_buffer.appendleft(data)
                break
            chunk, remainder = data[:size], data[size:]
            if remainder:
                conn.send_buffer.appendleft(remainder)

            segment = self._create_tcp_segment(conn, self._PSH | self._ACK, chunk)
            conn.unacked_segments[conn.local_seq] = segment
            conn.local_seq = self._seq_add(conn.local_seq, len(chunk))
            in_flight += len(chunk)
            self._send_packet(self._build_tcp_packet(conn, segment))

    # ------------------------------------------------------------------
    # Inbound processing
    # ------------------------------------------------------------------

    def process_packet(self, packet_data: bytes) -> bool:
        """Parse one IP packet and feed its TCP segment to the state machine.

        Returns True when the segment was handled by a connection or opened a
        new one on a listening port. Returns False for non-TCP packets, for
        unknown connections (answered with RST unless the packet is itself a
        RST), for segments a state handler rejected, and on errors (which are
        printed).
        """
        try:
            parsed = IPParser.parse_packet(packet_data)
            if not isinstance(parsed.transport_header, TCPHeader):
                return False

            ip_header = parsed.ip_header
            tcp_header = parsed.transport_header
            payload = parsed.payload or b''

            # Key from our point of view: the packet's destination is local.
            conn_key = self._get_connection_key(
                ip_header.dest_ip, tcp_header.dest_port,
                ip_header.source_ip, tcp_header.source_port
            )

            with self.lock:
                conn = self.connections.get(conn_key)
                listening = tcp_header.dest_port in self.listening_ports

            if conn is None:
                if tcp_header.syn and not tcp_header.ack and listening:
                    if self._handle_new_connection(ip_header, tcp_header) is not None:
                        return True
                # Never answer a RST with a RST (RFC 793 reset generation).
                if not tcp_header.rst:
                    self._send_rst(ip_header, tcp_header, len(payload))
                return False

            return self._process_segment(conn, tcp_header, payload)

        except Exception as e:
            print(f"Error processing TCP packet: {e}")
            return False

    def _handle_new_connection(self, ip_header: IPv4Header, tcp_header: TCPHeader) -> Optional[TCPConnection]:
        """Create, register and SYN-ACK a passive connection for an incoming SYN.

        Returns the new connection, the already-registered one if a duplicate
        SYN raced us, or None if the port is not listening.
        """
        with self.lock:
            accept_callback = self.listening_ports.get(tcp_header.dest_port)
        if not accept_callback:
            return None

        conn = TCPConnection(
            local_ip=ip_header.dest_ip,
            local_port=tcp_header.dest_port,
            remote_ip=ip_header.source_ip,
            remote_port=tcp_header.source_port,
            state=TCPState.SYN_RECEIVED,
            local_seq=self._generate_isn(),
            remote_seq=tcp_header.seq_num,
            local_ack=self._seq_add(tcp_header.seq_num, 1),  # peer SYN uses one seq
            mss=self.default_mss,
            local_window=self.default_window
        )
        conn.initial_seq = conn.local_seq

        # Check-and-insert atomically so concurrent duplicate SYNs cannot
        # create two TCBs for the same 4-tuple.
        with self.lock:
            existing = self.connections.get(conn.connection_id)
            if existing is not None:
                return existing
            self.connections[conn.connection_id] = conn

        segment = self._create_tcp_segment(conn, self._SYN | self._ACK)
        conn.unacked_segments[conn.local_seq] = segment
        conn.local_seq = self._seq_add(conn.local_seq, 1)
        conn.retransmit_timer = time.time()
        self._send_packet(self._build_tcp_packet(conn, segment))

        accept_callback(conn)
        return conn

    def _process_segment(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Dispatch a segment to the handler for the connection's state."""
        conn.last_activity = time.time()

        # NOTE(review): RST is accepted without sequence validation
        # (see RFC 5961 for blind-reset mitigation).
        if tcp_header.rst:
            conn.state = TCPState.CLOSED
            self._close_connection(conn)
            return True

        # Follow the peer's advertised window on every ACK once synchronized;
        # the handshake handlers record it themselves.
        if tcp_header.ack and conn.state not in (TCPState.SYN_SENT, TCPState.SYN_RECEIVED):
            conn.remote_window = tcp_header.window_size

        handlers = {
            TCPState.SYN_SENT: self._handle_syn_sent,
            TCPState.SYN_RECEIVED: self._handle_syn_received,
            TCPState.ESTABLISHED: self._handle_established,
            TCPState.FIN_WAIT_1: self._handle_fin_wait_1,
            TCPState.FIN_WAIT_2: self._handle_fin_wait_2,
            TCPState.CLOSE_WAIT: self._handle_close_wait,
            TCPState.CLOSING: self._handle_closing,
            TCPState.LAST_ACK: self._handle_last_ack,
            TCPState.TIME_WAIT: self._handle_time_wait,
        }
        handler = handlers.get(conn.state)
        if handler is None:
            return False
        return handler(conn, tcp_header, payload)

    def _handle_syn_sent(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Complete an active open when a SYN-ACK acknowledges our SYN."""
        if not (tcp_header.syn and tcp_header.ack):
            return False
        if tcp_header.ack_num != conn.local_seq:
            return False

        conn.remote_seq = tcp_header.seq_num
        conn.local_ack = self._seq_add(tcp_header.seq_num, 1)
        conn.remote_window = tcp_header.window_size
        conn.unacked_segments.pop(conn.initial_seq, None)  # our SYN is ACKed

        self._send_ack(conn)
        conn.state = TCPState.ESTABLISHED
        return True

    def _handle_syn_received(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Complete a passive open when the ACK for our SYN-ACK arrives."""
        if not (tcp_header.ack and tcp_header.ack_num == conn.local_seq):
            return False

        conn.remote_window = tcp_header.window_size
        conn.unacked_segments.pop(conn.initial_seq, None)  # our SYN-ACK is ACKed
        conn.state = TCPState.ESTABLISHED

        if payload or tcp_header.fin:
            # The handshake-completing ACK may already carry data or FIN.
            return self._handle_established(conn, tcp_header, payload)
        return True

    def _handle_established(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Process ACKs, data and a possible FIN while ESTABLISHED."""
        if tcp_header.ack:
            self._process_ack(conn, tcp_header.ack_num)
        if self._receive_payload(conn, tcp_header, payload):
            conn.state = TCPState.CLOSE_WAIT
        return True

    def _handle_fin_wait_1(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Wait for our FIN's ACK; the peer may still send data or its FIN."""
        if tcp_header.ack:
            self._process_ack(conn, tcp_header.ack_num)
            if not conn.unacked_segments:  # FIN (tracked as unacked) is ACKed
                conn.state = TCPState.FIN_WAIT_2

        if self._receive_payload(conn, tcp_header, payload):
            if conn.state == TCPState.FIN_WAIT_2:
                self._enter_time_wait(conn)
            else:
                conn.state = TCPState.CLOSING  # simultaneous close
        return True

    def _handle_fin_wait_2(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Receive remaining peer data until its FIN arrives."""
        if self._receive_payload(conn, tcp_header, payload):
            self._enter_time_wait(conn)
        return True

    def _handle_close_wait(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Peer has closed; keep processing ACKs for data we are still sending."""
        if tcp_header.ack:
            self._process_ack(conn, tcp_header.ack_num)
        if payload or tcp_header.fin:
            # Retransmitted data/FIN: our earlier ACK was lost, repeat it.
            self._send_ack(conn)
        return True

    def _handle_closing(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Both FINs sent; enter TIME_WAIT once our FIN is ACKed."""
        if tcp_header.ack:
            self._process_ack(conn, tcp_header.ack_num)
            if not conn.unacked_segments:
                self._enter_time_wait(conn)
        return True

    def _handle_last_ack(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Close for good once our FIN (sent from CLOSE_WAIT) is ACKed."""
        if tcp_header.ack:
            self._process_ack(conn, tcp_header.ack_num)
            if not conn.unacked_segments:
                conn.state = TCPState.CLOSED
                self._close_connection(conn)
        return True

    def _handle_time_wait(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Re-ACK a retransmitted FIN and restart the TIME_WAIT timer."""
        if tcp_header.fin:
            # A repeated FIN means our final ACK was lost (RFC 793).
            self._send_ack(conn)
            conn.time_wait_start = time.time()
        return True

    def _enter_time_wait(self, conn: TCPConnection):
        """Move *conn* to TIME_WAIT and start its timer."""
        conn.state = TCPState.TIME_WAIT
        conn.time_wait_start = time.time()

    def _receive_payload(self, conn: TCPConnection, tcp_header: TCPHeader, payload: bytes) -> bool:
        """Accept incoming data and FIN; return True if the peer's FIN was consumed.

        In-order bytes go to ``on_data_received``, followed by any buffered
        segments they make contiguous. Segments ahead of ``local_ack`` but
        within ``local_window`` are parked in ``out_of_order_buffer``, and
        already-received prefixes are trimmed. A FIN is accepted only when
        every byte before it has arrived. Any segment that occupies sequence
        space is ACKed once.
        """
        payload = payload or b''
        fin_seq = self._seq_add(tcp_header.seq_num, len(payload))
        data, data_seq = payload, tcp_header.seq_num

        # Trim bytes we already hold (overlapping retransmission).
        already_have = self._seq_diff(conn.local_ack, data_seq)
        if already_have > 0:
            data = data[already_have:]
            data_seq = conn.local_ack

        if data:
            if data_seq == conn.local_ack:
                self._deliver(conn, data)
                self._drain_out_of_order(conn)
            elif 0 < self._seq_diff(data_seq, conn.local_ack) < conn.local_window:
                conn.out_of_order_buffer.setdefault(data_seq, data)

        fin_consumed = False
        if tcp_header.fin and fin_seq == conn.local_ack:
            conn.local_ack = self._seq_add(conn.local_ack, 1)  # FIN uses one seq
            fin_consumed = True

        if payload or tcp_header.fin:
            # ACK duplicates and out-of-order segments too, so the peer
            # learns exactly what we hold.
            self._send_ack(conn)
        return fin_consumed

    def _deliver(self, conn: TCPConnection, data: bytes):
        """Advance ``local_ack`` past *data* and hand it to the application."""
        conn.local_ack = self._seq_add(conn.local_ack, len(data))
        if conn.on_data_received:
            conn.on_data_received(data)

    def _drain_out_of_order(self, conn: TCPConnection):
        """Deliver buffered segments that have become contiguous with ``local_ack``."""
        progressed = True
        while progressed:
            progressed = False
            for start in list(conn.out_of_order_buffer):
                behind = self._seq_diff(conn.local_ack, start)
                if behind < 0:
                    continue  # a gap still precedes this segment
                chunk = conn.out_of_order_buffer.pop(start)
                if behind < len(chunk):
                    self._deliver(conn, chunk[behind:])
                    progressed = True

    def _process_ack(self, conn: TCPConnection, ack_num: int):
        """Retire fully acknowledged segments, update RTO/cwnd, and send more data."""
        if self._seq_diff(ack_num, conn.local_seq) > 0:
            # ACK for bytes we never sent: ignore it.
            return
        conn.remote_ack = ack_num

        acked_segments = []
        for seq_num, segment in list(conn.unacked_segments.items()):
            # Only retire a segment once *all* of it is covered by the ACK.
            seg_end = segment.seq_end & self._SEQ_MASK
            if self._seq_diff(seg_end, ack_num) <= 0:
                acked_segments.append(segment)
                del conn.unacked_segments[seq_num]

        if acked_segments:
            acked_bytes = sum(seg.data_length for seg in acked_segments)
            # Karn's algorithm: never sample RTT from a retransmitted segment.
            # Use the newest clean segment, since it triggered this ACK.
            samples = [seg for seg in acked_segments if seg.retransmit_count == 0]
            if samples:
                self._update_rto(conn, time.time() - samples[-1].timestamp)
            self._update_congestion_window(conn, acked_bytes)
            if not conn.unacked_segments:
                conn.retransmit_timer = None

        self._try_send_data(conn)

    def _send_rst(self, ip_header: IPv4Header, tcp_header: TCPHeader, payload_len: int = 0):
        """Send a RST in reply to a segment for an unknown connection.

        Per RFC 793 reset generation: if the offending segment has ACK set,
        reply ``<SEQ=SEG.ACK><CTL=RST>``. Otherwise reply
        ``<SEQ=0><ACK=SEG.SEQ+SEG.LEN><CTL=RST,ACK>``, where SEG.LEN counts
        payload bytes plus SYN and FIN.
        """
        rst_ip = IPv4Header(
            protocol=6,
            source_ip=ip_header.dest_ip,
            dest_ip=ip_header.source_ip,
            ttl=64
        )

        seg_len = payload_len + (1 if tcp_header.syn else 0) + (1 if tcp_header.fin else 0)
        if tcp_header.ack:
            seq_num = tcp_header.ack_num
            flags = self._RST
        else:
            seq_num = 0
            flags = self._RST | self._ACK

        rst_tcp = TCPHeader(
            source_port=tcp_header.dest_port,
            dest_port=tcp_header.source_port,
            seq_num=seq_num,
            ack_num=self._seq_add(tcp_header.seq_num, seg_len),
            flags=flags
        )

        packet = IPParser.build_packet(rst_ip, rst_tcp)
        self._send_packet(packet)

    # ------------------------------------------------------------------
    # Background timer
    # ------------------------------------------------------------------

    def _timer_loop(self):
        """Once per second, run retransmission and timeout checks on all connections."""
        while self.running:
            now = time.time()
            with self.lock:
                connections_to_check = list(self.connections.values())

            for conn in connections_to_check:
                try:
                    self._check_timers(conn, now)
                except Exception as e:
                    # One failing connection (e.g. a raising _send_packet) must
                    # not stop the timer thread for every other connection.
                    print(f"Error in TCP timer: {e}")

            time.sleep(1)

    def _check_timers(self, conn: TCPConnection, now: float):
        """Apply retransmission, idle-timeout and TIME_WAIT expiry to *conn*."""
        if conn.unacked_segments:
            self._handle_retransmission(conn)
            if conn.state == TCPState.CLOSED:
                return  # already aborted; do not close twice

        if now - conn.last_activity > self.connection_timeout:
            self._close_connection(conn, reset=True)
            return

        if (conn.state == TCPState.TIME_WAIT and
                conn.time_wait_start is not None and
                now - conn.time_wait_start > self.time_wait_timeout):
            conn.state = TCPState.CLOSED
            self._close_connection(conn)

    def start(self):
        """Start the timer thread (no-op if already running)."""
        if self.running:
            return
        self.running = True
        self.timer_thread = threading.Thread(target=self._timer_loop, daemon=True)
        self.timer_thread.start()
        print("TCP engine started")

    def stop(self):
        """Stop the timer thread and reset every remaining connection."""
        self.running = False
        if self.timer_thread:
            self.timer_thread.join()
            self.timer_thread = None

        # Snapshot under the lock, close outside it: _close_connection takes
        # self.lock itself, and threading.Lock is not reentrant.
        with self.lock:
            remaining = list(self.connections.values())
        for conn in remaining:
            self._close_connection(conn, reset=True)

        print("TCP engine stopped")

    def get_connections(self) -> Dict[str, Dict]:
        """Return a snapshot of every connection's state as plain dicts."""
        with self.lock:
            return {
                conn_id: {
                    'local_ip': conn.local_ip,
                    'local_port': conn.local_port,
                    'remote_ip': conn.remote_ip,
                    'remote_port': conn.remote_port,
                    'state': conn.state.value,
                    'local_seq': conn.local_seq,
                    'local_ack': conn.local_ack,
                    'remote_seq': conn.remote_seq,
                    'remote_ack': conn.remote_ack,
                    'window_size': conn.local_window,
                    'cwnd': conn.cwnd,
                    'unacked_segments': len(conn.unacked_segments),
                    'last_activity': conn.last_activity
                }
                for conn_id, conn in self.connections.items()
            }
|
|
|
|
|
|