File size: 5,026 Bytes
2b49599
79f7888
 
 
51b6df5
1c4c2e7
3ba9db2
30321d8
1c4c2e7
30321d8
ddf0d56
1c4c2e7
b466a63
 
 
1c4c2e7
 
 
ddf0d56
3ba9db2
 
 
 
 
 
 
b466a63
ddf0d56
30321d8
 
b466a63
51b6df5
b466a63
ddf0d56
 
a95a8b3
 
 
 
 
 
 
 
b466a63
a95a8b3
 
 
 
 
 
 
 
 
 
ddf0d56
 
1c4c2e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b466a63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbdb569
b466a63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd9dc38
 
39fd04c
b466a63
 
 
ddf0d56
79f7888
 
 
51b6df5
 
 
 
 
79f7888
51b6df5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79f7888
 
1c4c2e7
 
 
 
 
 
 
51b6df5
 
 
 
 
 
 
1c4c2e7
 
 
 
 
ddf0d56
79f7888
b466a63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
# app.py
import os
import signal
import subprocess
import sys
import time
import warnings
from pathlib import Path
from typing import Any

import gradio as gr
from fastapi import Request
from fastapi.responses import HTMLResponse
from fastapi.staticfiles import StaticFiles
from gradio import Server
from gradio.blocks import Blocks
from gradio.components import LoginButton
from gradio.events import api as gr_api

# Suppress the UserWarning whose message starts with 'Field name "json"' when
# it is raised from the firecrawl.v2.types module. Registered before
# `from ui import server_api` below — presumably because importing `ui` pulls
# in firecrawl and the warning fires at import time; confirm before moving it.
warnings.filterwarnings(
    "ignore",
    message='Field name "json"',
    category=UserWarning,
    module=r"firecrawl\.v2\.types",
)

from ui import server_api

ASSETS_DIR = Path(__file__).resolve().parent.joinpath("assets")

app = Server(title="Borderless - Immigration Research Agent")
# Alias consumed by Gradio hot reload; _launch_with_oauth rebinds it to the
# OAuth-enabled Blocks instance it launches.
demo = app


def _versioned_homepage_html() -> str:
    html = (ASSETS_DIR / "index.html").read_text(encoding="utf-8")
    for asset in ("app.js", "globe.js", "server.css", "globe.css"):
        version = int((ASSETS_DIR / asset).stat().st_mtime)
        html = html.replace(f"/assets/{asset}", f"/assets/{asset}?v={version}")
    return html


@app.get("/", response_class=HTMLResponse)
async def homepage() -> HTMLResponse:
    """Serve the single-page UI.

    The HTML carries mtime-versioned asset URLs. ``no-cache`` makes the
    browser revalidate the page itself on each load, so it notices new
    ``?v=`` values.
    """
    revalidate_headers = {"Cache-Control": "no-cache, must-revalidate"}
    page = _versioned_homepage_html()
    return HTMLResponse(content=page, headers=revalidate_headers)


@app.get("/api/intake_choices")
def api_intake_choices() -> dict[str, Any]:
    """REST endpoint returning ``server_api.get_intake_choices()`` as JSON."""
    choices = server_api.get_intake_choices()
    return choices


@app.get("/api/auth/status")
async def auth_status(request: Request) -> dict[str, Any]:
    """Report whether the caller has a live Hugging Face OAuth session.

    Reads ``oauth_info`` from the request session. An entry whose
    ``expires_at`` (compared against ``time.time()``) is in the past is
    removed from the session and treated as logged out.

    Returns:
        ``{"logged_in": False}`` when there is no session, no ``oauth_info``,
        or the entry has expired. Otherwise ``{"logged_in": True,
        "username": ..., "name": ...}`` taken from ``oauth_info["userinfo"]``
        (either value may be ``None``).
    """
    # Starlette's ``Request.session`` property *asserts* that SessionMiddleware
    # is installed. It raises AssertionError, or KeyError under ``python -O``,
    # never AttributeError, so ``getattr(request, "session", None)`` cannot
    # fall back to None. Read the ASGI scope directly instead.
    session = request.scope.get("session")
    if session is None:
        return {"logged_in": False}

    oauth_info = session.get("oauth_info")
    if not oauth_info:
        return {"logged_in": False}

    expires_at = oauth_info.get("expires_at")
    if expires_at is not None and expires_at <= time.time():
        # Evict the expired entry from the session.
        session.pop("oauth_info", None)
        return {"logged_in": False}

    userinfo = oauth_info.get("userinfo") or {}
    return {
        "logged_in": True,
        "username": userinfo.get("preferred_username"),
        "name": userinfo.get("name"),
    }


@app.api(name="get_intake_choices")
def api_get_intake_choices() -> dict:
    """Gradio API counterpart of ``GET /api/intake_choices``."""
    choices = server_api.get_intake_choices()
    return choices


@app.api(name="build_research_prompt")
def api_build_research_prompt(
    current_country: server_api.DropdownValue,
    residence_status: server_api.DropdownValue,
    education: server_api.DropdownValue,
    occupation: server_api.DropdownValue,
    experience: server_api.DropdownValue,
    budget: server_api.DropdownValue,
    family: server_api.DropdownValue,
    timeline: server_api.DropdownValue,
    goals: str,
) -> dict[str, str]:
    """Gradio endpoint delegating to ``server_api.build_research_prompt``.

    The intake answers are forwarded positionally, in declaration order.
    """
    intake_answers = (
        current_country,
        residence_status,
        education,
        occupation,
        experience,
        budget,
        family,
        timeline,
        goals,
    )
    return server_api.build_research_prompt(*intake_answers)


@app.api(name="build_persona_prompt")
def api_build_persona_prompt(persona_id: str) -> str:
    """Return the prompt that ``server_api`` builds for *persona_id*."""
    prompt = server_api.build_persona_prompt(persona_id)
    return prompt


@app.api(name="chat")
def api_chat(
    message: str,
    history: list[dict],
    globe_state: dict | None,
    hf_token: gr.OAuthToken | None,
) -> dict[str, Any]:
    """Relay each update yielded by ``server_api.stream_chat``.

    This is a generator. The ``dict[str, Any]`` return annotation is kept
    deliberately: it is what lets Gradio register the endpoint's output for
    SSE streaming.
    """
    updates = server_api.stream_chat(message, history, globe_state, hf_token)
    yield from updates


# Serve every file in ASSETS_DIR (JS, CSS, index.html, ...) under /assets.
# The homepage HTML references these files as /assets/<name>.
app.mount("/assets", StaticFiles(directory=str(ASSETS_DIR)), name="assets")


def _cleanup_stale_gradio_nodes() -> None:
    """Send SIGTERM to Gradio Node SSR workers left over from earlier runs.

    Every process whose full command line matches one of the patterns below
    (``pgrep -f``) is signalled once. This is best-effort:

    - It does nothing when ``pgrep`` is unavailable.
    - It skips processes that have already exited.
    - It skips processes this user is not permitted to signal.

    NOTE(review): matching is purely by command line. Node SSR workers of
    another Gradio app still running on this host are terminated too, not
    only orphans.
    """
    patterns = (
        "gradio/templates/node/build",
        "gradio/templates/node/build/server/entry.js",
    )
    # Collect PIDs first. Any command line containing the second pattern also
    # contains the first, so the same PID can be reported twice and would
    # otherwise be signalled twice.
    pids: set[int] = set()
    for pattern in patterns:
        try:
            result = subprocess.run(
                ["pgrep", "-f", pattern],
                capture_output=True,
                text=True,
                check=False,
            )
        except FileNotFoundError:
            # pgrep is not installed (e.g. Windows), so skip cleanup entirely.
            return
        pids.update(int(token) for token in result.stdout.split() if token.isdigit())

    # Never signal this process, whatever its command line looks like.
    pids.discard(os.getpid())

    for pid in sorted(pids):
        try:
            os.kill(pid, signal.SIGTERM)
        except (ProcessLookupError, PermissionError):
            # The process is already gone, or it belongs to another user
            # (pgrep lists all users' processes). Neither should abort startup.
            continue


def _launch_with_oauth(**kwargs: Any):
    """Launch with a hidden LoginButton so HF OAuth routes are registered.

    Steps, in order:

    1. Build a ``Blocks`` containing an invisible ``LoginButton``.
    2. Inside that Blocks context, re-register every API the ``Server``
       deferred. These are the ``(fn, api_kwargs)`` pairs in
       ``app._deferred_apis``, registered with ``gr_api``.
    3. Set ``GRADIO_SERVER_MODE_ENABLED``.
    4. Point ``demo`` at the new Blocks, both in ``__main__`` and in this
       module.
    5. Return ``blocks.launch(_app=app, **kwargs)``.

    Args:
        **kwargs: Forwarded unchanged to ``Blocks.launch``.

    Returns:
        Whatever ``Blocks.launch`` returns.
    """
    with Blocks() as blocks:
        # Never shown; present only so Gradio wires up its OAuth routes.
        LoginButton(visible=False)
        # Replay the Server's deferred @app.api registrations on this Blocks.
        # NOTE(review): `_deferred_apis` is a private Gradio attribute, so
        # re-check it on Gradio upgrades.
        for fn, api_kwargs in app._deferred_apis:
            gr_api(fn=fn, **api_kwargs)

        # NOTE(review): this env var is read by Gradio internals. It must be
        # set before launch() below; confirm exactly what it toggles.
        os.environ["GRADIO_SERVER_MODE_ENABLED"] = "1"
        # Gradio Spaces hot reload matches __main__.demo to the launched Blocks.
        main = sys.modules.get("__main__")
        if main is not None:
            main.demo = blocks
        globals()["demo"] = blocks
        # `_app=app` hands the existing Server, which already carries the
        # custom routes and the /assets mount, to launch. Presumably Blocks
        # then serves on it instead of building a new app. `_app` is a private
        # parameter, so confirm. Note that launch() runs before this `with`
        # block exits.
        return blocks.launch(_app=app, **kwargs)


# Route every `app.launch(...)` call through the OAuth-aware launcher.
setattr(app, "launch", _launch_with_oauth)


if __name__ == "__main__":
    # Terminate Node SSR workers left behind by earlier runs before launching.
    _cleanup_stale_gradio_nodes()
    # `app.launch` is _launch_with_oauth (installed at module level above).
    app.launch(show_error=True)