|
|
| from __future__ import annotations
|
|
|
import functools
import os
from collections.abc import Iterator
from typing import Any
|
|
|
| import gradio as gr
|
| from gradio import ChatMessage
|
|
|
| from ui.agent.config import GPU_DURATION, local_inference_enabled
|
| from ui.agent.graph import respond_with_graph
|
| from ui.agent.respond import respond as respond_legacy
|
| from ui.agent.system_prompt import BORDERLESS_SYSTEM_PROMPT
|
| from ui.chat.defaults import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P
|
| from ui.globe_commands import empty_globe_state
|
| from ui.intake.choices import (
|
| BUDGET_CHOICES,
|
| EDUCATION_CHOICES,
|
| EXPERIENCE_CHOICES,
|
| FAMILY_CHOICES,
|
| OCCUPATION_CHOICES,
|
| RESIDENCE_STATUS_CHOICES,
|
| TIMELINE_CHOICES,
|
| country_choices,
|
| )
|
| from ui.intake.examples import demo_personas, persona_prompt
|
| from ui.intake.prompts import build_profile_prompt, build_session_title
|
|
|
| DropdownValue = str | int | float | list[str | int | float] | None
|
|
|
| AGENT_MODE = os.environ.get("BORDERLESS_AGENT_MODE", "graph").strip().lower()
|
|
|
|
|
def _respond_fn():
    """Return the agent ``respond`` callable selected by ``AGENT_MODE``.

    ``"legacy"`` selects ``respond_legacy``; any other mode selects
    ``respond_with_graph``. The mode is checked at call time.
    """
    return respond_legacy if AGENT_MODE == "legacy" else respond_with_graph
|
|
|
|
|
def _chat_message_to_dict(message: ChatMessage) -> dict[str, Any]:
    """Convert a ``ChatMessage`` into a plain role/content dict.

    The result always has ``role`` and ``content`` keys. A ``metadata`` key is
    added only when the message's metadata is truthy, so messages with no or
    empty metadata stay minimal.
    """
    optional_part: dict[str, Any] = (
        {"metadata": message.metadata} if message.metadata else {}
    )
    return {"role": message.role, "content": message.content, **optional_part}
|
|
|
|
|
def get_intake_choices() -> dict[str, Any]:
    """Collect every option list the intake form needs, keyed by field.

    Country choices and demo personas are produced by calling their helpers
    on each invocation; the remaining entries are the module-level choice
    constants, returned as-is (not copied).
    """
    return dict(
        countries=country_choices(),
        education=EDUCATION_CHOICES,
        occupation=OCCUPATION_CHOICES,
        experience=EXPERIENCE_CHOICES,
        budget=BUDGET_CHOICES,
        family=FAMILY_CHOICES,
        timeline=TIMELINE_CHOICES,
        residence_status=RESIDENCE_STATUS_CHOICES,
        personas=demo_personas(),
    )
|
|
|
|
|
def build_research_prompt(
    current_country: DropdownValue,
    residence_status: DropdownValue,
    education: DropdownValue,
    occupation: DropdownValue,
    experience: DropdownValue,
    budget: DropdownValue,
    family: DropdownValue,
    timeline: DropdownValue,
    goals: str,
) -> dict[str, str]:
    """Turn intake answers into a research prompt and a session title.

    Returns:
        ``{"text": ..., "title": ...}`` where ``text`` is the ``"text"`` entry
        of ``build_profile_prompt``'s result coerced to ``str`` (``""`` when
        missing or falsy) and ``title`` comes from ``build_session_title``.
    """
    profile_head = (current_country, residence_status, education, occupation, experience)
    profile_tail = (budget, family, timeline, goals)

    # NOTE(review): the two ``None`` slots fill positional parameters of
    # build_profile_prompt that this form does not collect — confirm their
    # meaning against that function's signature.
    prompt = build_profile_prompt(None, *profile_head, None, *profile_tail)
    prompt_text = str(prompt.get("text") or "")

    session_title = build_session_title(*profile_head, *profile_tail)
    return {"text": prompt_text, "title": session_title}
|
|
|
|
|
def build_persona_prompt(persona_id: str) -> str:
    """Return the prompt text for the first demo persona whose id matches.

    Returns ``""`` when no persona has ``persona_id`` or when the generated
    prompt has no (truthy) ``"text"`` entry.
    """
    chosen = next(
        (candidate for candidate in demo_personas() if candidate["id"] == persona_id),
        None,
    )
    if chosen is None:
        return ""
    return str(persona_prompt(chosen["profile"]).get("text") or "")
|
|
|
|
|
def _merge_chat_history(
    history: list[dict[str, Any]],
    message: str,
    ui_messages: list[ChatMessage],
    assistant_text: str,
) -> list[dict[str, Any]]:
    """Build a new history list: prior turns, the user message, then agent output.

    ``history`` itself is not modified. ``assistant_text`` is appended as a
    final assistant turn unless it is empty or the last UI message is already
    a plain assistant message (role ``"assistant"`` with no metadata), which
    is treated as already carrying that text.
    """
    merged: list[dict[str, Any]] = [*history, {"role": "user", "content": message}]
    merged.extend(_chat_message_to_dict(item) for item in ui_messages)

    already_rendered = (
        bool(ui_messages)
        and ui_messages[-1].role == "assistant"
        and not ui_messages[-1].metadata
    )
    if assistant_text and not already_rendered:
        merged.append({"role": "assistant", "content": assistant_text})
    return merged
|
|
|
|
|
def _stream_agent_response(
    message: str,
    history: list[dict[str, Any]],
    globe_state: dict[str, Any],
    hf_token: gr.OAuthToken | None,
) -> Iterator[tuple[Any, dict[str, Any]]]:
    """Relay every ``(payload, globe_state)`` chunk from the selected agent.

    The agent is chosen by ``_respond_fn`` when this generator first runs and
    is called with the Borderless system prompt and the default sampling
    settings (max tokens, temperature, top-p).
    """
    responder = _respond_fn()
    sampling = (DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P)
    yield from responder(
        message,
        history,
        BORDERLESS_SYSTEM_PROMPT,
        *sampling,
        globe_state,
        hf_token,
    )
|
|
|
|
|
@functools.cache
def _gpu_agent_runner():
    """Return the agent stream wrapped with ``spaces.GPU``, built only once.

    ``spaces`` is imported lazily so it is only needed when local inference
    is enabled. Caching the wrapper means the ``spaces.GPU`` decorator is
    applied once per process instead of being re-applied to a freshly
    defined closure on every chat turn.

    NOTE(review): ZeroGPU may expect ``@spaces.GPU`` functions to exist at
    startup; like the previous per-call wrapper, this one is created on first
    use — confirm the Space picks it up, or hoist it to module level.
    """
    import spaces

    @spaces.GPU(duration=GPU_DURATION)
    def _run_agent_graph_gpu(
        message: str,
        history: list[dict[str, Any]],
        globe_state: dict[str, Any],
        hf_token: gr.OAuthToken | None,
    ) -> Iterator[tuple[Any, dict[str, Any]]]:
        yield from _stream_agent_response(message, history, globe_state, hf_token)

    return _run_agent_graph_gpu


def _agent_stream(
    message: str,
    history: list[dict[str, Any]],
    globe_state: dict[str, Any],
    hf_token: gr.OAuthToken | None,
) -> Iterator[tuple[Any, dict[str, Any]]]:
    """Start the agent stream, via the GPU wrapper when local graph inference is on."""
    if local_inference_enabled() and AGENT_MODE != "legacy":
        return _gpu_agent_runner()(message, history, globe_state, hf_token)
    return _stream_agent_response(message, history, globe_state, hf_token)


def stream_chat(
    message: str,
    history: list[dict[str, Any]],
    globe_state: dict[str, Any] | None,
    hf_token: gr.OAuthToken | None,
) -> Iterator[dict[str, Any]]:
    """Stream one chat turn through the agent as incremental UI updates.

    Args:
        message: The user's new message.
        history: Prior chat history as role/content dicts.
        globe_state: Current globe state; a falsy value (None or empty) is
            replaced by ``empty_globe_state()``.
        hf_token: OAuth token forwarded unchanged to the agent.

    Yields:
        Dicts with ``history`` (prior history + user message + agent output so
        far), ``globe_state`` (the latest state reported by the agent) and
        ``done`` — ``False`` for every agent chunk, then one final ``True``
        update (emitted even if the agent produced no chunks).
    """
    state = globe_state if globe_state else empty_globe_state()
    ui_messages: list[ChatMessage] = []
    assistant_text = ""

    # Each chunk is (payload, new_state). A list payload replaces the rendered
    # UI messages; a str payload replaces the assistant text; any other
    # payload only advances the globe state.
    for payload, state in _agent_stream(message, history, state, hf_token):
        if isinstance(payload, list):
            ui_messages = payload
        elif isinstance(payload, str):
            assistant_text = payload

        yield {
            "history": _merge_chat_history(history, message, ui_messages, assistant_text),
            "globe_state": state,
            "done": False,
        }

    yield {
        "history": _merge_chat_history(history, message, ui_messages, assistant_text),
        "globe_state": state,
        "done": True,
    }
|
|
|