""" Universal multi-port reverse proxy + password gate for HF Spaces. Routes: / → Dashboard (lists known ports, add new) /__probe__/{port} → JSON probe: detect what app runs on a port /{port}/... → Proxied to localhost:{port} /login, /health → Auth gate + healthcheck All routes protected by the same password gate (HMAC cookie session). Designed to transparently host many app types behind a path-based proxy: JupyterLab, VS Code Server, Laravel, Next.js, React dev server, etc. Environment variables: password – gate password (empty = no auth required) SESSION_SECRET – HMAC signing key (auto-generated if unset) """ import asyncio import hmac import os import re import secrets import time from aiohttp import web, ClientSession, WSMsgType, ClientTimeout, TCPConnector # ── Configuration ──────────────────────────────────────────────────── PASSWORD = os.environ.get("password", "") SESSION_SECRET = os.environ.get("SESSION_SECRET") or secrets.token_hex(32) print("PASSWORD loaded: {} (len={})".format("yes" if PASSWORD else "EMPTY", len(PASSWORD)), flush=True) COOKIE_NAME = "jupiter_session" COOKIE_MAX_AGE = 7 * 24 * 3600 # 7 days CACHE_MAX_ITEMS = 512 CACHE_TTL = 300 # 5 min # NOTE: .json deliberately NOT cached — many APIs serve dynamic JSON. CACHEABLE_EXTENSIONS = { '.js', '.css', '.woff', '.woff2', '.ttf', '.eot', '.svg', '.png', '.jpg', '.jpeg', '.gif', '.ico', '.map', } MAX_CONNECTIONS = 200 MAX_PER_HOST = 100 CONNECTOR_KEEPALIVE = 30 # Regex to detect /{port}/... 
paths PORT_RE = re.compile(r'^/(\d{2,5})(?P/.*)?$') # ── In-memory static cache ────────────────────────────────────────── class StaticCache: """LRU-ish cache for truly static assets (js/css/fonts/images).""" def __init__(self, max_items=CACHE_MAX_ITEMS, ttl=CACHE_TTL): self.max_items = max_items self.ttl = ttl self._cache = {} def _is_cacheable(self, path, status, headers): """Decide if a response may be cached.""" if status != 200: return False lower = path.lower() if not any(lower.endswith(ext) for ext in CACHEABLE_EXTENSIONS): return False # Respect explicit cache directives. cc = headers.get("Cache-Control", "").lower() if any(d in cc for d in ("no-store", "no-cache", "private")): return False # Never cache responses that set cookies or vary on auth. if headers.get("Set-Cookie"): return False vary = headers.get("Vary", "").lower() if "authorization" in vary or "cookie" in vary: return False return True def get(self, key): entry = self._cache.get(key) if not entry: return None ts, status, headers, body = entry if time.time() - ts < self.ttl: return status, headers, body del self._cache[key] return None def put(self, key, status, headers, body): if not self._is_cacheable(key, status, headers): return if len(self._cache) >= self.max_items: oldest_key = min(self._cache, key=lambda k: self._cache[k][0]) del self._cache[oldest_key] self._cache[key] = (time.time(), status, headers, body) # ── Session helpers ────────────────────────────────────────────────── def _sign(value): return hmac.new(SESSION_SECRET.encode(), value.encode(), "sha256").hexdigest() def _make_cookie_value(): sid = secrets.token_hex(16) return "{}:{}".format(sid, _sign(sid)) def _valid_session(value): if not value or ":" not in value: return False sid, sig = value.split(":", 1) return hmac.compare_digest(_sign(sid), sig) # ── Cookie forwarding helpers ──────────────────────────────────────── def filter_request_cookies(cookie_header): """Strip the proxy's own session cookie before forwarding to 
upstream. Apps behind the proxy shouldn't see jupiter_session. """ if not cookie_header: return "" kept = [] for part in cookie_header.split(";"): part = part.strip() if not part: continue if part.startswith(COOKIE_NAME + "="): continue kept.append(part) return "; ".join(kept) def rewrite_set_cookie(value, port): """Scope a Set-Cookie header under /{port}/ so cookies don't leak. Rewrites Path=/ → Path=/{port}/ and strips Domain= attribute. """ if not value: return value attrs = value.split(";") new_attrs = [] for attr in attrs: a = attr.strip() low = a.lower() if low.startswith("domain="): continue # drop domain scoping if low.startswith("path="): path_val = a[len("path="):] # Normalise to the port-prefixed path. expected = "/{}/".format(port) if path_val == "/" or path_val == "": new_attrs.append("Path=" + expected) elif path_val.startswith(expected): new_attrs.append(a) # already prefixed else: # Sub-path like /foo → /{port}/foo new_attrs.append("Path=" + expected + path_val.lstrip("/")) else: new_attrs.append(a) return "; ".join(new_attrs) # ── Body URL rewriting (fixes absolute URLs in HTML/CSS) ──────────── # Matches: attr="/foo" attr='/foo' attr=/foo url(/foo) but NOT // http: https: data: _URL_ATTR_RE = re.compile( r'''(?P
(?:href|src|action|poster|formaction|data-src|srcset)\s*=\s*)(?P["']?)(?P/[^\s"'>)]*)(?P)''',
    re.IGNORECASE,
)
_CSS_URL_RE = re.compile(
    r'''(url\(\s*['"]?)(/[^\s'")]+)''',
    re.IGNORECASE,
)


def rewrite_body_urls(body_bytes, port):
    """Prefix absolute root-relative URLs with /{port}/ in HTML/CSS.

    Skips protocol-relative (//) URLs and URLs already under /{port}.
    Absolute (http://) and data: URIs never match the patterns, so they are
    left alone as well.

    The body is only rewritten when it is valid UTF-8. Anything else (binary
    data, or text in another charset) is returned unchanged, because decoding
    it lossily and re-encoding would corrupt the payload.

    Args:
        body_bytes: raw response body.
        port: upstream port used to build the "/{port}" prefix.

    Returns:
        The rewritten body as UTF-8 bytes, or ``body_bytes`` untouched when
        it cannot be decoded as UTF-8.
    """
    try:
        # Strict decode: errors="replace" would never raise and would
        # silently swap undecodable bytes for U+FFFD.
        text = body_bytes.decode("utf-8")
    except UnicodeDecodeError:
        return body_bytes  # binary / non-UTF-8 — don't touch

    prefix = "/{}".format(port)

    def _leave_alone(url):
        # Protocol-relative, or already under /{port} (avoid double prefix).
        return url.startswith("//") or url == prefix or url.startswith(prefix + "/")

    def _pref(m):
        url = m.group("url")
        if _leave_alone(url):
            return m.group(0)
        return m.group("pre") + m.group("q") + prefix + url

    def _css(m):
        url = m.group(2)
        if _leave_alone(url):
            return m.group(0)
        return m.group(1) + prefix + url

    text = _URL_ATTR_RE.sub(_pref, text)
    text = _CSS_URL_RE.sub(_css, text)
    return text.encode("utf-8")


# ── Middleware ─────────────────────────────────────────────────────────
@web.middleware
async def auth_middleware(request, handler):
    """Gate every non-public route behind a valid session cookie.

    Public routes ("/", "/login", "/health", "/favicon.ico") are passed
    straight to the handler. For everything else — including the
    /__probe__/ endpoint, so it cannot leak what is running — a request
    without a valid session is redirected to the login page while PASSWORD
    is set. With an empty PASSWORD no authentication is enforced.

    Raises:
        web.HTTPFound: redirect to /login?next=<original path and query>.
    """
    from urllib.parse import quote

    # Public routes
    if request.path in ("/", "/login", "/health", "/favicon.ico"):
        return await handler(request)

    # Everything else requires session cookie
    if PASSWORD and not _valid_session(request.cookies.get(COOKIE_NAME)):
        # Percent-encode the target so "&", "?", "#" or spaces in it cannot
        # break the login URL, and keep the query string so the user lands
        # on exactly the URL they asked for.
        raise web.HTTPFound("/login?next={}".format(quote(request.path_qs, safe="/")))

    return await handler(request)


# ── Route handlers: auth + dashboard ────────────────────────────────
async def health(request):
    """Healthcheck endpoint: always replies with the text body "ok"."""
    reply = web.Response(text="ok")
    return reply


async def dashboard(request):
    """Serve the module-level DASHBOARD_HTML string as a text/html page."""
    page = web.Response(text=DASHBOARD_HTML, content_type="text/html")
    return page


async def login_page(request):
    """Serve login.html, read from the directory that contains this module."""
    module_dir = os.path.dirname(os.path.abspath(__file__))
    with open(os.path.join(module_dir, "login.html"), "r", encoding="utf-8") as fh:
        markup = fh.read()
    return web.Response(text=markup, content_type="text/html")


async def login_submit(request):
    data = await request.post()
    submitted = data.get("password", "")
    if not PASSWORD or hmac.compare_digest(submitted, PASSWORD):
        next_url = data.get("next") or request.query.get("next") or "/"
        resp = web.Response(status=302, headers={"Location": next_url})
        resp.set_cookie(
            COOKIE_NAME,
            _make_cookie_value(),
            max_age=COOKIE_MAX_AGE,
            httponly=True,
            samesite="Lax",
            secure=False,
            path="/",
        )
        return resp
    return web.Response(status=302, headers={"Location": "/login?error=1"})


# ── Probe endpoint: detect what app runs on a port ──────────────────
def detect_app_type(status, headers):
    """Guess the app framework from response headers.

    Args:
        status: upstream HTTP status (currently not used for detection).
        headers: response headers; either a plain mapping or a multi-value
            mapping exposing ``getall(key, default)``.

    Returns:
        A ``(app_type, app_label, icon)`` tuple; ``("unknown", "Unknown", ...)``
        when nothing matches.
    """
    powered = headers.get("X-Powered-By", "").lower()
    server = headers.get("Server", "").lower()

    # A response may carry several Set-Cookie headers and .get() on a
    # multi-value mapping only returns the first one, so look at all of
    # them when the mapping supports it.
    getall = getattr(headers, "getall", None)
    if callable(getall):
        cookie_values = getall("Set-Cookie", [])
    else:
        cookie_values = [headers.get("Set-Cookie", "")]
    set_cookie = "; ".join(cookie_values).lower()

    if "next" in powered:
        return "nextjs", "Next.js", "\u2705"
    if "laravel" in set_cookie or "php" in powered:
        return "laravel", "Laravel", "\U0001F418"
    if "express" in powered or "webpack" in powered:
        return "react", "React/Node", "\u269B\uFE0F"
    if "code-server" in server or "vscode" in server:
        return "vscode", "VS Code Server", "\U0001F4BB"
    if "tornado" in server or "jupyter" in set_cookie:
        return "jupyter", "JupyterLab", "\U0001F4D3"
    if "vite" in powered or "vite" in server:
        return "vite", "Vite", "\u26A1"
    return "unknown", "Unknown", "\u2753"


async def probe(request):
    """GET /__probe__/{port} → JSON describing what's on that port.

    Responds 400 with ``{"error": "invalid port"}`` when the port is not an
    integer in 1..65535. Otherwise issues one GET to http://127.0.0.1:{port}/
    (redirects not followed, 8 s total timeout) and reports either
    ``status: "offline"`` with the error text, or ``status: "online"`` with
    the detected app type and the Server / X-Powered-By headers.
    """
    port_str = request.match_info["port"]
    try:
        port = int(port_str)
    except ValueError:
        return web.json_response({"error": "invalid port"}, status=400)
    if not 1 <= port <= 65535:
        return web.json_response({"error": "invalid port"}, status=400)

    upstream = "http://127.0.0.1:{}/".format(port)
    try:
        session = request.app["http_session"]
        # Detection only needs the status line and headers, so the body is
        # deliberately not read: a large or never-ending (streaming) root
        # page must not slow the probe down or make it time out.
        async with session.get(upstream, allow_redirects=False, timeout=ClientTimeout(total=8)) as resp:
            status = resp.status
            headers = resp.headers
    except Exception as e:
        return web.json_response({
            "port": port,
            "status": "offline",
            # Some exceptions (e.g. timeouts) have an empty message.
            "error": str(e) or type(e).__name__,
        })

    app_type, app_label, icon = detect_app_type(status, headers)
    return web.json_response({
        "port": port,
        "status": "online",
        "http_status": status,
        "app_type": app_type,
        "app_label": app_label,
        "icon": icon,
        "server": headers.get("Server", ""),
        "powered_by": headers.get("X-Powered-By", ""),
    })


# ── Catch-all: parse port and route ──────────────────────────────────
async def catch_all(request):
    """Route /{port}/... to the port proxy; any other path gets the dashboard."""
    hit = PORT_RE.match(request.path)
    if hit is None:
        return await dashboard(request)
    target_port = int(hit.group(1))
    # "/{port}" without a trailing path is forwarded as "/".
    remainder = hit.group("rest") or "/"
    return await port_proxy(request, target_port, remainder)


# ── Port proxy (HTTP + WebSocket) ────────────────────────────────────
async def port_proxy(request, port, rest_path):
    """Send a request to localhost:{port} over WebSocket or plain HTTP.

    A request whose Upgrade header equals "websocket" (case-insensitive) goes
    to ws_relay; everything else goes to http_proxy.
    """
    # JupyterLab has base_url='/8888/' configured, so keep its prefix:
    # port 8888 receives the full original path, other ports the stripped one.
    keep_prefix = port == 8888
    forward_path = request.path if keep_prefix else rest_path

    wants_websocket = request.headers.get("upgrade", "").lower() == "websocket"
    if wants_websocket:
        return await ws_relay(
            request, "ws://127.0.0.1:{}".format(port), forward_path, port
        )
    return await http_proxy(
        request, port, "http://127.0.0.1:{}".format(port), forward_path, keep_prefix
    )


# ── WebSocket relay (with subprotocol + Origin rewrite) ─────────────
async def ws_relay(request, upstream_base, path, port):
    """Relay a WebSocket connection between the client and localhost:{port}.

    The upstream connection is opened first. Only when it succeeds is the
    client handshake accepted, echoing back the subprotocol the upstream
    selected: a browser that offered subprotocols may fail the connection
    if the response selects none. If the upstream cannot be reached, a plain
    502 response is returned instead of accepting and immediately closing.

    Args:
        request: incoming aiohttp request carrying the upgrade.
        upstream_base: "ws://127.0.0.1:{port}".
        path: path to request on the upstream.
        port: upstream port (used for the Origin header and log messages).

    Returns:
        The client-side WebSocketResponse once the relay has finished, or a
        502 web.Response when the upstream connection could not be opened.
    """
    target_url = upstream_base + path
    if request.query_string:
        target_url += "?" + request.query_string

    # Forward subprotocols and extensions so VS Code Server & friends connect.
    proto_header = request.headers.get("Sec-WebSocket-Protocol", "")
    subprotocols = [p.strip() for p in proto_header.split(",") if p.strip()]

    # Build forwarded headers: rewrite Origin so upstream trusts us.
    fwd_headers = {
        "Origin": "http://127.0.0.1:{}".format(port),
    }
    # Preserve Sec-WebSocket-Extensions (permessage-deflate etc.)
    ext = request.headers.get("Sec-WebSocket-Extensions")
    if ext:
        fwd_headers["Sec-WebSocket-Extensions"] = ext
    # Forward app cookies (minus our own session cookie), same as the HTTP
    # path, so cookie-authenticated upstream sockets keep working.
    cookie_val = filter_request_cookies(request.headers.get("Cookie", ""))
    if cookie_val:
        fwd_headers["Cookie"] = cookie_val

    kwargs = {
        "autoclose": False,
        "autoping": True,
        "heartbeat": 30,
    }
    if subprotocols:
        kwargs["protocols"] = subprotocols

    try:
        upstream_ws = await request.app["http_session"].ws_connect(
            target_url, headers=fwd_headers, **kwargs
        )
    except Exception as e:
        print("WS RELAY ERROR {} port={}: {}".format(request.path, port, e), flush=True)
        return web.Response(
            status=502,
            text="Bad Gateway: Cannot open WebSocket to port {}. ({})".format(port, e),
        )

    # Offer the client exactly the subprotocol the upstream agreed to.
    selected = upstream_ws.protocol
    client_ws = web.WebSocketResponse(protocols=[selected] if selected else [])

    async def c2u():
        # Client → upstream pump. Errors are deliberately swallowed: a dropped
        # peer is the normal way for a relay to end.
        try:
            async for msg in client_ws:
                if msg.type == WSMsgType.TEXT:
                    await upstream_ws.send_str(msg.data)
                elif msg.type == WSMsgType.BINARY:
                    await upstream_ws.send_bytes(msg.data)
                elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING,
                                  WSMsgType.CLOSED, WSMsgType.ERROR):
                    break
        except Exception:
            pass
        finally:
            # Closing the other leg makes the opposite pump terminate too.
            if not upstream_ws.closed:
                await upstream_ws.close()

    async def u2c():
        # Upstream → client pump; mirror image of c2u().
        try:
            async for msg in upstream_ws:
                if msg.type == WSMsgType.TEXT:
                    await client_ws.send_str(msg.data)
                elif msg.type == WSMsgType.BINARY:
                    await client_ws.send_bytes(msg.data)
                elif msg.type in (WSMsgType.CLOSE, WSMsgType.CLOSING,
                                  WSMsgType.CLOSED, WSMsgType.ERROR):
                    break
        except Exception:
            pass
        finally:
            if not client_ws.closed:
                await client_ws.close()

    prepared = False
    try:
        await client_ws.prepare(request)
        prepared = True
        await asyncio.gather(c2u(), u2c(), return_exceptions=True)
    finally:
        # Never leak the upstream socket, even if the client handshake fails.
        if not upstream_ws.closed:
            await upstream_ws.close()
        # close() is only valid on a prepared response.
        if prepared and not client_ws.closed:
            await client_ws.close()

    return client_ws


# ── HTTP reverse proxy (streaming + rewriting) ─────────────────────
async def http_proxy(request, port, upstream_base, path, keep_prefix=False):
    target_url = upstream_base + path
    if request.query_string:
        target_url += "?" + request.query_string

    # Static cache lookup for GET
    cache = request.app["static_cache"]
    cache_key = "{}{}".format(port, path)
    if request.method == "GET":
        cached = cache.get(cache_key)
        if cached:
            status, headers, body = cached
            content_type = headers.get("Content-Type", "")
            if "text/html" in content_type or "text/css" in content_type:
                body = rewrite_body_urls(body, port)
            stream = web.StreamResponse(status=status, headers=dict(headers))
            stream.content_length = len(body)
            await stream.prepare(request)
            await stream.write(body)
            await stream.write_eof()
            return stream

    # Forward request headers (drop hop-by-hop + proxy-specific).
    skip = {"host", "connection", "keep-alive", "transfer-encoding",
            "upgrade", "content-length", "accept-encoding"}
    headers = {k: v for k, v in request.headers.items() if k.lower() not in skip}

    # Forward app cookies (strip our own session cookie).
    cookie_val = filter_request_cookies(request.headers.get("Cookie", ""))
    if cookie_val:
        headers["Cookie"] = cookie_val

    # Rewrite Referer to strip the /{port}/ prefix so upstream sees clean paths.
    referer = headers.get("Referer")
    if referer:
        prefix = "/{}/".format(port)
        # Replace any /{port}/ occurrence in the referer path.
        try:
            from urllib.parse import urlsplit, urlunsplit
            parts = urlsplit(referer)
            new_path = parts.path
            if new_path.startswith(prefix):
                new_path = "/" + new_path[len(prefix):]
            headers["Referer"] = urlunsplit(
                (parts.scheme, parts.netloc, new_path, parts.query, parts.fragment)
            )
        except Exception:
            pass

    # Determine if body needs streaming (non-empty) — stream it to upstream.
    body_iter = None
    if request.can_read_body:
        body_iter = request.content  # async stream of bytes

    try:
        session = request.app["http_session"]
        async with session.request(
            request.method,
            target_url,
            headers=headers,
            data=body_iter,
            allow_redirects=False,
        ) as upstream:
            # Build response headers, dropping hop-by-hop + content-encoding.
            skip_resp = {"transfer-encoding", "connection", "keep-alive",
                         "content-encoding", "content-length"}
            resp_headers = {}
            for k, v in upstream.headers.items():
                if k.lower() in skip_resp:
                    continue
                if k.lower() == "set-cookie":
                    # Scope cookie under /{port}/
                    resp_headers[k] = rewrite_set_cookie(v, port)
                    continue
                if k.lower() == "location" and not keep_prefix:
                    # Rewrite redirect Location (both absolute path & localhost URL).
                    loc = v
                    if loc.startswith("/") and not loc.startswith("//"):
                        resp_headers[k] = "/{}{}".format(port, loc)
                    elif loc.startswith(upstream_base):
                        rest = loc[len(upstream_base):]
                        if rest and not rest.startswith("/"):
                            rest = "/" + rest
                        resp_headers[k] = "/{}{}".format(port, rest)
                    else:
                        resp_headers[k] = v
                    continue
                resp_headers[k] = v

            # Decide: stream as-is, or rewrite body for HTML/CSS?
            content_type = upstream.headers.get("Content-Type", "").lower()
            should_rewrite = ("text/html" in content_type or "text/css" in content_type)

            if not should_rewrite:
                # ── Pure streaming path (SSE, large files, API, everything else)
                stream = web.StreamResponse(status=upstream.status, headers=resp_headers)
                # Disable any upstream/edge buffering for SSE / HMR.
                stream.headers["X-Accel-Buffering"] = "no"
                await stream.prepare(request)
                async for chunk in upstream.content.iter_any():
                    await stream.write(chunk)
                await stream.write_eof()
                return stream

            # ── Rewrite path: buffer HTML/CSS, rewrite URLs, then send.
            body_bytes = await upstream.read()
            body_bytes = rewrite_body_urls(body_bytes, port)
            # Cache static asset (mostly .css) if eligible.
            if request.method == "GET":
                cache.put(cache_key, upstream.status,
                          {k: v for k, v in resp_headers.items()}, body_bytes)
            return web.Response(
                status=upstream.status,
                headers=resp_headers,
                body=body_bytes,
            )
    except Exception as e:
        print("PROXY ERROR {} port={}: {}".format(request.method, port, e), flush=True)
        return web.Response(
            status=502,
            text="Bad Gateway: Cannot reach port {}. Is an app running on it? ({})".format(port, e),
        )


# ── App lifecycle ───────────────────────────────────────────────────
async def on_startup(app):
    """Create the shared upstream ClientSession and the static cache.

    Both are stored on the application: the session under "http_session",
    the cache under "static_cache".
    """
    # total=None so SSE/HMR long-poll connections never time out at the
    # client-session level; connect timeouts stay bounded.
    upstream_timeout = ClientTimeout(
        total=None, connect=10, sock_connect=10, sock_read=None
    )
    pool = TCPConnector(
        limit=MAX_CONNECTIONS,
        limit_per_host=MAX_PER_HOST,
        keepalive_timeout=CONNECTOR_KEEPALIVE,
        enable_cleanup_closed=True,
    )
    app["http_session"] = ClientSession(connector=pool, timeout=upstream_timeout)
    app["static_cache"] = StaticCache()

    banner = "Proxy pool: max_conn={} keepalive={}s cache_items={}".format(
        MAX_CONNECTIONS, CONNECTOR_KEEPALIVE, CACHE_MAX_ITEMS)
    print(banner, flush=True)


async def on_cleanup(app):
    """Close the shared upstream session stored under app["http_session"]."""
    upstream_session = app["http_session"]
    await upstream_session.close()


# ── Dashboard HTML (defined at import time so handlers can reference it) ──
# NOTE(review): this literal is served as text/html by dashboard() but
# contains no HTML tags — confirm the markup was not stripped from the file.
DASHBOARD_HTML = """




Jupiter — Dashboard



  

Jupiter

Multi-Port Proxy Dashboard · auto-detect
"""


def create_app():
    """Build the aiohttp application: auth middleware, lifecycle hooks, routes.

    Explicit routes are registered first (/health, /login, /__probe__/{port});
    the final catch-all sends every other path and method to catch_all, which
    serves either the dashboard or the /{port}/... proxy.
    """
    app = web.Application(middlewares=[auth_middleware])
    app.on_startup.append(on_startup)
    app.on_cleanup.append(on_cleanup)

    app.router.add_get("/health", health)
    app.router.add_get("/login", login_page)
    app.router.add_post("/login", login_submit)
    app.router.add_get("/__probe__/{port}", probe)
    # Catch-all: dashboard or /{port}/...
    app.router.add_route("*", "/{tail:.*}", catch_all)
    return app


if __name__ == "__main__":
    app = create_app()
    print("Proxy starting on http://0.0.0.0:7860", flush=True)
    web.run_app(app, host="0.0.0.0", port=7860, print=None, handle_signals=True)