""" NAT Engine Module Implements Network Address Translation: - Map (virtualIP, virtualPort) to (hostIP, hostPort) - Maintain connection tracking table - Handle port allocation and deallocation - Support connection state tracking """ import time import threading import socket import random from typing import Dict, Optional, Tuple, Set from dataclasses import dataclass from enum import Enum from .ip_parser import IPProtocol class NATType(Enum): SNAT = "SNAT" # Source NAT DNAT = "DNAT" # Destination NAT @dataclass class NATSession: """Represents a NAT session""" # Virtual (internal) endpoint virtual_ip: str virtual_port: int # Real (external) endpoint real_ip: str real_port: int # Host (translated) endpoint host_ip: str host_port: int # Session metadata protocol: str # TCP or UDP nat_type: NATType created_time: float last_activity: float bytes_in: int = 0 bytes_out: int = 0 packets_in: int = 0 packets_out: int = 0 @property def session_id(self) -> str: """Get unique session identifier""" return f"{self.virtual_ip}:{self.virtual_port}-{self.real_ip}:{self.real_port}-{self.protocol}" @property def is_expired(self) -> bool: """Check if session has expired""" timeout = 300 if self.protocol == 'TCP' else 60 # 5 min for TCP, 1 min for UDP return time.time() - self.last_activity > timeout @property def duration(self) -> float: """Get session duration in seconds""" return time.time() - self.created_time def update_activity(self, bytes_transferred: int = 0, direction: str = 'out'): """Update session activity""" self.last_activity = time.time() if direction == 'out': self.bytes_out += bytes_transferred self.packets_out += 1 else: self.bytes_in += bytes_transferred self.packets_in += 1 class PortPool: """Manages available ports for NAT""" def __init__(self, start_port: int = 10000, end_port: int = 65535): self.start_port = start_port self.end_port = end_port self.available_ports: Set[int] = set(range(start_port, end_port + 1)) self.allocated_ports: Dict[int, str] = {} # port -> session_id self.lock = threading.Lock() def allocate_port(self, session_id: str) -> Optional[int]: """Allocate a port for a session""" with self.lock: if not self.available_ports: return None # Try to get a random port to distribute load port = random.choice(list(self.available_ports)) self.available_ports.remove(port) self.allocated_ports[port] = session_id return port def release_port(self, port: int) -> bool: """Release a port back to the pool""" with self.lock: if port in self.allocated_ports: del self.allocated_ports[port] if self.start_port <= port <= self.end_port: self.available_ports.add(port) return True return False def get_session_for_port(self, port: int) -> Optional[str]: """Get session ID for a port""" with self.lock: return self.allocated_ports.get(port) def get_stats(self) -> Dict: """Get port pool statistics""" with self.lock: return { 'total_ports': self.end_port - self.start_port + 1, 'available_ports': len(self.available_ports), 'allocated_ports': len(self.allocated_ports), 'utilization': len(self.allocated_ports) / (self.end_port - self.start_port + 1) } class NATEngine: """Network Address Translation engine""" def __init__(self, config: Dict): self.config = config self.sessions: Dict[str, NATSession] = {} # session_id -> session self.virtual_to_session: Dict[Tuple[str, int, str], str] = {} # (vip, vport, proto) -> session_id self.host_to_session: Dict[Tuple[str, int, str], str] = {} # (hip, hport, proto) -> session_id self.lock = threading.Lock() # Port pool for outbound connections self.port_pool = PortPool( config.get('port_range_start', 10000), config.get('port_range_end', 65535) ) # Host IP for outbound connections self.host_ip = config.get('host_ip', self._get_default_host_ip()) # Session timeout self.session_timeout = config.get('session_timeout', 300) # Statistics self.stats = { 'total_sessions': 0, 'active_sessions': 0, 'expired_sessions': 0, 'port_exhaustion_events': 0, 'bytes_translated': 0, 'packets_translated': 0 } # Cleanup thread self.running = False self.cleanup_thread = None def _get_default_host_ip(self) -> str: """Get default host IP address""" try: # Connect to a remote address to determine local IP with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s: s.connect(('8.8.8.8', 80)) return s.getsockname()[0] except Exception: return '127.0.0.1' def _cleanup_expired_sessions(self): """Clean up expired sessions""" current_time = time.time() expired_sessions = [] with self.lock: for session_id, session in self.sessions.items(): if session.is_expired: expired_sessions.append(session_id) for session_id in expired_sessions: self._remove_session(session_id) self.stats['expired_sessions'] += 1 def _remove_session(self, session_id: str): """Remove a session and clean up resources""" with self.lock: if session_id not in self.sessions: return session = self.sessions[session_id] # Remove from lookup tables virtual_key = (session.virtual_ip, session.virtual_port, session.protocol) if virtual_key in self.virtual_to_session: del self.virtual_to_session[virtual_key] host_key = (session.host_ip, session.host_port, session.protocol) if host_key in self.host_to_session: del self.host_to_session[host_key] # Release port self.port_pool.release_port(session.host_port) # Remove session del self.sessions[session_id] self.stats['active_sessions'] = len(self.sessions) def create_outbound_session(self, virtual_ip: str, virtual_port: int, real_ip: str, real_port: int, protocol: str) -> Optional[NATSession]: """Create NAT session for outbound connection""" # Allocate host port session_id = f"{virtual_ip}:{virtual_port}-{real_ip}:{real_port}-{protocol}" host_port = self.port_pool.allocate_port(session_id) if host_port is None: self.stats['port_exhaustion_events'] += 1 return None # Create session session = NATSession( virtual_ip=virtual_ip, virtual_port=virtual_port, real_ip=real_ip, real_port=real_port, host_ip=self.host_ip, host_port=host_port, protocol=protocol, nat_type=NATType.SNAT, created_time=time.time(), last_activity=time.time() ) with self.lock: self.sessions[session_id] = session # Add to lookup tables virtual_key = (virtual_ip, virtual_port, protocol) self.virtual_to_session[virtual_key] = session_id host_key = (self.host_ip, host_port, protocol) self.host_to_session[host_key] = session_id self.stats['total_sessions'] += 1 self.stats['active_sessions'] = len(self.sessions) return session def translate_outbound(self, virtual_ip: str, virtual_port: int, real_ip: str, real_port: int, protocol: str) -> Optional[Tuple[str, int]]: """Translate outbound packet (virtual -> host)""" virtual_key = (virtual_ip, virtual_port, protocol) with self.lock: session_id = self.virtual_to_session.get(virtual_key) if session_id: session = self.sessions[session_id] session.update_activity(direction='out') return (session.host_ip, session.host_port) else: # Create new session session = self.create_outbound_session(virtual_ip, virtual_port, real_ip, real_port, protocol) if session: return (session.host_ip, session.host_port) return None def translate_inbound(self, host_ip: str, host_port: int, protocol: str) -> Optional[Tuple[str, int]]: """Translate inbound packet (host -> virtual)""" host_key = (host_ip, host_port, protocol) with self.lock: session_id = self.host_to_session.get(host_key) if session_id and session_id in self.sessions: session = self.sessions[session_id] session.update_activity(direction='in') return (session.virtual_ip, session.virtual_port) return None def get_session_by_virtual(self, virtual_ip: str, virtual_port: int, protocol: str) -> Optional[NATSession]: """Get session by virtual endpoint""" virtual_key = (virtual_ip, virtual_port, protocol) with self.lock: session_id = self.virtual_to_session.get(virtual_key) if session_id and session_id in self.sessions: return self.sessions[session_id] return None def get_session_by_host(self, host_ip: str, host_port: int, protocol: str) -> Optional[NATSession]: """Get session by host endpoint""" host_key = (host_ip, host_port, protocol) with self.lock: session_id = self.host_to_session.get(host_key) if session_id and session_id in self.sessions: return self.sessions[session_id] return None def close_session(self, session_id: str) -> bool: """Manually close a session""" with self.lock: if session_id in self.sessions: self._remove_session(session_id) return True return False def close_session_by_virtual(self, virtual_ip: str, virtual_port: int, protocol: str) -> bool: """Close session by virtual endpoint""" virtual_key = (virtual_ip, virtual_port, protocol) with self.lock: session_id = self.virtual_to_session.get(virtual_key) if session_id: self._remove_session(session_id) return True return False def get_sessions(self) -> Dict[str, Dict]: """Get all active sessions""" with self.lock: return { session_id: { 'virtual_ip': session.virtual_ip, 'virtual_port': session.virtual_port, 'real_ip': session.real_ip, 'real_port': session.real_port, 'host_ip': session.host_ip, 'host_port': session.host_port, 'protocol': session.protocol, 'nat_type': session.nat_type.value, 'created_time': session.created_time, 'last_activity': session.last_activity, 'duration': session.duration, 'bytes_in': session.bytes_in, 'bytes_out': session.bytes_out, 'packets_in': session.packets_in, 'packets_out': session.packets_out, 'is_expired': session.is_expired } for session_id, session in self.sessions.items() } def get_stats(self) -> Dict: """Get NAT statistics""" port_stats = self.port_pool.get_stats() with self.lock: current_stats = self.stats.copy() current_stats['active_sessions'] = len(self.sessions) current_stats.update(port_stats) return current_stats def update_packet_stats(self, bytes_count: int): """Update packet statistics""" self.stats['bytes_translated'] += bytes_count self.stats['packets_translated'] += 1 def _cleanup_loop(self): """Background cleanup loop""" while self.running: try: self._cleanup_expired_sessions() time.sleep(30) # Cleanup every 30 seconds except Exception as e: print(f"NAT cleanup error: {e}") time.sleep(5) def start(self): """Start NAT engine""" self.running = True self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) self.cleanup_thread.start() print(f"NAT engine started - Host IP: {self.host_ip}, Port range: {self.port_pool.start_port}-{self.port_pool.end_port}") def stop(self): """Stop NAT engine""" self.running = False if self.cleanup_thread: self.cleanup_thread.join() # Close all sessions with self.lock: session_ids = list(self.sessions.keys()) for session_id in session_ids: self._remove_session(session_id) print("NAT engine stopped") def reset_stats(self): """Reset statistics""" self.stats = { 'total_sessions': 0, 'active_sessions': len(self.sessions), 'expired_sessions': 0, 'port_exhaustion_events': 0, 'bytes_translated': 0, 'packets_translated': 0 } class NATRule: """Represents a NAT rule for DNAT (port forwarding)""" def __init__(self, external_port: int, internal_ip: str, internal_port: int, protocol: str = 'TCP', enabled: bool = True): self.external_port = external_port self.internal_ip = internal_ip self.internal_port = internal_port self.protocol = protocol.upper() self.enabled = enabled self.created_time = time.time() self.hit_count = 0 self.last_hit = None def matches(self, port: int, protocol: str) -> bool: """Check if rule matches the given port and protocol""" return (self.enabled and self.external_port == port and self.protocol == protocol.upper()) def record_hit(self): """Record a rule hit""" self.hit_count += 1 self.last_hit = time.time() def to_dict(self) -> Dict: """Convert rule to dictionary""" return { 'external_port': self.external_port, 'internal_ip': self.internal_ip, 'internal_port': self.internal_port, 'protocol': self.protocol, 'enabled': self.enabled, 'created_time': self.created_time, 'hit_count': self.hit_count, 'last_hit': self.last_hit } class DNATEngine: """Destination NAT engine for port forwarding""" def __init__(self): self.rules: Dict[str, NATRule] = {} # rule_id -> rule self.lock = threading.Lock() def add_rule(self, rule_id: str, external_port: int, internal_ip: str, internal_port: int, protocol: str = 'TCP') -> bool: """Add DNAT rule""" with self.lock: if rule_id in self.rules: return False rule = NATRule(external_port, internal_ip, internal_port, protocol) self.rules[rule_id] = rule return True def remove_rule(self, rule_id: str) -> bool: """Remove DNAT rule""" with self.lock: if rule_id in self.rules: del self.rules[rule_id] return True return False def enable_rule(self, rule_id: str) -> bool: """Enable DNAT rule""" with self.lock: if rule_id in self.rules: self.rules[rule_id].enabled = True return True return False def disable_rule(self, rule_id: str) -> bool: """Disable DNAT rule""" with self.lock: if rule_id in self.rules: self.rules[rule_id].enabled = False return True return False def translate_inbound_dnat(self, external_port: int, protocol: str) -> Optional[Tuple[str, int]]: """Translate inbound packet using DNAT rules""" with self.lock: for rule in self.rules.values(): if rule.matches(external_port, protocol): rule.record_hit() return (rule.internal_ip, rule.internal_port) return None def get_rules(self) -> Dict[str, Dict]: """Get all DNAT rules""" with self.lock: return { rule_id: rule.to_dict() for rule_id, rule in self.rules.items() } def clear_rules(self): """Clear all DNAT rules""" with self.lock: self.rules.clear()