|
|
""" |
|
|
Packet Bridge Module |
|
|
|
|
|
Handles communication with virtual clients: |
|
|
- Accept packet streams over WebSocket/TCP |
|
|
- Deliver response packets back to clients |
|
|
- Frame processing (Ethernet → IPv4) |
|
|
- Connection management |
|
|
""" |
|
|
|
|
|
import asyncio |
|
|
import websockets |
|
|
import socket |
|
|
import threading |
|
|
import time |
|
|
import struct |
|
|
from typing import Dict, List, Optional, Callable, Set, Any, Tuple |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
import json |
|
|
import logging |
|
|
|
|
|
from .ip_parser import IPParser, ParsedPacket |
|
|
|
|
|
|
|
|
class BridgeType(Enum):
    """Transport over which a client is attached to the bridge.

    Every member's value is simply its own name, so ``.value`` is a readable
    string suitable for logging or serialisation.
    """

    WEBSOCKET = "WEBSOCKET"    # WebSocket connection
    TCP_SOCKET = "TCP_SOCKET"  # raw TCP stream
    UDP_SOCKET = "UDP_SOCKET"  # declared; not used by the visible PacketBridge code
|
|
|
|
|
|
|
|
@dataclass
class ClientConnection:
    """Per-client bookkeeping for the packet bridge.

    Holds the transport handle, connection timestamps and traffic counters for
    one connected client.
    """

    client_id: str
    bridge_type: BridgeType
    remote_address: str
    remote_port: int
    # Transport handle; the bridge fills in whichever one matches bridge_type.
    websocket: Optional[Any] = None
    socket: Optional['socket.socket'] = None
    # Epoch seconds (time.time()); 0 means "use the construction time".
    connected_time: float = 0
    last_activity: float = 0
    # Traffic counters, maintained through update_activity().
    # NOTE(review): updated without locking; concurrent callers may race.
    packets_received: int = 0
    packets_sent: int = 0
    bytes_received: int = 0
    bytes_sent: int = 0
    # Cleared when the bridge disconnects or gives up on this client.
    is_active: bool = True

    def __post_init__(self):
        """Fill unset (zero) timestamps with the construction time."""
        # Sample the clock once so a fresh connection reports identical
        # connected_time and last_activity values.
        now = time.time()
        if self.connected_time == 0:
            self.connected_time = now
        if self.last_activity == 0:
            self.last_activity = now

    def update_activity(self, packet_count: int = 1, byte_count: int = 0, direction: str = 'received'):
        """Record traffic and refresh ``last_activity``.

        Args:
            packet_count: Number of packets to add to the counter.
            byte_count: Number of bytes to add to the counter.
            direction: ``'received'`` updates the receive counters; any other
                value updates the send counters.
        """
        self.last_activity = time.time()

        if direction == 'received':
            self.packets_received += packet_count
            self.bytes_received += byte_count
        else:
            self.packets_sent += packet_count
            self.bytes_sent += byte_count

    def to_dict(self) -> Dict:
        """Return a JSON-serialisable snapshot of this connection.

        ``duration`` is the number of seconds since ``connected_time``.
        The transport handles are not included.
        """
        return {
            'client_id': self.client_id,
            'bridge_type': self.bridge_type.value,
            'remote_address': self.remote_address,
            'remote_port': self.remote_port,
            'connected_time': self.connected_time,
            'last_activity': self.last_activity,
            'packets_received': self.packets_received,
            'packets_sent': self.packets_sent,
            'bytes_received': self.bytes_received,
            'bytes_sent': self.bytes_sent,
            'is_active': self.is_active,
            'duration': time.time() - self.connected_time
        }
|
|
|
|
|
|
|
|
class EthernetFrame:
    """Minimal Ethernet II frame: destination MAC, source MAC, EtherType, payload.

    The header is 14 bytes: destination MAC (6), source MAC (6) and a
    big-endian 16-bit EtherType. VLAN tags, preamble and FCS are not interpreted.
    """

    def __init__(self):
        # Defaults: all-zero MAC addresses, IPv4 EtherType, empty payload.
        self.dest_mac = bytes(6)
        self.src_mac = bytes(6)
        self.ethertype = 0x0800
        self.payload = b''

    @classmethod
    def parse(cls, data: bytes) -> Optional['EthernetFrame']:
        """Decode ``data`` into a frame.

        Returns None when ``data`` is shorter than the 14-byte header.
        Everything after the header becomes the payload.
        """
        header_len = 14
        if len(data) >= header_len:
            frame = cls()
            frame.dest_mac, frame.src_mac = data[:6], data[6:12]
            (frame.ethertype,) = struct.unpack('!H', data[12:header_len])
            frame.payload = data[header_len:]
            return frame
        return None

    def build(self) -> bytes:
        """Serialise the frame: header bytes followed by the payload."""
        ethertype_field = struct.pack('!H', self.ethertype)
        return self.dest_mac + self.src_mac + ethertype_field + self.payload

    def is_ipv4(self) -> bool:
        """Return True when the EtherType is 0x0800 (IPv4)."""
        return 0x0800 == self.ethertype

    def is_arp(self) -> bool:
        """Return True when the EtherType is 0x0806 (ARP)."""
        return 0x0806 == self.ethertype
|
|
|
|
|
|
|
|
class PacketBridge:
    """Bridge Ethernet frames between remote clients and local packet handlers.

    Clients connect either over WebSocket or over raw TCP:

    * WebSocket: binary messages carry Ethernet frames, and text messages carry
      JSON control messages.
    * TCP: each frame is preceded by a 4-byte big-endian length.

    For every IPv4 frame, the registered handlers are tried in order. The first
    truthy result is sent back to the client inside an Ethernet frame whose MAC
    addresses are swapped.

    Threading model: ``start()`` launches four kinds of background threads:

    * a WebSocket thread that drives a private asyncio loop;
    * a TCP accept thread, plus one thread per TCP client;
    * an idle-client cleanup thread.

    ``self.lock`` guards ``clients`` and ``stats`` across all of them.
    """

    def __init__(self, config: Dict):
        """Initialise the bridge from ``config``.

        Recognised keys and their defaults:

        * ``websocket_host`` ('0.0.0.0') and ``websocket_port`` (8765)
        * ``tcp_host`` ('0.0.0.0') and ``tcp_port`` (8766)
        * ``max_clients`` (100)
        * ``client_timeout`` (300 seconds). This is used both as the idle
          cutoff and as the TCP socket timeout.
        """
        self.config = config
        self.clients: Dict[str, ClientConnection] = {}
        self.packet_handlers: List[Callable[[ParsedPacket, str], Optional[bytes]]] = []
        # Plain, non-reentrant threading lock. Never hold it across an
        # ``await`` or while running packet handlers.
        self.lock = threading.Lock()

        # Listener configuration.
        self.websocket_host = config.get('websocket_host', '0.0.0.0')
        self.websocket_port = config.get('websocket_port', 8765)
        self.tcp_host = config.get('tcp_host', '0.0.0.0')
        self.tcp_port = config.get('tcp_port', 8766)
        self.max_clients = config.get('max_clients', 100)
        self.client_timeout = config.get('client_timeout', 300)

        # Listener handles, set once the servers are up.
        self.websocket_server = None
        self.tcp_server_socket = None

        # Lifecycle state; start() stores its background threads here.
        self.running = False
        self.websocket_task = None
        self.tcp_task = None
        self.cleanup_task = None

        self.stats = self._empty_stats(active_clients=0)

        # Event loop that runs the WebSocket server; created by start().
        self.loop = None

    @staticmethod
    def _empty_stats(active_clients: int) -> Dict[str, int]:
        """Return a fresh statistics dictionary with every counter zeroed."""
        return {
            'total_clients': 0,
            'active_clients': active_clients,
            'packets_processed': 0,
            'packets_forwarded': 0,
            'packets_dropped': 0,
            'bytes_processed': 0,
            'websocket_connections': 0,
            'tcp_connections': 0,
            'connection_errors': 0
        }

    def _bump_stat(self, key: str, amount: int = 1):
        """Increment ``self.stats[key]`` under the lock (caller must not hold it)."""
        with self.lock:
            self.stats[key] += amount

    def add_packet_handler(self, handler: Callable[[ParsedPacket, str], Optional[bytes]]):
        """Register ``handler(packet, client_id)``.

        A truthy return value is used as the IP payload of the reply frame.
        """
        self.packet_handlers.append(handler)

    def remove_packet_handler(self, handler: Callable[[ParsedPacket, str], Optional[bytes]]):
        """Unregister ``handler``; does nothing if it was never registered."""
        if handler in self.packet_handlers:
            self.packet_handlers.remove(handler)

    def _generate_client_id(self, remote_address: str, remote_port: int) -> str:
        """Build a client ID from the peer address, port and current time in ms."""
        timestamp = int(time.time() * 1000)
        return f"client_{remote_address}_{remote_port}_{timestamp}"

    def _process_ethernet_frame(self, frame_data: bytes, client_id: str) -> Optional[bytes]:
        """Run one Ethernet frame through the packet handlers.

        Args:
            frame_data: Raw Ethernet frame received from the client.
            client_id: ID of the sending client, passed through to handlers.

        Returns:
            The Ethernet-framed response to send back. Returns None when:

            * the frame is not IPv4;
            * no handler produced a response;
            * processing failed.
        """
        try:
            frame = EthernetFrame.parse(frame_data)
            # Non-IPv4 and truncated frames are ignored without touching stats.
            if not frame or not frame.is_ipv4():
                return None

            packet = IPParser.parse_packet(frame.payload)
            with self.lock:
                self.stats['packets_processed'] += 1
                self.stats['bytes_processed'] += len(frame_data)

            response_packet = None
            # Snapshot the handler list so concurrent add/remove calls cannot
            # disturb the iteration.
            for handler in list(self.packet_handlers):
                try:
                    response = handler(packet, client_id)
                    if response:
                        response_packet = response
                        break
                except Exception as e:
                    logging.error(f"Packet handler error: {e}")

            if not response_packet:
                self._bump_stat('packets_dropped')
                return None

            # Reply to the sender by swapping the incoming frame's MAC addresses.
            response_frame = EthernetFrame()
            response_frame.dest_mac = frame.src_mac
            response_frame.src_mac = frame.dest_mac
            response_frame.ethertype = 0x0800
            response_frame.payload = response_packet

            self._bump_stat('packets_forwarded')
            return response_frame.build()

        except Exception as e:
            logging.error(f"Error processing Ethernet frame: {e}")
            self._bump_stat('packets_dropped')
            return None

    async def _handle_websocket_client(self, websocket, path=None):
        """Serve one WebSocket client until it disconnects.

        Binary messages are treated as Ethernet frames and answered through
        ``_process_ethernet_frame``. Text messages are parsed as JSON control
        messages.

        Args:
            websocket: Server-side connection object provided by ``websockets``.
            path: Request path. It is unused and optional because newer
                ``websockets`` releases call the handler with the connection only.
        """
        client_address = websocket.remote_address
        client_id = self._generate_client_id(client_address[0], client_address[1])

        client = ClientConnection(
            client_id=client_id,
            bridge_type=BridgeType.WEBSOCKET,
            remote_address=client_address[0],
            remote_port=client_address[1],
            websocket=websocket
        )

        with self.lock:
            accepted = len(self.clients) < self.max_clients
            if accepted:
                self.clients[client_id] = client
                self.stats['total_clients'] += 1
                self.stats['active_clients'] = len(self.clients)
                self.stats['websocket_connections'] += 1

        if not accepted:
            # Close outside the lock. Awaiting while holding a threading.Lock
            # would let another coroutine block the loop thread on that lock,
            # and this coroutine could then never resume: a deadlock.
            await websocket.close(code=1013, reason="Too many clients")
            return

        logging.info(f"WebSocket client connected: {client_id} from {client_address}")

        try:
            async for message in websocket:
                if isinstance(message, bytes):
                    client.update_activity(1, len(message), 'received')

                    # NOTE(review): handlers run synchronously on the event-loop
                    # thread, so a slow handler stalls every WebSocket client.
                    response = self._process_ethernet_frame(message, client_id)
                    if response:
                        await websocket.send(response)
                        client.update_activity(1, len(response), 'sent')

                elif isinstance(message, str):
                    try:
                        control_msg = json.loads(message)
                    except json.JSONDecodeError:
                        logging.warning(f"Invalid control message from {client_id}: {message}")
                        continue
                    if not isinstance(control_msg, dict):
                        # Valid JSON that is not an object (e.g. a list) has no
                        # 'type' field; reject it instead of dropping the client.
                        logging.warning(f"Invalid control message from {client_id}: {message}")
                        continue
                    await self._handle_control_message(client, control_msg)

        except websockets.exceptions.ConnectionClosed:
            logging.info(f"WebSocket client disconnected: {client_id}")
        except Exception as e:
            logging.error(f"WebSocket client error: {e}")
            self._bump_stat('connection_errors')

        finally:
            with self.lock:
                if client_id in self.clients:
                    self.clients[client_id].is_active = False
                    del self.clients[client_id]

                self.stats['active_clients'] = len(self.clients)

    async def _handle_control_message(self, client: ClientConnection, message: Dict):
        """Answer a JSON control message from a WebSocket client.

        Supported ``type`` values:

        * ``'ping'`` is answered with ``'pong'``.
        * ``'stats'`` returns client and bridge statistics.
        * ``'config'`` is acknowledged; its ``data`` payload is not applied.

        Other types are ignored.
        """
        msg_type = message.get('type')

        if msg_type == 'ping':
            response = {'type': 'pong', 'timestamp': time.time()}
            await client.websocket.send(json.dumps(response))

        elif msg_type == 'stats':
            response = {
                'type': 'stats',
                'client_stats': client.to_dict(),
                'bridge_stats': self.get_stats()
            }
            await client.websocket.send(json.dumps(response))

        elif msg_type == 'config':
            response = {'type': 'config_ack', 'status': 'ok'}
            await client.websocket.send(json.dumps(response))

    @staticmethod
    def _recv_exact(sock: socket.socket, count: int) -> Optional[bytes]:
        """Read exactly ``count`` bytes from ``sock``.

        Returns None if the peer closes the connection before ``count`` bytes
        arrive. Socket errors, including timeouts, propagate to the caller.
        """
        buf = bytearray()
        while len(buf) < count:
            chunk = sock.recv(count - len(buf))
            if not chunk:
                return None
            buf += chunk
        return bytes(buf)

    def _handle_tcp_client(self, client_socket: socket.socket, client_address: Tuple[str, int]):
        """Serve one TCP client until it disconnects, misbehaves or is deactivated.

        Wire format, in both directions: a 4-byte big-endian length followed by
        that many bytes of Ethernet frame. A frame longer than 65536 bytes ends
        the connection. This runs on a dedicated thread started by the accept loop.
        """
        client_id = self._generate_client_id(client_address[0], client_address[1])

        client = ClientConnection(
            client_id=client_id,
            bridge_type=BridgeType.TCP_SOCKET,
            remote_address=client_address[0],
            remote_port=client_address[1],
            socket=client_socket
        )

        with self.lock:
            if len(self.clients) >= self.max_clients:
                client_socket.close()
                return

            self.clients[client_id] = client
            self.stats['total_clients'] += 1
            self.stats['active_clients'] = len(self.clients)
            self.stats['tcp_connections'] += 1

        logging.info(f"TCP client connected: {client_id} from {client_address}")

        try:
            client_socket.settimeout(self.client_timeout)

            while client.is_active:
                try:
                    # Wait for the next message. A timeout before any byte
                    # arrives only means the client is idle, so loop and
                    # re-check is_active.
                    try:
                        header = client_socket.recv(4)
                    except socket.timeout:
                        continue
                    if not header:
                        break

                    # The length prefix itself may arrive split across reads.
                    if len(header) < 4:
                        remainder = self._recv_exact(client_socket, 4 - len(header))
                        if remainder is None:
                            break
                        header += remainder

                    frame_length = struct.unpack('!I', header)[0]
                    if frame_length > 65536:
                        break

                    frame_data = self._recv_exact(client_socket, frame_length)
                    if frame_data is None:
                        break

                    client.update_activity(1, len(frame_data), 'received')

                    response = self._process_ethernet_frame(frame_data, client_id)
                    if response:
                        response_length = struct.pack('!I', len(response))
                        # sendall: send() may write only part of the buffer.
                        client_socket.sendall(response_length + response)
                        client.update_activity(1, len(response), 'sent')

                except socket.timeout:
                    # Timing out part-way through a message leaves the stream
                    # unsynchronised, so drop the connection.
                    logging.warning(f"TCP client {client_id} timed out mid-frame")
                    break
                except Exception as e:
                    logging.error(f"TCP client error: {e}")
                    break

        except Exception as e:
            logging.error(f"TCP client handler error: {e}")
            self._bump_stat('connection_errors')

        finally:
            try:
                client_socket.close()
            except OSError:
                pass

            with self.lock:
                if client_id in self.clients:
                    self.clients[client_id].is_active = False
                    del self.clients[client_id]

                self.stats['active_clients'] = len(self.clients)

            logging.info(f"TCP client disconnected: {client_id}")

    def _tcp_server_loop(self):
        """Accept TCP clients until stop() is called; start one thread per client."""
        try:
            self.tcp_server_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            self.tcp_server_socket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            self.tcp_server_socket.bind((self.tcp_host, self.tcp_port))
            self.tcp_server_socket.listen(10)

            logging.info(f"TCP bridge server listening on {self.tcp_host}:{self.tcp_port}")

            while self.running:
                try:
                    client_socket, client_address = self.tcp_server_socket.accept()
                except socket.error as e:
                    # accept() fails once stop() closes the listening socket;
                    # report errors only while the bridge is meant to be running.
                    if self.running:
                        logging.error(f"TCP server error: {e}")
                        time.sleep(1)
                    continue

                try:
                    threading.Thread(
                        target=self._handle_tcp_client,
                        args=(client_socket, client_address),
                        daemon=True
                    ).start()
                except RuntimeError as e:
                    # Could not spawn a handler thread: do not leak the socket.
                    logging.error(f"TCP server error: {e}")
                    client_socket.close()

        except Exception as e:
            logging.error(f"TCP server loop error: {e}")

        finally:
            if self.tcp_server_socket:
                self.tcp_server_socket.close()

    def _close_client_transport(self, client: ClientConnection):
        """Best-effort close of a client's WebSocket or TCP socket.

        WebSocket closes are scheduled on the bridge's event loop, so this is
        safe from any thread. TCP sockets are shut down before closing, which
        wakes a handler thread blocked in recv(). It takes no lock itself.
        """
        if client.websocket:
            loop = self.loop
            if loop is not None and not loop.is_closed():
                try:
                    asyncio.run_coroutine_threadsafe(client.websocket.close(), loop)
                except RuntimeError as e:
                    logging.debug(f"Could not schedule close for {client.client_id}: {e}")

        if client.socket:
            try:
                client.socket.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass  # already closed or never fully connected
            try:
                client.socket.close()
            except OSError:
                pass

    def _cleanup_loop(self):
        """Every 30 s, disconnect clients idle for longer than ``client_timeout``."""
        while self.running:
            try:
                current_time = time.time()

                with self.lock:
                    expired_clients = [
                        client_id
                        for client_id, client in self.clients.items()
                        if current_time - client.last_activity > self.client_timeout
                    ]

                    for client_id in expired_clients:
                        client = self.clients.pop(client_id)
                        client.is_active = False
                        self._close_client_transport(client)
                        logging.info(f"Cleaned up expired client: {client_id}")

                    self.stats['active_clients'] = len(self.clients)

                time.sleep(30)

            except Exception as e:
                logging.error(f"Cleanup loop error: {e}")
                time.sleep(5)

    def send_packet_to_client(self, client_id: str, packet_data: bytes) -> bool:
        """Send raw ``packet_data`` to one client.

        WebSocket sends are scheduled on the bridge's event loop and not
        awaited, so True there means "queued", not "delivered". TCP sends are
        written synchronously behind a 4-byte big-endian length prefix. On
        failure the client is marked inactive.

        Returns:
            True if the data was handed to the transport, otherwise False.
        """
        with self.lock:
            client = self.clients.get(client_id)

        if not client or not client.is_active:
            return False

        try:
            if client.bridge_type == BridgeType.WEBSOCKET:
                if client.websocket:
                    asyncio.run_coroutine_threadsafe(
                        client.websocket.send(packet_data),
                        self.loop
                    )
                    client.update_activity(1, len(packet_data), 'sent')
                    return True

            elif client.bridge_type == BridgeType.TCP_SOCKET:
                if client.socket:
                    # NOTE(review): the client's handler thread may write replies
                    # on this socket concurrently; frames could interleave.
                    # Consider a per-client send lock.
                    length_prefix = struct.pack('!I', len(packet_data))
                    client.socket.sendall(length_prefix + packet_data)
                    client.update_activity(1, len(packet_data), 'sent')
                    return True

        except Exception as e:
            logging.error(f"Failed to send packet to client {client_id}: {e}")
            client.is_active = False

        return False

    def broadcast_packet(self, packet_data: bytes, exclude_client: Optional[str] = None) -> int:
        """Send ``packet_data`` to every client except ``exclude_client``.

        Returns:
            The number of clients for which the send succeeded.
        """
        sent_count = 0

        with self.lock:
            client_ids = list(self.clients.keys())

        for client_id in client_ids:
            if client_id != exclude_client:
                if self.send_packet_to_client(client_id, packet_data):
                    sent_count += 1

        return sent_count

    def get_clients(self) -> Dict[str, Dict]:
        """Return a ``{client_id: client.to_dict()}`` snapshot of all clients."""
        with self.lock:
            return {
                client_id: client.to_dict()
                for client_id, client in self.clients.items()
            }

    def get_client(self, client_id: str) -> Optional[Dict]:
        """Return the ``to_dict()`` snapshot of one client, or None if unknown."""
        with self.lock:
            client = self.clients.get(client_id)
            return client.to_dict() if client else None

    def disconnect_client(self, client_id: str) -> bool:
        """Close and forget one client.

        Returns:
            True if the client existed, otherwise False.
        """
        with self.lock:
            client = self.clients.pop(client_id, None)
            if client is None:
                return False

            client.is_active = False
            self._close_client_transport(client)
            self.stats['active_clients'] = len(self.clients)

        return True

    def get_stats(self) -> Dict:
        """Return a copy of the statistics with a fresh ``active_clients`` count."""
        with self.lock:
            stats = self.stats.copy()
            stats['active_clients'] = len(self.clients)

        return stats

    def reset_stats(self):
        """Zero all counters; ``active_clients`` is set to the current client count."""
        with self.lock:
            self.stats = self._empty_stats(active_clients=len(self.clients))

    async def start_websocket_server(self):
        """Run the WebSocket server on ``self.loop`` until it is closed."""
        try:
            self.websocket_server = await websockets.serve(
                self._handle_websocket_client,
                self.websocket_host,
                self.websocket_port,
                max_size=1024*1024,
                ping_interval=30,
                ping_timeout=10
            )

            logging.info(f"WebSocket bridge server started on {self.websocket_host}:{self.websocket_port}")

            await self.websocket_server.wait_closed()

        except Exception as e:
            logging.error(f"WebSocket server error: {e}")

    def start(self):
        """Start the WebSocket, TCP and cleanup threads, then return immediately."""
        self.running = True

        # Private loop, driven only by the WebSocket thread. It is deliberately
        # not installed as the calling thread's current loop, because it runs
        # (and must be used) on another thread.
        self.loop = asyncio.new_event_loop()

        self.websocket_task = threading.Thread(target=self._run_websocket_server_in_thread, daemon=True)
        self.websocket_task.start()

        self.tcp_task = threading.Thread(target=self._tcp_server_loop, daemon=True)
        self.tcp_task.start()

        self.cleanup_task = threading.Thread(target=self._cleanup_loop, daemon=True)
        self.cleanup_task.start()

        logging.info("Packet bridge started")

    def stop(self):
        """Stop the listeners, disconnect every client and halt the WebSocket loop.

        The background threads are daemons and are not joined.
        """
        self.running = False

        if self.websocket_server:
            # The server belongs to the loop thread and asyncio objects are not
            # thread-safe, so ask that thread to close it.
            if self.loop and not self.loop.is_closed():
                self.loop.call_soon_threadsafe(self.websocket_server.close)
            else:
                self.websocket_server.close()

        if self.tcp_server_socket:
            # close() alone may not wake a thread blocked in accept() (e.g. on
            # Linux); shutdown() does. Some platforms reject shutdown() on a
            # listening socket, which is harmless here.
            try:
                self.tcp_server_socket.shutdown(socket.SHUT_RDWR)
            except OSError:
                pass
            self.tcp_server_socket.close()

        with self.lock:
            client_ids = list(self.clients.keys())

        for client_id in client_ids:
            self.disconnect_client(client_id)

        if self.loop and not self.loop.is_closed():
            self.loop.call_soon_threadsafe(self.loop.stop)

        logging.info("Packet bridge stopped")

    def _run_websocket_server_in_thread(self):
        """Thread target: drive ``self.loop`` until the WebSocket server finishes."""
        asyncio.set_event_loop(self.loop)
        try:
            self.loop.run_until_complete(self.start_websocket_server())
        except RuntimeError as e:
            # stop() halts the loop with loop.stop(), which makes
            # run_until_complete raise before the server coroutine finishes.
            if self.running:
                raise
            logging.debug(f"WebSocket event loop stopped: {e}")
|
|
|
|
|
|
|
|
|