"""
Traffic Forwarding Engine

Handles IP packet forwarding and NAT for VPN tunnels.
"""

import asyncio
import socket
import struct
from typing import Dict, Optional, Tuple, Union
from dataclasses import dataclass
import os

from .ip_parser import IPv4Header, IPParser
from .logger import Logger, LogCategory
from .nat_engine import NATEngine


@dataclass
class ForwardSession:
    """Bookkeeping for one observed flow, identified by its addresses, protocol and ports."""

    src_ip: str
    dst_ip: str
    src_port: int  # 0 when the protocol carries no ports
    dst_port: int  # 0 when the protocol carries no ports
    protocol: int  # IP protocol number
    created_at: float  # event-loop time (loop.time()) of the first packet
    last_seen: float  # event-loop time of the most recent packet
    bytes_in: int = 0  # not updated anywhere in this module
    bytes_out: int = 0  # total length of pre-NAT packets seen from the client


class TrafficForwarder:
    """Handles packet forwarding and NAT for packets coming from tunnel clients."""

    # IP protocol numbers whose transport header begins with a 16-bit source
    # port followed by a 16-bit destination port: TCP (6), UDP (17),
    # DCCP (33), SCTP (132) and UDP-Lite (136). For any other protocol
    # (ICMP, ESP, GRE, ...) those bytes are not ports and 0 is reported.
    _PORTED_PROTOCOLS = frozenset({6, 17, 33, 132, 136})

    # Idle timeouts, in seconds of event-loop time, applied by cleanup().
    TCP_IDLE_TIMEOUT = 300
    UDP_IDLE_TIMEOUT = 60
    SESSION_IDLE_TIMEOUT = 300

    def __init__(self, logger: Logger, nat_engine: NATEngine):
        """Create a forwarder.

        Args:
            logger: Destination for error messages.
            nat_engine: Engine used to translate outbound packets.
        """
        self.logger = logger
        self.nat_engine = nat_engine
        # Keyed by (src_ip, dst_ip, protocol, src_port, dst_port).
        self.sessions: Dict[Tuple[str, str, int, int, int], ForwardSession] = {}
        # Keyed by (src_ip, src_port, dst_ip, dst_port); values are dicts with
        # 'state', 'seq', 'ack' and 'last_seen' entries.
        self.tcp_connections: Dict[Tuple[str, int, str, int], dict] = {}
        # Keyed by (src_ip, src_port, dst_ip, dst_port); values are the
        # event-loop time of the last packet seen on that endpoint pair.
        self.udp_endpoints: Dict[Tuple[str, int, str, int], float] = {}

    async def forward_packet(self, data: bytes, client_ip: str) -> Optional[bytes]:
        """Translate, track and forward one outbound IPv4 packet.

        Args:
            data: Raw IPv4 packet as received from the tunnel client.
            client_ip: Tunnel address of the sending client (currently unused).

        Returns:
            The packet to emit (NAT-translated), or None if NAT rejected it,
            the protocol handler dropped it, or an error occurred. Errors are
            logged rather than raised.
        """
        try:
            ip_header = IPParser.parse_ipv4_header(data)

            translated_packet = self.nat_engine.translate_outbound(data)
            if not translated_packet:
                return None

            # Session tracking uses the original (pre-NAT) addresses and ports.
            transport = data[ip_header.ihl * 4:]
            protocol = ip_header.protocol
            src_port = self._get_src_port(transport, protocol)
            dst_port = self._get_dst_port(transport, protocol)
            session_key = (ip_header.src_ip, ip_header.dst_ip, protocol, src_port, dst_port)

            now = asyncio.get_running_loop().time()
            session = self.sessions.get(session_key)
            if session is None:
                session = ForwardSession(
                    src_ip=ip_header.src_ip,
                    dst_ip=ip_header.dst_ip,
                    src_port=src_port,
                    dst_port=dst_port,
                    protocol=protocol,
                    created_at=now,
                    last_seen=now,
                )
                self.sessions[session_key] = session
            session.last_seen = now
            session.bytes_out += len(data)

            if protocol == socket.IPPROTO_TCP:
                return await self._forward_tcp(translated_packet, session)
            if protocol == socket.IPPROTO_UDP:
                return await self._forward_udp(translated_packet, session)
            # Other IP protocols are forwarded directly.
            return translated_packet
        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding packet: {e}")
            return None

    async def _forward_tcp(self, data: bytes, session: ForwardSession) -> Optional[bytes]:
        """Update the TCP connection table from one packet, then forward it.

        Tracking is coarse and driven only by the SYN, FIN and RST flags:
        SYN creates an entry (if none exists), FIN marks it 'FIN_WAIT', and
        RST removes it. Every packet on a tracked connection refreshes its
        'last_seen' time so cleanup() only evicts idle connections.

        Args:
            data: NAT-translated IPv4 packet carrying a TCP segment.
            session: Flow record holding the pre-NAT addresses and ports.

        Returns:
            The packet to emit, or None if it is too short to hold a TCP
            header or an error occurred (logged).
        """
        try:
            ip_header = IPParser.parse_ipv4_header(data)
            tcp_offset = ip_header.ihl * 4
            if len(data) < tcp_offset + 20:  # minimum TCP header is 20 bytes
                return None

            flags = data[tcp_offset + 13]  # control-bits byte of the TCP header
            seq_num = struct.unpack_from('!I', data, tcp_offset + 4)[0]

            conn_key = (session.src_ip, session.src_port, session.dst_ip, session.dst_port)
            now = asyncio.get_running_loop().time()

            if flags & 0x02:  # SYN
                if conn_key not in self.tcp_connections:
                    self.tcp_connections[conn_key] = {
                        'state': 'SYN_SENT',
                        'seq': seq_num,
                        'ack': 0,
                        'last_seen': now,
                    }
            elif flags & 0x01:  # FIN
                if conn_key in self.tcp_connections:
                    self.tcp_connections[conn_key]['state'] = 'FIN_WAIT'
            elif flags & 0x04:  # RST
                self.tcp_connections.pop(conn_key, None)

            # Without this refresh cleanup() would see no 'last_seen' and
            # evict every connection, active or not, on each pass.
            conn = self.tcp_connections.get(conn_key)
            if conn is not None:
                conn['last_seen'] = now

            return await self._send_packet(data)
        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding TCP: {e}")
            return None

    async def _forward_udp(self, data: bytes, session: ForwardSession) -> Optional[bytes]:
        """Record UDP endpoint activity, then forward the packet.

        Args:
            data: NAT-translated IPv4 packet carrying a UDP datagram.
            session: Flow record holding the pre-NAT addresses and ports.

        Returns:
            The packet to emit, or None if it is too short to hold a UDP
            header or an error occurred (logged).
        """
        try:
            ip_header = IPParser.parse_ipv4_header(data)
            udp_offset = ip_header.ihl * 4
            if len(data) < udp_offset + 8:  # UDP header is 8 bytes
                return None

            endpoint_key = (session.src_ip, session.src_port, session.dst_ip, session.dst_port)
            self.udp_endpoints[endpoint_key] = asyncio.get_running_loop().time()

            return await self._send_packet(data)
        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding UDP: {e}")
            return None

    async def _send_packet(self, data: bytes) -> Optional[bytes]:
        """Send packet to destination.

        Placeholder: nothing is transmitted here. The packet is returned
        unchanged so the VPN server can handle delivery.
        """
        try:
            return data
        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error sending packet: {e}")
            return None

    def _get_src_port(self, transport_header: bytes, protocol: int) -> int:
        """Return the source port from *transport_header*.

        Returns 0 when *protocol* does not carry ports or the header is
        shorter than 2 bytes.
        """
        if protocol in self._PORTED_PROTOCOLS and len(transport_header) >= 2:
            return struct.unpack_from('!H', transport_header, 0)[0]
        return 0

    def _get_dst_port(self, transport_header: bytes, protocol: int) -> int:
        """Return the destination port from *transport_header*.

        Returns 0 when *protocol* does not carry ports or the header is
        shorter than 4 bytes.
        """
        if protocol in self._PORTED_PROTOCOLS and len(transport_header) >= 4:
            return struct.unpack_from('!H', transport_header, 2)[0]
        return 0

    async def cleanup(self):
        """Evict TCP connections, UDP endpoints and sessions idle past their timeouts."""
        now = asyncio.get_running_loop().time()

        for key, conn in list(self.tcp_connections.items()):
            if now - conn.get('last_seen', 0) > self.TCP_IDLE_TIMEOUT:
                del self.tcp_connections[key]

        for key, last_seen in list(self.udp_endpoints.items()):
            if now - last_seen > self.UDP_IDLE_TIMEOUT:
                del self.udp_endpoints[key]

        for key, session in list(self.sessions.items()):
            if now - session.last_seen > self.SESSION_IDLE_TIMEOUT:
                del self.sessions[key]