File size: 5,683 Bytes
b466a63 b23426c 147e220 b466a63 147e220 b23426c b466a63 dbdb569 b466a63 b23426c b466a63 dbdb569 b466a63 0c56308 b466a63 0c56308 b466a63 dbdb569 b466a63 39fd04c 147e220 b466a63 147e220 b466a63 147e220 b466a63 147e220 b466a63 147e220 71a7158 147e220 b466a63 39fd04c b466a63 39fd04c b466a63 39fd04c b466a63 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 | # ui/server_api.py
from __future__ import annotations

import os
from collections.abc import Iterator
from functools import cache
from typing import Any

import gradio as gr
from gradio import ChatMessage

from ui.agent.config import GPU_DURATION, local_inference_enabled
from ui.agent.graph import respond_with_graph
from ui.agent.respond import respond as respond_legacy
from ui.agent.system_prompt import BORDERLESS_SYSTEM_PROMPT
from ui.chat.defaults import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P
from ui.globe_commands import empty_globe_state
from ui.intake.choices import (
    BUDGET_CHOICES,
    EDUCATION_CHOICES,
    EXPERIENCE_CHOICES,
    FAMILY_CHOICES,
    OCCUPATION_CHOICES,
    RESIDENCE_STATUS_CHOICES,
    TIMELINE_CHOICES,
    country_choices,
)
from ui.intake.examples import demo_personas, persona_prompt
from ui.intake.prompts import build_profile_prompt, build_session_title
DropdownValue = str | int | float | list[str | int | float] | None
AGENT_MODE = os.environ.get("BORDERLESS_AGENT_MODE", "graph").strip().lower()
def _respond_fn():
    """Pick the responder callable that matches ``AGENT_MODE``.

    Returns ``respond_legacy`` when the mode is exactly ``"legacy"``; any
    other mode falls back to ``respond_with_graph``.
    """
    use_legacy = AGENT_MODE == "legacy"
    return respond_legacy if use_legacy else respond_with_graph
def _chat_message_to_dict(message: ChatMessage) -> dict[str, Any]:
payload: dict[str, Any] = {
"role": message.role,
"content": message.content,
}
if message.metadata:
payload["metadata"] = message.metadata
return payload
def get_intake_choices() -> dict[str, Any]:
    """Bundle every intake-form option list, plus the demo personas, into one dict.

    Keys: ``countries``, ``education``, ``occupation``, ``experience``,
    ``budget``, ``family``, ``timeline``, ``residence_status`` and ``personas``.
    ``country_choices()`` and ``demo_personas()`` are called on every
    invocation; the remaining entries are the module-level choice constants
    themselves (not copies).
    """
    choices: dict[str, Any] = {"countries": country_choices()}
    choices.update(
        education=EDUCATION_CHOICES,
        occupation=OCCUPATION_CHOICES,
        experience=EXPERIENCE_CHOICES,
        budget=BUDGET_CHOICES,
        family=FAMILY_CHOICES,
        timeline=TIMELINE_CHOICES,
        residence_status=RESIDENCE_STATUS_CHOICES,
    )
    choices["personas"] = demo_personas()
    return choices
def build_research_prompt(
    current_country: DropdownValue,
    residence_status: DropdownValue,
    education: DropdownValue,
    occupation: DropdownValue,
    experience: DropdownValue,
    budget: DropdownValue,
    family: DropdownValue,
    timeline: DropdownValue,
    goals: str,
) -> dict[str, str]:
    """Build the research prompt text and a session title from intake answers.

    Returns a dict with:
      * ``"text"``: the ``"text"`` entry of ``build_profile_prompt``'s result,
        coerced to ``str`` (``""`` when missing or falsy);
      * ``"title"``: whatever ``build_session_title`` returns for the same answers.
    """
    # build_profile_prompt is called positionally; its 1st and 7th slots are
    # always None from this entry point. NOTE(review): those slots are not
    # collected by this function -- confirm their meaning against
    # build_profile_prompt's signature.
    profile_slots = (
        None,
        current_country,
        residence_status,
        education,
        occupation,
        experience,
        None,
        budget,
        family,
        timeline,
        goals,
    )
    profile_prompt = build_profile_prompt(*profile_slots)
    prompt_text = str(profile_prompt.get("text") or "")

    session_title = build_session_title(
        current_country,
        residence_status,
        education,
        occupation,
        experience,
        budget,
        family,
        timeline,
        goals,
    )
    return {"text": prompt_text, "title": session_title}
def build_persona_prompt(persona_id: str) -> str:
    """Return the prompt text for the demo persona whose ``"id"`` is ``persona_id``.

    The first matching persona wins. An unknown id yields ``""``, as does a
    prompt whose ``"text"`` entry is missing or falsy.
    """
    candidates = (p for p in demo_personas() if p["id"] == persona_id)
    chosen = next(candidates, None)
    if chosen is None:
        return ""
    return str(persona_prompt(chosen["profile"]).get("text") or "")
def _merge_chat_history(
history: list[dict[str, Any]],
message: str,
ui_messages: list[ChatMessage],
assistant_text: str,
) -> list[dict[str, Any]]:
updated_history = list(history)
updated_history.append({"role": "user", "content": message})
for ui_message in ui_messages:
updated_history.append(_chat_message_to_dict(ui_message))
if assistant_text and (
not ui_messages
or ui_messages[-1].role != "assistant"
or ui_messages[-1].metadata
):
updated_history.append({"role": "assistant", "content": assistant_text})
return updated_history
def _stream_agent_response(
    message: str,
    history: list[dict[str, Any]],
    globe_state: dict[str, Any],
    hf_token: gr.OAuthToken | None,
) -> Iterator[tuple[Any, dict[str, Any]]]:
    """Run the configured responder with this app's fixed generation settings.

    The responder comes from ``_respond_fn()``. It receives the system prompt
    and the default max-tokens / temperature / top-p between ``history`` and
    ``globe_state``. Every item it yields is re-yielded unchanged.
    """
    responder = _respond_fn()
    generation_settings = (
        BORDERLESS_SYSTEM_PROMPT,
        DEFAULT_MAX_TOKENS,
        DEFAULT_TEMPERATURE,
        DEFAULT_TOP_P,
    )
    yield from responder(
        message, history, *generation_settings, globe_state, hf_token
    )
@cache
def _gpu_agent_runner():
    """Build the ``spaces.GPU``-wrapped agent generator once and reuse it.

    ``spaces`` is imported lazily, so it is only loaded when local inference
    runs the graph agent. ``@cache`` applies the decorator on first use only,
    rather than re-importing ``spaces`` and decorating a fresh closure on
    every chat request.
    """
    import spaces

    @spaces.GPU(duration=GPU_DURATION)
    def _run_agent_graph_gpu(
        message: str,
        history: list[dict[str, Any]],
        globe_state: dict[str, Any],
        hf_token: gr.OAuthToken | None,
    ) -> Iterator[tuple[Any, dict[str, Any]]]:
        yield from _stream_agent_response(message, history, globe_state, hf_token)

    return _run_agent_graph_gpu


def stream_chat(
    message: str,
    history: list[dict[str, Any]],
    globe_state: dict[str, Any] | None,
    hf_token: gr.OAuthToken | None,
) -> Iterator[dict[str, Any]]:
    """Stream one chat turn as a sequence of incremental update dicts.

    Args:
        message: The user's new message.
        history: Prior chat turns; copied, never mutated.
        globe_state: Current globe state; a falsy value (``None`` or empty)
            is replaced with ``empty_globe_state()``.
        hf_token: OAuth token forwarded to the agent.

    Yields:
        ``{"history": ..., "globe_state": ..., "done": False}`` once per agent
        chunk, then a final ``{"history": ..., "globe_state": ..., "done": True}``
        with the same merged history.
    """
    state = globe_state if globe_state else empty_globe_state()

    # Local inference with the graph agent runs inside the GPU-wrapped
    # generator; everything else calls the agent directly.
    if local_inference_enabled() and AGENT_MODE != "legacy":
        agent_stream = _gpu_agent_runner()(message, history, state, hf_token)
    else:
        agent_stream = _stream_agent_response(message, history, state, hf_token)

    ui_messages: list[ChatMessage] = []
    assistant_text = ""
    for payload, state in agent_stream:
        # A list payload replaces the structured UI messages; a str payload
        # replaces the plain assistant text. Other payload types only update
        # the globe state.
        if isinstance(payload, list):
            ui_messages = payload
        elif isinstance(payload, str):
            assistant_text = payload
        yield {
            "history": _merge_chat_history(
                history, message, ui_messages, assistant_text
            ),
            "globe_state": state,
            "done": False,
        }

    yield {
        "history": _merge_chat_history(history, message, ui_messages, assistant_text),
        "globe_state": state,
        "done": True,
    }
|