File size: 5,855 Bytes
daaa6ed
 
 
 
 
 
 
 
 
 
 
36ce73b
daaa6ed
 
36ce73b
daaa6ed
36ce73b
daaa6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36ce73b
daaa6ed
 
36ce73b
daaa6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36ce73b
 
 
 
 
 
daaa6ed
 
36ce73b
daaa6ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
"""

Reverse proxy for virtual ports.



Single Responsibility: only handles HTTP/WebSocket proxying.

Port CRUD is in ports.py — separate concern.

"""

import asyncio
import json

import httpx
from fastapi import APIRouter, Depends, Request, WebSocket, WebSocketDisconnect
from fastapi.responses import Response

from auth import AuthUser, get_current_user, get_ws_user
from config import MIN_PORT, MAX_PORT
from storage import load_meta, check_zone_owner

# Router for the HTTP and WebSocket reverse-proxy endpoints in this module
# (tagged "proxy" in the generated OpenAPI docs).
router = APIRouter(tags=["proxy"])

# ── Shared HTTP client ────────────────────────

_HOP_HEADERS = frozenset({
    "connection", "keep-alive", "proxy-authenticate", "proxy-authorization",
    "te", "trailers", "transfer-encoding", "upgrade",
})

# Lazily-created httpx client shared by every proxied HTTP request.
_client: httpx.AsyncClient | None = None


def _get_client() -> httpx.AsyncClient:
    """Return the module-wide httpx client, creating it on the first call.

    The client is configured with a 30 s overall timeout (5 s to connect), at
    most 50 connections, and redirect following disabled so 3xx responses are
    handed back to the caller unchanged.
    """
    global _client
    if _client is not None:
        return _client
    timeout = httpx.Timeout(30.0, connect=5.0)
    limits = httpx.Limits(max_connections=50)
    _client = httpx.AsyncClient(timeout=timeout, follow_redirects=False, limits=limits)
    return _client


def _validate_proxy_access(zone_name: str, port: int):
    """Check that ``port`` may be proxied for ``zone_name``.

    Returns None when access is allowed.

    Raises:
        ValueError: if ``port`` lies outside ``[MIN_PORT, MAX_PORT]``, if
            ``zone_name`` is not a key of the metadata returned by
            ``load_meta()``, or if no entry in that zone's ``"ports"`` list
            has a ``"port"`` value equal to ``port``.
    """
    if port < MIN_PORT or port > MAX_PORT:
        raise ValueError(f"Port must be between {MIN_PORT} and {MAX_PORT}")

    meta = load_meta()
    if zone_name not in meta:
        raise ValueError(f"Zone '{zone_name}' does not exist")

    # Stop at the first matching entry; a missing "ports" list means nothing is mapped.
    for entry in meta[zone_name].get("ports", []):
        if entry["port"] == port:
            return
    raise ValueError("Port not mapped")


# ── HTTP Reverse Proxy ────────────────────────

@router.api_route(
    "/port/{zone_name}/{port}/{subpath:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"],
)
async def proxy_http(request: Request, zone_name: str, port: int, subpath: str = "", user: AuthUser = Depends(get_current_user)):
    """Forward an HTTP request to ``http://127.0.0.1:{port}/{subpath}``.

    The port must pass ``_validate_proxy_access`` and the caller must pass
    ``check_zone_owner`` for ``zone_name``; a ``ValueError`` from either is
    reported as a 404 "Port not mapped". The request body is read in full and
    sent upstream with the original method and query string. The upstream
    response is also buffered in full and returned with hop-by-hop headers
    removed.

    Returns:
        The upstream status, headers and body; or a plain-text 502 when the
        upstream cannot be reached or fails at the transport level, or a
        plain-text 504 on timeout.
    """
    try:
        _validate_proxy_access(zone_name, port)
        check_zone_owner(zone_name, user.sub, user.role)
    except ValueError:
        return Response(content="Port not mapped", status_code=404)

    target_url = f"http://127.0.0.1:{port}/{subpath}"
    if request.url.query:
        target_url += f"?{request.url.query}"

    # Hop-by-hop headers are connection-specific; Host is rewritten to the upstream.
    headers = {
        key: value
        for key, value in request.headers.items()
        if key.lower() not in _HOP_HEADERS and key.lower() != "host"
    }
    headers["host"] = f"127.0.0.1:{port}"
    headers["x-forwarded-for"] = request.client.host if request.client else "127.0.0.1"
    headers["x-forwarded-proto"] = request.url.scheme
    headers["x-forwarded-prefix"] = f"/port/{zone_name}/{port}"

    body = await request.body()
    client = _get_client()

    try:
        resp = await client.request(method=request.method, url=target_url, headers=headers, content=body)
    except httpx.ConnectError:
        return Response(
            content=f"Cannot connect to port {port}. Make sure your server is running.",
            status_code=502,
            media_type="text/plain",
        )
    except httpx.TimeoutException:
        return Response(content=f"Timeout connecting to port {port}", status_code=504, media_type="text/plain")
    except httpx.RequestError as exc:
        # Any other transport/protocol failure (e.g. the upstream dropping the
        # connection mid-response) is still a bad-gateway condition, not a 500.
        return Response(
            content=f"Error proxying to port {port}: {type(exc).__name__}",
            status_code=502,
            media_type="text/plain",
        )

    # httpx transparently decodes compressed bodies, so resp.content is no longer
    # described by the upstream Content-Encoding / Content-Length. Drop both and
    # let the Response recompute Content-Length from the body. HEAD responses
    # carry no body, so their upstream Content-Length is kept as-is.
    dropped = _HOP_HEADERS | {"content-encoding", "set-cookie"}
    if request.method != "HEAD":
        dropped = dropped | {"content-length"}

    # httpx's Headers.items() joins repeated headers with ", ", which is only
    # safe for headers other than Set-Cookie (handled separately below).
    resp_headers = {
        key: value
        for key, value in resp.headers.items()
        if key.lower() not in dropped
    }
    response = Response(content=resp.content, status_code=resp.status_code, headers=resp_headers)

    # Each cookie must travel as its own Set-Cookie header line.
    for cookie in resp.headers.get_list("set-cookie"):
        response.headers.append("set-cookie", cookie)
    return response


# ── WebSocket Reverse Proxy ──────────────────

@router.websocket("/port/{zone_name}/{port}/ws/{subpath:path}")
async def proxy_ws(websocket: WebSocket, zone_name: str, port: int, subpath: str = ""):
    """Relay a WebSocket session to ``ws://127.0.0.1:{port}/ws/{subpath}``.

    The user is resolved with ``get_ws_user``. If no user is resolved, the
    socket is closed with code 4001. If validation or the ownership check
    raises ``ValueError``, it is closed with code 4004. Otherwise the client
    socket is accepted and frames are relayed in both directions until either
    side stops. The other direction is then cancelled and both sockets are
    closed. If connecting to (or relaying with) the backend raises, a JSON
    error message is sent to the client before it is closed.
    """
    # Authenticate via query parameter (resolved by get_ws_user).
    user = get_ws_user(websocket)
    if not user:
        # The close reason is Vietnamese for "Not logged in".
        await websocket.close(code=4001, reason="Chưa đăng nhập")
        return

    try:
        _validate_proxy_access(zone_name, port)
        check_zone_owner(zone_name, user.sub, user.role)
    except ValueError:
        await websocket.close(code=4004, reason="Port not mapped")
        return

    await websocket.accept()
    target_url = f"ws://127.0.0.1:{port}/ws/{subpath}"

    # Imported here: only the WebSocket route uses this library.
    import websockets as ws_lib

    try:
        async with ws_lib.connect(target_url) as backend_ws:
            async def client_to_backend():
                # Relay client -> backend until the client disconnects.
                try:
                    while True:
                        msg = await websocket.receive()
                        if msg.get("type") == "websocket.disconnect":
                            break
                        # Per ASGI, exactly one of "text"/"bytes" is non-None; the
                        # other key may still be present with a None value.
                        if msg.get("text") is not None:
                            await backend_ws.send(msg["text"])
                        elif msg.get("bytes") is not None:
                            await backend_ws.send(msg["bytes"])
                except WebSocketDisconnect:
                    pass  # client went away
                except Exception:
                    pass  # send/receive failed; teardown below closes both sockets

            async def backend_to_client():
                # Relay backend -> client until the backend closes.
                try:
                    async for message in backend_ws:
                        if isinstance(message, str):
                            await websocket.send_text(message)
                        else:
                            await websocket.send_bytes(message)
                except WebSocketDisconnect:
                    pass  # client went away
                except Exception:
                    pass  # send/receive failed; teardown below closes both sockets

            pumps = [
                asyncio.create_task(client_to_backend()),
                asyncio.create_task(backend_to_client()),
            ]
            try:
                # The session ends as soon as either direction ends. Waiting for
                # both would leave the other socket open until its own peer closed.
                await asyncio.wait(pumps, return_when=asyncio.FIRST_COMPLETED)
            finally:
                for task in pumps:
                    task.cancel()  # no-op for the task that already finished
                # Let cancellations settle before the backend socket is closed.
                await asyncio.gather(*pumps, return_exceptions=True)
    except Exception:
        try:
            await websocket.send_text(json.dumps({"error": f"Cannot connect WebSocket to port {port}"}))
        except Exception:
            pass  # client may already be gone
    finally:
        try:
            await websocket.close()
        except Exception:
            pass  # already closed (e.g. the client disconnected first)