|
|
""" |
|
|
NAT Engine Module |
|
|
|
|
|
Implements Network Address Translation: |
|
|
- Map (virtualIP, virtualPort) to (hostIP, hostPort) |
|
|
- Maintain connection tracking table |
|
|
- Handle port allocation and deallocation |
|
|
- Support connection state tracking |
|
|
""" |
|
|
|
|
|
import time |
|
|
import threading |
|
|
import socket |
|
|
import random |
|
|
from typing import Dict, Optional, Tuple, Set |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
|
|
|
from .ip_parser import IPProtocol |
|
|
|
|
|
|
|
|
class NATType(Enum):
    """Kind of address translation performed by a NAT session or rule."""

    # Source NAT: the sender's endpoint is rewritten (used for outbound sessions).
    SNAT = "SNAT"
    # Destination NAT: the destination endpoint is rewritten (port forwarding).
    DNAT = "DNAT"
|
|
|
|
|
|
|
|
@dataclass
class NATSession:
    """Represents a NAT session: one translated flow plus its traffic counters.

    ``virtual_*`` is the endpoint on the virtual side of the NAT and ``host_*``
    is the host-side endpoint it is translated to. ``real_*`` is the peer
    endpoint supplied when the session was created (presumably the remote
    destination -- confirm against callers).

    ``created_time`` and ``last_activity`` are ``time.time()`` timestamps; they
    are compared against ``time.time()`` by ``duration`` and ``is_expired``.
    """

    virtual_ip: str
    virtual_port: int

    real_ip: str
    real_port: int

    host_ip: str
    host_port: int

    protocol: str
    nat_type: NATType
    created_time: float
    last_activity: float
    bytes_in: int = 0
    bytes_out: int = 0
    packets_in: int = 0
    packets_out: int = 0

    # Idle timeouts in seconds used by ``is_expired``. These are plain class
    # attributes, not dataclass fields, because they carry no annotation.
    TCP_IDLE_TIMEOUT = 300
    DEFAULT_IDLE_TIMEOUT = 60

    @property
    def session_id(self) -> str:
        """Unique identifier: ``"vip:vport-rip:rport-protocol"``."""
        return f"{self.virtual_ip}:{self.virtual_port}-{self.real_ip}:{self.real_port}-{self.protocol}"

    def _is_tcp(self) -> bool:
        """Return True when ``protocol`` names TCP, ignoring letter case.

        The exact ``== 'TCP'`` comparison is kept first so any value that
        matched before still matches; string values are additionally compared
        case-insensitively (``'tcp'`` previously fell through to the shorter
        non-TCP timeout).
        """
        proto = self.protocol
        return proto == 'TCP' or (isinstance(proto, str) and proto.upper() == 'TCP')

    @property
    def is_expired(self) -> bool:
        """True when the session has been idle longer than its protocol timeout.

        TCP sessions use ``TCP_IDLE_TIMEOUT``; every other protocol uses
        ``DEFAULT_IDLE_TIMEOUT``.
        """
        timeout = self.TCP_IDLE_TIMEOUT if self._is_tcp() else self.DEFAULT_IDLE_TIMEOUT
        return time.time() - self.last_activity > timeout

    @property
    def duration(self) -> float:
        """Seconds elapsed since the session was created."""
        return time.time() - self.created_time

    def update_activity(self, bytes_transferred: int = 0, direction: str = 'out'):
        """Refresh ``last_activity`` and bump the traffic counters.

        Args:
            bytes_transferred: Bytes to add to the counter for ``direction``.
            direction: ``'out'`` updates the outbound counters; any other value
                is counted as inbound.
        """
        self.last_activity = time.time()

        if direction == 'out':
            self.bytes_out += bytes_transferred
            self.packets_out += 1
        else:
            self.bytes_in += bytes_transferred
            self.packets_in += 1
|
|
|
|
|
|
|
|
class PortPool:
    """Manages available ports for NAT.

    Ports in the inclusive range ``[start_port, end_port]`` start out in
    ``available_ports``. ``allocate_port`` moves a randomly chosen port into
    ``allocated_ports`` (port -> session_id) and ``release_port`` returns it.
    Every public method serializes on ``self.lock``.
    """

    # Uniform random probes over the range before falling back to an O(n)
    # pick from a snapshot of the available set.
    _MAX_RANDOM_PROBES = 32

    def __init__(self, start_port: int = 10000, end_port: int = 65535):
        self.start_port = start_port
        self.end_port = end_port
        self.available_ports: Set[int] = set(range(start_port, end_port + 1))
        self.allocated_ports: Dict[int, str] = {}
        self.lock = threading.Lock()

    def _total_ports(self) -> int:
        """Size of the configured range; 0 when ``end_port < start_port``."""
        return max(0, self.end_port - self.start_port + 1)

    def _pick_available_port(self) -> int:
        """Return a uniformly random port from ``available_ports``.

        Caller must hold ``self.lock`` and ensure the set is non-empty.
        Rejection sampling over the range costs O(1) expected time while the
        pool is lightly used, instead of copying the whole set on every call.
        When the probes keep hitting allocated ports, fall back to choosing
        from a snapshot of the set. Both paths are uniform over the available
        ports, because ``release_port`` only re-adds in-range ports.
        """
        # NOTE(review): `random` is not a CSPRNG. If port predictability
        # matters (RFC 6056), consider random.SystemRandom().
        for _ in range(self._MAX_RANDOM_PROBES):
            candidate = random.randint(self.start_port, self.end_port)
            if candidate in self.available_ports:
                return candidate
        return random.choice(tuple(self.available_ports))

    def allocate_port(self, session_id: str) -> Optional[int]:
        """Reserve a random free port for ``session_id``.

        Returns:
            The allocated port, or None when the pool is exhausted.
        """
        with self.lock:
            if not self.available_ports:
                return None

            port = self._pick_available_port()
            self.available_ports.remove(port)
            self.allocated_ports[port] = session_id
            return port

    def release_port(self, port: int) -> bool:
        """Return ``port`` to the pool.

        Returns:
            True if the port was allocated (and is now released), else False.
        """
        with self.lock:
            if port in self.allocated_ports:
                del self.allocated_ports[port]
                if self.start_port <= port <= self.end_port:
                    self.available_ports.add(port)
                return True
            return False

    def get_session_for_port(self, port: int) -> Optional[str]:
        """Return the session id holding ``port``, or None if it is free."""
        with self.lock:
            return self.allocated_ports.get(port)

    def get_stats(self) -> Dict:
        """Return size, availability and utilization figures for the pool."""
        with self.lock:
            total = self._total_ports()
            allocated = len(self.allocated_ports)
            return {
                'total_ports': total,
                'available_ports': len(self.available_ports),
                'allocated_ports': allocated,
                # An empty range would otherwise divide by zero.
                'utilization': allocated / total if total else 0.0
            }
|
|
|
|
|
|
|
|
class NATEngine:
    """Network Address Translation engine (source NAT with connection tracking).

    Each outbound flow from a virtual endpoint is given a port from
    ``port_pool`` on ``host_ip``. Inbound traffic addressed to that host
    endpoint is mapped back to the virtual endpoint. Three indexes are kept
    consistent under ``self.lock``:

    * ``sessions``: session_id -> NATSession
    * ``virtual_to_session``: (virtual_ip, virtual_port, protocol) -> session_id
    * ``host_to_session``: (host_ip, host_port, protocol) -> session_id

    ``config`` keys read here: ``port_range_start`` (default 10000),
    ``port_range_end`` (default 65535), ``host_ip`` (default: auto-detected)
    and ``session_timeout`` (default 300).
    """

    # Seconds between cleanup passes, and back-off after a failed pass.
    _CLEANUP_INTERVAL = 30
    _CLEANUP_RETRY_INTERVAL = 5

    def __init__(self, config: Dict):
        self.config = config
        self.sessions: Dict[str, NATSession] = {}
        self.virtual_to_session: Dict[Tuple[str, int, str], str] = {}
        self.host_to_session: Dict[Tuple[str, int, str], str] = {}

        # Re-entrant on purpose. translate_outbound, close_session,
        # close_session_by_virtual and stop hold the lock while calling
        # create_outbound_session / _remove_session, which acquire it again.
        # A plain threading.Lock deadlocks on those paths.
        self.lock = threading.RLock()

        self.port_pool = PortPool(
            config.get('port_range_start', 10000),
            config.get('port_range_end', 65535)
        )

        # Only probe the network when no address is configured. dict.get()
        # would evaluate the fallback (a socket call) eagerly.
        if 'host_ip' in config:
            self.host_ip = config['host_ip']
        else:
            self.host_ip = self._get_default_host_ip()

        # NOTE(review): stored but not consulted anywhere in this class; idle
        # expiry is decided by NATSession.is_expired. Confirm intended use.
        self.session_timeout = config.get('session_timeout', 300)

        self.stats = self._initial_stats()

        self.running = False
        self.cleanup_thread = None
        # Set by stop() so the cleanup loop wakes at once instead of finishing
        # its current wait.
        self._stop_event = threading.Event()

    @staticmethod
    def _initial_stats(active_sessions: int = 0) -> Dict:
        """Return a fresh statistics dict with every counter zeroed."""
        return {
            'total_sessions': 0,
            'active_sessions': active_sessions,
            'expired_sessions': 0,
            'port_exhaustion_events': 0,
            'bytes_translated': 0,
            'packets_translated': 0
        }

    def _get_default_host_ip(self) -> str:
        """Return the local IPv4 address the OS would use toward 8.8.8.8.

        connect() on a UDP socket only selects a route and source address; no
        data is sent here. Falls back to '127.0.0.1' on any socket error.
        """
        try:
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(('8.8.8.8', 80))
                return s.getsockname()[0]
        except OSError:
            return '127.0.0.1'

    def _cleanup_expired_sessions(self):
        """Remove every session whose ``is_expired`` property is true.

        Scan and removal happen under a single lock hold. A session refreshed
        by concurrent traffic therefore cannot be removed after being judged
        expired, and ``expired_sessions`` counts only sessions actually removed.
        """
        with self.lock:
            expired_ids = [sid for sid, session in self.sessions.items()
                           if session.is_expired]
            for session_id in expired_ids:
                self._remove_session(session_id)
            self.stats['expired_sessions'] += len(expired_ids)

    def _remove_session(self, session_id: str):
        """Drop a session from all indexes and return its host port to the pool.

        Does nothing if ``session_id`` is unknown.
        """
        with self.lock:
            session = self.sessions.pop(session_id, None)
            if session is None:
                return

            # Only delete index entries that still point at this session; a
            # newer session may have claimed the same virtual or host key.
            virtual_key = (session.virtual_ip, session.virtual_port, session.protocol)
            if self.virtual_to_session.get(virtual_key) == session_id:
                del self.virtual_to_session[virtual_key]

            host_key = (session.host_ip, session.host_port, session.protocol)
            if self.host_to_session.get(host_key) == session_id:
                del self.host_to_session[host_key]

            self.port_pool.release_port(session.host_port)
            self.stats['active_sessions'] = len(self.sessions)

    def create_outbound_session(self, virtual_ip: str, virtual_port: int,
                                real_ip: str, real_port: int, protocol: str) -> Optional[NATSession]:
        """Create an SNAT session for an outbound flow.

        Allocates a host port and registers the session in all three indexes.
        An existing session with the same id is replaced, and its host port is
        released rather than leaked.

        Returns:
            The new session, or None if the port pool is exhausted (counted in
            ``stats['port_exhaustion_events']``).
        """
        session_id = f"{virtual_ip}:{virtual_port}-{real_ip}:{real_port}-{protocol}"

        with self.lock:
            host_port = self.port_pool.allocate_port(session_id)
            if host_port is None:
                self.stats['port_exhaustion_events'] += 1
                return None

            # Overwriting the entry would orphan the old session's port and
            # index entries, so tear it down properly first.
            if session_id in self.sessions:
                self._remove_session(session_id)

            now = time.time()
            session = NATSession(
                virtual_ip=virtual_ip,
                virtual_port=virtual_port,
                real_ip=real_ip,
                real_port=real_port,
                host_ip=self.host_ip,
                host_port=host_port,
                protocol=protocol,
                nat_type=NATType.SNAT,
                created_time=now,
                last_activity=now
            )

            self.sessions[session_id] = session
            self.virtual_to_session[(virtual_ip, virtual_port, protocol)] = session_id
            self.host_to_session[(self.host_ip, host_port, protocol)] = session_id

            self.stats['total_sessions'] += 1
            self.stats['active_sessions'] = len(self.sessions)

        return session

    def translate_outbound(self, virtual_ip: str, virtual_port: int,
                           real_ip: str, real_port: int, protocol: str) -> Optional[Tuple[str, int]]:
        """Translate an outbound packet (virtual -> host).

        Reuses the session mapped to the virtual endpoint, or creates one.

        Returns:
            ``(host_ip, host_port)``, or None if a new session was needed and
            no host port was available.
        """
        virtual_key = (virtual_ip, virtual_port, protocol)

        # Lookup and creation run under one (re-entrant) lock hold, so two
        # threads cannot race to create sessions for the same endpoint.
        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            session = self.sessions.get(session_id) if session_id is not None else None

            if session is None:
                session = self.create_outbound_session(virtual_ip, virtual_port,
                                                       real_ip, real_port, protocol)
                if session is None:
                    return None

            # Count every packet, including the one that created the session.
            session.update_activity(direction='out')
            return (session.host_ip, session.host_port)

    def translate_inbound(self, host_ip: str, host_port: int, protocol: str) -> Optional[Tuple[str, int]]:
        """Translate an inbound packet (host -> virtual).

        Returns:
            ``(virtual_ip, virtual_port)`` for a tracked host endpoint, else None.
        """
        host_key = (host_ip, host_port, protocol)

        with self.lock:
            session_id = self.host_to_session.get(host_key)
            session = self.sessions.get(session_id) if session_id is not None else None
            if session is None:
                return None

            session.update_activity(direction='in')
            return (session.virtual_ip, session.virtual_port)

    def get_session_by_virtual(self, virtual_ip: str, virtual_port: int, protocol: str) -> Optional[NATSession]:
        """Return the session mapped to a virtual endpoint, or None."""
        virtual_key = (virtual_ip, virtual_port, protocol)

        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            if session_id is None:
                return None
            return self.sessions.get(session_id)

    def get_session_by_host(self, host_ip: str, host_port: int, protocol: str) -> Optional[NATSession]:
        """Return the session mapped to a host endpoint, or None."""
        host_key = (host_ip, host_port, protocol)

        with self.lock:
            session_id = self.host_to_session.get(host_key)
            if session_id is None:
                return None
            return self.sessions.get(session_id)

    def close_session(self, session_id: str) -> bool:
        """Close a session by id. Returns True if it existed."""
        with self.lock:
            if session_id in self.sessions:
                self._remove_session(session_id)
                return True
            return False

    def close_session_by_virtual(self, virtual_ip: str, virtual_port: int, protocol: str) -> bool:
        """Close the session mapped to a virtual endpoint. Returns True if mapped."""
        virtual_key = (virtual_ip, virtual_port, protocol)

        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            if session_id:
                self._remove_session(session_id)
                return True
            return False

    def get_sessions(self) -> Dict[str, Dict]:
        """Return a snapshot of every tracked session as plain dicts."""
        with self.lock:
            return {
                session_id: {
                    'virtual_ip': session.virtual_ip,
                    'virtual_port': session.virtual_port,
                    'real_ip': session.real_ip,
                    'real_port': session.real_port,
                    'host_ip': session.host_ip,
                    'host_port': session.host_port,
                    'protocol': session.protocol,
                    'nat_type': session.nat_type.value,
                    'created_time': session.created_time,
                    'last_activity': session.last_activity,
                    'duration': session.duration,
                    'bytes_in': session.bytes_in,
                    'bytes_out': session.bytes_out,
                    'packets_in': session.packets_in,
                    'packets_out': session.packets_out,
                    'is_expired': session.is_expired
                }
                for session_id, session in self.sessions.items()
            }

    def get_stats(self) -> Dict:
        """Return engine counters merged with the port pool's statistics."""
        port_stats = self.port_pool.get_stats()

        with self.lock:
            current_stats = self.stats.copy()
            current_stats['active_sessions'] = len(self.sessions)

        current_stats.update(port_stats)
        return current_stats

    def update_packet_stats(self, bytes_count: int):
        """Record one translated packet of ``bytes_count`` bytes."""
        with self.lock:
            self.stats['bytes_translated'] += bytes_count
            self.stats['packets_translated'] += 1

    def _cleanup_loop(self):
        """Background loop: purge expired sessions until ``running`` is cleared."""
        while self.running:
            try:
                self._cleanup_expired_sessions()
                delay = self._CLEANUP_INTERVAL
            except Exception as e:
                # Keep the thread alive and retry sooner after a failure.
                print(f"NAT cleanup error: {e}")
                delay = self._CLEANUP_RETRY_INTERVAL
            # Returns early when stop() sets the event, so stop() need not
            # wait out a full interval.
            self._stop_event.wait(delay)

    def start(self):
        """Start the background cleanup thread. No-op if already running."""
        if self.running:
            return

        self._stop_event.clear()
        self.running = True
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()

        print(f"NAT engine started - Host IP: {self.host_ip}, Port range: {self.port_pool.start_port}-{self.port_pool.end_port}")

    def stop(self):
        """Stop the cleanup thread and tear down every remaining session."""
        self.running = False
        self._stop_event.set()
        if self.cleanup_thread:
            self.cleanup_thread.join()

        with self.lock:
            for session_id in list(self.sessions):
                self._remove_session(session_id)

        print("NAT engine stopped")

    def reset_stats(self):
        """Zero all counters; ``active_sessions`` is re-seeded from the live table."""
        with self.lock:
            self.stats = self._initial_stats(len(self.sessions))
|
|
|
|
|
|
|
|
class NATRule:
    """A single DNAT (port-forwarding) rule.

    Traffic arriving on ``external_port`` with a matching protocol is forwarded
    to ``internal_ip:internal_port``. The protocol is stored upper-cased and
    compared case-insensitively. The rule also tracks how often it matched.
    """

    def __init__(self, external_port: int, internal_ip: str, internal_port: int,
                 protocol: str = 'TCP', enabled: bool = True):
        self.external_port, self.internal_ip, self.internal_port = (
            external_port, internal_ip, internal_port)
        self.protocol = protocol.upper()
        self.enabled = enabled

        # Bookkeeping: creation timestamp plus hit statistics.
        self.created_time = time.time()
        self.hit_count = 0
        self.last_hit = None

    def matches(self, port: int, protocol: str) -> bool:
        """Return whether this rule applies to ``port`` / ``protocol``.

        A disabled rule never matches. The protocol argument is only
        normalized once the port has matched.
        """
        if not self.enabled:
            return self.enabled
        port_matches = self.external_port == port
        if not port_matches:
            return port_matches
        return self.protocol == protocol.upper()

    def record_hit(self):
        """Count one more match and stamp the time it happened."""
        self.hit_count = self.hit_count + 1
        self.last_hit = time.time()

    def to_dict(self) -> Dict:
        """Serialize the rule's configuration and statistics to a plain dict."""
        exported = ('external_port', 'internal_ip', 'internal_port', 'protocol',
                    'enabled', 'created_time', 'hit_count', 'last_hit')
        return {attr: getattr(self, attr) for attr in exported}
|
|
|
|
|
|
|
|
class DNATEngine:
    """Destination-NAT (port forwarding) rule table.

    Rules are stored under caller-chosen ids. Inbound lookups walk the rules in
    insertion order and use the first one whose ``matches`` is true. All
    access to ``rules`` goes through ``self.lock``.
    """

    def __init__(self):
        self.rules: Dict[str, NATRule] = {}
        self.lock = threading.Lock()

    def add_rule(self, rule_id: str, external_port: int, internal_ip: str,
                 internal_port: int, protocol: str = 'TCP') -> bool:
        """Register a new rule; returns False if ``rule_id`` is already taken."""
        with self.lock:
            if rule_id not in self.rules:
                self.rules[rule_id] = NATRule(external_port, internal_ip,
                                              internal_port, protocol)
                return True
            return False

    def remove_rule(self, rule_id: str) -> bool:
        """Delete a rule; returns False if ``rule_id`` is unknown."""
        with self.lock:
            try:
                del self.rules[rule_id]
            except KeyError:
                return False
            return True

    def _set_enabled(self, rule_id: str, enabled: bool) -> bool:
        """Set the ``enabled`` flag of an existing rule; False if unknown."""
        with self.lock:
            if rule_id not in self.rules:
                return False
            self.rules[rule_id].enabled = enabled
            return True

    def enable_rule(self, rule_id: str) -> bool:
        """Enable a rule; returns False if ``rule_id`` is unknown."""
        return self._set_enabled(rule_id, True)

    def disable_rule(self, rule_id: str) -> bool:
        """Disable a rule; returns False if ``rule_id`` is unknown."""
        return self._set_enabled(rule_id, False)

    def translate_inbound_dnat(self, external_port: int, protocol: str) -> Optional[Tuple[str, int]]:
        """Map an inbound ``external_port`` / ``protocol`` to its internal endpoint.

        Records a hit on the first matching rule and returns
        ``(internal_ip, internal_port)``, or None when no rule matches.
        """
        with self.lock:
            hit = next(
                (candidate for candidate in self.rules.values()
                 if candidate.matches(external_port, protocol)),
                None,
            )
            if hit is None:
                return None
            hit.record_hit()
            return (hit.internal_ip, hit.internal_port)

    def get_rules(self) -> Dict[str, Dict]:
        """Return a snapshot of every rule as ``{rule_id: rule.to_dict()}``."""
        with self.lock:
            snapshot = {}
            for rule_id, rule in self.rules.items():
                snapshot[rule_id] = rule.to_dict()
            return snapshot

    def clear_rules(self):
        """Remove every rule."""
        with self.lock:
            self.rules.clear()
|
|
|
|
|
|