HuggingRun / ws_ssh_bridge.py
tao-shen's picture
v2: complete redesign — single entrypoint, tar.zst persistence
52013b5
#!/usr/bin/env python3
"""WebSocket-to-SSH bridge: accepts WebSocket connections on port 7862 and
bridges each one to sshd on 127.0.0.1:$SSH_PORT (default 2222). Handles concurrent connections.
Used by nginx to provide SSH-over-WebSocket through the single port 7860.
Client usage:
ssh -o 'ProxyCommand=websocat -b wss://tao-shen-huggingrun.hf.space/ssh' user@huggingrun
"""
import asyncio
import os
import signal
import sys
try:
import websockets
except ImportError:
print("[ws-ssh-bridge] websockets not installed, pip installing...", file=sys.stderr)
import subprocess
subprocess.check_call([sys.executable, "-m", "pip", "install", "websockets", "-q"])
import websockets
# Support both old and new websockets API
try:
from websockets.asyncio.server import serve
except ImportError:
from websockets.server import serve
# sshd endpoint the bridge forwards every WebSocket connection to.
SSH_HOST = "127.0.0.1"
SSH_PORT = int(os.environ.get("SSH_PORT", "2222"))
# Port the bridge itself listens on. Overridable via the WS_PORT environment
# variable, consistent with SSH_PORT; the default remains 7862.
WS_PORT = int(os.environ.get("WS_PORT", "7862"))
async def bridge(websocket, *_legacy_path):
    """Bridge a single WebSocket connection to sshd via TCP.

    Opens a TCP connection to ``SSH_HOST:SSH_PORT`` and pumps bytes in both
    directions. The session ends as soon as EITHER direction finishes (client
    disconnect, sshd EOF, or a socket error): the other pump is cancelled and
    both the TCP stream and the WebSocket are closed, so neither peer is left
    attached to a dead session.

    Args:
        websocket: Server-side WebSocket connection supplied by ``serve``.
        *_legacy_path: Ignored. Absorbs the ``path`` argument that websockets
            releases older than 10.1 pass to connection handlers.
    """
    try:
        reader, writer = await asyncio.open_connection(SSH_HOST, SSH_PORT)
    except OSError as e:
        print(f"[ws-ssh-bridge] Cannot connect to sshd: {e}", file=sys.stderr)
        # A close reason must fit in 123 UTF-8 bytes (125-byte control-frame
        # payload minus the 2-byte code); clip it so close() cannot fail.
        reason = f"sshd unreachable: {e}".encode()[:123].decode("utf-8", "ignore")
        await websocket.close(1011, reason)
        return

    async def ws_to_tcp():
        # Client -> sshd. Text frames are forwarded as their UTF-8 bytes.
        try:
            async for msg in websocket:
                if isinstance(msg, str):
                    msg = msg.encode()
                writer.write(msg)
                await writer.drain()
        except (websockets.ConnectionClosed, OSError):
            pass  # one of the peers went away; the session is over

    async def tcp_to_ws():
        # sshd -> client, in chunks of up to 64 KiB.
        try:
            while True:
                data = await reader.read(65536)
                if not data:  # sshd closed its end
                    break
                await websocket.send(data)
        except (websockets.ConnectionClosed, OSError):
            pass  # one of the peers went away; the session is over

    pumps = [asyncio.create_task(ws_to_tcp()), asyncio.create_task(tcp_to_ws())]
    try:
        # Whichever direction finishes first ends the whole session; waiting
        # for both would leave the client hanging after sshd disconnects.
        await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
    finally:
        for task in pumps:
            task.cancel()
        writer.close()
        # Reap both pumps so no task is orphaned; report only genuine errors
        # (CancelledError is an Exception subclass on Python 3.7).
        for result in await asyncio.gather(*pumps, return_exceptions=True):
            if isinstance(result, Exception) and not isinstance(
                result, asyncio.CancelledError
            ):
                print(f"[ws-ssh-bridge] Bridge error: {result!r}", file=sys.stderr)
        try:
            await writer.wait_closed()
        except OSError:
            pass  # the TCP connection was already reset; nothing left to do
        await websocket.close()
async def main():
    """Serve the WebSocket-to-SSH bridge on 127.0.0.1:WS_PORT until SIGINT/SIGTERM.

    Each accepted connection is handled by ``bridge``. The "Listening" line is
    printed only after the server socket is bound, so it never appears when
    the bind fails.
    """
    # Inside a coroutine the running loop is the one to attach handlers to;
    # get_event_loop() is the deprecated spelling of this.
    loop = asyncio.get_running_loop()
    stop = asyncio.Event()
    for sig in (signal.SIGINT, signal.SIGTERM):
        loop.add_signal_handler(sig, stop.set)

    # Keepalive ping every 30 s, 120 s pong timeout; max_size=None lifts the
    # incoming-message size cap so large SSH chunks are never rejected.
    async with serve(bridge, "127.0.0.1", WS_PORT,
                     ping_interval=30, ping_timeout=120,
                     max_size=None):
        print(f"[ws-ssh-bridge] Listening on 127.0.0.1:{WS_PORT} -> sshd {SSH_HOST}:{SSH_PORT}",
              file=sys.stderr)
        await stop.wait()


if __name__ == "__main__":
    asyncio.run(main())