# TNT/core/virtual_router.py
# Uploaded by Fred808 ("Upload 48 files"), commit 50d86e3.
"""
Virtual Router Module
Implements packet routing between virtual clients and external internet:
- Maintain routing table for virtual network
- Forward packets based on destination IP
- Handle internal vs external routing decisions
- Support static route configuration
"""
import ipaddress
import time
import threading
from typing import Dict, List, Optional, Tuple, Set
from dataclasses import dataclass
from enum import Enum
from .ip_parser import ParsedPacket, IPv4Header
class RouteType(Enum):
    """How a route entered the routing table."""

    # Route to a network one of the router's own interfaces sits on
    DIRECT = "DIRECT"
    # Route supplied explicitly (configuration or add_route())
    STATIC = "STATIC"
    # Catch-all route (0.0.0.0/0) through the default gateway
    DEFAULT = "DEFAULT"
@dataclass
class RouteEntry:
    """A single routing-table entry.

    ``created_time`` may be omitted (or passed as ``0``); ``__post_init__``
    then stamps the entry with the current time, matching ``Interface``.
    """
    destination: str  # Network in CIDR notation (e.g. "10.0.0.0/24")
    gateway: Optional[str]  # Next-hop IP; None for directly connected routes
    interface: str  # Name of the outgoing interface
    metric: int  # Route cost; lower values are preferred
    route_type: "RouteType"  # How the route was learned (direct/static/default)
    created_time: float = 0  # Creation timestamp; 0 means "stamp with now"
    last_used: Optional[float] = None  # Time of the most recent record_use()
    use_count: int = 0  # Number of times record_use() was called

    def __post_init__(self):
        # A zero timestamp is the "not provided" marker.
        if self.created_time == 0:
            self.created_time = time.time()

    def record_use(self):
        """Increment ``use_count`` and set ``last_used`` to the current time."""
        self.use_count += 1
        self.last_used = time.time()

    def matches_destination(self, ip: str) -> bool:
        """Return True if ``ip`` lies inside this route's destination network.

        An unparseable destination or address yields False instead of raising.
        """
        try:
            network = ipaddress.ip_network(self.destination, strict=False)
            return ipaddress.ip_address(ip) in network
        except ValueError:  # AddressValueError/NetmaskValueError are ValueErrors
            return False

    def to_dict(self) -> Dict:
        """Return the route as a plain dict (``route_type`` as its string value)."""
        return {
            'destination': self.destination,
            'gateway': self.gateway,
            'interface': self.interface,
            'metric': self.metric,
            'route_type': self.route_type.value,
            'created_time': self.created_time,
            'last_used': self.last_used,
            'use_count': self.use_count
        }
@dataclass
class Interface:
    """A network interface attached to the virtual router.

    If ``network`` is omitted or empty it is derived from ``ip_address`` and
    ``netmask``; if that derivation fails it falls back to ``"0.0.0.0/0"``.
    """
    name: str
    ip_address: str
    netmask: str = "255.255.255.0"  # Same default the router uses when adding interfaces
    network: Optional[str] = None  # Network in CIDR notation; derived when falsy
    enabled: bool = True
    mtu: int = 1500
    created_time: float = 0  # 0 means "stamp with the current time"

    def __post_init__(self):
        if self.created_time == 0:
            self.created_time = time.time()
        # Calculate network if not provided
        if not self.network:
            try:
                interface_network = ipaddress.ip_interface(f"{self.ip_address}/{self.netmask}")
                self.network = str(interface_network.network)
            except ValueError:  # AddressValueError/NetmaskValueError are ValueErrors
                # NOTE(review): with this fallback is_local_address() reports every
                # IPv4 address as local; consider rejecting invalid input instead.
                self.network = "0.0.0.0/0"

    def is_local_address(self, ip: str) -> bool:
        """Return True if ``ip`` is inside this interface's network.

        Unparseable input yields False instead of raising.
        """
        try:
            network = ipaddress.ip_network(self.network, strict=False)
            return ipaddress.ip_address(ip) in network
        except ValueError:
            return False

    def to_dict(self) -> Dict:
        """Return the interface as a plain dict."""
        return {
            'name': self.name,
            'ip_address': self.ip_address,
            'netmask': self.netmask,
            'network': self.network,
            'enabled': self.enabled,
            'mtu': self.mtu,
            'created_time': self.created_time
        }
class VirtualRouter:
    """Software router for the virtual network.

    Owns the interface table, the routing table, a simple IP -> MAC (ARP)
    table and a set of counters. Shared state is accessed under ``self.lock``,
    which is re-entrant because several methods (``add_interface``,
    ``_initialize_interfaces``, ``import_config``) call other locking methods
    such as ``add_route`` while already holding it.
    """

    def __init__(self, config: Dict):
        """Build the router from ``config``.

        Keys read: ``router_id`` (default ``'virtual-router-1'``),
        ``default_gateway``, ``interfaces`` and ``static_routes``.
        """
        self.config = config
        self.routing_table: List[RouteEntry] = []
        self.interfaces: Dict[str, Interface] = {}
        self.arp_table: Dict[str, str] = {}  # IP -> MAC mapping
        # RLock rather than Lock: add_route() is called while the lock is
        # already held, which would self-deadlock a non-reentrant Lock.
        self.lock = threading.RLock()

        # Router configuration
        self.router_id = config.get('router_id', 'virtual-router-1')
        self.default_gateway = config.get('default_gateway')

        # Statistics
        self.stats = self._new_stats()

        # Initialize interfaces and routes
        self._initialize_interfaces()
        self._initialize_routes()

    @staticmethod
    def _new_stats() -> Dict[str, int]:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            'packets_routed': 0,
            'packets_dropped': 0,
            'route_lookups': 0,
            'arp_requests': 0,
            'arp_replies': 0,
            'routing_errors': 0,
        }

    def _initialize_interfaces(self):
        """Create the interfaces listed under ``config['interfaces']``.

        Each entry needs ``name`` and ``ip_address``; ``netmask`` (default
        ``255.255.255.0``), ``network``, ``enabled`` and ``mtu`` are optional.
        A directly connected route (metric 0) is added for every interface.
        """
        for iface_config in self.config.get('interfaces', []):
            interface = Interface(
                name=iface_config['name'],
                ip_address=iface_config['ip_address'],
                netmask=iface_config.get('netmask', '255.255.255.0'),
                network=iface_config.get('network'),
                enabled=iface_config.get('enabled', True),
                mtu=iface_config.get('mtu', 1500),
            )
            with self.lock:
                self.interfaces[interface.name] = interface
                self.add_route(
                    destination=interface.network,
                    gateway=None,
                    interface=interface.name,
                    metric=0,
                    route_type=RouteType.DIRECT,
                )

    def _initialize_routes(self):
        """Add ``config['static_routes']`` and, if configured, a default route.

        Static routes default to metric 10. The default route (``0.0.0.0/0``,
        metric 100) is installed only when some interface's network contains
        ``default_gateway``; otherwise it is silently skipped.
        """
        for route_config in self.config.get('static_routes', []):
            self.add_route(
                destination=route_config['destination'],
                gateway=route_config.get('gateway'),
                interface=route_config['interface'],
                metric=route_config.get('metric', 10),
                route_type=RouteType.STATIC,
            )

        if not self.default_gateway:
            return

        with self.lock:
            # First interface (enabled or not) whose network holds the gateway.
            default_interface = next(
                (
                    iface.name
                    for iface in self.interfaces.values()
                    if iface.is_local_address(self.default_gateway)
                ),
                None,
            )
            if default_interface:
                self.add_route(
                    destination="0.0.0.0/0",
                    gateway=self.default_gateway,
                    interface=default_interface,
                    metric=100,
                    route_type=RouteType.DEFAULT,
                )

    def add_interface(self, name: str, ip_address: str, netmask: str = "255.255.255.0",
                      network: Optional[str] = None, mtu: int = 1500) -> bool:
        """Register a new interface together with its directly connected route.

        Returns:
            False if an interface called ``name`` already exists, else True.
        """
        with self.lock:
            if name in self.interfaces:
                return False
            interface = Interface(
                name=name,
                ip_address=ip_address,
                netmask=netmask,
                network=network,
                mtu=mtu,
            )
            self.interfaces[name] = interface
            # Both steps happen under one (re-entrant) acquisition so other
            # threads never see the interface without its direct route.
            self.add_route(
                destination=interface.network,
                gateway=None,
                interface=name,
                metric=0,
                route_type=RouteType.DIRECT,
            )
        return True

    def remove_interface(self, name: str) -> bool:
        """Delete interface ``name`` and every route that uses it.

        Returns:
            False if no such interface exists, else True.
        """
        with self.lock:
            if name not in self.interfaces:
                return False
            del self.interfaces[name]
            self.routing_table = [
                route for route in self.routing_table if route.interface != name
            ]
            return True

    def _set_interface_enabled(self, name: str, enabled: bool) -> bool:
        """Set the ``enabled`` flag of interface ``name``; False if it is unknown."""
        with self.lock:
            interface = self.interfaces.get(name)
            if interface is None:
                return False
            interface.enabled = enabled
            return True

    def enable_interface(self, name: str) -> bool:
        """Enable network interface ``name``; returns False if it is unknown."""
        return self._set_interface_enabled(name, True)

    def disable_interface(self, name: str) -> bool:
        """Disable network interface ``name``; returns False if it is unknown."""
        return self._set_interface_enabled(name, False)

    def add_route(self, destination: str, gateway: Optional[str], interface: str,
                  metric: int = 10, route_type: RouteType = RouteType.STATIC) -> bool:
        """Insert a route to ``destination`` via ``interface``.

        Any existing route with the same destination *and* interface is
        replaced. The table is kept ordered by ``(metric, created_time)``.

        Returns:
            True on success; False if ``destination`` is not a valid network,
            ``gateway`` (when given) is not a valid address, or ``interface``
            is not a known interface.
        """
        try:
            ipaddress.ip_network(destination, strict=False)
            if gateway:
                ipaddress.ip_address(gateway)
        except ValueError:  # AddressValueError/NetmaskValueError are ValueErrors
            return False

        route = RouteEntry(
            destination=destination,
            gateway=gateway,
            interface=interface,
            metric=metric,
            route_type=route_type,
            created_time=time.time(),
        )
        with self.lock:
            if interface not in self.interfaces:
                return False
            self.routing_table = [
                r for r in self.routing_table
                if not (r.destination == destination and r.interface == interface)
            ]
            self.routing_table.append(route)
            # Lower metric first; older route first among equal metrics.
            self.routing_table.sort(key=lambda r: (r.metric, r.created_time))
        return True

    def remove_route(self, destination: str, interface: str) -> bool:
        """Remove the route(s) matching ``destination`` on ``interface``.

        Returns:
            True if at least one route was removed.
        """
        with self.lock:
            original_count = len(self.routing_table)
            self.routing_table = [
                route for route in self.routing_table
                if not (route.destination == destination and route.interface == interface)
            ]
            return len(self.routing_table) < original_count

    def lookup_route(self, destination_ip: str) -> Optional[RouteEntry]:
        """Return the best route for ``destination_ip`` (longest-prefix match).

        Only routes whose interface exists and is enabled are considered.
        Equal prefix lengths are broken by lower metric, then earlier
        ``created_time``, then table order. The chosen route's usage counters
        are updated. Returns None and counts a routing error when nothing
        matches, including when ``destination_ip`` is not a valid address.
        """
        with self.lock:
            self.stats['route_lookups'] += 1

            try:
                target = ipaddress.ip_address(destination_ip)
            except ValueError:
                target = None

            best_route = None
            best_key = None
            if target is not None:
                for route in self.routing_table:
                    interface = self.interfaces.get(route.interface)
                    if interface is None or not interface.enabled:
                        continue
                    try:
                        network = ipaddress.ip_network(route.destination, strict=False)
                    except ValueError:
                        continue
                    if target not in network:
                        continue
                    key = (-network.prefixlen, route.metric, route.created_time)
                    # Strict '<' keeps the earliest table entry on exact ties.
                    if best_key is None or key < best_key:
                        best_route, best_key = route, key

            if best_route is None:
                self.stats['routing_errors'] += 1
                return None

            best_route.record_use()
            return best_route

    def route_packet(self, packet: ParsedPacket) -> Optional[Tuple[str, str]]:
        """Choose the next hop for ``packet``.

        Returns:
            ``(next_hop_ip, interface_name)``; the next hop is the route's
            gateway, or the packet's destination for gateway-less routes.
            None when no route exists (counted in ``packets_dropped``).
        """
        destination_ip = packet.ip_header.dest_ip
        route = self.lookup_route(destination_ip)

        with self.lock:
            if route is None:
                self.stats['packets_dropped'] += 1
                return None
            # Only successfully routed packets count here, so 'packets_routed'
            # and 'packets_dropped' never overlap.
            self.stats['packets_routed'] += 1

        next_hop = route.gateway if route.gateway else destination_ip
        return (next_hop, route.interface)

    def is_local_destination(self, ip: str) -> bool:
        """Return True if ``ip`` equals the address of one of the router's interfaces."""
        with self.lock:
            return any(iface.ip_address == ip for iface in self.interfaces.values())

    def is_local_network(self, ip: str) -> bool:
        """Return True if ``ip`` is inside any interface's network (enabled or not)."""
        with self.lock:
            return any(iface.is_local_address(ip) for iface in self.interfaces.values())

    def get_interface_for_ip(self, ip: str) -> Optional[Interface]:
        """Return the first enabled interface whose network contains ``ip``, or None."""
        with self.lock:
            return next(
                (
                    iface for iface in self.interfaces.values()
                    if iface.enabled and iface.is_local_address(ip)
                ),
                None,
            )

    def add_arp_entry(self, ip: str, mac: str):
        """Record (or overwrite) the MAC address for ``ip``."""
        with self.lock:
            self.arp_table[ip] = mac

    def get_arp_entry(self, ip: str) -> Optional[str]:
        """Return the MAC address recorded for ``ip``, or None."""
        with self.lock:
            return self.arp_table.get(ip)

    def remove_arp_entry(self, ip: str) -> bool:
        """Delete the ARP entry for ``ip``; returns False if there was none."""
        with self.lock:
            if ip in self.arp_table:
                del self.arp_table[ip]
                return True
            return False

    def clear_arp_table(self):
        """Remove every ARP entry."""
        with self.lock:
            self.arp_table.clear()

    def get_routing_table(self) -> List[Dict]:
        """Return a snapshot of the routing table as a list of dicts."""
        with self.lock:
            return [route.to_dict() for route in self.routing_table]

    def get_interfaces(self) -> Dict[str, Dict]:
        """Return a snapshot of all interfaces keyed by name."""
        with self.lock:
            return {
                name: interface.to_dict()
                for name, interface in self.interfaces.items()
            }

    def get_arp_table(self) -> Dict[str, str]:
        """Return a copy of the ARP table."""
        with self.lock:
            return self.arp_table.copy()

    def get_stats(self) -> Dict:
        """Return a copy of the counters plus table-size summaries."""
        with self.lock:
            stats = self.stats.copy()
            stats['total_routes'] = len(self.routing_table)
            stats['total_interfaces'] = len(self.interfaces)
            stats['enabled_interfaces'] = sum(
                1 for iface in self.interfaces.values() if iface.enabled
            )
            stats['arp_entries'] = len(self.arp_table)
            return stats

    def reset_stats(self):
        """Zero all counters and every route's usage statistics."""
        with self.lock:
            self.stats = self._new_stats()
            for route in self.routing_table:
                route.use_count = 0
                route.last_used = None

    def flush_routes(self, route_type: Optional[RouteType] = None):
        """Remove routes of ``route_type``, or every route when it is None."""
        with self.lock:
            if route_type is None:
                self.routing_table.clear()
            else:
                self.routing_table = [
                    route for route in self.routing_table
                    if route.route_type != route_type
                ]

    def export_config(self) -> Dict:
        """Return a config dict that ``import_config``/``__init__`` can consume.

        Only STATIC routes are exported; DIRECT routes are recreated from the
        interfaces and the DEFAULT route from ``default_gateway``.
        """
        with self.lock:
            return {
                'router_id': self.router_id,
                'default_gateway': self.default_gateway,
                'interfaces': [
                    {
                        'name': iface.name,
                        'ip_address': iface.ip_address,
                        'netmask': iface.netmask,
                        'network': iface.network,
                        'enabled': iface.enabled,
                        'mtu': iface.mtu
                    }
                    for iface in self.interfaces.values()
                ],
                'static_routes': [
                    {
                        'destination': route.destination,
                        'gateway': route.gateway,
                        'interface': route.interface,
                        'metric': route.metric
                    }
                    for route in self.routing_table
                    if route.route_type == RouteType.STATIC
                ]
            }

    def import_config(self, config: Dict):
        """Rebuild the router from ``config``.

        Interfaces, routes and ARP entries are cleared. ``config`` is then
        merged into ``self.config`` (keys it lacks keep their previous values)
        and the router is reinitialized from the merged dict. The whole swap
        runs under the lock so other threads never observe a half-built router.
        """
        with self.lock:
            self.interfaces.clear()
            self.routing_table.clear()
            self.arp_table.clear()

            self.router_id = config.get('router_id', self.router_id)
            self.default_gateway = config.get('default_gateway', self.default_gateway)

            self.config.update(config)
            self._initialize_interfaces()
            self._initialize_routes()
class RouterUtils:
    """Stateless helpers for IP address and network arithmetic.

    The predicates return False instead of raising when ``ipaddress`` cannot
    parse their input.
    """

    @staticmethod
    def ip_to_int(ip: str) -> int:
        """Return the numeric value of address ``ip`` (raises ValueError if invalid)."""
        address = ipaddress.ip_address(ip)
        return int(address)

    @staticmethod
    def int_to_ip(ip_int: int) -> str:
        """Return the textual form of the address whose numeric value is ``ip_int``."""
        address = ipaddress.ip_address(ip_int)
        return str(address)

    @staticmethod
    def calculate_network(ip: str, netmask: str) -> str:
        """Return the CIDR network that contains ``ip`` under ``netmask``.

        Returns "0.0.0.0/0" when the address/mask pair cannot be parsed.
        """
        spec = f"{ip}/{netmask}"
        try:
            iface = ipaddress.ip_interface(spec)
        except (ipaddress.AddressValueError, ValueError):
            return "0.0.0.0/0"
        return str(iface.network)

    @staticmethod
    def _address_flag(ip: str, flag: str) -> bool:
        """Return attribute ``flag`` of the parsed address, or False if ``ip`` is invalid."""
        try:
            parsed = ipaddress.ip_address(ip)
        except (ipaddress.AddressValueError, ValueError):
            return False
        return getattr(parsed, flag)

    @staticmethod
    def is_private_ip(ip: str) -> bool:
        """Return True if ``ip`` is a private address."""
        return RouterUtils._address_flag(ip, "is_private")

    @staticmethod
    def is_multicast_ip(ip: str) -> bool:
        """Return True if ``ip`` is a multicast address."""
        return RouterUtils._address_flag(ip, "is_multicast")

    @staticmethod
    def validate_cidr(cidr: str) -> bool:
        """Return True if ``cidr`` parses as a network (host bits are allowed)."""
        try:
            ipaddress.ip_network(cidr, strict=False)
        except (ipaddress.AddressValueError, ValueError):
            return False
        return True