"""
NAT Engine Module

Implements Network Address Translation:
- Map (virtualIP, virtualPort) to (hostIP, hostPort)
- Maintain connection tracking table
- Handle port allocation and deallocation
- Support connection state tracking

Packets handed to ``NATEngine.process_inbound_packet`` /
``NATEngine.process_outbound_packet`` are IPv4 TCP/UDP packets, by default
preceded by a 14-byte Ethernet header.
"""

import logging
import random
import socket
import struct
import threading
import time
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List, NamedTuple, Optional, Set, Tuple

logger = logging.getLogger(__name__)

# Minimum IPv4 header length (IHL == 5, no options).
_IPV4_MIN_HEADER_LEN = 20
# Minimum transport header length per supported IP protocol number.
_L4_MIN_HEADER_LEN = {socket.IPPROTO_TCP: 20, socket.IPPROTO_UDP: 8}
# Offset of the checksum field inside each supported transport header.
_L4_CHECKSUM_OFFSET = {socket.IPPROTO_TCP: 16, socket.IPPROTO_UDP: 6}


def _make_session_id(virtual_ip: str, virtual_port: int, real_ip: str,
                     real_port: int, protocol: int) -> str:
    """Build the session identifier shared by NATSession and NATEngine."""
    return f"{virtual_ip}:{virtual_port}-{real_ip}:{real_port}-{protocol}"


def _ones_complement_sum(data: bytes) -> int:
    """Return the folded 16-bit one's-complement sum of *data*.

    Odd-length input is padded with a trailing zero byte, as the Internet
    checksum requires.
    """
    data = bytes(data)
    if len(data) % 2:
        data += b'\x00'
    total = sum(struct.unpack(f'!{len(data) // 2}H', data))
    while total >> 16:
        total = (total & 0xFFFF) + (total >> 16)
    return total


def _adjust_checksum(checksum: int, old: bytes, new: bytes) -> int:
    """Incrementally update an Internet checksum (RFC 1624, eqn. 3).

    Computes HC' = ~(~HC + ~m + m') for every 16-bit word m of *old* that is
    replaced by the corresponding word m' of *new*. *old* and *new* must have
    the same even length and be 16-bit aligned within the checksummed data.
    """
    total = ~checksum & 0xFFFF
    for i in range(0, len(old), 2):
        old_word = int.from_bytes(old[i:i + 2], 'big')
        new_word = int.from_bytes(new[i:i + 2], 'big')
        total += (~old_word & 0xFFFF) + new_word
    while total >> 16:
        total = (total & 0xFFFF) + (total >> 16)
    return ~total & 0xFFFF


class NATType(Enum):
    """Kind of translation a session performs."""

    SNAT = "SNAT"  # Source NAT
    DNAT = "DNAT"  # Destination NAT


@dataclass
class NATSession:
    """Represents a NAT session"""

    # Virtual (internal) endpoint
    virtual_ip: str
    virtual_port: int

    # Real (external) endpoint
    real_ip: str
    real_port: int

    # Host (translated) endpoint
    host_ip: str
    host_port: int

    # Session metadata
    protocol: int  # IP protocol number (e.g., 6 for TCP, 17 for UDP)
    nat_type: NATType
    created_time: float
    last_activity: float
    bytes_in: int = 0
    bytes_out: int = 0
    packets_in: int = 0
    packets_out: int = 0

    @property
    def session_id(self) -> str:
        """Get unique session identifier"""
        return _make_session_id(self.virtual_ip, self.virtual_port,
                                self.real_ip, self.real_port, self.protocol)

    @property
    def is_expired(self) -> bool:
        """Return True once the session has been idle longer than its timeout.

        The timeout is 300 s for TCP and 60 s for every other protocol.
        """
        timeout = 300 if self.protocol == socket.IPPROTO_TCP else 60
        return time.time() - self.last_activity > timeout

    @property
    def duration(self) -> float:
        """Get session duration in seconds"""
        return time.time() - self.created_time

    def update_activity(self, bytes_transferred: int = 0, direction: str = 'out'):
        """Refresh the idle timer and add one packet to the given direction.

        Args:
            bytes_transferred: Byte count to add to the direction's counter.
            direction: ``'out'`` updates the outbound counters; any other value
                updates the inbound counters.
        """
        self.last_activity = time.time()
        if direction == 'out':
            self.bytes_out += bytes_transferred
            self.packets_out += 1
        else:
            self.bytes_in += bytes_transferred
            self.packets_in += 1


class _L4Packet(NamedTuple):
    """Offsets and endpoints of a parsed IPv4 TCP/UDP packet."""

    protocol: int
    ip_offset: int
    ip_header_length: int
    transport_offset: int
    src_ip: str
    dst_ip: str
    src_port: int
    dst_port: int


class PortPool:
    """Manages available ports for NAT"""

    def __init__(self, start_port: int = 10000, end_port: int = 65535):
        self.start_port = start_port
        self.end_port = end_port
        self.available_ports: Set[int] = set(range(start_port, end_port + 1))
        self.allocated_ports: Dict[int, str] = {}  # port -> session_id
        # Indexable mirror of available_ports so a uniformly random free port
        # can be taken in O(1); random.choice() on a set needs an O(n) copy.
        self._free_list: List[int] = list(range(start_port, end_port + 1))
        self._free_index: Dict[int, int] = {
            port: idx for idx, port in enumerate(self._free_list)
        }
        self.lock = threading.Lock()

    def _remove_free(self, port: int) -> None:
        """Drop *port* from the free list by swapping the last entry into its slot."""
        idx = self._free_index.pop(port)
        last = self._free_list.pop()
        if last != port:
            self._free_list[idx] = last
            self._free_index[last] = idx

    def allocate_port(self, session_id: str) -> Optional[int]:
        """Allocate a random free port for *session_id*; None when exhausted."""
        with self.lock:
            if not self._free_list:
                return None
            # Random choice spreads sessions over the range.
            port = self._free_list[random.randrange(len(self._free_list))]
            self._remove_free(port)
            self.available_ports.discard(port)
            self.allocated_ports[port] = session_id
            return port

    def release_port(self, port: int) -> bool:
        """Return an allocated port to the pool; False if it was not allocated."""
        with self.lock:
            if port not in self.allocated_ports:
                return False
            del self.allocated_ports[port]
            if self.start_port <= port <= self.end_port and port not in self._free_index:
                self.available_ports.add(port)
                self._free_index[port] = len(self._free_list)
                self._free_list.append(port)
            return True

    def get_session_for_port(self, port: int) -> Optional[str]:
        """Get session ID for a port"""
        with self.lock:
            return self.allocated_ports.get(port)

    def get_stats(self) -> Dict:
        """Get port pool statistics (utilization is 0.0 for an empty range)."""
        with self.lock:
            total = self.end_port - self.start_port + 1
            allocated = len(self.allocated_ports)
            return {
                'total_ports': total,
                'available_ports': len(self.available_ports),
                'allocated_ports': allocated,
                'utilization': allocated / total if total > 0 else 0.0,
            }


class NATEngine:
    """Network Address Translation engine"""

    def __init__(self, config: Dict):
        """Create the engine from a config dict.

        Recognised keys: ``port_range_start`` (10000), ``port_range_end``
        (65535), ``host_ip`` (auto-detected when absent), ``session_timeout``
        (300) and ``cleanup_interval`` (0.1 s between expiry sweeps).
        """
        self.config = config
        self.sessions: Dict[str, NATSession] = {}  # session_id -> session
        self.virtual_to_session: Dict[Tuple[str, int, int], str] = {}  # (vip, vport, proto) -> session_id
        self.host_to_session: Dict[Tuple[str, int, int], str] = {}  # (hip, hport, proto) -> session_id
        # Re-entrant: methods that hold the lock call helpers
        # (_remove_session, create_outbound_session) that acquire it again.
        self.lock = threading.RLock()

        # Port pool for outbound connections
        self.port_pool = PortPool(
            config.get('port_range_start', 10000),
            config.get('port_range_end', 65535)
        )

        # Host IP for outbound connections; only probe the network when unset.
        host_ip = config.get('host_ip')
        self.host_ip = host_ip if host_ip is not None else self._get_default_host_ip()

        # NOTE(review): stored for callers, but NATSession.is_expired uses
        # fixed timeouts (300 s TCP / 60 s other) and never reads this value.
        self.session_timeout = config.get('session_timeout', 300)

        # Seconds between expiry sweeps in the background cleanup thread.
        self.cleanup_interval = config.get('cleanup_interval', 0.1)

        # Statistics
        self.stats = {
            'total_sessions': 0,
            'active_sessions': 0,
            'expired_sessions': 0,
            'port_exhaustion_events': 0,
            'bytes_translated': 0,
            'packets_translated': 0
        }

        # Cleanup thread
        self.running = False
        self.cleanup_thread = None
        self._stop_event = threading.Event()

    def _get_default_host_ip(self) -> str:
        """Return the local address used to reach the internet, or 127.0.0.1."""
        try:
            # A UDP connect() only selects a route/source address; nothing is sent.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(('8.8.8.8', 80))
                return s.getsockname()[0]
        except OSError:
            return '127.0.0.1'

    def _cleanup_expired_sessions(self):
        """Remove every expired session and count them in stats."""
        with self.lock:
            expired = [sid for sid, session in self.sessions.items() if session.is_expired]
            for session_id in expired:
                self._remove_session(session_id)
            self.stats['expired_sessions'] += len(expired)

    def _remove_session(self, session_id: str):
        """Remove a session, its index entries and its host port (no-op if unknown)."""
        with self.lock:
            session = self.sessions.pop(session_id, None)
            if session is None:
                return

            # Only drop index entries that still refer to this session: a newer
            # session for the same endpoint may have replaced them.
            virtual_key = (session.virtual_ip, session.virtual_port, session.protocol)
            if self.virtual_to_session.get(virtual_key) == session_id:
                del self.virtual_to_session[virtual_key]

            host_key = (session.host_ip, session.host_port, session.protocol)
            if self.host_to_session.get(host_key) == session_id:
                del self.host_to_session[host_key]

            self.port_pool.release_port(session.host_port)
            self.stats['active_sessions'] = len(self.sessions)

    def create_outbound_session(self, virtual_ip: str, virtual_port: int,
                                real_ip: str, real_port: int,
                                protocol: int) -> Optional[NATSession]:
        """Create (or return the existing) SNAT session for a 5-tuple.

        Returns None when the port pool is exhausted.
        """
        session_id = _make_session_id(virtual_ip, virtual_port, real_ip, real_port, protocol)

        with self.lock:
            existing = self.sessions.get(session_id)
            if existing is not None:
                # Re-creating the same 5-tuple would otherwise leak a host port.
                return existing

            host_port = self.port_pool.allocate_port(session_id)
            if host_port is None:
                self.stats['port_exhaustion_events'] += 1
                return None

            now = time.time()
            session = NATSession(
                virtual_ip=virtual_ip,
                virtual_port=virtual_port,
                real_ip=real_ip,
                real_port=real_port,
                host_ip=self.host_ip,
                host_port=host_port,
                protocol=protocol,
                nat_type=NATType.SNAT,
                created_time=now,
                last_activity=now
            )

            self.sessions[session_id] = session
            self.virtual_to_session[(virtual_ip, virtual_port, protocol)] = session_id
            self.host_to_session[(session.host_ip, host_port, protocol)] = session_id

            self.stats['total_sessions'] += 1
            self.stats['active_sessions'] = len(self.sessions)
            return session

    def translate_outbound(self, virtual_ip: str, virtual_port: int,
                           real_ip: str, real_port: int,
                           protocol: int) -> Optional[Tuple[str, int]]:
        """Translate outbound packet (virtual -> host).

        Reuses the session already mapped to (virtual_ip, virtual_port,
        protocol) regardless of destination, otherwise creates one. Returns
        (host_ip, host_port), or None when no port could be allocated.
        """
        virtual_key = (virtual_ip, virtual_port, protocol)
        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            session = self.sessions.get(session_id) if session_id else None
            if session is not None:
                session.update_activity(direction='out')
                return (session.host_ip, session.host_port)

            session = self.create_outbound_session(virtual_ip, virtual_port,
                                                   real_ip, real_port, protocol)
            return (session.host_ip, session.host_port) if session else None

    def translate_inbound(self, host_ip: str, host_port: int,
                          protocol: int) -> Optional[Tuple[str, int]]:
        """Translate inbound packet (host -> virtual); None if no session matches."""
        host_key = (host_ip, host_port, protocol)
        with self.lock:
            session_id = self.host_to_session.get(host_key)
            if session_id and session_id in self.sessions:
                session = self.sessions[session_id]
                session.update_activity(direction='in')
                return (session.virtual_ip, session.virtual_port)
        return None

    def get_session_by_virtual(self, virtual_ip: str, virtual_port: int,
                               protocol: int) -> Optional[NATSession]:
        """Get session by virtual endpoint"""
        virtual_key = (virtual_ip, virtual_port, protocol)
        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            if session_id and session_id in self.sessions:
                return self.sessions[session_id]
        return None

    def get_session_by_host(self, host_ip: str, host_port: int,
                            protocol: int) -> Optional[NATSession]:
        """Get session by host endpoint"""
        host_key = (host_ip, host_port, protocol)
        with self.lock:
            session_id = self.host_to_session.get(host_key)
            if session_id and session_id in self.sessions:
                return self.sessions[session_id]
        return None

    def close_session(self, session_id: str) -> bool:
        """Manually close a session; False if it does not exist."""
        with self.lock:
            if session_id not in self.sessions:
                return False
            self._remove_session(session_id)
            return True

    def close_session_by_virtual(self, virtual_ip: str, virtual_port: int,
                                 protocol: int) -> bool:
        """Close the session mapped to a virtual endpoint; False if none."""
        virtual_key = (virtual_ip, virtual_port, protocol)
        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            if not session_id:
                return False
            self._remove_session(session_id)
            return True

    def get_sessions(self) -> Dict[str, Dict]:
        """Get all active sessions"""
        with self.lock:
            return {
                session_id: {
                    'virtual_ip': session.virtual_ip,
                    'virtual_port': session.virtual_port,
                    'real_ip': session.real_ip,
                    'real_port': session.real_port,
                    'host_ip': session.host_ip,
                    'host_port': session.host_port,
                    'protocol': session.protocol,
                    'nat_type': session.nat_type.value,
                    'created_time': session.created_time,
                    'last_activity': session.last_activity,
                    'duration': session.duration,
                    'bytes_in': session.bytes_in,
                    'bytes_out': session.bytes_out,
                    'packets_in': session.packets_in,
                    'packets_out': session.packets_out,
                    'is_expired': session.is_expired
                }
                for session_id, session in self.sessions.items()
            }

    def get_stats(self) -> Dict:
        """Get NAT statistics merged with the port pool statistics."""
        port_stats = self.port_pool.get_stats()
        with self.lock:
            current_stats = self.stats.copy()
            current_stats['active_sessions'] = len(self.sessions)
        current_stats.update(port_stats)
        return current_stats

    def update_packet_stats(self, bytes_count: int):
        """Count one translated packet of *bytes_count* bytes."""
        with self.lock:
            self.stats['bytes_translated'] += bytes_count
            self.stats['packets_translated'] += 1

    def _cleanup_loop(self, stop_event: Optional[threading.Event] = None):
        """Background cleanup loop: sweep expired sessions until stopped."""
        if stop_event is None:
            stop_event = self._stop_event
        while self.running and not stop_event.is_set():
            try:
                self._cleanup_expired_sessions()
            except Exception:
                # Keep the sweeper alive; one failed sweep must not stop expiry.
                logger.exception("NAT cleanup error")
            # Returns early as soon as stop() sets the event.
            stop_event.wait(self.cleanup_interval)

    def start(self):
        """Start the background cleanup thread (no-op if already running)."""
        if self.running and self.cleanup_thread is not None and self.cleanup_thread.is_alive():
            return
        self.running = True
        # A fresh event per thread so a lingering old loop cannot be revived.
        self._stop_event = threading.Event()
        self.cleanup_thread = threading.Thread(
            target=self._cleanup_loop, args=(self._stop_event,), daemon=True
        )
        self.cleanup_thread.start()

    def stop(self):
        """Stop the cleanup thread and close all sessions."""
        self.running = False
        self._stop_event.set()
        thread = self.cleanup_thread
        if thread is not None and thread.is_alive():
            thread.join(timeout=1)
            if thread.is_alive():
                logger.warning("NAT cleanup thread did not terminate in time.")

        with self.lock:
            for session_id in list(self.sessions):
                self._remove_session(session_id)

    def _calculate_ip_checksum(self, ip_header_no_checksum: bytes) -> int:
        """Return the Internet checksum of *ip_header_no_checksum*.

        The checksum field inside the input must be zero. Odd lengths are
        zero-padded.
        """
        return ~_ones_complement_sum(ip_header_no_checksum) & 0xFFFF

    def _parse_l4_packet(self, packet: bytes, ip_header_offset: int) -> Optional[_L4Packet]:
        """Parse an IPv4 TCP/UDP packet starting at *ip_header_offset*.

        Returns None for packets this engine does not translate: truncated
        input, non-IPv4, bad IHL, protocols other than TCP/UDP, and non-first
        fragments (which carry no transport header).
        """
        if len(packet) < ip_header_offset + _IPV4_MIN_HEADER_LEN:
            return None
        version_ihl = packet[ip_header_offset]
        ip_header_length = (version_ihl & 0x0F) * 4
        if version_ihl >> 4 != 4 or ip_header_length < _IPV4_MIN_HEADER_LEN:
            return None

        (_, _, _, _, flags_fragment, _, protocol, _,
         src_raw, dst_raw) = struct.unpack_from('!BBHHHBBH4s4s', packet, ip_header_offset)
        if protocol not in _L4_MIN_HEADER_LEN:
            return None
        if flags_fragment & 0x1FFF:  # fragment offset != 0
            return None

        transport_offset = ip_header_offset + ip_header_length
        if len(packet) < transport_offset + _L4_MIN_HEADER_LEN[protocol]:
            return None
        src_port, dst_port = struct.unpack_from('!HH', packet, transport_offset)

        return _L4Packet(
            protocol=protocol,
            ip_offset=ip_header_offset,
            ip_header_length=ip_header_length,
            transport_offset=transport_offset,
            src_ip=socket.inet_ntoa(src_raw),
            dst_ip=socket.inet_ntoa(dst_raw),
            src_port=src_port,
            dst_port=dst_port,
        )

    def _rewrite_endpoint(self, packet: bytes, parsed: _L4Packet, new_ip: str,
                          new_port: int, *, source: bool) -> bytes:
        """Return a copy of *packet* with the source (or destination) IP/port replaced.

        IP options, TCP options, payload and any trailing link-layer padding
        are preserved. The IPv4 header checksum is recomputed. The TCP/UDP
        checksum is adjusted incrementally, because it covers both the
        pseudo-header addresses and the ports.
        """
        buf = bytearray(packet)
        ip_field = parsed.ip_offset + (12 if source else 16)
        port_field = parsed.transport_offset + (0 if source else 2)

        new_bytes = socket.inet_aton(new_ip) + struct.pack('!H', new_port)
        old_bytes = bytes(buf[ip_field:ip_field + 4]) + bytes(buf[port_field:port_field + 2])
        buf[ip_field:ip_field + 4] = new_bytes[:4]
        buf[port_field:port_field + 2] = new_bytes[4:]

        # IPv4 header checksum over the full header, options included.
        ip_checksum_field = parsed.ip_offset + 10
        struct.pack_into('!H', buf, ip_checksum_field, 0)
        ip_header = bytes(buf[parsed.ip_offset:parsed.ip_offset + parsed.ip_header_length])
        struct.pack_into('!H', buf, ip_checksum_field, self._calculate_ip_checksum(ip_header))

        # Transport checksum. A zero UDP checksum means "not computed" (RFC 768)
        # and must stay zero; a computed zero is transmitted as 0xFFFF.
        l4_field = parsed.transport_offset + _L4_CHECKSUM_OFFSET[parsed.protocol]
        (old_l4_checksum,) = struct.unpack_from('!H', buf, l4_field)
        is_udp = parsed.protocol == socket.IPPROTO_UDP
        if is_udp and old_l4_checksum == 0:
            new_l4_checksum = 0
        else:
            new_l4_checksum = _adjust_checksum(old_l4_checksum, old_bytes, new_bytes)
            if is_udp and new_l4_checksum == 0:
                new_l4_checksum = 0xFFFF
        struct.pack_into('!H', buf, l4_field, new_l4_checksum)
        return bytes(buf)

    def process_inbound_packet(self, packet: bytes, *,
                               ip_header_offset: int = 14) -> Optional[bytes]:
        """Process an inbound packet (from internet to VPN client) for DNAT.

        Looks up the session by the packet's destination (host) IP/port and
        rewrites the destination to the session's virtual endpoint.

        Args:
            packet: Frame whose IPv4 header starts at *ip_header_offset*.
            ip_header_offset: 14 for Ethernet frames, 0 for raw IP.

        Returns:
            The rewritten packet, or None if it is not a translatable IPv4
            TCP/UDP packet or no session matches.
        """
        parsed = self._parse_l4_packet(packet, ip_header_offset)
        if parsed is None:
            return None
        translated = self.translate_inbound(parsed.dst_ip, parsed.dst_port, parsed.protocol)
        if translated is None:
            return None
        virtual_ip, virtual_port = translated
        return self._rewrite_endpoint(packet, parsed, virtual_ip, virtual_port, source=False)

    def process_outbound_packet(self, packet: bytes, *,
                                ip_header_offset: int = 14) -> Optional[bytes]:
        """Process an outbound packet (from VPN client to internet) for SNAT.

        Rewrites the packet's source IP/port to the host endpoint of the
        (possibly newly created) session for its virtual source endpoint.

        Args:
            packet: Frame whose IPv4 header starts at *ip_header_offset*.
            ip_header_offset: 14 for Ethernet frames, 0 for raw IP.

        Returns:
            The rewritten packet, or None if it is not a translatable IPv4
            TCP/UDP packet or no host port could be allocated.
        """
        parsed = self._parse_l4_packet(packet, ip_header_offset)
        if parsed is None:
            return None
        translated = self.translate_outbound(parsed.src_ip, parsed.src_port,
                                             parsed.dst_ip, parsed.dst_port,
                                             parsed.protocol)
        if translated is None:
            return None
        host_ip, host_port = translated
        return self._rewrite_endpoint(packet, parsed, host_ip, host_port, source=True)


class NATRule:
    """Represents a NAT rule for DNAT (port forwarding)"""

    def __init__(self, external_port: int, internal_ip: str, internal_port: int,
                 protocol: int, enabled: bool = True):
        self.external_port = external_port
        self.internal_ip = internal_ip
        self.internal_port = internal_port
        self.protocol = protocol
        self.enabled = enabled
        self.created_time = time.time()
        self.hit_count = 0
        self.last_hit = None

    def matches(self, port: int, protocol: int) -> bool:
        """Check if rule is enabled and matches the given port and protocol"""
        return (self.enabled and
                self.external_port == port and
                self.protocol == protocol)

    def record_hit(self):
        """Record a rule hit"""
        self.hit_count += 1
        self.last_hit = time.time()

    def to_dict(self) -> Dict:
        """Convert rule to dictionary"""
        return {
            'external_port': self.external_port,
            'internal_ip': self.internal_ip,
            'internal_port': self.internal_port,
            'protocol': self.protocol,
            'enabled': self.enabled,
            'created_time': self.created_time,
            'hit_count': self.hit_count,
            'last_hit': self.last_hit
        }


class DNATEngine:
    """Destination NAT engine for port forwarding"""

    def __init__(self):
        self.rules: Dict[str, NATRule] = {}  # rule_id -> rule
        self.lock = threading.Lock()

    def add_rule(self, rule_id: str, external_port: int, internal_ip: str,
                 internal_port: int, protocol: int) -> bool:
        """Add DNAT rule; False if *rule_id* already exists."""
        with self.lock:
            if rule_id in self.rules:
                return False
            self.rules[rule_id] = NATRule(external_port, internal_ip, internal_port, protocol)
            return True

    def remove_rule(self, rule_id: str) -> bool:
        """Remove DNAT rule; False if it does not exist."""
        with self.lock:
            if rule_id in self.rules:
                del self.rules[rule_id]
                return True
            return False

    def get_rule(self, rule_id: str) -> Optional[NATRule]:
        """Get DNAT rule by ID"""
        with self.lock:
            return self.rules.get(rule_id)

    def get_matching_rule(self, port: int, protocol: int) -> Optional[NATRule]:
        """Return the first matching rule (recording a hit on it), or None."""
        with self.lock:
            for rule in self.rules.values():
                if rule.matches(port, protocol):
                    rule.record_hit()
                    return rule
            return None

    def get_all_rules(self) -> Dict[str, Dict]:
        """Get all DNAT rules"""
        with self.lock:
            return {rule_id: rule.to_dict() for rule_id, rule in self.rules.items()}