NNT2 / core /nat_engine.py
Fred808's picture
Upload 54 files
47a9eda verified
"""
NAT Engine Module
Implements Network Address Translation:
- Map (virtualIP, virtualPort) to (hostIP, hostPort)
- Maintain connection tracking table
- Handle port allocation and deallocation
- Support connection state tracking
"""
import time
import threading
import socket
import random
from typing import Dict, Optional, Tuple, Set
from dataclasses import dataclass
from enum import Enum
from .ip_parser import IPProtocol
class NATType(Enum):
    """Which side of a packet's addressing a NAT translation rewrites."""

    # Source NAT: the source endpoint is rewritten.
    SNAT = "SNAT"
    # Destination NAT: the destination endpoint is rewritten.
    DNAT = "DNAT"
@dataclass
class NATSession:
    """One tracked translation between a virtual endpoint and a host endpoint.

    The virtual endpoint is the internal side, the real endpoint is the remote
    peer, and the host endpoint is the translated address/port. Traffic
    counters are accumulated through :meth:`update_activity`.
    """

    # Internal (virtual) side of the connection.
    virtual_ip: str
    virtual_port: int
    # Remote (real/external) peer.
    real_ip: str
    real_port: int
    # Translated endpoint on the host.
    host_ip: str
    host_port: int
    # Bookkeeping.
    protocol: str  # TCP or UDP
    nat_type: NATType
    created_time: float  # time.time() at creation
    last_activity: float  # time.time() of the most recent update
    bytes_in: int = 0
    bytes_out: int = 0
    packets_in: int = 0
    packets_out: int = 0

    @property
    def session_id(self) -> str:
        """Identifier of the form ``vip:vport-rip:rport-PROTO``."""
        virtual_side = f"{self.virtual_ip}:{self.virtual_port}"
        real_side = f"{self.real_ip}:{self.real_port}"
        return f"{virtual_side}-{real_side}-{self.protocol}"

    @property
    def is_expired(self) -> bool:
        """True once the session has been idle longer than its protocol's limit.

        The idle limit is 300 seconds when ``protocol == 'TCP'`` and 60 seconds
        for any other protocol value.
        """
        idle_limit = 60
        if self.protocol == 'TCP':
            idle_limit = 300
        idle_for = time.time() - self.last_activity
        return idle_for > idle_limit

    @property
    def duration(self) -> float:
        """Seconds elapsed since ``created_time``."""
        elapsed = time.time() - self.created_time
        return elapsed

    def update_activity(self, bytes_transferred: int = 0, direction: str = 'out'):
        """Refresh ``last_activity`` and count one packet.

        ``direction == 'out'`` updates the outbound counters; any other value
        updates the inbound counters.
        """
        self.last_activity = time.time()
        if direction != 'out':
            self.bytes_in += bytes_transferred
            self.packets_in += 1
            return
        self.bytes_out += bytes_transferred
        self.packets_out += 1
class PortPool:
    """Thread-safe pool of host ports handed out to NAT sessions.

    Every port in the inclusive range ``[start_port, end_port]`` starts in
    ``available_ports``. ``allocate_port`` moves one into ``allocated_ports``
    (port -> owning session id), and ``release_port`` moves it back.
    """

    # Random probes tried before falling back to picking from the free set directly.
    _RANDOM_PROBES = 32

    def __init__(self, start_port: int = 10000, end_port: int = 65535):
        self.start_port = start_port
        self.end_port = end_port
        self.available_ports: Set[int] = set(range(start_port, end_port + 1))
        self.allocated_ports: Dict[int, str] = {}  # port -> session_id
        self.lock = threading.Lock()

    def allocate_port(self, session_id: str) -> Optional[int]:
        """Reserve a random free port for ``session_id``.

        Returns:
            The allocated port, or ``None`` if no port is free.
        """
        with self.lock:
            if not self.available_ports:
                return None
            # Random choice spreads sessions across the whole range.
            port = self._pick_free_port()
            self.available_ports.remove(port)
            self.allocated_ports[port] = session_id
            return port

    def _pick_free_port(self) -> int:
        """Return a random member of ``available_ports`` (lock held, set non-empty)."""
        # Rejection sampling over the range is O(1) expected while the pool is
        # mostly free. Copying the free set into a list on every call costs
        # O(pool size), about 55k elements with the default bounds.
        if self.start_port <= self.end_port:
            for _ in range(self._RANDOM_PROBES):
                candidate = random.randint(self.start_port, self.end_port)
                if candidate in self.available_ports:
                    return candidate
        # Pool is nearly exhausted: choose directly among the remaining ports.
        return random.choice(tuple(self.available_ports))

    def release_port(self, port: int) -> bool:
        """Return ``port`` to the pool.

        Returns:
            True if the port was allocated (and is now free), False otherwise.
        """
        with self.lock:
            if port not in self.allocated_ports:
                return False
            del self.allocated_ports[port]
            # Ports outside the configured range are dropped, not made available.
            if self.start_port <= port <= self.end_port:
                self.available_ports.add(port)
            return True

    def get_session_for_port(self, port: int) -> Optional[str]:
        """Return the session id owning ``port``, or ``None`` if it is not allocated."""
        with self.lock:
            return self.allocated_ports.get(port)

    def get_stats(self) -> Dict:
        """Return pool size, free/allocated counts and utilization (0.0-1.0)."""
        with self.lock:
            total = self.end_port - self.start_port + 1
            allocated = len(self.allocated_ports)
            return {
                'total_ports': total,
                'available_ports': len(self.available_ports),
                'allocated_ports': allocated,
                # An empty/inverted range used to raise ZeroDivisionError here.
                'utilization': allocated / total if total > 0 else 0.0,
            }
class NATEngine:
    """Network Address Translation engine for outbound (SNAT) sessions.

    Each session maps a virtual endpoint ``(virtual_ip, virtual_port,
    protocol)`` to a host endpoint ``(host_ip, host_port, protocol)`` whose
    port comes from a :class:`PortPool`. Two dictionaries index sessions in
    each direction. ``start()`` launches a daemon thread that periodically
    removes sessions whose ``is_expired`` property is true.

    Session tables and ``stats`` are guarded by ``self.lock``. It is a
    re-entrant lock because public methods that already hold it call
    ``create_outbound_session`` / ``_remove_session``, which acquire it again.
    With a plain ``threading.Lock`` those paths (e.g. ``close_session``)
    deadlocked.
    """

    # Seconds between expiry sweeps, and the back-off after a sweep raises.
    CLEANUP_INTERVAL = 30
    CLEANUP_ERROR_BACKOFF = 5

    def __init__(self, config: Dict):
        """Create the engine from a configuration dict.

        Keys read: ``port_range_start`` (default 10000), ``port_range_end``
        (default 65535), ``host_ip`` (default: ``_get_default_host_ip()``)
        and ``session_timeout`` (default 300).
        """
        self.config = config
        self.sessions: Dict[str, NATSession] = {}  # session_id -> session
        self.virtual_to_session: Dict[Tuple[str, int, str], str] = {}  # (vip, vport, proto) -> session_id
        self.host_to_session: Dict[Tuple[str, int, str], str] = {}  # (hip, hport, proto) -> session_id
        self.lock = threading.RLock()

        # Port pool for outbound connections
        self.port_pool = PortPool(
            config.get('port_range_start', 10000),
            config.get('port_range_end', 65535),
        )

        # Host IP for outbound connections. Only probe for a default when none
        # is configured. Passing the probe as the .get() default ran it
        # (including a socket connect) on every construction.
        host_ip = config.get('host_ip')
        self.host_ip = host_ip if host_ip is not None else self._get_default_host_ip()

        # NOTE(review): not read anywhere in this module. Expiry relies on
        # NATSession.is_expired's own fixed per-protocol limits. Confirm intent.
        self.session_timeout = config.get('session_timeout', 300)

        self.stats = self._fresh_stats(active_sessions=0)

        # Cleanup thread state. stop() sets the event so the sweeper wakes
        # immediately instead of sleeping out its interval.
        self.running = False
        self.cleanup_thread = None
        self._stop_event = threading.Event()

    @staticmethod
    def _fresh_stats(active_sessions: int) -> Dict[str, int]:
        """Return a zeroed statistics dict carrying the given active-session count."""
        return {
            'total_sessions': 0,
            'active_sessions': active_sessions,
            'expired_sessions': 0,
            'port_exhaustion_events': 0,
            'bytes_translated': 0,
            'packets_translated': 0,
        }

    def _get_default_host_ip(self) -> str:
        """Return the local IPv4 address chosen for outbound traffic, or '127.0.0.1' on failure."""
        try:
            # Connect to a remote address to determine local IP
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(('8.8.8.8', 80))
                return s.getsockname()[0]
        except Exception:
            return '127.0.0.1'

    def _cleanup_expired_sessions(self):
        """Remove every session whose ``is_expired`` property is true."""
        with self.lock:
            # Scan and removal happen under one lock hold, so a session cannot
            # be removed twice (and counted twice) by concurrent callers.
            expired = [sid for sid, session in self.sessions.items() if session.is_expired]
            for session_id in expired:
                if self._remove_session(session_id):
                    self.stats['expired_sessions'] += 1

    def _remove_session(self, session_id: str) -> bool:
        """Drop a session, its index entries and its host port.

        Safe to call while already holding ``self.lock``.

        Returns:
            True if the session existed and was removed, False otherwise.
        """
        with self.lock:
            session = self.sessions.pop(session_id, None)
            if session is None:
                return False
            # Only unlink index entries that still point at this session, so
            # removing an older session cannot orphan a newer one that took
            # over the same key.
            virtual_key = (session.virtual_ip, session.virtual_port, session.protocol)
            if self.virtual_to_session.get(virtual_key) == session_id:
                del self.virtual_to_session[virtual_key]
            host_key = (session.host_ip, session.host_port, session.protocol)
            if self.host_to_session.get(host_key) == session_id:
                del self.host_to_session[host_key]
            self.port_pool.release_port(session.host_port)
            self.stats['active_sessions'] = len(self.sessions)
            return True

    def create_outbound_session(self, virtual_ip: str, virtual_port: int,
                                real_ip: str, real_port: int, protocol: str) -> Optional[NATSession]:
        """Create the SNAT session for this flow, or return it if it already exists.

        Returns:
            The session, or ``None`` when no host port is free (this also
            increments ``stats['port_exhaustion_events']``).
        """
        session_id = f"{virtual_ip}:{virtual_port}-{real_ip}:{real_port}-{protocol}"
        with self.lock:
            existing = self.sessions.get(session_id)
            if existing is not None:
                # Re-creating an identical session used to overwrite it and
                # leak the host port allocated for the previous one.
                return existing

            host_port = self.port_pool.allocate_port(session_id)
            if host_port is None:
                self.stats['port_exhaustion_events'] += 1
                return None

            now = time.time()
            session = NATSession(
                virtual_ip=virtual_ip,
                virtual_port=virtual_port,
                real_ip=real_ip,
                real_port=real_port,
                host_ip=self.host_ip,
                host_port=host_port,
                protocol=protocol,
                nat_type=NATType.SNAT,
                created_time=now,
                last_activity=now,
            )
            self.sessions[session_id] = session
            self.virtual_to_session[(virtual_ip, virtual_port, protocol)] = session_id
            self.host_to_session[(self.host_ip, host_port, protocol)] = session_id
            self.stats['total_sessions'] += 1
            self.stats['active_sessions'] = len(self.sessions)
            return session

    def translate_outbound(self, virtual_ip: str, virtual_port: int,
                           real_ip: str, real_port: int, protocol: str) -> Optional[Tuple[str, int]]:
        """Translate an outbound packet (virtual -> host), creating a session if needed.

        Returns:
            ``(host_ip, host_port)``, or ``None`` if a new session was needed
            but no host port was free.
        """
        virtual_key = (virtual_ip, virtual_port, protocol)
        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            session = self.sessions.get(session_id) if session_id else None
            if session is not None:
                session.update_activity(direction='out')
            else:
                # Creating under the same lock hold keeps two threads from both
                # missing the lookup and allocating two host ports for one flow.
                session = self.create_outbound_session(virtual_ip, virtual_port,
                                                       real_ip, real_port, protocol)
                if session is None:
                    return None
            return (session.host_ip, session.host_port)

    def translate_inbound(self, host_ip: str, host_port: int, protocol: str) -> Optional[Tuple[str, int]]:
        """Translate an inbound packet (host -> virtual); ``None`` if no session matches."""
        host_key = (host_ip, host_port, protocol)
        with self.lock:
            session_id = self.host_to_session.get(host_key)
            if session_id and session_id in self.sessions:
                session = self.sessions[session_id]
                session.update_activity(direction='in')
                return (session.virtual_ip, session.virtual_port)
            return None

    def get_session_by_virtual(self, virtual_ip: str, virtual_port: int, protocol: str) -> Optional[NATSession]:
        """Return the session indexed under the virtual endpoint, if any."""
        virtual_key = (virtual_ip, virtual_port, protocol)
        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            if session_id and session_id in self.sessions:
                return self.sessions[session_id]
            return None

    def get_session_by_host(self, host_ip: str, host_port: int, protocol: str) -> Optional[NATSession]:
        """Return the session indexed under the host endpoint, if any."""
        host_key = (host_ip, host_port, protocol)
        with self.lock:
            session_id = self.host_to_session.get(host_key)
            if session_id and session_id in self.sessions:
                return self.sessions[session_id]
            return None

    def close_session(self, session_id: str) -> bool:
        """Manually close a session; True if it existed."""
        return self._remove_session(session_id)

    def close_session_by_virtual(self, virtual_ip: str, virtual_port: int, protocol: str) -> bool:
        """Close the session indexed under the virtual endpoint.

        Returns:
            True if the virtual endpoint was indexed, False otherwise.
        """
        virtual_key = (virtual_ip, virtual_port, protocol)
        with self.lock:
            session_id = self.virtual_to_session.get(virtual_key)
            if not session_id:
                return False
            if not self._remove_session(session_id):
                # Index entry without a backing session: drop the stale entry.
                self.virtual_to_session.pop(virtual_key, None)
            return True

    def get_sessions(self) -> Dict[str, Dict]:
        """Return a snapshot of every session as plain dicts keyed by session id."""
        with self.lock:
            return {
                session_id: {
                    'virtual_ip': session.virtual_ip,
                    'virtual_port': session.virtual_port,
                    'real_ip': session.real_ip,
                    'real_port': session.real_port,
                    'host_ip': session.host_ip,
                    'host_port': session.host_port,
                    'protocol': session.protocol,
                    'nat_type': session.nat_type.value,
                    'created_time': session.created_time,
                    'last_activity': session.last_activity,
                    'duration': session.duration,
                    'bytes_in': session.bytes_in,
                    'bytes_out': session.bytes_out,
                    'packets_in': session.packets_in,
                    'packets_out': session.packets_out,
                    'is_expired': session.is_expired,
                }
                for session_id, session in self.sessions.items()
            }

    def get_stats(self) -> Dict:
        """Return engine statistics merged with the port pool's statistics."""
        port_stats = self.port_pool.get_stats()
        with self.lock:
            current_stats = self.stats.copy()
            current_stats['active_sessions'] = len(self.sessions)
        current_stats.update(port_stats)
        return current_stats

    def update_packet_stats(self, bytes_count: int):
        """Count one translated packet of ``bytes_count`` bytes."""
        # Read-modify-write on shared counters, so take the lock.
        with self.lock:
            self.stats['bytes_translated'] += bytes_count
            self.stats['packets_translated'] += 1

    def _cleanup_loop(self):
        """Sweep expired sessions every CLEANUP_INTERVAL seconds until stopped."""
        while self.running and not self._stop_event.is_set():
            try:
                self._cleanup_expired_sessions()
                delay = self.CLEANUP_INTERVAL
            except Exception as e:
                print(f"NAT cleanup error: {e}")
                delay = self.CLEANUP_ERROR_BACKOFF
            # Returns early as soon as stop() sets the event.
            self._stop_event.wait(delay)

    def start(self):
        """Start the background expiry thread."""
        self._stop_event.clear()
        self.running = True
        self.cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_thread.start()
        print(f"NAT engine started - Host IP: {self.host_ip}, Port range: {self.port_pool.start_port}-{self.port_pool.end_port}")

    def stop(self):
        """Stop the expiry thread and close every session."""
        self.running = False
        self._stop_event.set()
        if self.cleanup_thread:
            self.cleanup_thread.join()
        with self.lock:
            for session_id in list(self.sessions):
                self._remove_session(session_id)
        print("NAT engine stopped")

    def reset_stats(self):
        """Zero all counters, keeping ``active_sessions`` at the current count."""
        with self.lock:
            self.stats = self._fresh_stats(active_sessions=len(self.sessions))
class NATRule:
    """A single DNAT (port-forwarding) rule.

    While enabled, traffic on ``external_port`` whose protocol matches
    (case-insensitively) is forwarded to ``internal_ip:internal_port``.
    ``record_hit`` keeps a hit counter and the time of the latest hit.
    """

    def __init__(self, external_port: int, internal_ip: str, internal_port: int,
                 protocol: str = 'TCP', enabled: bool = True):
        self.external_port = external_port
        self.internal_ip = internal_ip
        self.internal_port = internal_port
        # Normalized to upper case; matches() upper-cases its argument too.
        self.protocol = protocol.upper()
        self.enabled = enabled
        self.created_time = time.time()
        self.hit_count = 0
        self.last_hit = None  # time.time() of the latest hit, None until first hit

    def matches(self, port: int, protocol: str) -> bool:
        """True if the rule is enabled and covers ``port``/``protocol``."""
        if not self.enabled:
            return False
        if not self.external_port == port:
            return False
        return self.protocol == protocol.upper()

    def record_hit(self):
        """Increment the hit counter and stamp ``last_hit`` with the current time."""
        now = time.time()
        self.hit_count += 1
        self.last_hit = now

    def to_dict(self) -> Dict:
        """Return the rule's attributes as a plain dict."""
        exported = ('external_port', 'internal_ip', 'internal_port', 'protocol',
                    'enabled', 'created_time', 'hit_count', 'last_hit')
        return {name: getattr(self, name) for name in exported}
class DNATEngine:
    """Destination NAT (port forwarding) rule table keyed by rule id.

    Every access to ``rules`` happens while holding ``self.lock``.
    """

    def __init__(self):
        self.rules: Dict[str, NATRule] = {}  # rule_id -> rule
        self.lock = threading.Lock()

    def add_rule(self, rule_id: str, external_port: int, internal_ip: str,
                 internal_port: int, protocol: str = 'TCP') -> bool:
        """Register a new rule under ``rule_id``; False if the id is already used."""
        with self.lock:
            if rule_id not in self.rules:
                self.rules[rule_id] = NATRule(external_port, internal_ip, internal_port, protocol)
                return True
            return False

    def remove_rule(self, rule_id: str) -> bool:
        """Delete the rule; False if no such rule exists."""
        with self.lock:
            if rule_id not in self.rules:
                return False
            del self.rules[rule_id]
            return True

    def _set_enabled(self, rule_id: str, enabled: bool) -> bool:
        """Set the named rule's ``enabled`` flag; False if no such rule exists."""
        with self.lock:
            if rule_id not in self.rules:
                return False
            self.rules[rule_id].enabled = enabled
            return True

    def enable_rule(self, rule_id: str) -> bool:
        """Enable the rule; False if no such rule exists."""
        return self._set_enabled(rule_id, True)

    def disable_rule(self, rule_id: str) -> bool:
        """Disable the rule; False if no such rule exists."""
        return self._set_enabled(rule_id, False)

    def translate_inbound_dnat(self, external_port: int, protocol: str) -> Optional[Tuple[str, int]]:
        """Forward target of the first matching rule (in insertion order), or None.

        The matching rule's hit is recorded.
        """
        with self.lock:
            hit = next(
                (rule for rule in self.rules.values() if rule.matches(external_port, protocol)),
                None,
            )
            if hit is None:
                return None
            hit.record_hit()
            return (hit.internal_ip, hit.internal_port)

    def get_rules(self) -> Dict[str, Dict]:
        """Snapshot of every rule as a plain dict, keyed by rule id."""
        with self.lock:
            return {rid: entry.to_dict() for rid, entry in self.rules.items()}

    def clear_rules(self):
        """Remove every rule."""
        with self.lock:
            self.rules.clear()