File size: 2,727 Bytes
2a081e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9c608e0
 
 
 
 
2a081e2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52013b5
2a081e2
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
#!/usr/bin/env python3
"""WebSocket-to-SSH bridge: accepts WebSocket connections on port 7862,
bridges each to sshd on 127.0.0.1:2222. Handles concurrent connections.

Used by nginx to provide SSH-over-WebSocket through the single port 7860.
Client usage:
  ssh -o 'ProxyCommand=websocat -b wss://tao-shen-huggingrun.hf.space/ssh' user@huggingrun
"""
import asyncio
import os
import signal
import sys

try:
    import websockets
except ImportError:
    print("[ws-ssh-bridge] websockets not installed, pip installing...", file=sys.stderr)
    import subprocess
    subprocess.check_call([sys.executable, "-m", "pip", "install", "websockets", "-q"])
    import websockets

# Support both old and new websockets API
try:
    from websockets.asyncio.server import serve
except ImportError:
    from websockets.server import serve

# sshd is reached over loopback; only its port is configurable (env SSH_PORT).
SSH_HOST = "127.0.0.1"
SSH_PORT = int(os.environ.get("SSH_PORT", "2222"))
# Local port the WebSocket listener binds to. Defaults to 7862; override with
# env WS_SSH_PORT (the proxy in front of this bridge must point at the same port).
WS_PORT = int(os.environ.get("WS_SSH_PORT", "7862"))


async def bridge(websocket):
    """Bridge a single WebSocket connection to sshd via TCP.

    Opens a TCP connection to ``SSH_HOST:SSH_PORT`` and pumps bytes in both
    directions. As soon as either direction finishes (client disconnect,
    EOF from sshd, or a connection error) the other direction is cancelled
    and both connections are closed, so an ended sshd session does not leave
    the WebSocket hanging open.

    If sshd cannot be reached, the WebSocket is closed with code 1011.
    """
    try:
        reader, writer = await asyncio.open_connection(SSH_HOST, SSH_PORT)
    except Exception as e:
        print(f"[ws-ssh-bridge] Cannot connect to sshd: {e}", file=sys.stderr)
        # A close-frame reason may be at most 123 bytes of UTF-8 (RFC 6455
        # §5.5: control payload <= 125 bytes, 2 of which carry the code).
        reason = f"sshd unreachable: {e}".encode()[:123].decode(errors="ignore")
        await websocket.close(1011, reason)
        return

    async def ws_to_tcp():
        """Forward WebSocket messages to sshd until the client goes away."""
        try:
            async for msg in websocket:
                if isinstance(msg, str):
                    # Text frames are forwarded as their UTF-8 encoding.
                    msg = msg.encode()
                writer.write(msg)
                await writer.drain()  # apply backpressure from sshd
        except (websockets.ConnectionClosed, ConnectionError):
            pass

    async def tcp_to_ws():
        """Forward sshd output to the WebSocket until sshd closes its side."""
        try:
            while True:
                data = await reader.read(65536)
                if not data:
                    break  # EOF from sshd
                await websocket.send(data)
        except (websockets.ConnectionClosed, ConnectionError):
            pass

    tasks = [asyncio.create_task(ws_to_tcp()), asyncio.create_task(tcp_to_ws())]
    try:
        # Finish when EITHER direction ends; waiting for both (gather) would
        # block forever on the client after sshd has already hung up.
        done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
        for task in done:
            if not task.cancelled() and task.exception() is not None:
                print(f"[ws-ssh-bridge] Bridge error: {task.exception()!r}",
                      file=sys.stderr)
    finally:
        for task in tasks:
            task.cancel()
        # Reap the cancelled direction so its exception/cancellation is retrieved.
        await asyncio.gather(*tasks, return_exceptions=True)
        if not writer.is_closing():
            writer.close()
        await websocket.close()


async def main():
    """Serve ``bridge`` on 127.0.0.1:WS_PORT until SIGINT or SIGTERM arrives."""
    loop = asyncio.get_running_loop()
    stop = asyncio.Event()
    # NOTE(review): max_size=None disables the inbound message size limit, so a
    # single huge frame is buffered in memory — consider a finite bound.
    async with serve(bridge, "127.0.0.1", WS_PORT,
                     ping_interval=30, ping_timeout=120,
                     max_size=None):
        # Log only once the listening socket is actually bound.
        print(f"[ws-ssh-bridge] Listening on 127.0.0.1:{WS_PORT} -> sshd {SSH_HOST}:{SSH_PORT}",
              file=sys.stderr)
        for sig in (signal.SIGINT, signal.SIGTERM):
            try:
                loop.add_signal_handler(sig, stop.set)
            except NotImplementedError:
                # Not supported by some event loops (e.g. on Windows); fall back
                # to default signal handling (Ctrl+C -> KeyboardInterrupt).
                pass
        await stop.wait()


if __name__ == "__main__":
    # Run the bridge until main() observes SIGINT or SIGTERM.
    asyncio.run(main())