|
|
""" |
|
|
Virtual Router Module |
|
|
|
|
|
Implements packet routing between virtual clients and external internet: |
|
|
- Maintain routing table for virtual network |
|
|
- Forward packets based on destination IP |
|
|
- Handle internal vs external routing decisions |
|
|
- Support static route configuration |
|
|
""" |
|
|
|
|
|
import ipaddress |
|
|
import time |
|
|
import threading |
|
|
from typing import Dict, List, Optional, Tuple, Set |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
|
|
|
from .ip_parser import ParsedPacket, IPv4Header |
|
|
|
|
|
|
|
|
class RouteType(Enum):
    """Origin of a routing-table entry.

    Each member's value is its own name; ``RouteEntry.to_dict`` serialises the
    route type through ``.value``.
    """

    # Connected network of a local interface (installed with metric 0).
    DIRECT = "DIRECT"
    # Route loaded from configuration or added explicitly.
    STATIC = "STATIC"
    # Catch-all 0.0.0.0/0 route via the configured default gateway.
    DEFAULT = "DEFAULT"
|
|
|
|
|
|
|
|
@dataclass
class RouteEntry:
    """Represents a routing table entry.

    ``destination`` is a network in CIDR notation (host bits are tolerated
    because matching uses ``strict=False``). ``gateway`` is the next-hop
    address, or None for a directly connected network. A ``created_time`` of
    0 (the default) is replaced with the current time on construction.
    """

    destination: str
    gateway: Optional[str]
    interface: str
    metric: int
    route_type: RouteType
    # 0 means "stamp with time.time() in __post_init__". It used to have no
    # default, so callers had to pass 0 explicitly to get that behaviour.
    created_time: float = 0
    last_used: Optional[float] = None
    use_count: int = 0

    def __post_init__(self):
        if self.created_time == 0:
            self.created_time = time.time()

    def record_use(self):
        """Increment ``use_count`` and set ``last_used`` to the current time."""
        self.use_count += 1
        self.last_used = time.time()

    def matches_destination(self, ip: str) -> bool:
        """Return True if ``ip`` lies inside this route's destination network.

        An unparseable address or destination yields False instead of raising.
        """
        try:
            network = ipaddress.ip_network(self.destination, strict=False)
            address = ipaddress.ip_address(ip)
        except ValueError:  # AddressValueError / NetmaskValueError subclass it
            return False
        return address in network

    def to_dict(self) -> Dict:
        """Convert route to a plain dictionary (route_type as its string value)."""
        return {
            'destination': self.destination,
            'gateway': self.gateway,
            'interface': self.interface,
            'metric': self.metric,
            'route_type': self.route_type.value,
            'created_time': self.created_time,
            'last_used': self.last_used,
            'use_count': self.use_count
        }
|
|
|
|
|
|
|
|
@dataclass
class Interface:
    """Represents a network interface.

    If ``network`` is empty or None it is derived from ``ip_address`` and
    ``netmask``. If that derivation fails, the network falls back to
    ``0.0.0.0/0``. A ``created_time`` of 0 (the default) is replaced with the
    current time.
    """

    name: str
    ip_address: str
    netmask: str
    # Optional: the router passes None when the configuration omits it, and
    # __post_init__ then derives it from ip_address/netmask.
    network: Optional[str] = None
    enabled: bool = True
    mtu: int = 1500
    created_time: float = 0

    def __post_init__(self):
        if self.created_time == 0:
            self.created_time = time.time()

        if not self.network:
            try:
                iface = ipaddress.ip_interface(f"{self.ip_address}/{self.netmask}")
            except ValueError:  # AddressValueError / NetmaskValueError subclass it
                # NOTE(review): a malformed address/netmask silently becomes a
                # catch-all network, so is_local_address() then matches every
                # IP and the router installs a metric-0 DIRECT 0.0.0.0/0 route.
                # Consider rejecting bad input instead.
                self.network = "0.0.0.0/0"
            else:
                self.network = str(iface.network)

    def is_local_address(self, ip: str) -> bool:
        """Return True if ``ip`` lies inside this interface's network.

        An unparseable address or network yields False instead of raising.
        """
        try:
            network = ipaddress.ip_network(self.network, strict=False)
            address = ipaddress.ip_address(ip)
        except ValueError:
            return False
        return address in network

    def to_dict(self) -> Dict:
        """Convert interface to a plain dictionary."""
        return {
            'name': self.name,
            'ip_address': self.ip_address,
            'netmask': self.netmask,
            'network': self.network,
            'enabled': self.enabled,
            'mtu': self.mtu,
            'created_time': self.created_time
        }
|
|
|
|
|
|
|
|
class VirtualRouter:
    """Virtual router implementation.

    Holds a routing table, a set of named interfaces and a simple IP -> MAC
    ARP table, and answers next-hop queries for parsed packets.

    All mutable state is guarded by ``self.lock``, a re-entrant lock. Several
    methods take the lock and then call other methods that take it again
    (e.g. ``add_interface`` -> ``add_route``). With a plain ``Lock`` that
    nesting deadlocked.
    """

    def __init__(self, config: Dict):
        """Create a router from a configuration dictionary.

        Args:
            config: Mapping that may contain ``router_id``, ``default_gateway``,
                ``interfaces`` (list of interface dicts) and ``static_routes``
                (list of route dicts). It is stored by reference and updated
                in place by :meth:`import_config`.
        """
        self.config = config
        self.routing_table: List[RouteEntry] = []
        self.interfaces: Dict[str, Interface] = {}
        self.arp_table: Dict[str, str] = {}
        # Re-entrant so that locked methods may call other locked methods.
        self.lock = threading.RLock()

        self.router_id = config.get('router_id', 'virtual-router-1')
        self.default_gateway = config.get('default_gateway')

        self.stats = self._empty_stats()

        self._initialize_interfaces()
        self._initialize_routes()

    @staticmethod
    def _empty_stats() -> Dict[str, int]:
        """Return a fresh, zeroed statistics dictionary."""
        return {
            'packets_routed': 0,
            'packets_dropped': 0,
            'route_lookups': 0,
            'arp_requests': 0,
            'arp_replies': 0,
            'routing_errors': 0
        }

    def _initialize_interfaces(self):
        """Create interfaces, and their DIRECT routes, from ``config['interfaces']``.

        Each entry needs ``name`` and ``ip_address``. ``netmask`` defaults to
        ``255.255.255.0``, ``enabled`` to True and ``mtu`` to 1500.
        """
        for iface_config in self.config.get('interfaces', []):
            interface = Interface(
                name=iface_config['name'],
                ip_address=iface_config['ip_address'],
                netmask=iface_config.get('netmask', '255.255.255.0'),
                network=iface_config.get('network'),
                enabled=iface_config.get('enabled', True),
                mtu=iface_config.get('mtu', 1500)
            )

            with self.lock:
                self.interfaces[interface.name] = interface
                # Every interface gets a metric-0 route to its own network.
                self.add_route(
                    destination=interface.network,
                    gateway=None,
                    interface=interface.name,
                    metric=0,
                    route_type=RouteType.DIRECT
                )

    def _initialize_routes(self):
        """Install configured static routes and the default route.

        Static routes come from ``config['static_routes']`` (metric defaults
        to 10). Entries that ``add_route`` rejects are skipped silently.

        If ``default_gateway`` is set, a ``0.0.0.0/0`` route with metric 100
        is added through the first interface whose network contains the
        gateway. If no interface contains it, no default route is installed.
        """
        for route_config in self.config.get('static_routes', []):
            self.add_route(
                destination=route_config['destination'],
                gateway=route_config.get('gateway'),
                interface=route_config['interface'],
                metric=route_config.get('metric', 10),
                route_type=RouteType.STATIC
            )

        if not self.default_gateway:
            return

        with self.lock:
            default_interface = None
            for interface in self.interfaces.values():
                if interface.is_local_address(self.default_gateway):
                    default_interface = interface.name
                    break

            if default_interface:
                self.add_route(
                    destination="0.0.0.0/0",
                    gateway=self.default_gateway,
                    interface=default_interface,
                    metric=100,
                    route_type=RouteType.DEFAULT
                )

    def add_interface(self, name: str, ip_address: str, netmask: str = "255.255.255.0",
                      network: Optional[str] = None, mtu: int = 1500) -> bool:
        """Add a network interface and its DIRECT route.

        Returns:
            False if an interface with ``name`` already exists, else True.
        """
        with self.lock:
            if name in self.interfaces:
                return False

            interface = Interface(
                name=name,
                ip_address=ip_address,
                netmask=netmask,
                network=network,
                mtu=mtu
            )
            self.interfaces[name] = interface

            # add_route takes self.lock again; safe because the lock is re-entrant.
            self.add_route(
                destination=interface.network,
                gateway=None,
                interface=name,
                metric=0,
                route_type=RouteType.DIRECT
            )
            return True

    def remove_interface(self, name: str) -> bool:
        """Remove interface ``name`` and every route that uses it.

        Returns:
            False if no such interface exists, else True.
        """
        with self.lock:
            if name not in self.interfaces:
                return False

            del self.interfaces[name]
            self.routing_table = [
                route for route in self.routing_table
                if route.interface != name
            ]
            return True

    def enable_interface(self, name: str) -> bool:
        """Mark interface ``name`` enabled; return False if it does not exist."""
        with self.lock:
            if name in self.interfaces:
                self.interfaces[name].enabled = True
                return True
            return False

    def disable_interface(self, name: str) -> bool:
        """Mark interface ``name`` disabled; return False if it does not exist.

        Routes through a disabled interface are ignored by :meth:`lookup_route`.
        """
        with self.lock:
            if name in self.interfaces:
                self.interfaces[name].enabled = False
                return True
            return False

    def add_route(self, destination: str, gateway: Optional[str], interface: str,
                  metric: int = 10, route_type: RouteType = RouteType.STATIC) -> bool:
        """Add a route, replacing any existing route with the same destination and interface.

        The table is kept sorted by ``(metric, created_time)``.

        Returns:
            False if ``destination`` or ``gateway`` cannot be parsed, or if
            ``interface`` is unknown; otherwise True.
        """
        try:
            ipaddress.ip_network(destination, strict=False)
            if gateway:
                ipaddress.ip_address(gateway)
        except ValueError:  # AddressValueError / NetmaskValueError subclass it
            return False

        route = RouteEntry(
            destination=destination,
            gateway=gateway,
            interface=interface,
            metric=metric,
            route_type=route_type,
            created_time=time.time()
        )

        with self.lock:
            if interface not in self.interfaces:
                return False

            # Replace any existing route for the same (destination, interface).
            self.routing_table = [
                r for r in self.routing_table
                if not (r.destination == destination and r.interface == interface)
            ]
            self.routing_table.append(route)
            self.routing_table.sort(key=lambda r: (r.metric, r.created_time))

        return True

    def remove_route(self, destination: str, interface: str) -> bool:
        """Remove routes whose destination string and interface both match.

        Returns:
            True if at least one route was removed.
        """
        with self.lock:
            original_count = len(self.routing_table)
            self.routing_table = [
                route for route in self.routing_table
                if not (route.destination == destination and route.interface == interface)
            ]
            return len(self.routing_table) < original_count

    def lookup_route(self, destination_ip: str) -> Optional[RouteEntry]:
        """Return the best route for ``destination_ip`` using longest-prefix match.

        Only routes whose interface exists and is enabled are considered.
        Among the matches, the longest prefix wins. Ties are broken by lower
        metric, then by earlier ``created_time``. The chosen route's usage
        counters are updated.

        Returns:
            The chosen route, or None if nothing matches. A miss also
            increments ``routing_errors``.
        """
        with self.lock:
            self.stats['route_lookups'] += 1

            candidates = []
            for route in self.routing_table:
                interface = self.interfaces.get(route.interface)
                if interface is None or not interface.enabled:
                    continue
                if route.matches_destination(destination_ip):
                    candidates.append(route)

            if not candidates:
                self.stats['routing_errors'] += 1
                return None

            def route_priority(route):
                try:
                    prefixlen = ipaddress.ip_network(route.destination, strict=False).prefixlen
                except ValueError:
                    prefixlen = 0
                # Longer prefix first, then lower metric, then older route.
                return (-prefixlen, route.metric, route.created_time)

            best_route = min(candidates, key=route_priority)
            best_route.record_use()
            return best_route

    def route_packet(self, packet: ParsedPacket) -> Optional[Tuple[str, str]]:
        """Route a packet and return ``(next_hop_ip, interface_name)``.

        Reads the destination from ``packet.ip_header.dest_ip``. The next hop
        is the route's gateway, or the destination itself for a directly
        connected route.

        ``packets_routed`` counts only packets for which a route was found.
        Previously it was also incremented for dropped packets.

        Returns:
            The ``(next_hop_ip, interface_name)`` pair, or None if no route
            matches. A miss increments ``packets_dropped``.
        """
        destination_ip = packet.ip_header.dest_ip
        route = self.lookup_route(destination_ip)

        with self.lock:
            if route is None:
                self.stats['packets_dropped'] += 1
                return None
            self.stats['packets_routed'] += 1

        next_hop = route.gateway if route.gateway else destination_ip
        return (next_hop, route.interface)

    def is_local_destination(self, ip: str) -> bool:
        """True if ``ip`` equals the address of any interface, enabled or not."""
        with self.lock:
            for interface in self.interfaces.values():
                if interface.ip_address == ip:
                    return True
            return False

    def is_local_network(self, ip: str) -> bool:
        """True if ``ip`` lies in any interface's network, enabled or not."""
        with self.lock:
            for interface in self.interfaces.values():
                if interface.is_local_address(ip):
                    return True
            return False

    def get_interface_for_ip(self, ip: str) -> Optional[Interface]:
        """Return the first enabled interface whose network contains ``ip``, or None."""
        with self.lock:
            for interface in self.interfaces.values():
                if interface.enabled and interface.is_local_address(ip):
                    return interface
            return None

    def add_arp_entry(self, ip: str, mac: str):
        """Add or overwrite the ARP entry for ``ip``."""
        with self.lock:
            self.arp_table[ip] = mac

    def get_arp_entry(self, ip: str) -> Optional[str]:
        """Return the MAC recorded for ``ip``, or None."""
        with self.lock:
            return self.arp_table.get(ip)

    def remove_arp_entry(self, ip: str) -> bool:
        """Remove the ARP entry for ``ip``; return False if there was none."""
        with self.lock:
            if ip in self.arp_table:
                del self.arp_table[ip]
                return True
            return False

    def clear_arp_table(self):
        """Remove all ARP entries."""
        with self.lock:
            self.arp_table.clear()

    def get_routing_table(self) -> List[Dict]:
        """Return a snapshot of the routing table as dicts, in table order."""
        with self.lock:
            return [route.to_dict() for route in self.routing_table]

    def get_interfaces(self) -> Dict[str, Dict]:
        """Return a snapshot of interfaces keyed by name."""
        with self.lock:
            return {
                name: interface.to_dict()
                for name, interface in self.interfaces.items()
            }

    def get_arp_table(self) -> Dict[str, str]:
        """Return a copy of the ARP table."""
        with self.lock:
            return self.arp_table.copy()

    def get_stats(self) -> Dict:
        """Return a copy of the counters plus derived table/interface sizes."""
        with self.lock:
            stats = self.stats.copy()
            stats['total_routes'] = len(self.routing_table)
            stats['total_interfaces'] = len(self.interfaces)
            stats['enabled_interfaces'] = sum(1 for iface in self.interfaces.values() if iface.enabled)
            stats['arp_entries'] = len(self.arp_table)
            return stats

    def reset_stats(self):
        """Zero all counters and clear per-route usage information."""
        with self.lock:
            self.stats = self._empty_stats()
            for route in self.routing_table:
                route.use_count = 0
                route.last_used = None

    def flush_routes(self, route_type: Optional[RouteType] = None):
        """Remove all routes of ``route_type``, or every route if it is None."""
        with self.lock:
            if route_type is not None:
                self.routing_table = [
                    route for route in self.routing_table
                    if route.route_type != route_type
                ]
            else:
                self.routing_table.clear()

    def export_config(self) -> Dict:
        """Export the router configuration in the shape ``__init__`` accepts.

        Only STATIC routes are exported. DIRECT and DEFAULT routes are
        recreated from the interfaces and ``default_gateway`` on import.
        """
        with self.lock:
            return {
                'router_id': self.router_id,
                'default_gateway': self.default_gateway,
                'interfaces': [
                    {
                        'name': iface.name,
                        'ip_address': iface.ip_address,
                        'netmask': iface.netmask,
                        'network': iface.network,
                        'enabled': iface.enabled,
                        'mtu': iface.mtu
                    }
                    for iface in self.interfaces.values()
                ],
                'static_routes': [
                    {
                        'destination': route.destination,
                        'gateway': route.gateway,
                        'interface': route.interface,
                        'metric': route.metric
                    }
                    for route in self.routing_table
                    if route.route_type == RouteType.STATIC
                ]
            }

    def import_config(self, config: Dict):
        """Replace interfaces, routes and ARP entries using ``config``.

        ``config`` is merged into the stored configuration via ``dict.update``.
        Keys it omits (e.g. ``interfaces``) therefore keep their previous
        values and are re-applied.

        The whole swap runs under the lock, so concurrent lookups never see a
        half-rebuilt table.
        """
        with self.lock:
            self.interfaces.clear()
            self.routing_table.clear()
            self.arp_table.clear()

            self.router_id = config.get('router_id', self.router_id)
            self.default_gateway = config.get('default_gateway', self.default_gateway)

            self.config.update(config)
            self._initialize_interfaces()
            self._initialize_routes()
|
|
|
|
|
|
|
|
class RouterUtils:
    """Stateless helpers for IP address and network arithmetic."""

    @staticmethod
    def ip_to_int(ip: str) -> int:
        """Return the integer value of address ``ip``.

        Raises ValueError if ``ip`` cannot be parsed.
        """
        address = ipaddress.ip_address(ip)
        return int(address)

    @staticmethod
    def int_to_ip(ip_int: int) -> str:
        """Return the textual form of the address whose integer value is ``ip_int``."""
        address = ipaddress.ip_address(ip_int)
        return str(address)

    @staticmethod
    def calculate_network(ip: str, netmask: str) -> str:
        """Return the CIDR network that ``ip`` with ``netmask`` belongs to.

        Falls back to ``"0.0.0.0/0"`` if either part cannot be parsed.
        """
        spec = f"{ip}/{netmask}"
        try:
            net = ipaddress.ip_interface(spec).network
        except (ipaddress.AddressValueError, ValueError):
            return "0.0.0.0/0"
        return str(net)

    @staticmethod
    def is_private_ip(ip: str) -> bool:
        """Return whether ``ip`` is a private address; False if unparseable."""
        try:
            parsed = ipaddress.ip_address(ip)
        except (ipaddress.AddressValueError, ValueError):
            return False
        else:
            return parsed.is_private

    @staticmethod
    def is_multicast_ip(ip: str) -> bool:
        """Return whether ``ip`` is a multicast address; False if unparseable."""
        try:
            parsed = ipaddress.ip_address(ip)
        except (ipaddress.AddressValueError, ValueError):
            return False
        else:
            return parsed.is_multicast

    @staticmethod
    def validate_cidr(cidr: str) -> bool:
        """Return True if ``cidr`` parses as a network (host bits allowed)."""
        try:
            ipaddress.ip_network(cidr, strict=False)
        except (ipaddress.AddressValueError, ValueError):
            return False
        else:
            return True
|
|
|
|
|
|