File size: 6,380 Bytes
daaa6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36ce73b
 
daaa6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36ce73b
 
 
 
 
 
 
 
 
 
 
 
 
daaa6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
"""
Terminal WebSocket with persistent PTY sessions.

Single Responsibility: only handles PTY lifecycle and WebSocket communication.

Depends on storage.get_zone_path (path resolution), plus storage.check_zone_owner
and auth.get_ws_user (access control), per Dependency Inversion.
"""

import asyncio
import collections
import fcntl
import json
import logging
import os
import pty
import select
import signal
import struct
import termios
import threading

from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from config import SCROLLBACK_SIZE
from storage import get_zone_path, check_zone_owner
from auth import get_ws_user

# Router exposing the terminal WebSocket endpoint defined in this module.
router = APIRouter(tags=["terminal"])

# Active terminals, keyed by zone name. Each value is the dict returned by
# _spawn_shell (fd, pid, buffer, buffer_size) plus "ws" (the WebSocket that
# currently receives live output, or None) and "bg_task" (the _bg_reader
# task), both set by terminal_ws.
active_terminals: dict[str, dict] = {}


# ── PTY Management ────────────────────────────

def _spawn_shell(zone_name: str) -> dict:
    """Fork an interactive bash attached to a fresh PTY for *zone_name*.

    The shell runs with ``--norc`` in the zone directory (also used as
    ``$HOME``), with ``TERM=xterm-256color`` and a zone-tagged ``PS1``.

    Returns:
        A registry entry: the non-blocking PTY master ``fd``, the shell
        ``pid``, an empty scrollback ``buffer`` deque and ``buffer_size`` 0.

    Raises:
        Whatever get_zone_path raises for the zone (terminal_ws treats
        ValueError as "unknown zone"), or OSError if the PTY cannot be opened
        or the fork fails; in the latter case both PTY fds are closed first.
    """
    zone_path = get_zone_path(zone_name)
    master_fd, slave_fd = pty.openpty()
    try:
        child_pid = os.fork()
    except OSError:
        os.close(master_fd)
        os.close(slave_fd)
        raise

    if child_pid == 0:
        # Child process. No exception may escape from here: this is a forked
        # copy of the server, and unwinding would resume the parent's event
        # loop inside the child.
        try:
            os.setsid()
            # Make the PTY our controlling terminal. Without it bash runs with
            # "no job control" and Ctrl-C / Ctrl-Z never reach foreground jobs.
            tiocsctty = getattr(termios, "TIOCSCTTY", None)
            if tiocsctty is not None:
                try:
                    fcntl.ioctl(slave_fd, tiocsctty, 0)
                except OSError:
                    pass  # still usable, just without job control
            os.dup2(slave_fd, 0)
            os.dup2(slave_fd, 1)
            os.dup2(slave_fd, 2)
            os.close(master_fd)
            os.close(slave_fd)
            os.chdir(str(zone_path))
            env = os.environ.copy()
            env["TERM"] = "xterm-256color"
            env["HOME"] = str(zone_path)
            env["PS1"] = f"[{zone_name}] \\w $ "
            os.execvpe("/bin/bash", ["/bin/bash", "--norc"], env)
        finally:
            # Only reached when setup or exec failed.
            os._exit(127)

    # Parent: keep only the master side, in non-blocking mode so the
    # background reader and input writes never stall the event loop.
    os.close(slave_fd)
    flag = fcntl.fcntl(master_fd, fcntl.F_GETFL)
    fcntl.fcntl(master_fd, fcntl.F_SETFL, flag | os.O_NONBLOCK)
    return {"fd": master_fd, "pid": child_pid, "buffer": collections.deque(), "buffer_size": 0}


def _resize_terminal(zone_name: str, rows: int, cols: int):
    """Set the window size of *zone_name*'s PTY to *rows* x *cols*.

    The values arrive straight from client JSON. They are coerced to ``int``
    and clamped into the unsigned-short range that the ``TIOCSWINSZ``
    ``winsize`` struct holds. Previously a string, float, negative or
    oversized value made ``struct.pack`` raise, which ended the whole
    session. Values that cannot be converted to ``int`` are ignored.
    Does nothing when the zone has no registered terminal.
    """
    info = active_terminals.get(zone_name)
    if info is None:
        return
    try:
        rows, cols = int(rows), int(cols)
    except (TypeError, ValueError, OverflowError):
        return
    rows = min(max(rows, 0), 0xFFFF)
    cols = min(max(cols, 0), 0xFFFF)
    winsize = struct.pack("HHHH", rows, cols, 0, 0)
    fcntl.ioctl(info["fd"], termios.TIOCSWINSZ, winsize)


def _append_buffer(info: dict, data: bytes):
    """Add one PTY output chunk to the scrollback and trim it to SCROLLBACK_SIZE.

    Trimming drops whole chunks from the oldest end, so afterwards the retained
    byte count is at or below the limit.
    """
    queue = info["buffer"]
    queue.append(data)
    size = info["buffer_size"] + len(data)
    info["buffer_size"] = size
    while size > SCROLLBACK_SIZE:
        size -= len(queue.popleft())
        info["buffer_size"] = size


def _get_buffer(info: dict) -> bytes:
    return b"".join(info["buffer"])


def _is_alive(zone_name: str) -> bool:
    """Return True if *zone_name* has a registered shell that has not exited.

    Uses a non-blocking ``waitpid``, which reaps the child if it has exited.
    Once reaped, the PID may be recycled by the OS, so the registry entry is
    released right here. Leaving it in place would let a later kill_terminal()
    SIGKILL whatever process reused the PID. The old ChildProcessError path
    also dropped the entry without closing its PTY fd.
    """
    info = active_terminals.get(zone_name)
    if info is None:
        return False
    try:
        if os.waitpid(info["pid"], os.WNOHANG) == (0, 0):
            return True  # child exists and has not changed state
    except ChildProcessError:
        pass  # already reaped elsewhere; treat as dead
    _discard_dead_terminal(zone_name, info)
    return False


def _discard_dead_terminal(zone_name: str, info: dict):
    """Release a terminal whose shell already exited, without sending signals."""
    if active_terminals.get(zone_name) is info:
        del active_terminals[zone_name]
    bg = info.get("bg_task")
    if bg is not None:
        # Safe even when called from the reader task itself: it stops at its
        # next await, or finishes as cancelled if it returns first.
        bg.cancel()
    try:
        os.close(info["fd"])
    except OSError:
        pass


async def _bg_reader(zone_name: str):
    """Background task: pump PTY output into the scrollback and the live socket.

    Polls the non-blocking PTY master every 20 ms. Each chunk read is appended
    to the zone's scrollback via _append_buffer and forwarded to the currently
    registered WebSocket (``info["ws"]``), if any. The task stops when the
    shell is no longer alive, or when the zone's registry entry is replaced by
    a different terminal.
    """
    info = active_terminals.get(zone_name)
    if not info:
        return
    fd = info["fd"]
    # Stay bound to *this* terminal. If the zone's entry is replaced, stop
    # rather than read from an fd number that may now belong to the new PTY.
    while active_terminals.get(zone_name) is info and _is_alive(zone_name):
        await asyncio.sleep(0.02)
        # The fd is O_NONBLOCK, so read directly instead of using select().
        # select() raises ValueError for fds >= FD_SETSIZE (1024), which used
        # to end this reader silently on busy servers.
        try:
            data = os.read(fd, 4096)
        except OSError:
            # BlockingIOError: nothing pending yet. Other errors (e.g. EIO once
            # the shell side closes): keep looping and let the liveness check
            # above decide when to stop.
            continue
        if not data:
            continue
        _append_buffer(info, data)
        ws = info.get("ws")
        if ws is None:
            continue
        try:
            await ws.send_bytes(data)
        except Exception:
            # The client went away. Detach it, unless a newer socket registered
            # itself while the send was pending.
            if info.get("ws") is ws:
                info["ws"] = None


def kill_terminal(zone_name: str):
    """Kill the shell for *zone_name*, if any, and release its resources.

    Removes the registry entry, cancels the background reader, SIGKILLs the
    shell, reaps it so it does not linger as a zombie, and closes the PTY
    master fd. A no-op when the zone has no registered terminal.
    """
    info = active_terminals.pop(zone_name, None)
    if info is None:
        return
    bg = info.get("bg_task")
    if bg:
        bg.cancel()
    pid = info["pid"]
    try:
        os.kill(pid, signal.SIGKILL)
    except (ProcessLookupError, PermissionError):
        # Already gone, or the PID now belongs to a process we do not own.
        # Either way it is not ours to reap.
        pass
    else:
        _reap_child(pid)
    try:
        os.close(info["fd"])
    except OSError:
        pass


def _reap_child(pid: int):
    """Collect *pid*'s exit status without blocking the event loop.

    SIGKILL is delivered asynchronously, so a WNOHANG wait issued right after
    os.kill usually finds the child still running. Stopping there, as the old
    code did, left a zombie behind for every killed shell. Try once without
    blocking, then finish the wait on a daemon thread.
    """
    try:
        if os.waitpid(pid, os.WNOHANG) != (0, 0):
            return  # reaped immediately
    except ChildProcessError:
        return  # already reaped elsewhere
    threading.Thread(
        target=_wait_for_child, args=(pid,), name=f"reap-{pid}", daemon=True
    ).start()


def _wait_for_child(pid: int):
    """Blocking reap used by _reap_child's helper thread."""
    try:
        os.waitpid(pid, 0)
    except ChildProcessError:
        pass


# ── WebSocket Handler ─────────────────────────

async def _write_to_pty(fd: int, payload: bytes, timeout: float = 5.0) -> None:
    """Write all of *payload* to the non-blocking PTY master *fd*.

    One os.write on an O_NONBLOCK PTY may accept only part of the data, or
    raise BlockingIOError when the input queue is full (large pastes). Retry
    after yielding to the event loop instead of truncating or failing. Give up
    on the remainder after *timeout* seconds so a shell that never reads its
    input cannot wedge the handler. Other OSErrors (terminal gone) propagate.
    """
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    view = memoryview(payload)
    while view:
        try:
            written = os.write(fd, view)
        except BlockingIOError:
            if loop.time() >= deadline:
                return
            await asyncio.sleep(0.01)
            continue
        view = view[written:]


async def _handle_text_frame(zone_name: str, fd: int, text: str) -> None:
    """Apply one JSON control frame (``resize`` or ``input``) from the client.

    Malformed frames are ignored. These include invalid JSON, non-object
    payloads, a missing or non-string ``data``, and unusable resize
    dimensions. Previously any of them escaped to the handler's catch-all
    and silently ended the session.
    """
    try:
        data = json.loads(text)
    except ValueError:
        return
    if not isinstance(data, dict):
        return
    kind = data.get("type")
    if kind == "resize":
        try:
            _resize_terminal(zone_name, data.get("rows", 24), data.get("cols", 80))
        except (TypeError, ValueError, struct.error):
            pass  # non-integer or out-of-range dimensions from the client
    elif kind == "input":
        payload = data.get("data")
        if isinstance(payload, str):
            # "replace" only affects lone surrogates, which cannot be UTF-8 encoded.
            await _write_to_pty(fd, payload.encode("utf-8", errors="replace"))


@router.websocket("/ws/terminal/{zone_name}")
async def terminal_ws(websocket: WebSocket, zone_name: str):
    """Attach a WebSocket client to the zone's persistent shell.

    Authenticates the user and checks zone ownership before accepting. It then
    reuses the zone's live shell or spawns a new one with its background
    reader, and replays the buffered scrollback. After that it relays client
    input (binary frames, or JSON ``input``/``resize`` frames) to the PTY
    until the client disconnects or the terminal is killed or replaced. The
    shell keeps running after the client leaves.
    """
    # Authenticate via query parameter
    user = get_ws_user(websocket)
    if not user:
        await websocket.close(code=4001, reason="Chưa đăng nhập")
        return

    # Check zone ownership
    try:
        check_zone_owner(zone_name, user.sub, user.role)
    except ValueError as e:
        await websocket.close(code=4003, reason=str(e))
        return

    await websocket.accept()

    try:
        get_zone_path(zone_name)
    except ValueError as e:
        await websocket.send_json({"error": str(e)})
        await websocket.close()
        return

    # Spawn or reuse terminal
    if not _is_alive(zone_name):
        kill_terminal(zone_name)
        try:
            info = _spawn_shell(zone_name)
            info["ws"] = None
            active_terminals[zone_name] = info
            info["bg_task"] = asyncio.create_task(_bg_reader(zone_name))
        except Exception as e:
            await websocket.send_json({"error": f"Cannot create terminal: {e}"})
            await websocket.close()
            return

    info = active_terminals[zone_name]
    fd = info["fd"]

    # Snapshot the scrollback and register as the live receiver with no await
    # in between. Output the reader picks up while the replay is in flight is
    # then forwarded live instead of being lost.
    backlog = _get_buffer(info)
    info["ws"] = websocket

    try:
        if backlog:
            await websocket.send_bytes(backlog)
        while True:
            msg = await websocket.receive()
            if msg.get("type") == "websocket.disconnect":
                break
            if active_terminals.get(zone_name) is not info:
                # Terminal was killed or replaced; its fd number may have been
                # reused, so never write through it again.
                break
            if msg.get("text") is not None:
                await _handle_text_frame(zone_name, fd, msg["text"])
            elif msg.get("bytes") is not None:
                await _write_to_pty(fd, msg["bytes"])
    except WebSocketDisconnect:
        pass
    except OSError:
        pass  # PTY closed underneath us (shell exited or was killed)
    except Exception:
        logging.getLogger(__name__).exception(
            "Terminal session for zone %r ended unexpectedly", zone_name
        )
    finally:
        current = active_terminals.get(zone_name)
        if current is not None and current.get("ws") is websocket:
            current["ws"] = None