Spaces:
Configuration error
Configuration error
| import argparse | |
| import asyncio | |
| import struct | |
| import websockets | |
| import socket | |
| import uuid | |
| # Simple type-tagged, length-prefixed framing | |
| # frame: 1-byte type + 4-byte big-endian length + payload | |
async def websocket_send_frame(ws, data: bytes, frame_type: bytes = b'D'):
    """Send one framed message over the websocket.

    Wire format: 1-byte type + 4-byte big-endian payload length + payload.

    Args:
        ws: Object exposing an awaitable ``send(message)`` method.
        data: Payload bytes.
        frame_type: Exactly one byte identifying the frame kind. This file
            uses ``b'D'`` for data frames and ``b'O'`` for open frames.

    Raises:
        ValueError: If ``frame_type`` is not exactly one byte long.
        struct.error: If ``len(data)`` does not fit in an unsigned 32-bit int.
    """
    # The receiver reads byte 0 as the type and bytes 1-4 as the length, so a
    # type of any other width would shift the header and corrupt the frame.
    if len(frame_type) != 1:
        raise ValueError('frame_type must be exactly one byte')
    header = frame_type + struct.pack('>I', len(data))
    await ws.send(header + data)
async def websocket_recv_frame(ws, recv_lock: asyncio.Lock):
    """Receive one websocket message and parse it as a frame.

    Wire format: 1-byte type + 4-byte big-endian payload length + payload.

    Args:
        ws: Object exposing an awaitable ``recv()`` method.
        recv_lock: Lock held around ``ws.recv()`` so that only one task at a
            time is waiting on the websocket.

    Returns:
        A ``(frame_type, payload)`` tuple; ``frame_type`` is a 1-byte
        ``bytes`` object.

    Raises:
        ValueError: If the message is shorter than the 5-byte header, or if
            it carries fewer payload bytes than the header declares.
    """
    # Hold the lock only for the network read; parsing needs no exclusion.
    async with recv_lock:
        data = await ws.recv()
    # Text messages are parsed through their encoded bytes.
    if isinstance(data, str):
        data = data.encode()
    if len(data) < 5:
        raise ValueError('frame too short')
    frame_type = data[0:1]
    (length,) = struct.unpack_from('>I', data, 1)
    payload = data[5:5 + length]
    # Slicing silently shortens the payload when bytes are missing; surface
    # that instead of forwarding a partial chunk into the tunnelled stream.
    if len(payload) != length:
        raise ValueError('frame truncated')
    return frame_type, payload
async def handle_local_connection(local_reader, local_writer, ws, recv_lock: asyncio.Lock):
    """Relay one local TCP connection over the websocket (non-multiplexed).

    Two child tasks run concurrently: one copies bytes read from
    ``local_reader`` into data frames, the other writes received frame
    payloads to ``local_writer``. Either side signals end-of-stream with a
    data frame whose payload is exactly ``b'__CLOSE__'``.

    Args:
        local_reader: Stream reader of the accepted local connection.
        local_writer: Stream writer of the accepted local connection; it is
            closed when the receive direction ends.
        ws: Websocket used for ``websocket_send_frame`` /
            ``websocket_recv_frame``.
        recv_lock: Lock serialising ``ws.recv()`` calls.
    """
    peername = None
    try:
        sock = local_writer.get_extra_info('socket')
        if sock:
            peername = sock.getpeername()
    except Exception:
        # The peer address is only used for log lines.
        peername = None
    print(f'New local connection from {peername}')

    async def read_local_then_send():
        # local socket -> websocket
        try:
            while True:
                chunk = await local_reader.read(4096)
                if not chunk:
                    # EOF on the local side: tell the remote end to close.
                    print('local EOF, sending close frame')
                    await websocket_send_frame(ws, b'__CLOSE__')
                    break
                await websocket_send_frame(ws, chunk)
        except asyncio.CancelledError:
            # Leaf task: being cancelled is the normal teardown path.
            pass
        except Exception as e:
            print('local read/send error', e)

    async def recv_then_write_local():
        # websocket -> local socket
        try:
            while True:
                try:
                    frame_type, payload = await websocket_recv_frame(ws, recv_lock)
                except Exception as e:
                    print('ws recv/local write error', e)
                    break
                if payload == b'__CLOSE__':
                    print('received close frame from server')
                    break
                try:
                    local_writer.write(payload)
                    await local_writer.drain()
                except Exception as e:
                    print('local writer error while writing payload', e)
                    break
        finally:
            try:
                local_writer.close()
                await local_writer.wait_closed()
            except Exception:
                # Best-effort close; the connection is going away regardless.
                pass

    tasks = {
        asyncio.create_task(read_local_then_send()),
        asyncio.create_task(recv_then_write_local()),
    }
    try:
        # FIRST_EXCEPTION keeps half-closed connections alive: a clean finish
        # of one direction does not tear down the other.
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        for t in done:
            # Task.exception() raises CancelledError on a cancelled task.
            if t.cancelled():
                continue
            exc = t.exception()
            if exc is not None:
                print('task exception in handle_local_connection:', exc)
    finally:
        # Also runs when this handler itself is cancelled, so the child
        # tasks never outlive it; awaiting them lets their cleanup finish.
        unfinished = [t for t in tasks if not t.done()]
        for t in unfinished:
            t.cancel()
        if unfinished:
            await asyncio.gather(*unfinished, return_exceptions=True)
        print(f'Connection from {peername} closed')
async def handle_local_connection_multiplex(local_reader, local_writer, ws, recv_lock: asyncio.Lock, conn_id: bytes, conn_queue: asyncio.Queue, conn_queues: dict):
    """Relay one local TCP connection over a shared, multiplexed websocket.

    Outgoing chunks are sent as data frames whose payload is ``conn_id``
    followed by the chunk. Incoming payloads for this connection are taken
    from ``conn_queue``. End-of-stream is signalled in either direction by
    the payload ``b'__CLOSE__'``.

    Args:
        local_reader: Stream reader of the accepted local connection.
        local_writer: Stream writer of the accepted local connection; it is
            closed when the receive direction ends.
        ws: Websocket passed to ``websocket_send_frame``.
        recv_lock: Unused here (this handler never reads the websocket
            directly); kept for signature compatibility.
        conn_id: Identifier prefixed to every outgoing payload.
        conn_queue: Queue from which this connection's inbound payloads
            are read.
        conn_queues: Registry mapping connection ids to queues; this
            connection's entry is removed on exit.
    """
    peername = None
    try:
        sock = local_writer.get_extra_info('socket')
        if sock:
            peername = sock.getpeername()
    except Exception:
        # The peer address is only used for log lines.
        peername = None
    print(f'New multiplexed connection {conn_id.decode()} from {peername}')

    async def read_local_then_send():
        # local socket -> websocket, tagged with conn_id
        try:
            while True:
                chunk = await local_reader.read(4096)
                if not chunk:
                    # EOF on the local side: tell the remote end to close.
                    print(f'local EOF for {conn_id.decode()}, sending close frame')
                    await websocket_send_frame(ws, conn_id + b'__CLOSE__', frame_type=b'D')
                    break
                await websocket_send_frame(ws, conn_id + chunk, frame_type=b'D')
        except asyncio.CancelledError:
            # Leaf task: being cancelled is the normal teardown path.
            pass
        except Exception as e:
            print('local read/send error', e)

    async def recv_then_write_local():
        # per-connection queue -> local socket
        try:
            while True:
                try:
                    payload = await conn_queue.get()
                except Exception as e:
                    print('conn_queue get error', e)
                    break
                if payload == b'__CLOSE__':
                    print(f'received close for {conn_id.decode()}')
                    break
                try:
                    local_writer.write(payload)
                    await local_writer.drain()
                except Exception as e:
                    print('local writer error while writing payload', e)
                    break
        finally:
            try:
                local_writer.close()
                await local_writer.wait_closed()
            except Exception:
                # Best-effort close; the connection is going away regardless.
                pass

    tasks = {
        asyncio.create_task(read_local_then_send()),
        asyncio.create_task(recv_then_write_local()),
    }
    try:
        # FIRST_EXCEPTION keeps half-closed connections alive: a clean finish
        # of one direction does not tear down the other.
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_EXCEPTION)
        for t in done:
            # Task.exception() raises CancelledError on a cancelled task.
            if t.cancelled():
                continue
            exc = t.exception()
            if exc is not None:
                print('task exception in handle_local_connection_multiplex:', exc)
    finally:
        # Also runs when this handler itself is cancelled, so neither the
        # child tasks nor the queue registration outlive it.
        unfinished = [t for t in tasks if not t.done()]
        for t in unfinished:
            t.cancel()
        if unfinished:
            await asyncio.gather(*unfinished, return_exceptions=True)
        conn_queues.pop(conn_id, None)
        print(f'Multiplexed connection {conn_id.decode()} from {peername} closed')
async def run_client(server_url: str, local_port: int):
    """Connect to the tunnel server and forward local TCP connections to it.

    After a ``'ping'`` handshake, a demultiplexer task reads every frame from
    the websocket and routes data-frame payloads to per-connection queues
    keyed by the first 8 payload bytes. A listener on ``127.0.0.1`` accepts
    local connections; each gets a fresh 8-byte id, is announced with an
    ``b'O'`` frame, and is then relayed by
    ``handle_local_connection_multiplex``.

    The function returns when the demultiplexer stops (for example because
    the websocket closed), since no connection can be forwarded after that.

    Args:
        server_url: Websocket URL of the tunnel server.
        local_port: Local TCP port to listen on.
    """
    async with websockets.connect(server_url) as ws:
        print('Connected to server')
        # One lock serialises every ws.recv() call, handshake included.
        recv_lock = asyncio.Lock()
        await ws.send('ping')
        async with recv_lock:
            resp = await ws.recv()
        print('handshake response:', resp)

        # Session id is informational only; it is not sent to the server here.
        session_id = uuid.uuid4().hex[:8].encode()
        print('session id:', session_id.decode())

        # conn_id bytes -> queue of inbound payloads for that connection
        conn_queues: dict[bytes, asyncio.Queue] = {}
        # Strong references to handler tasks: the event loop keeps only weak
        # references, so an unreferenced task may be garbage-collected mid-run.
        handler_tasks: set[asyncio.Task] = set()

        async def demuxer_loop():
            try:
                while True:
                    try:
                        frame_type, payload = await websocket_recv_frame(ws, recv_lock)
                    except ValueError as e:
                        # One malformed frame should not take the tunnel down.
                        print('demux: dropping malformed frame:', e)
                        continue
                    # Only data frames (type 'D') are demultiplexed.
                    if frame_type != b'D':
                        continue
                    if len(payload) < 8:
                        print('demux: payload too short')
                        continue
                    cid = payload[:8]
                    body = payload[8:]
                    target = conn_queues.get(cid)
                    if target is None:
                        print('demux: no queue for', cid)
                    else:
                        await target.put(body)
            except Exception as e:
                print('demuxer exiting:', e)
            finally:
                # No more inbound data can arrive; wake every handler that is
                # blocked on its queue so it closes its local socket.
                for waiting in list(conn_queues.values()):
                    waiting.put_nowait(b'__CLOSE__')

        async def accept_callback(r, w):
            # Peek at the first bytes to tell browsers from raw TCP clients.
            try:
                data = await asyncio.wait_for(r.read(1024), timeout=1.5)
            except (asyncio.TimeoutError, OSError):
                data = b''
            # Browser quick response: answer plain HTTP locally, no tunnelling.
            if data.startswith((b'GET ', b'POST ', b'HEAD ')):
                body = b"<html><body><h1>cpolar demo</h1><p>Connected</p></body></html>"
                html = b"HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: " + str(len(body)).encode() + b"\r\n\r\n" + body
                try:
                    w.write(html)
                    await w.drain()
                except Exception:
                    pass
                try:
                    w.close()
                    await w.wait_closed()
                except Exception:
                    pass
                return
            # Register the queue before announcing the connection so replies
            # that arrive immediately are not dropped by the demuxer.
            conn_id = uuid.uuid4().hex[:8].encode()
            q: asyncio.Queue = asyncio.Queue()
            conn_queues[conn_id] = q
            try:
                await websocket_send_frame(ws, conn_id, frame_type=b'O')
                # Bytes consumed by the peek above belong to the stream.
                if data:
                    await websocket_send_frame(ws, conn_id + data, frame_type=b'D')
            except Exception as e:
                # Tunnel unusable: undo the registration and drop the socket.
                print('failed to open tunnelled connection', e)
                conn_queues.pop(conn_id, None)
                w.close()
                return
            task = asyncio.create_task(handle_local_connection_multiplex(r, w, ws, recv_lock, conn_id, q, conn_queues))
            handler_tasks.add(task)
            task.add_done_callback(handler_tasks.discard)

        demux_task = asyncio.create_task(demuxer_loop())
        serve_task = None
        try:
            server = await asyncio.start_server(accept_callback, '127.0.0.1', local_port)
            addrs = ', '.join(str(sock.getsockname()) for sock in server.sockets)
            print(f'Listening on {addrs}, forward to server (multiplexed)')
            serve_task = asyncio.create_task(server.serve_forever())
            # Stop listening as soon as the tunnel itself is gone.
            await asyncio.wait({serve_task, demux_task}, return_when=asyncio.FIRST_COMPLETED)
            if demux_task.done():
                print('tunnel closed, stopping local listener')
        finally:
            # Cancelling serve_forever() also closes the listening server.
            leftovers = [demux_task, *handler_tasks]
            if serve_task is not None:
                leftovers.append(serve_task)
            for t in leftovers:
                t.cancel()
            await asyncio.gather(*leftovers, return_exceptions=True)
if __name__ == '__main__':
    # Command-line entry point: parse options, then run the client until it
    # stops or the user interrupts it.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--server',
        required=True,
        help='WebSocket server URL, e.g. ws://localhost:8000/ws',
    )
    cli.add_argument(
        '--local-port',
        type=int,
        default=9000,
        help='Local port to listen on',
    )
    options = cli.parse_args()
    try:
        asyncio.run(run_client(options.server, options.local_port))
    except KeyboardInterrupt:
        # Ctrl-C is the expected way to stop the client.
        print('client exiting')