Spaces:
Runtime error
Runtime error
| """ | |
| Traffic Forwarding Engine | |
| Handles IP packet forwarding and NAT for VPN tunnels | |
| """ | |
| import asyncio | |
| import socket | |
| import struct | |
| from typing import Dict, Optional, Tuple, Union | |
| from dataclasses import dataclass | |
| import os | |
| from .ip_parser import IPv4Header, IPParser | |
| from .logger import Logger, LogCategory | |
| from .nat_engine import NATEngine | |
@dataclass
class ForwardSession:
    """Bookkeeping record for one forwarded flow.

    TrafficForwarder creates one instance per distinct
    (src_ip, dst_ip, protocol, src_port, dst_port) tuple and updates it on
    every packet of that flow.  The class must be a dataclass: the forwarder
    constructs it with keyword arguments, which requires the generated
    ``__init__``.
    """

    # Addresses and ports as taken from the packet that opened the session.
    src_ip: str
    dst_ip: str
    src_port: int
    dst_port: int
    # IP protocol number from the IPv4 header (e.g. socket.IPPROTO_TCP).
    protocol: int
    # Event-loop clock readings (asyncio loop.time()), not wall-clock time.
    created_at: float
    last_seen: float
    # Byte counters; the forwarder adds the packet length to bytes_out for
    # every packet it forwards.  bytes_in is not updated in this file.
    bytes_in: int = 0
    bytes_out: int = 0
class TrafficForwarder:
    """Forward IPv4 packets for VPN tunnels and track per-flow state.

    For every packet the forwarder asks the NAT engine for the translated
    packet, records/updates a ForwardSession for the flow, keeps lightweight
    TCP connection and UDP endpoint tables, and returns the packet that
    should be sent on.  ``cleanup()`` must be called periodically by the
    owner to expire idle entries.
    """

    # Idle timeouts, in seconds of event-loop time, applied by cleanup().
    TCP_IDLE_TIMEOUT = 300
    UDP_IDLE_TIMEOUT = 60
    SESSION_IDLE_TIMEOUT = 300

    # Flag bits in byte 13 of the TCP header.
    _TCP_FIN = 0x01
    _TCP_SYN = 0x02
    _TCP_RST = 0x04

    # Minimum transport header sizes, in bytes.
    _TCP_MIN_HEADER_LEN = 20
    _UDP_HEADER_LEN = 8

    # IP protocols whose header starts with 16-bit source and destination
    # ports: TCP, UDP, DCCP (33), SCTP (132) and UDP-Lite (136).  For any
    # other protocol (ICMP, GRE, ESP, ...) those bytes are not ports.
    _PORT_PROTOCOLS = frozenset(
        {socket.IPPROTO_TCP, socket.IPPROTO_UDP, 33, 132, 136}
    )

    def __init__(self, logger: "Logger", nat_engine: "NATEngine"):
        """Store collaborators and create empty tracking tables.

        Args:
            logger: Project logger; only ``error(category, source, msg)`` is
                used here.
            nat_engine: NAT engine; only ``translate_outbound(data)`` is used.
        """
        self.logger = logger
        self.nat_engine = nat_engine
        # Keyed by (src_ip, dst_ip, protocol, src_port, dst_port).
        self.sessions: Dict[Tuple[str, str, int, int, int], ForwardSession] = {}
        # Keyed by (src_ip, src_port, dst_ip, dst_port); values are dicts with
        # 'state', 'seq', 'ack' and 'last_seen'.
        self.tcp_connections: Dict[Tuple[str, int, str, int], dict] = {}
        # Keyed by (src_ip, src_port, dst_ip, dst_port); value is last-seen time.
        self.udp_endpoints: Dict[Tuple[str, int, str, int], float] = {}

    async def forward_packet(self, data: bytes, client_ip: str) -> Optional[bytes]:
        """Translate, account for and forward one IPv4 packet.

        Args:
            data: Raw IPv4 packet as received from the client.
            client_ip: Address of the sending client.  Currently unused; kept
                for interface compatibility.

        Returns:
            The packet to send on (after NAT), or None when the NAT engine
            rejects the packet, the packet is too short for its transport
            header, or any error occurs (errors are logged, not raised).
        """
        try:
            # Parse IP header
            ip_header = IPParser.parse_ipv4_header(data)

            # Apply NAT; a falsy result means "do not forward".
            translated_packet = self.nat_engine.translate_outbound(data)
            if not translated_packet:
                return None

            # Track session, keyed on the pre-NAT addresses and ports.
            protocol = ip_header.protocol
            transport_header = data[ip_header.ihl * 4:]
            src_port = self._get_src_port(transport_header, protocol)
            dst_port = self._get_dst_port(transport_header, protocol)
            session_key = (
                ip_header.src_ip,
                ip_header.dst_ip,
                protocol,
                src_port,
                dst_port,
            )
            now = asyncio.get_running_loop().time()

            session = self.sessions.get(session_key)
            if session is None:
                session = ForwardSession(
                    src_ip=ip_header.src_ip,
                    dst_ip=ip_header.dst_ip,
                    src_port=src_port,
                    dst_port=dst_port,
                    protocol=protocol,
                    created_at=now,
                    last_seen=now,
                )
                self.sessions[session_key] = session
            session.last_seen = now
            session.bytes_out += len(data)

            # Forward based on protocol
            if protocol == socket.IPPROTO_TCP:
                return await self._forward_tcp(translated_packet, session)
            if protocol == socket.IPPROTO_UDP:
                return await self._forward_udp(translated_packet, session)
            # Forward other IP protocols directly
            return translated_packet
        except Exception as e:
            # Per-packet boundary: one bad packet must not kill the tunnel.
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding packet: {e}")
            return None

    async def _forward_tcp(self, data: bytes, session: "ForwardSession") -> Optional[bytes]:
        """Update TCP connection tracking for a segment and forward it.

        Args:
            data: The (NAT-translated) IPv4 packet carrying the TCP segment.
            session: Session record of the flow; supplies the tracking key.

        Returns:
            The packet to send, or None if it is shorter than a minimal TCP
            header or an error occurs (logged).
        """
        try:
            ip_header = IPParser.parse_ipv4_header(data)
            tcp_header_offset = ip_header.ihl * 4
            if len(data) < tcp_header_offset + self._TCP_MIN_HEADER_LEN:
                return None

            # Read only the fields we need straight from the packet buffer.
            flags = data[tcp_header_offset + 13]
            seq_num = struct.unpack_from('!I', data, tcp_header_offset + 4)[0]

            conn_key = (session.src_ip, session.src_port, session.dst_ip, session.dst_port)
            self._update_tcp_state(
                conn_key, flags, seq_num, asyncio.get_running_loop().time()
            )

            # Forward the packet
            return await self._send_packet(data)
        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding TCP: {e}")
            return None

    def _update_tcp_state(self, conn_key, flags: int, seq_num: int, now: float) -> None:
        """Apply one TCP segment's flags to the connection table.

        RST removes the entry and wins over every other flag.  SYN creates an
        entry if none exists.  FIN (without SYN) marks an existing entry
        'FIN_WAIT'.  Every segment that leaves an entry in place refreshes its
        'last_seen' so cleanup() can expire idle connections rather than all
        of them.
        """
        if flags & self._TCP_RST:
            self.tcp_connections.pop(conn_key, None)
            return

        conn = self.tcp_connections.get(conn_key)
        if conn is None:
            if flags & self._TCP_SYN:
                self.tcp_connections[conn_key] = {
                    'state': 'SYN_SENT',
                    'seq': seq_num,
                    'ack': 0,
                    'last_seen': now,
                }
            return

        if flags & self._TCP_FIN and not flags & self._TCP_SYN:
            conn['state'] = 'FIN_WAIT'
        conn['last_seen'] = now

    async def _forward_udp(self, data: bytes, session: "ForwardSession") -> Optional[bytes]:
        """Record UDP endpoint activity for a datagram and forward it.

        Args:
            data: The (NAT-translated) IPv4 packet carrying the datagram.
            session: Session record of the flow; supplies the tracking key.

        Returns:
            The packet to send, or None if it is shorter than a UDP header or
            an error occurs (logged).
        """
        try:
            ip_header = IPParser.parse_ipv4_header(data)
            udp_header_offset = ip_header.ihl * 4
            if len(data) < udp_header_offset + self._UDP_HEADER_LEN:
                return None

            # Track UDP endpoint
            endpoint_key = (session.src_ip, session.src_port, session.dst_ip, session.dst_port)
            self.udp_endpoints[endpoint_key] = asyncio.get_running_loop().time()

            # Forward the packet
            return await self._send_packet(data)
        except Exception as e:
            self.logger.error(LogCategory.SYSTEM, "traffic_forwarder", f"Error forwarding UDP: {e}")
            return None

    async def _send_packet(self, data: bytes) -> Optional[bytes]:
        """Hand the packet back to the caller for transmission.

        This is the hook where the packet would actually be sent; for now it
        is returned unchanged so the VPN server can handle it.
        """
        return data

    def _get_src_port(self, transport_header: bytes, protocol: int) -> int:
        """Return the source port, or 0 if the protocol has none.

        Args:
            transport_header: Bytes following the IPv4 header.
            protocol: IP protocol number from the IPv4 header.
        """
        if protocol in self._PORT_PROTOCOLS and len(transport_header) >= 2:
            return struct.unpack_from('!H', transport_header, 0)[0]
        return 0

    def _get_dst_port(self, transport_header: bytes, protocol: int) -> int:
        """Return the destination port, or 0 if the protocol has none.

        Args:
            transport_header: Bytes following the IPv4 header.
            protocol: IP protocol number from the IPv4 header.
        """
        if protocol in self._PORT_PROTOCOLS and len(transport_header) >= 4:
            return struct.unpack_from('!H', transport_header, 2)[0]
        return 0

    async def cleanup(self):
        """Remove TCP connections, UDP endpoints and sessions that are idle.

        An entry is removed when the time since it was last seen is strictly
        greater than the corresponding ``*_IDLE_TIMEOUT``.  A TCP entry with
        no 'last_seen' value is treated as last seen at time 0.
        """
        current_time = asyncio.get_running_loop().time()

        # Clean TCP connections
        for key, conn in list(self.tcp_connections.items()):
            if current_time - conn.get('last_seen', 0) > self.TCP_IDLE_TIMEOUT:
                del self.tcp_connections[key]

        # Clean UDP endpoints
        for key, last_seen in list(self.udp_endpoints.items()):
            if current_time - last_seen > self.UDP_IDLE_TIMEOUT:
                del self.udp_endpoints[key]

        # Clean sessions
        for key, session in list(self.sessions.items()):
            if current_time - session.last_seen > self.SESSION_IDLE_TIMEOUT:
                del self.sessions[key]