# cpolarClient / client.py
# BOSS
# 客户端 (client)
# 2f8c384
import argparse
import asyncio
import struct
import websockets
import socket
import uuid
# Simple length-prefixed framing, one frame per WebSocket message.
# frame: 1-byte type + 4-byte big-endian payload length + payload
async def websocket_send_frame(ws, data: bytes, frame_type: bytes = b'D'):
    """Send *data* over *ws* as a single framed WebSocket message.

    The message is laid out as ``frame_type`` followed by the length of
    *data* as a 4-byte big-endian unsigned integer, followed by *data*.
    Conventional type tags used in this file: ``b'D'`` for data, ``b'O'``
    for opening a connection.

    :param ws: object with an awaitable ``send(bytes)`` method.
    :param data: payload bytes.
    :param frame_type: type tag placed before the length prefix.
    """
    length_prefix = struct.pack('>I', len(data))
    await ws.send(b''.join((frame_type, length_prefix, data)))
async def websocket_recv_frame(ws, recv_lock: asyncio.Lock):
    """Receive one WebSocket message from *ws* and decode it as a frame.

    Frame layout: 1-byte type, 4-byte big-endian payload length, payload.
    A text (``str``) message is UTF-8 encoded before decoding. Bytes beyond
    the declared length are ignored.

    :param ws: object with an awaitable ``recv()`` returning bytes or str.
    :param recv_lock: lock held around ``ws.recv()`` so that concurrent
        callers do not call ``recv`` at the same time.
    :returns: ``(frame_type, payload)``, where ``frame_type`` is a 1-byte
        ``bytes`` object.
    :raises ValueError: if the message is shorter than the 5-byte header,
        or shorter than the payload length its header declares.
    """
    # Only the recv itself needs serializing; decoding is local work.
    async with recv_lock:
        data = await ws.recv()
    if isinstance(data, str):
        data = data.encode()
    if len(data) < 5:
        raise ValueError('frame too short')
    frame_type = data[0:1]
    (length,) = struct.unpack_from('>I', data, 1)
    payload = data[5:5 + length]
    # A short payload means the message was truncated or malformed. Passing
    # it on would silently corrupt the relayed byte stream.
    if len(payload) < length:
        raise ValueError(
            f'frame truncated: header declares {length} bytes, got {len(payload)}'
        )
    return frame_type, payload
async def handle_local_connection(local_reader, local_writer, ws, recv_lock: asyncio.Lock):
    """Relay one local TCP connection over *ws* without multiplexing.

    Two tasks run concurrently:

    * outbound: bytes read from *local_reader* (up to 4096 at a time) are
      sent as ``'D'`` frames. On local EOF a ``b'__CLOSE__'`` frame is sent.
    * inbound: frames are read from *ws* (serialized by *recv_lock*), and
      their payloads are written to *local_writer*. This continues until a
      ``b'__CLOSE__'`` payload arrives or an error occurs, after which the
      local writer is closed.

    The type of an inbound frame is not inspected, and payloads carry no
    connection id. NOTE(review): if several handlers share one *ws*, each
    frame goes to whichever handler acquires *recv_lock* first. This variant
    presumably expects one connection per WebSocket; confirm.

    Errors in either direction are printed, not raised. If this coroutine is
    cancelled, both relay tasks are cancelled with it.
    """
    peername = None
    try:
        sock = local_writer.get_extra_info('socket')
        if sock:
            peername = sock.getpeername()
    except Exception:
        peername = None
    print(f'New local connection from {peername}')

    async def read_local_then_send():
        # Local -> WebSocket.
        try:
            while True:
                chunk = await local_reader.read(4096)
                if not chunk:
                    # EOF
                    print('local EOF, sending close frame')
                    await websocket_send_frame(ws, b'__CLOSE__')
                    break
                await websocket_send_frame(ws, chunk)
        except asyncio.CancelledError:
            pass
        except Exception as e:
            print('local read/send error', e)

    async def recv_then_write_local():
        # WebSocket -> local. The writer is always closed when this ends.
        try:
            while True:
                try:
                    frame_type, payload = await websocket_recv_frame(ws, recv_lock)
                except Exception as e:
                    print('ws recv/local write error', e)
                    break
                if payload == b'__CLOSE__':
                    print('received close frame from server')
                    break
                try:
                    local_writer.write(payload)
                    await local_writer.drain()
                except Exception as e:
                    print('local writer error while writing payload', e)
                    break
        finally:
            try:
                local_writer.close()
                await local_writer.wait_closed()
            except Exception:
                pass

    # Both relay tasks catch their own exceptions, so in practice this waits
    # for both directions to finish. A local EOF therefore does not cut off
    # the inbound direction.
    task_send = asyncio.create_task(read_local_then_send())
    task_recv = asyncio.create_task(recv_then_write_local())
    try:
        done, _pending = await asyncio.wait(
            {task_send, task_recv}, return_when=asyncio.FIRST_EXCEPTION
        )
        for t in done:
            # exception() raises CancelledError on a cancelled task.
            if not t.cancelled() and t.exception():
                print('task exception in handle_local_connection:', t.exception())
    finally:
        # asyncio.wait() never cancels what it waits on. Cancel explicitly so
        # no relay task outlives this handler, including when it is cancelled.
        for t in (task_send, task_recv):
            if not t.done():
                t.cancel()
    print(f'Connection from {peername} closed')
async def handle_local_connection_multiplex(local_reader, local_writer, ws, recv_lock: asyncio.Lock, conn_id: bytes, conn_queue: asyncio.Queue, conn_queues: dict):
    """Relay one local TCP connection through a shared, multiplexed WebSocket.

    * outbound: bytes read from *local_reader* (up to 4096 at a time) are
      sent as ``'D'`` frames whose payload is ``conn_id + chunk``. On local
      EOF, ``conn_id + b'__CLOSE__'`` is sent.
    * inbound: payloads for this connection are taken from *conn_queue* and
      written to *local_writer*. The payload ``b'__CLOSE__'`` ends the inbound
      side. The local writer is closed when the inbound side ends.

    *recv_lock* is accepted for signature compatibility but is not used: this
    handler never reads from *ws* itself.

    When the handler finishes, *conn_id* is removed from *conn_queues*. This
    also happens if the handler is cancelled, in which case both relay tasks
    are cancelled too. Errors are printed, not raised.
    """
    label = conn_id.decode()
    peername = None
    try:
        sock = local_writer.get_extra_info('socket')
        if sock:
            peername = sock.getpeername()
    except Exception:
        peername = None
    print(f'New multiplexed connection {label} from {peername}')

    async def read_local_then_send():
        # Local -> tunnel. Every chunk is prefixed with conn_id so the far
        # side can tell connections apart.
        try:
            while True:
                chunk = await local_reader.read(4096)
                if not chunk:
                    # EOF
                    print(f'local EOF for {label}, sending close frame')
                    await websocket_send_frame(ws, conn_id + b'__CLOSE__', frame_type=b'D')
                    break
                await websocket_send_frame(ws, conn_id + chunk, frame_type=b'D')
        except asyncio.CancelledError:
            pass
        except Exception as e:
            print('local read/send error', e)

    async def recv_then_write_local():
        # Tunnel -> local. Payloads arrive on conn_queue without the conn_id
        # prefix, as put there by run_client's demuxer.
        try:
            while True:
                try:
                    payload = await conn_queue.get()
                except Exception as e:
                    print('conn_queue get error', e)
                    break
                if payload == b'__CLOSE__':
                    print(f'received close for {label}')
                    break
                try:
                    local_writer.write(payload)
                    await local_writer.drain()
                except Exception as e:
                    print('local writer error while writing payload', e)
                    break
        finally:
            try:
                local_writer.close()
                await local_writer.wait_closed()
            except Exception:
                pass

    task_send = asyncio.create_task(read_local_then_send())
    task_recv = asyncio.create_task(recv_then_write_local())
    try:
        done, _pending = await asyncio.wait(
            {task_send, task_recv}, return_when=asyncio.FIRST_EXCEPTION
        )
        for t in done:
            # exception() raises CancelledError on a cancelled task.
            if not t.cancelled() and t.exception():
                print('task exception in handle_local_connection_multiplex:', t.exception())
    finally:
        # asyncio.wait() does not cancel what it waits on. Cancel any
        # stragglers, and deregister the queue even if this handler is
        # cancelled, so that the demuxer stops routing to it.
        for t in (task_send, task_recv):
            if not t.done():
                t.cancel()
        conn_queues.pop(conn_id, None)
    print(f'Multiplexed connection {label} from {peername} closed')
async def run_client(server_url: str, local_port: int):
    """Tunnel local TCP connections to a WebSocket server, multiplexed.

    1. Connect to *server_url*, send ``'ping'``, and print the reply.
    2. Start a demultiplexer task. It reads frames from the WebSocket and
       routes the body of each ``'D'`` frame to the queue registered under
       its 8-byte connection-id prefix.
    3. Listen on ``127.0.0.1:local_port``. For each accepted connection, read
       up to 1024 bytes with a 1.5 s timeout.
       * Data starting with ``GET``/``POST``/``HEAD`` gets a canned HTML page,
         and the connection is closed.
       * Anything else is given a random 8-hex-char id and announced with an
         ``'O'`` frame, plus a ``'D'`` frame for any bytes already read. It
         is then relayed by ``handle_local_connection_multiplex``.

    Returns once the demultiplexer stops (the WebSocket closed or delivered
    a malformed frame). At that point every open local connection is sent
    the close sentinel, the listener is closed, and the WebSocket context
    exits.
    """
    async with websockets.connect(server_url) as ws:
        print('Connected to server')
        # Serializes ws.recv(): the handshake below and the demuxer share it.
        recv_lock = asyncio.Lock()
        await ws.send('ping')
        async with recv_lock:
            resp = await ws.recv()
        print('handshake response:', resp)
        # Random per-run id; it is only printed here, not sent to the server.
        session_id = uuid.uuid4().hex[:8].encode()
        print('session id:', session_id.decode())
        # conn_id (8 ASCII bytes) -> queue of inbound payloads for that connection.
        conn_queues: dict[bytes, asyncio.Queue] = {}

        async def close_quietly(writer):
            """Close *writer*, ignoring errors from an already-broken socket."""
            try:
                writer.close()
                await writer.wait_closed()
            except Exception:
                pass

        async def demuxer_loop():
            """Route inbound 'D' frames to per-connection queues until ws fails."""
            try:
                while True:
                    frame_type, payload = await websocket_recv_frame(ws, recv_lock)
                    # Only data frames are demultiplexed; other types are ignored.
                    if frame_type != b'D':
                        continue
                    if len(payload) < 8:
                        print('demux: payload too short')
                        continue
                    cid, body = payload[:8], payload[8:]
                    q = conn_queues.get(cid)
                    if q is not None:
                        await q.put(body)
                    else:
                        print('demux: no queue for', cid)
            except Exception as e:
                print('demuxer exiting:', e)
            finally:
                # No more data can arrive for any connection. Give each one the
                # close sentinel so its handler closes the local socket instead
                # of waiting on an empty queue forever.
                for q in list(conn_queues.values()):
                    q.put_nowait(b'__CLOSE__')

        demux_task = asyncio.create_task(demuxer_loop())

        async def accept_callback(r, w):
            """Handle one accepted local connection (HTTP probe or tunnelled)."""
            try:
                data = await asyncio.wait_for(r.read(1024), timeout=1.5)
            except Exception:
                data = b''
            # Browser quick response: answer HTTP requests locally.
            if data.startswith((b'GET ', b'POST ', b'HEAD ')):
                body = b"<html><body><h1>cpolar demo</h1><p>Connected</p></body></html>"
                html = b"HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: " + str(len(body)).encode() + b"\r\n\r\n" + body
                try:
                    w.write(html)
                    await w.drain()
                except Exception:
                    pass
                await close_quietly(w)
                return
            if demux_task.done():
                # The demuxer has already sent the close sentinel to every
                # registered queue. A queue registered now would never be
                # woken, so drop the connection instead.
                print('tunnel is down, dropping local connection')
                await close_quietly(w)
                return
            # Ids are only 32 bits. A collision would merge two streams.
            conn_id = uuid.uuid4().hex[:8].encode()
            while conn_id in conn_queues:
                conn_id = uuid.uuid4().hex[:8].encode()
            q: asyncio.Queue = asyncio.Queue()
            conn_queues[conn_id] = q
            try:
                # Announce the connection, then forward the bytes already read.
                await websocket_send_frame(ws, conn_id, frame_type=b'O')
                if data:
                    await websocket_send_frame(ws, conn_id + data, frame_type=b'D')
            except Exception as e:
                print(f'failed to open tunnel for {conn_id.decode()}:', e)
                conn_queues.pop(conn_id, None)
                await close_quietly(w)
                return
            # Run the handler inside this callback's task rather than spawning
            # a separate task that nothing keeps a reference to.
            await handle_local_connection_multiplex(r, w, ws, recv_lock, conn_id, q, conn_queues)

        server = await asyncio.start_server(accept_callback, '127.0.0.1', local_port)
        addrs = ', '.join(str(sock.getsockname()) for sock in server.sockets)
        print(f'Listening on {addrs}, forward to server (multiplexed)')
        async with server:
            # Serve until the tunnel is gone. Nothing can be forwarded after
            # that, so stop listening and let the WebSocket context close.
            await demux_task
if __name__ == '__main__':
    # Command-line entry point: parse options, then run the client until the
    # tunnel ends or the user interrupts it.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        '--server',
        required=True,
        help='WebSocket server URL, e.g. ws://localhost:8000/ws',
    )
    cli.add_argument(
        '--local-port',
        type=int,
        default=9000,
        help='Local port to listen on',
    )
    opts = cli.parse_args()
    try:
        asyncio.run(run_client(opts.server, opts.local_port))
    except KeyboardInterrupt:
        print('client exiting')