| |
| import os |
| import signal |
| import subprocess |
| import sys |
| import time |
| import warnings |
| from pathlib import Path |
| from typing import Any |
|
|
| import gradio as gr |
| from fastapi import Request |
| from fastapi.responses import HTMLResponse |
| from fastapi.staticfiles import StaticFiles |
| from gradio import Server |
| from gradio.blocks import Blocks |
| from gradio.components import LoginButton |
| from gradio.events import api as gr_api |
|
|
# Silence the noisy 'Field name "json" ...' UserWarning emitted from the
# firecrawl.v2.types module. The filter must be registered before
# `from ui import server_api` below, which presumably imports firecrawl
# transitively — keep this above that import.
warnings.filterwarnings(
    action="ignore",
    category=UserWarning,
    module=r"firecrawl\.v2\.types",
    message='Field name "json"',
)
|
|
| from ui import server_api |
|
|
# Front-end files (index.html plus the JS/CSS it references) live in an
# `assets` directory next to this module.
ASSETS_DIR = Path(__file__).resolve().parent.joinpath("assets")

# The Gradio Server that every route and API endpoint below is registered on.
app = Server(title="Borderless - Immigration Research Agent")

# Module-level `demo` alias. NOTE(review): presumably kept so tooling that looks
# for a `demo` object finds one — confirm. _launch_with_oauth later rebinds it
# to the Blocks instance it launches.
demo = app
|
|
|
|
| def _versioned_homepage_html() -> str: |
| html = (ASSETS_DIR / "index.html").read_text(encoding="utf-8") |
| for asset in ("app.js", "globe.js", "server.css", "globe.css"): |
| version = int((ASSETS_DIR / asset).stat().st_mtime) |
| html = html.replace(f"/assets/{asset}", f"/assets/{asset}?v={version}") |
| return html |
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def homepage() -> HTMLResponse:
    """Serve assets/index.html with versioned asset URLs.

    The ``no-cache, must-revalidate`` header makes browsers revalidate the
    HTML on each visit, so updated ``?v=`` asset versions are seen promptly.
    """
    no_cache_headers = {"Cache-Control": "no-cache, must-revalidate"}
    page = _versioned_homepage_html()
    return HTMLResponse(content=page, headers=no_cache_headers)
|
|
|
|
@app.get("/api/intake_choices")
def api_intake_choices() -> dict[str, Any]:
    """Plain HTTP GET that returns server_api's intake choices as JSON."""
    choices = server_api.get_intake_choices()
    return choices
|
|
|
|
@app.get("/api/auth/status")
async def auth_status(request: Request) -> dict[str, Any]:
    """Report whether this browser session holds a non-expired OAuth login.

    Returns ``{"logged_in": False}`` in three cases:

    * there is no session;
    * no ``oauth_info`` is stored in the session;
    * ``oauth_info["expires_at"]`` is not after the current time. In this
      case the stale entry is also removed from the session.

    Otherwise returns ``{"logged_in": True, "username": ..., "name": ...}``,
    with the last two read from the stored ``userinfo``. Either may be
    ``None``.
    """
    # Starlette's ``Request.session`` property asserts that SessionMiddleware
    # put "session" into the ASGI scope. ``getattr(request, "session", None)``
    # therefore raises AssertionError (KeyError under ``python -O``) rather than
    # returning None when the middleware is absent. Read the scope directly so
    # that case reports "logged out" instead of a 500.
    session = request.scope.get("session")
    if session is None:
        return {"logged_in": False}

    oauth_info = session.get("oauth_info")
    if not oauth_info:
        return {"logged_in": False}

    expires_at = oauth_info.get("expires_at")
    # NOTE(review): assumes expires_at is a Unix timestamp in seconds, since it
    # is compared against time.time() — confirm against the OAuth flow.
    if expires_at is not None and expires_at <= time.time():
        session.pop("oauth_info", None)
        return {"logged_in": False}

    userinfo = oauth_info.get("userinfo") or {}
    return {
        "logged_in": True,
        "username": userinfo.get("preferred_username"),
        "name": userinfo.get("name"),
    }
|
|
|
|
@app.api(name="get_intake_choices")
def api_get_intake_choices() -> dict:
    """Gradio API counterpart of GET /api/intake_choices (same server_api call)."""
    result = server_api.get_intake_choices()
    return result
|
|
|
|
@app.api(name="build_research_prompt")
def api_build_research_prompt(
    current_country: server_api.DropdownValue,
    residence_status: server_api.DropdownValue,
    education: server_api.DropdownValue,
    occupation: server_api.DropdownValue,
    experience: server_api.DropdownValue,
    budget: server_api.DropdownValue,
    family: server_api.DropdownValue,
    timeline: server_api.DropdownValue,
    goals: str,
) -> dict[str, str]:
    """Build the research prompt from the intake fields.

    Thin pass-through to ``server_api.build_research_prompt``. The nine fields
    are forwarded positionally, in the order declared here.
    """
    fields = (
        current_country,
        residence_status,
        education,
        occupation,
        experience,
        budget,
        family,
        timeline,
        goals,
    )
    return server_api.build_research_prompt(*fields)
|
|
|
|
@app.api(name="build_persona_prompt")
def api_build_persona_prompt(persona_id: str) -> str:
    """Return the prompt text server_api builds for ``persona_id``."""
    prompt = server_api.build_persona_prompt(persona_id)
    return prompt
|
|
|
|
@app.api(name="chat")
def api_chat(
    message: str,
    history: list[dict],
    globe_state: dict | None,
    hf_token: gr.OAuthToken | None,
) -> dict[str, Any]:
    """Stream chat updates produced by server_api.stream_chat.

    This is a generator. The ``dict[str, Any]`` return annotation, rather
    than an Iterator type, is deliberate: it is what lets Gradio register
    the SSE output for this endpoint.
    """
    updates = server_api.stream_chat(message, history, globe_state, hf_token)
    yield from updates
|
|
|
|
# Serve everything in ASSETS_DIR (the JS/CSS referenced by index.html) under /assets.
_static_files = StaticFiles(directory=str(ASSETS_DIR))
app.mount("/assets", _static_files, name="assets")
|
|
|
|
def _cleanup_stale_gradio_nodes() -> None:
    """Send SIGTERM to leftover Gradio Node SSR worker processes.

    Runs ``pgrep -f`` to find every process on the host whose full command
    line contains the Gradio Node build path (e.g. workers orphaned by an
    earlier restart). Each distinct PID is signalled once.

    This is best-effort:

    * it returns silently when ``pgrep`` is not installed;
    * it never signals the current process;
    * it ignores processes that have already exited or that this process is
      not permitted to signal.

    Because matching is by command line only, SSR workers of other Gradio
    apps running on the same host are targeted too.
    """
    # The second pattern is subsumed by the first; PIDs are de-duplicated
    # below, so keeping both is harmless.
    patterns = (
        "gradio/templates/node/build",
        "gradio/templates/node/build/server/entry.js",
    )
    pids: set[int] = set()
    for pattern in patterns:
        try:
            result = subprocess.run(
                ["pgrep", "-f", pattern],
                capture_output=True,
                text=True,
                check=False,
            )
        except FileNotFoundError:
            # pgrep is unavailable (e.g. Windows or a minimal container).
            return
        pids.update(int(tok) for tok in result.stdout.split() if tok.isdigit())

    # Never signal ourselves, even if our own command line happened to match.
    pids.discard(os.getpid())

    for pid in sorted(pids):
        try:
            os.kill(pid, signal.SIGTERM)
        except (ProcessLookupError, PermissionError):
            # Already gone, or owned by another user: not ours to clean up,
            # and it must not abort startup.
            pass
|
|
|
|
def _launch_with_oauth(**kwargs: Any):
    """Launch through a Blocks wrapper so HF OAuth routes get registered.

    Steps, in order:

    1. Build a throwaway ``Blocks`` holding a hidden ``LoginButton``.
    2. Re-register every function queued in ``app._deferred_apis`` (a private
       Gradio attribute, filled by the ``@app.api`` decorators above) as an
       API endpoint on that Blocks.
    3. Set ``GRADIO_SERVER_MODE_ENABLED=1``.
    4. Publish the Blocks as ``demo`` on ``__main__`` and on this module.
    5. Launch it on top of ``app`` via the private ``_app`` keyword.

    ``kwargs`` are forwarded unchanged to ``Blocks.launch``, and its result
    is returned.
    """
    global demo

    with Blocks() as blocks:
        # Per the original intent, the (invisible) LoginButton is what makes
        # Gradio wire up the HF OAuth routes.
        LoginButton(visible=False)
        for deferred_fn, deferred_kwargs in app._deferred_apis:
            gr_api(fn=deferred_fn, **deferred_kwargs)

    # NOTE(review): meaning of this flag is defined inside Gradio — confirm.
    os.environ["GRADIO_SERVER_MODE_ENABLED"] = "1"

    main_module = sys.modules.get("__main__")
    if main_module is not None:
        setattr(main_module, "demo", blocks)
    demo = blocks

    return blocks.launch(_app=app, **kwargs)
|
|
|
|
# Route every `app.launch(...)` call through _launch_with_oauth, which adds
# the hidden LoginButton and registers the deferred @app.api endpoints.
setattr(app, "launch", _launch_with_oauth)


if __name__ == "__main__":
    # Terminate Node SSR workers left over from a previous run before starting.
    _cleanup_stale_gradio_nodes()
    app.launch(show_error=True)
|
|