Spaces:
Paused
Paused
| """ | |
| Packet Bridge Module | |
| Handles communication with virtual clients: | |
| - Accept packet streams over WebSocket/TCP | |
| - Deliver response packets back to clients | |
| - Frame processing (Ethernet → IPv4) | |
| - Connection management | |
| """ | |
| import asyncio | |
| import websockets | |
| import socket | |
| import threading | |
| import time | |
| import struct | |
| from typing import Dict, List, Optional, Callable, Set, Any, Tuple | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| import json | |
| import logging | |
| from .ip_parser import IPParser, ParsedPacket | |
class BridgeType(Enum):
    """Transport through which a client is attached to the packet bridge.

    Every member's value is identical to its name, so the string stored in a
    client's ``to_dict()['bridge_type']`` can be turned back into a member
    with ``BridgeType(value)``.
    """

    # Client served by the WebSocket server.
    WEBSOCKET = "WEBSOCKET"
    # Client served by the TCP server (length-prefixed frames).
    TCP_SOCKET = "TCP_SOCKET"
    # Datagram transport; not served by PacketBridge.start().
    UDP_SOCKET = "UDP_SOCKET"
@dataclass
class ClientConnection:
    """State and traffic counters for one client connected to the bridge.

    Timestamps are ``time.time()`` values (seconds since the epoch). Passing
    0 (the default) for ``connected_time`` or ``last_activity`` means "now";
    ``__post_init__`` fills in the current time.
    """

    client_id: str
    # Quoted forward reference, like the 'socket.socket' annotation below.
    bridge_type: 'BridgeType'
    remote_address: str
    remote_port: int
    websocket: Optional[Any] = None  # WebSocket connection
    socket: Optional['socket.socket'] = None  # TCP/UDP socket
    connected_time: float = 0
    last_activity: float = 0
    packets_received: int = 0
    packets_sent: int = 0
    bytes_received: int = 0
    bytes_sent: int = 0
    is_active: bool = True

    def __post_init__(self):
        """Replace zero timestamps with the current time."""
        now = time.time()
        if self.connected_time == 0:
            self.connected_time = now
        if self.last_activity == 0:
            self.last_activity = now

    def update_activity(self, packet_count: int = 1, byte_count: int = 0, direction: str = 'received'):
        """Refresh ``last_activity`` and add to the traffic counters.

        Args:
            packet_count: Number of packets to add.
            byte_count: Number of bytes to add.
            direction: ``'received'`` updates the inbound counters; any other
                value (callers pass ``'sent'``) updates the outbound counters.
        """
        self.last_activity = time.time()
        if direction == 'received':
            self.packets_received += packet_count
            self.bytes_received += byte_count
        else:
            self.packets_sent += packet_count
            self.bytes_sent += byte_count

    def to_dict(self) -> Dict:
        """Return a JSON-serialisable snapshot of this connection.

        ``duration`` is the number of seconds since ``connected_time``. The
        websocket/socket handles are deliberately omitted.
        """
        return {
            'client_id': self.client_id,
            'bridge_type': self.bridge_type.value,
            'remote_address': self.remote_address,
            'remote_port': self.remote_port,
            'connected_time': self.connected_time,
            'last_activity': self.last_activity,
            'packets_received': self.packets_received,
            'packets_sent': self.packets_sent,
            'bytes_received': self.bytes_received,
            'bytes_sent': self.bytes_sent,
            'is_active': self.is_active,
            'duration': time.time() - self.connected_time,
        }
class EthernetFrame:
    """Minimal Ethernet II frame parser/builder.

    Layout: 6-byte destination MAC, 6-byte source MAC, 2-byte big-endian
    EtherType, then the payload. 802.1Q VLAN tags are not unwrapped (a
    tagged frame reports EtherType 0x8100).
    """

    HEADER_SIZE = 14
    ETHERTYPE_IPV4 = 0x0800
    ETHERTYPE_ARP = 0x0806

    def __init__(self):
        """Create an all-zero-MAC IPv4 frame with an empty payload."""
        self.dest_mac = b'\x00' * 6
        self.src_mac = b'\x00' * 6
        self.ethertype = self.ETHERTYPE_IPV4
        self.payload = b''

    @classmethod
    def parse(cls, data: bytes) -> Optional['EthernetFrame']:
        """Parse an Ethernet frame from raw bytes.

        Returns:
            The parsed frame, or None if ``data`` is shorter than the
            14-byte header. Everything after the header becomes ``payload``.
        """
        if len(data) < cls.HEADER_SIZE:
            return None
        frame = cls()
        frame.dest_mac = data[0:6]
        frame.src_mac = data[6:12]
        (frame.ethertype,) = struct.unpack('!H', data[12:14])
        frame.payload = data[cls.HEADER_SIZE:]
        return frame

    def build(self) -> bytes:
        """Serialise the frame back to bytes (header followed by payload)."""
        header = self.dest_mac + self.src_mac + struct.pack('!H', self.ethertype)
        return header + self.payload

    def is_ipv4(self) -> bool:
        """Return True if the EtherType is IPv4 (0x0800)."""
        return self.ethertype == self.ETHERTYPE_IPV4

    def is_arp(self) -> bool:
        """Return True if the EtherType is ARP (0x0806)."""
        return self.ethertype == self.ETHERTYPE_ARP
class PacketBridge:
    """Bridge Ethernet frames between virtual clients and packet handlers.

    Clients connect either over WebSocket (binary messages carry raw Ethernet
    frames, text messages carry JSON control messages) or over TCP (every
    frame, in both directions, is prefixed with a 4-byte big-endian length).
    Inbound IPv4 frames are parsed with ``IPParser.parse_packet`` and offered
    to the registered packet handlers in order. The first handler that
    returns a non-empty payload has it wrapped in an Ethernet frame and sent
    back to the originating client.

    Threads: the WebSocket server runs on ``self.loop`` in a dedicated
    thread, each TCP client is served on its own thread, and a cleanup thread
    evicts idle clients. ``self.lock`` (a non-reentrant ``threading.Lock``)
    guards ``self.clients`` and ``self.stats``.
    """

    # Largest length-prefixed TCP frame accepted; a larger length drops the client.
    MAX_TCP_FRAME_SIZE = 65536
    # Seconds between wake-ups of the TCP accept loop, so stop() is noticed
    # even when no client connects.
    ACCEPT_POLL_INTERVAL = 1.0

    def __init__(self, config: Dict):
        """Create a bridge from a configuration dict.

        Recognised keys (defaults in parentheses): ``websocket_host``
        ('0.0.0.0'), ``websocket_port`` (8765), ``tcp_host`` ('0.0.0.0'),
        ``tcp_port`` (8766), ``max_clients`` (100), ``client_timeout``
        (300; seconds without activity before the cleanup thread evicts a
        client, also used as the TCP socket timeout).
        """
        self.config = config
        self.clients: Dict[str, ClientConnection] = {}
        self.packet_handlers: List[Callable[[ParsedPacket, str], Optional[bytes]]] = []
        self.lock = threading.Lock()

        # Configuration
        self.websocket_host = config.get('websocket_host', '0.0.0.0')
        self.websocket_port = config.get('websocket_port', 8765)
        self.tcp_host = config.get('tcp_host', '0.0.0.0')
        self.tcp_port = config.get('tcp_port', 8766)
        self.max_clients = config.get('max_clients', 100)
        self.client_timeout = config.get('client_timeout', 300)

        # Servers (created once start() runs)
        self.websocket_server = None
        self.tcp_server_socket = None

        # Background tasks (the *_task attributes are not assigned in this class)
        self.running = False
        self.websocket_task = None
        self.tcp_task = None
        self.cleanup_task = None

        # Statistics
        self.stats = self._empty_stats(active_clients=0)

        # Event loop hosting the WebSocket server (created by start())
        self.loop = None

    @staticmethod
    def _empty_stats(active_clients: int) -> Dict[str, int]:
        """Return a fresh statistics dict with every counter at zero."""
        return {
            'total_clients': 0,
            'active_clients': active_clients,
            'packets_processed': 0,
            'packets_forwarded': 0,
            'packets_dropped': 0,
            'bytes_processed': 0,
            'websocket_connections': 0,
            'tcp_connections': 0,
            'connection_errors': 0
        }

    def _incr_stat(self, key: str, amount: int = 1) -> None:
        """Add ``amount`` to ``self.stats[key]`` under ``self.lock``.

        Must not be called while ``self.lock`` is already held (the lock is
        not re-entrant).
        """
        with self.lock:
            self.stats[key] += amount

    def add_packet_handler(self, handler: Callable[[ParsedPacket, str], Optional[bytes]]):
        """Register a handler called as ``handler(packet, client_id)``.

        A handler returns the response IP packet bytes, or a falsy value to
        let the next handler try.
        """
        self.packet_handlers.append(handler)

    def remove_packet_handler(self, handler: Callable[[ParsedPacket, str], Optional[bytes]]):
        """Unregister a handler; unknown handlers are ignored."""
        if handler in self.packet_handlers:
            self.packet_handlers.remove(handler)

    def _generate_client_id(self, remote_address: str, remote_port: int) -> str:
        """Build a client ID from the peer address, port and a millisecond timestamp."""
        timestamp = int(time.time() * 1000)
        return f"client_{remote_address}_{remote_port}_{timestamp}"

    def _process_ethernet_frame(self, frame_data: bytes, client_id: str) -> Optional[bytes]:
        """Run one inbound Ethernet frame through the packet handlers.

        Args:
            frame_data: Raw Ethernet frame received from the client.
            client_id: ID of the sending client (passed to each handler).

        Returns:
            The response Ethernet frame to send back, or None when the frame
            is not IPv4, no handler produced a response, or processing failed.
            Never raises; errors are logged and counted as dropped packets.
        """
        try:
            frame = EthernetFrame.parse(frame_data)
            if not frame or not frame.is_ipv4():
                # Truncated or non-IPv4 (e.g. ARP) frames are ignored without counting.
                return None

            packet = IPParser.parse_packet(frame.payload)
            with self.lock:
                self.stats['packets_processed'] += 1
                self.stats['bytes_processed'] += len(frame_data)

            response_packet = None
            # Iterate over a snapshot so a handler may (un)register handlers safely.
            for handler in list(self.packet_handlers):
                try:
                    response = handler(packet, client_id)
                except Exception as e:
                    logging.error(f"Packet handler error: {e}")
                    continue
                if response:
                    # First handler with a non-empty answer wins.
                    response_packet = response
                    break

            if not response_packet:
                self._incr_stat('packets_dropped')
                return None

            # Reply to the sender by swapping the request's MAC addresses.
            # NOTE(review): if the request went to a broadcast MAC, the reply's
            # source MAC is that broadcast address; confirm that is acceptable.
            response_frame = EthernetFrame()
            response_frame.dest_mac = frame.src_mac
            response_frame.src_mac = frame.dest_mac
            response_frame.ethertype = 0x0800
            response_frame.payload = response_packet
            self._incr_stat('packets_forwarded')
            return response_frame.build()
        except Exception as e:
            logging.error(f"Error processing Ethernet frame: {e}")
            self._incr_stat('packets_dropped')
            return None

    async def _handle_websocket_client(self, websocket, path=None):
        """Serve one WebSocket client until it disconnects.

        Binary messages are treated as Ethernet frames; text messages must be
        JSON objects and are dispatched to ``_handle_control_message``.
        ``path`` is accepted for websockets versions that pass the request
        path to the handler; it is unused (newer versions pass only the
        connection, hence the default).
        """
        client_address = websocket.remote_address
        client_id = self._generate_client_id(client_address[0], client_address[1])

        client = ClientConnection(
            client_id=client_id,
            bridge_type=BridgeType.WEBSOCKET,
            remote_address=client_address[0],
            remote_port=client_address[1],
            websocket=websocket
        )

        with self.lock:
            accepted = len(self.clients) < self.max_clients
            if accepted:
                self.clients[client_id] = client
                self.stats['total_clients'] += 1
                self.stats['active_clients'] = len(self.clients)
                self.stats['websocket_connections'] += 1
        if not accepted:
            # Close outside the lock. Awaiting while holding a threading.Lock
            # would block the event loop thread if any other coroutine then
            # tried to take the lock, deadlocking the loop.
            await websocket.close(code=1013, reason="Too many clients")
            return

        logging.info(f"WebSocket client connected: {client_id} from {client_address}")

        try:
            async for message in websocket:
                if isinstance(message, bytes):
                    # Binary message - treat as Ethernet frame
                    client.update_activity(1, len(message), 'received')
                    response = self._process_ethernet_frame(message, client_id)
                    if response:
                        await websocket.send(response)
                        client.update_activity(1, len(response), 'sent')
                elif isinstance(message, str):
                    # Text message - treat as JSON control message
                    try:
                        control_msg = json.loads(message)
                    except json.JSONDecodeError:
                        logging.warning(f"Invalid control message from {client_id}: {message}")
                        continue
                    if not isinstance(control_msg, dict):
                        # Valid JSON but not an object (e.g. '[]' or '42'):
                        # skip it instead of letting .get() kill the connection.
                        logging.warning(f"Invalid control message from {client_id}: {message}")
                        continue
                    await self._handle_control_message(client, control_msg)
        except websockets.exceptions.ConnectionClosed:
            logging.info(f"WebSocket client disconnected: {client_id}")
        except Exception as e:
            logging.error(f"WebSocket client error: {e}")
            self._incr_stat('connection_errors')
        finally:
            with self.lock:
                removed = self.clients.pop(client_id, None)
                if removed is not None:
                    removed.is_active = False
                self.stats['active_clients'] = len(self.clients)

    async def _handle_control_message(self, client: ClientConnection, message: Dict):
        """Answer a JSON control message from a WebSocket client.

        Supported ``type`` values:
            'ping'   -> {'type': 'pong', 'timestamp': <time.time()>}
            'stats'  -> this client's ``to_dict()`` plus ``get_stats()``
            'config' -> {'type': 'config_ack', 'status': 'ok'} (payload not applied)
        Any other type is ignored without a reply.
        """
        msg_type = message.get('type')

        if msg_type == 'ping':
            response = {'type': 'pong', 'timestamp': time.time()}
        elif msg_type == 'stats':
            response = {
                'type': 'stats',
                'client_stats': client.to_dict(),
                'bridge_stats': self.get_stats()
            }
        elif msg_type == 'config':
            # TODO: apply message.get('data', {}) once runtime config updates exist.
            response = {'type': 'config_ack', 'status': 'ok'}
        else:
            return

        await client.websocket.send(json.dumps(response))

    @staticmethod
    def _recv_exact(sock: socket.socket, size: int) -> Optional[bytes]:
        """Read exactly ``size`` bytes from ``sock``.

        Returns:
            The bytes read, or None if the peer closed the connection first.

        Raises:
            socket.timeout: only if the timeout expired before any byte was
                read, so the stream position is unchanged.
            ConnectionError: if the timeout expired after a partial read.
        """
        buf = bytearray()
        while len(buf) < size:
            try:
                chunk = sock.recv(size - len(buf))
            except socket.timeout:
                if buf:
                    raise ConnectionError(
                        f"timed out after reading {len(buf)} of {size} bytes"
                    )
                raise
            if not chunk:
                return None
            buf += chunk
        return bytes(buf)

    def _read_tcp_frame(self, sock: socket.socket) -> Optional[bytes]:
        """Read one length-prefixed frame (4-byte big-endian length + data).

        Returns:
            The frame bytes, or None on EOF or an oversized length (both end
            the session).

        Raises:
            socket.timeout: the timeout expired while idle, before any byte of
                the frame was read (safe to retry).
            ConnectionError: the timeout expired mid-frame (stream desynchronised).
        """
        header = self._recv_exact(sock, 4)
        if header is None:
            return None
        (frame_length,) = struct.unpack('!I', header)
        if frame_length > self.MAX_TCP_FRAME_SIZE:  # Sanity check
            logging.warning(f"TCP frame length {frame_length} exceeds {self.MAX_TCP_FRAME_SIZE}")
            return None
        try:
            return self._recv_exact(sock, frame_length)
        except socket.timeout as e:
            # The header was already consumed, so a retry would misread the stream.
            raise ConnectionError("timed out in the middle of a frame") from e

    @staticmethod
    def _close_socket(sock: Optional[socket.socket]) -> None:
        """Shut down and close ``sock``, ignoring errors (it may already be closed).

        shutdown() is attempted first so that a thread blocked in recv() on
        this socket is woken up; close() alone does not guarantee that.
        """
        if sock is None:
            return
        try:
            sock.shutdown(socket.SHUT_RDWR)
        except OSError:
            pass  # not connected, or already closed
        try:
            sock.close()
        except OSError:
            pass

    def _close_client_transport(self, client: ClientConnection) -> None:
        """Best-effort close of a client's WebSocket and/or socket.

        The WebSocket close is scheduled on ``self.loop`` with
        ``run_coroutine_threadsafe`` and is not waited for.
        """
        if client.websocket:
            loop = self.loop
            if loop is not None and not loop.is_closed():
                try:
                    asyncio.run_coroutine_threadsafe(client.websocket.close(), loop)
                except Exception as e:
                    logging.debug(f"Could not schedule WebSocket close for {client.client_id}: {e}")
        if client.socket:
            self._close_socket(client.socket)

    def _handle_tcp_client(self, client_socket: socket.socket, client_address: Tuple[str, int]):
        """Serve one TCP client on the calling thread until it disconnects.

        Wire format in both directions: 4-byte big-endian length followed by
        that many bytes of Ethernet frame. The socket is always closed and
        the client unregistered on exit.
        """
        client_id = self._generate_client_id(client_address[0], client_address[1])

        client = ClientConnection(
            client_id=client_id,
            bridge_type=BridgeType.TCP_SOCKET,
            remote_address=client_address[0],
            remote_port=client_address[1],
            socket=client_socket
        )

        with self.lock:
            accepted = len(self.clients) < self.max_clients
            if accepted:
                self.clients[client_id] = client
                self.stats['total_clients'] += 1
                self.stats['active_clients'] = len(self.clients)
                self.stats['tcp_connections'] += 1
        if not accepted:
            self._close_socket(client_socket)
            return

        logging.info(f"TCP client connected: {client_id} from {client_address}")

        try:
            client_socket.settimeout(self.client_timeout)

            while client.is_active:
                try:
                    frame_data = self._read_tcp_frame(client_socket)
                except socket.timeout:
                    # Idle timeout with nothing consumed: framing intact, keep waiting.
                    continue
                except Exception as e:
                    logging.error(f"TCP client error: {e}")
                    break
                if frame_data is None:
                    break

                client.update_activity(1, len(frame_data), 'received')

                response = self._process_ethernet_frame(frame_data, client_id)
                if response:
                    try:
                        # sendall: send() may transmit only part of the buffer.
                        client_socket.sendall(struct.pack('!I', len(response)) + response)
                    except Exception as e:
                        logging.error(f"TCP client error: {e}")
                        break
                    client.update_activity(1, len(response), 'sent')
        except Exception as e:
            logging.error(f"TCP client handler error: {e}")
            self._incr_stat('connection_errors')
        finally:
            self._close_socket(client_socket)
            with self.lock:
                removed = self.clients.pop(client_id, None)
                if removed is not None:
                    removed.is_active = False
                self.stats['active_clients'] = len(self.clients)
            logging.info(f"TCP client disconnected: {client_id}")

    def _tcp_server_loop(self):
        """Accept TCP clients until ``self.running`` is cleared.

        Each accepted client is served by ``_handle_tcp_client`` on a new
        daemon thread. The listening socket is closed on exit.
        """
        try:
            self.tcp_server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.tcp_server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.tcp_server_socket.bind((self.tcp_host, self.tcp_port))
            self.tcp_server_socket.listen(10)
            # Wake up periodically: closing the socket from stop() does not
            # reliably interrupt a blocking accept() on every platform.
            self.tcp_server_socket.settimeout(self.ACCEPT_POLL_INTERVAL)

            logging.info(f"TCP bridge server listening on {self.tcp_host}:{self.tcp_port}")

            while self.running:
                try:
                    client_socket, client_address = self.tcp_server_socket.accept()
                except socket.timeout:
                    continue  # poll tick; re-check self.running
                except OSError as e:
                    if self.running:
                        logging.error(f"TCP server error: {e}")
                        time.sleep(1)
                    continue

                client_thread = threading.Thread(
                    target=self._handle_tcp_client,
                    args=(client_socket, client_address),
                    daemon=True
                )
                client_thread.start()
        except Exception as e:
            logging.error(f"TCP server loop error: {e}")
        finally:
            if self.tcp_server_socket:
                self.tcp_server_socket.close()

    def _cleanup_loop(self):
        """Every 30 s, evict clients idle for longer than ``client_timeout``."""
        while self.running:
            try:
                cutoff = time.time() - self.client_timeout
                with self.lock:
                    expired = [
                        (client_id, client)
                        for client_id, client in self.clients.items()
                        if client.last_activity < cutoff
                    ]
                    for client_id, client in expired:
                        client.is_active = False
                        del self.clients[client_id]
                    self.stats['active_clients'] = len(self.clients)

                # Close transports outside the lock to keep the critical section short.
                for client_id, client in expired:
                    self._close_client_transport(client)
                    logging.info(f"Cleaned up expired client: {client_id}")

                time.sleep(30)  # Cleanup every 30 seconds
            except Exception as e:
                logging.error(f"Cleanup loop error: {e}")
                time.sleep(5)

    def send_packet_to_client(self, client_id: str, packet_data: bytes) -> bool:
        """Send ``packet_data`` to one client.

        WebSocket sends are scheduled on ``self.loop`` and not awaited, so
        True only means the send was queued. TCP sends are length-prefixed
        and block until written.

        Returns:
            False if the client is unknown or inactive, or the send raised (the
            client is then marked inactive); True otherwise.
        """
        with self.lock:
            client = self.clients.get(client_id)
        if not client or not client.is_active:
            return False

        try:
            if client.bridge_type == BridgeType.WEBSOCKET:
                if client.websocket:
                    asyncio.run_coroutine_threadsafe(
                        client.websocket.send(packet_data),
                        self.loop
                    )
                    client.update_activity(1, len(packet_data), 'sent')
                    return True
            elif client.bridge_type == BridgeType.TCP_SOCKET:
                if client.socket:
                    # NOTE(review): the client's handler thread may write to the
                    # same socket concurrently; frames could interleave. Consider
                    # a per-client send lock if that matters.
                    length_prefix = struct.pack('!I', len(packet_data))
                    client.socket.sendall(length_prefix + packet_data)
                    client.update_activity(1, len(packet_data), 'sent')
                    return True
        except Exception as e:
            logging.error(f"Failed to send packet to client {client_id}: {e}")
            client.is_active = False

        return False

    def broadcast_packet(self, packet_data: bytes, exclude_client: Optional[str] = None) -> int:
        """Send ``packet_data`` to every client except ``exclude_client``.

        Returns the number of clients for which ``send_packet_to_client``
        returned True.
        """
        sent_count = 0
        with self.lock:
            client_ids = list(self.clients.keys())

        for client_id in client_ids:
            if client_id != exclude_client:
                if self.send_packet_to_client(client_id, packet_data):
                    sent_count += 1

        return sent_count

    def get_clients(self) -> Dict[str, Dict]:
        """Return ``{client_id: client.to_dict()}`` for all registered clients."""
        with self.lock:
            return {
                client_id: client.to_dict()
                for client_id, client in self.clients.items()
            }

    def get_client(self, client_id: str) -> Optional[Dict]:
        """Return one client's ``to_dict()``, or None if it is not registered."""
        with self.lock:
            client = self.clients.get(client_id)
            return client.to_dict() if client else None

    def disconnect_client(self, client_id: str) -> bool:
        """Unregister a client and close its transport.

        Returns False if the client was not registered.
        """
        with self.lock:
            client = self.clients.pop(client_id, None)
            if client is None:
                return False
            client.is_active = False
            self.stats['active_clients'] = len(self.clients)

        # Close outside the lock so the close cannot stall other threads.
        self._close_client_transport(client)
        return True

    def get_stats(self) -> Dict:
        """Return a copy of the bridge statistics with a fresh ``active_clients``."""
        with self.lock:
            stats = self.stats.copy()
            stats['active_clients'] = len(self.clients)
            return stats

    def reset_stats(self):
        """Zero all counters; ``active_clients`` is set to the current client count."""
        with self.lock:
            self.stats = self._empty_stats(active_clients=len(self.clients))

    async def start_websocket_server(self):
        """Start the WebSocket server and wait until it is closed.

        Errors (e.g. the port is in use) are logged, not raised.
        """
        try:
            self.websocket_server = await websockets.serve(
                self._handle_websocket_client,
                self.websocket_host,
                self.websocket_port,
                max_size=1024*1024,  # 1MB max message size
                ping_interval=30,
                ping_timeout=10
            )

            logging.info(f"WebSocket bridge server started on {self.websocket_host}:{self.websocket_port}")

            # Keep server running
            await self.websocket_server.wait_closed()

        except Exception as e:
            logging.error(f"WebSocket server error: {e}")

    def start(self):
        """Start the WebSocket, TCP and cleanup threads and return immediately.

        Note: this also makes ``self.loop`` the current event loop of the
        calling thread.
        """
        self.running = True

        self.loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.loop)

        # WebSocket server runs self.loop in its own thread
        websocket_thread = threading.Thread(target=self._run_websocket_server_in_thread, daemon=True)
        websocket_thread.start()

        tcp_thread = threading.Thread(target=self._tcp_server_loop, daemon=True)
        tcp_thread.start()

        cleanup_thread = threading.Thread(target=self._cleanup_loop, daemon=True)
        cleanup_thread.start()

        logging.info("Packet bridge started")

    def stop(self):
        """Stop the servers, disconnect every client and stop ``self.loop``."""
        self.running = False

        loop = self.loop
        loop_usable = loop is not None and not loop.is_closed()

        if self.websocket_server:
            if loop_usable:
                # The server belongs to self.loop, which runs in another thread.
                # asyncio objects are not thread-safe, so close it on that loop.
                loop.call_soon_threadsafe(self.websocket_server.close)
            else:
                self.websocket_server.close()

        if self.tcp_server_socket:
            self.tcp_server_socket.close()

        with self.lock:
            client_ids = list(self.clients.keys())
        for client_id in client_ids:
            self.disconnect_client(client_id)

        if loop_usable:
            loop.call_soon_threadsafe(loop.stop)

        logging.info("Packet bridge stopped")

    def _run_websocket_server_in_thread(self):
        """Thread target: run ``self.loop`` until the WebSocket server finishes."""
        asyncio.set_event_loop(self.loop)
        try:
            self.loop.run_until_complete(self.start_websocket_server())
        except RuntimeError as e:
            # stop() halts the loop with loop.stop(), which makes
            # run_until_complete raise before the server coroutine finishes.
            if self.running:
                raise
            logging.debug(f"WebSocket event loop stopped: {e}")