""" Packet Bridge Module Handles communication with virtual clients: - Accept packet streams over WebSocket/TCP - Deliver response packets back to clients - Frame processing (Ethernet → IPv4) - Connection management """ import asyncio import websockets import socket import threading import time import struct from typing import Dict, List, Optional, Callable, Set, Any, Tuple from dataclasses import dataclass from enum import Enum import json import logging from .ip_parser import IPParser, ParsedPacket class BridgeType(Enum): WEBSOCKET = "WEBSOCKET" TCP_SOCKET = "TCP_SOCKET" UDP_SOCKET = "UDP_SOCKET" @dataclass class ClientConnection: """Represents a client connection to the bridge""" client_id: str bridge_type: BridgeType remote_address: str remote_port: int websocket: Optional[Any] = None # WebSocket connection socket: Optional['socket.socket'] = None # TCP/UDP socket connected_time: float = 0 last_activity: float = 0 packets_received: int = 0 packets_sent: int = 0 bytes_received: int = 0 bytes_sent: int = 0 is_active: bool = True def __post_init__(self): if self.connected_time == 0: self.connected_time = time.time() if self.last_activity == 0: self.last_activity = time.time() def update_activity(self, packet_count: int = 1, byte_count: int = 0, direction: str = 'received'): """Update connection activity""" self.last_activity = time.time() if direction == 'received': self.packets_received += packet_count self.bytes_received += byte_count else: self.packets_sent += packet_count self.bytes_sent += byte_count def to_dict(self) -> Dict: """Convert to dictionary""" return { 'client_id': self.client_id, 'bridge_type': self.bridge_type.value, 'remote_address': self.remote_address, 'remote_port': self.remote_port, 'connected_time': self.connected_time, 'last_activity': self.last_activity, 'packets_received': self.packets_received, 'packets_sent': self.packets_sent, 'bytes_received': self.bytes_received, 'bytes_sent': self.bytes_sent, 'is_active': self.is_active, 'duration': time.time() - self.connected_time } class EthernetFrame: """Ethernet frame parser""" def __init__(self): self.dest_mac = b'\x00' * 6 self.src_mac = b'\x00' * 6 self.ethertype = 0x0800 # IPv4 self.payload = b'' @classmethod def parse(cls, data: bytes) -> Optional['EthernetFrame']: """Parse Ethernet frame from raw bytes""" if len(data) < 14: # Minimum Ethernet header size return None frame = cls() frame.dest_mac = data[0:6] frame.src_mac = data[6:12] frame.ethertype = struct.unpack('!H', data[12:14])[0] frame.payload = data[14:] return frame def build(self) -> bytes: """Build Ethernet frame as bytes""" header = self.dest_mac + self.src_mac + struct.pack('!H', self.ethertype) return header + self.payload def is_ipv4(self) -> bool: """Check if frame contains IPv4 packet""" return self.ethertype == 0x0800 def is_arp(self) -> bool: """Check if frame contains ARP packet""" return self.ethertype == 0x0806 class PacketBridge: """Packet bridge implementation""" def __init__(self, config: Dict): self.config = config self.clients: Dict[str, ClientConnection] = {} self.packet_handlers: List[Callable[[ParsedPacket, str], Optional[bytes]]] = [] self.lock = threading.Lock() # Configuration self.websocket_host = config.get('websocket_host', '0.0.0.0') self.websocket_port = config.get('websocket_port', 8765) self.tcp_host = config.get('tcp_host', '0.0.0.0') self.tcp_port = config.get('tcp_port', 8766) self.max_clients = config.get('max_clients', 100) self.client_timeout = config.get('client_timeout', 300) # WebSocket server self.websocket_server = None self.tcp_server_socket = None # Background tasks self.running = False self.websocket_task = None self.tcp_task = None self.cleanup_task = None # Statistics self.stats = { 'total_clients': 0, 'active_clients': 0, 'packets_processed': 0, 'packets_forwarded': 0, 'packets_dropped': 0, 'bytes_processed': 0, 'websocket_connections': 0, 'tcp_connections': 0, 'connection_errors': 0 } # Event loop self.loop = None def add_packet_handler(self, handler: Callable[[ParsedPacket, str], Optional[bytes]]): """Add packet handler function""" self.packet_handlers.append(handler) def remove_packet_handler(self, handler: Callable[[ParsedPacket, str], Optional[bytes]]): """Remove packet handler function""" if handler in self.packet_handlers: self.packet_handlers.remove(handler) def _generate_client_id(self, remote_address: str, remote_port: int) -> str: """Generate unique client ID""" timestamp = int(time.time() * 1000) return f"client_{remote_address}_{remote_port}_{timestamp}" def _process_ethernet_frame(self, frame_data: bytes, client_id: str) -> Optional[bytes]: """Process Ethernet frame and extract IP packet""" try: # Parse Ethernet frame frame = EthernetFrame.parse(frame_data) if not frame or not frame.is_ipv4(): return None # Parse IP packet packet = IPParser.parse_packet(frame.payload) self.stats['packets_processed'] += 1 self.stats['bytes_processed'] += len(frame_data) # Process through packet handlers response_packet = None for handler in self.packet_handlers: try: response = handler(packet, client_id) if response: response_packet = response break except Exception as e: logging.error(f"Packet handler error: {e}") if response_packet: # Wrap response in Ethernet frame response_frame = EthernetFrame() response_frame.dest_mac = frame.src_mac response_frame.src_mac = frame.dest_mac response_frame.ethertype = 0x0800 response_frame.payload = response_packet self.stats['packets_forwarded'] += 1 return response_frame.build() else: self.stats['packets_dropped'] += 1 return None except Exception as e: logging.error(f"Error processing Ethernet frame: {e}") self.stats['packets_dropped'] += 1 return None async def _handle_websocket_client(self, websocket, path): """Handle WebSocket client connection""" client_address = websocket.remote_address client_id = self._generate_client_id(client_address[0], client_address[1]) # Create client connection client = ClientConnection( client_id=client_id, bridge_type=BridgeType.WEBSOCKET, remote_address=client_address[0], remote_port=client_address[1], websocket=websocket ) with self.lock: if len(self.clients) >= self.max_clients: await websocket.close(code=1013, reason="Too many clients") return self.clients[client_id] = client self.stats['total_clients'] += 1 self.stats['active_clients'] = len(self.clients) self.stats['websocket_connections'] += 1 logging.info(f"WebSocket client connected: {client_id} from {client_address}") try: async for message in websocket: if isinstance(message, bytes): # Binary message - treat as Ethernet frame client.update_activity(1, len(message), 'received') response = self._process_ethernet_frame(message, client_id) if response: await websocket.send(response) client.update_activity(1, len(response), 'sent') elif isinstance(message, str): # Text message - treat as control message try: control_msg = json.loads(message) await self._handle_control_message(client, control_msg) except json.JSONDecodeError: logging.warning(f"Invalid control message from {client_id}: {message}") except websockets.exceptions.ConnectionClosed: logging.info(f"WebSocket client disconnected: {client_id}") except Exception as e: logging.error(f"WebSocket client error: {e}") self.stats['connection_errors'] += 1 finally: # Clean up client with self.lock: if client_id in self.clients: self.clients[client_id].is_active = False del self.clients[client_id] self.stats['active_clients'] = len(self.clients) async def _handle_control_message(self, client: ClientConnection, message: Dict): """Handle control message from client""" msg_type = message.get('type') if msg_type == 'ping': # Respond to ping response = {'type': 'pong', 'timestamp': time.time()} await client.websocket.send(json.dumps(response)) elif msg_type == 'stats': # Send client statistics response = { 'type': 'stats', 'client_stats': client.to_dict(), 'bridge_stats': self.get_stats() } await client.websocket.send(json.dumps(response)) elif msg_type == 'config': # Handle configuration updates config_data = message.get('data', {}) # Process configuration updates here response = {'type': 'config_ack', 'status': 'ok'} await client.websocket.send(json.dumps(response)) def _handle_tcp_client(self, client_socket: socket.socket, client_address: Tuple[str, int]): """Handle TCP client connection""" client_id = self._generate_client_id(client_address[0], client_address[1]) # Create client connection client = ClientConnection( client_id=client_id, bridge_type=BridgeType.TCP_SOCKET, remote_address=client_address[0], remote_port=client_address[1], socket=client_socket ) with self.lock: if len(self.clients) >= self.max_clients: client_socket.close() return self.clients[client_id] = client self.stats['total_clients'] += 1 self.stats['active_clients'] = len(self.clients) self.stats['tcp_connections'] += 1 logging.info(f"TCP client connected: {client_id} from {client_address}") try: client_socket.settimeout(self.client_timeout) while client.is_active: try: # Read frame length (4 bytes) length_data = client_socket.recv(4) if not length_data: break frame_length = struct.unpack('!I', length_data)[0] if frame_length > 65536: # Sanity check break # Read frame data frame_data = b'' while len(frame_data) < frame_length: chunk = client_socket.recv(frame_length - len(frame_data)) if not chunk: break frame_data += chunk if len(frame_data) != frame_length: break client.update_activity(1, len(frame_data), 'received') # Process frame response = self._process_ethernet_frame(frame_data, client_id) if response: # Send response with length prefix response_length = struct.pack('!I', len(response)) client_socket.send(response_length + response) client.update_activity(1, len(response), 'sent') except socket.timeout: continue except Exception as e: logging.error(f"TCP client error: {e}") break except Exception as e: logging.error(f"TCP client handler error: {e}") self.stats['connection_errors'] += 1 finally: # Clean up client try: client_socket.close() except: pass with self.lock: if client_id in self.clients: self.clients[client_id].is_active = False del self.clients[client_id] self.stats['active_clients'] = len(self.clients) logging.info(f"TCP client disconnected: {client_id}") def _tcp_server_loop(self): """TCP server loop""" try: self.tcp_server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.tcp_server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) self.tcp_server_socket.bind((self.tcp_host, self.tcp_port)) self.tcp_server_socket.listen(10) logging.info(f"TCP bridge server listening on {self.tcp_host}:{self.tcp_port}") while self.running: try: client_socket, client_address = self.tcp_server_socket.accept() # Handle client in separate thread client_thread = threading.Thread( target=self._handle_tcp_client, args=(client_socket, client_address), daemon=True ) client_thread.start() except socket.error as e: if self.running: logging.error(f"TCP server error: {e}") time.sleep(1) except Exception as e: logging.error(f"TCP server loop error: {e}") finally: if self.tcp_server_socket: self.tcp_server_socket.close() def _cleanup_loop(self): """Background cleanup loop""" while self.running: try: current_time = time.time() expired_clients = [] with self.lock: for client_id, client in self.clients.items(): # Mark inactive clients for removal if current_time - client.last_activity > self.client_timeout: expired_clients.append(client_id) # Clean up expired clients for client_id in expired_clients: with self.lock: if client_id in self.clients: client = self.clients[client_id] client.is_active = False # Close connections if client.websocket: try: asyncio.run_coroutine_threadsafe( client.websocket.close(), self.loop ) except: pass if client.socket: try: client.socket.close() except: pass del self.clients[client_id] logging.info(f"Cleaned up expired client: {client_id}") self.stats['active_clients'] = len(self.clients) time.sleep(30) # Cleanup every 30 seconds except Exception as e: logging.error(f"Cleanup loop error: {e}") time.sleep(5) def send_packet_to_client(self, client_id: str, packet_data: bytes) -> bool: """Send packet to specific client""" with self.lock: client = self.clients.get(client_id) if not client or not client.is_active: return False try: if client.bridge_type == BridgeType.WEBSOCKET: # Send via WebSocket if client.websocket: asyncio.run_coroutine_threadsafe( client.websocket.send(packet_data), self.loop ) client.update_activity(1, len(packet_data), 'sent') return True elif client.bridge_type == BridgeType.TCP_SOCKET: # Send via TCP socket with length prefix if client.socket: length_prefix = struct.pack('!I', len(packet_data)) client.socket.send(length_prefix + packet_data) client.update_activity(1, len(packet_data), 'sent') return True except Exception as e: logging.error(f"Failed to send packet to client {client_id}: {e}") # Mark client as inactive client.is_active = False return False def broadcast_packet(self, packet_data: bytes, exclude_client: Optional[str] = None) -> int: """Broadcast packet to all clients""" sent_count = 0 with self.lock: client_ids = list(self.clients.keys()) for client_id in client_ids: if client_id != exclude_client: if self.send_packet_to_client(client_id, packet_data): sent_count += 1 return sent_count def get_clients(self) -> Dict[str, Dict]: """Get all connected clients""" with self.lock: return { client_id: client.to_dict() for client_id, client in self.clients.items() } def get_client(self, client_id: str) -> Optional[Dict]: """Get specific client""" with self.lock: client = self.clients.get(client_id) return client.to_dict() if client else None def disconnect_client(self, client_id: str) -> bool: """Disconnect specific client""" with self.lock: client = self.clients.get(client_id) if not client: return False client.is_active = False # Close connection if client.websocket: try: asyncio.run_coroutine_threadsafe( client.websocket.close(), self.loop ) except: pass if client.socket: try: client.socket.close() except: pass del self.clients[client_id] self.stats['active_clients'] = len(self.clients) return True def get_stats(self) -> Dict: """Get bridge statistics""" with self.lock: stats = self.stats.copy() stats['active_clients'] = len(self.clients) return stats def reset_stats(self): """Reset bridge statistics""" self.stats = { 'total_clients': 0, 'active_clients': len(self.clients), 'packets_processed': 0, 'packets_forwarded': 0, 'packets_dropped': 0, 'bytes_processed': 0, 'websocket_connections': 0, 'tcp_connections': 0, 'connection_errors': 0 } async def start_websocket_server(self): """Start WebSocket server""" try: self.websocket_server = await websockets.serve( self._handle_websocket_client, self.websocket_host, self.websocket_port, max_size=1024*1024, # 1MB max message size ping_interval=30, ping_timeout=10 ) logging.info(f"WebSocket bridge server started on {self.websocket_host}:{self.websocket_port}") # Keep server running await self.websocket_server.wait_closed() except Exception as e: logging.error(f"WebSocket server error: {e}") def start(self): """Start packet bridge""" self.running = True # Start event loop self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) # Start WebSocket server in a separate thread websocket_thread = threading.Thread(target=self._run_websocket_server_in_thread, daemon=True) websocket_thread.start() # Start TCP server in separate thread tcp_thread = threading.Thread(target=self._tcp_server_loop, daemon=True) tcp_thread.start() # Start cleanup thread cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True) cleanup_thread.start() logging.info("Packet bridge started") def stop(self): """Stop packet bridge""" self.running = False # Close WebSocket server if self.websocket_server: self.websocket_server.close() # Close TCP server if self.tcp_server_socket: self.tcp_server_socket.close() # Disconnect all clients with self.lock: client_ids = list(self.clients.keys()) for client_id in client_ids: self.disconnect_client(client_id) # Stop event loop if self.loop and not self.loop.is_closed(): self.loop.call_soon_threadsafe(self.loop.stop) logging.info("Packet bridge stopped") def _run_websocket_server_in_thread(self): """Run the WebSocket server in a separate thread with its own event loop.""" asyncio.set_event_loop(self.loop) self.loop.run_until_complete(self.start_websocket_server())