borderless / ui /server_api.py
spagestic's picture
switched to qwen
71a7158
Raw
History Blame Contribute Delete
5.68 kB
# ui/server_api.py
from __future__ import annotations
import os
from collections.abc import Iterator
from typing import Any
import gradio as gr
from gradio import ChatMessage
from ui.agent.config import GPU_DURATION, local_inference_enabled
from ui.agent.graph import respond_with_graph
from ui.agent.respond import respond as respond_legacy
from ui.agent.system_prompt import BORDERLESS_SYSTEM_PROMPT
from ui.chat.defaults import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P
from ui.globe_commands import empty_globe_state
from ui.intake.choices import (
BUDGET_CHOICES,
EDUCATION_CHOICES,
EXPERIENCE_CHOICES,
FAMILY_CHOICES,
OCCUPATION_CHOICES,
RESIDENCE_STATUS_CHOICES,
TIMELINE_CHOICES,
country_choices,
)
from ui.intake.examples import demo_personas, persona_prompt
from ui.intake.prompts import build_profile_prompt, build_session_title
DropdownValue = str | int | float | list[str | int | float] | None
AGENT_MODE = os.environ.get("BORDERLESS_AGENT_MODE", "graph").strip().lower()
def _respond_fn():
    """Return the agent response generator selected by ``AGENT_MODE``.

    ``"legacy"`` selects ``respond_legacy``. Every other value falls back to
    ``respond_with_graph``.
    """
    return respond_legacy if AGENT_MODE == "legacy" else respond_with_graph
def _chat_message_to_dict(message: ChatMessage) -> dict[str, Any]:
    """Convert a chat message into a plain ``role``/``content`` dict.

    The ``metadata`` key is included only when the message's metadata is
    truthy. The value is passed through by reference, not copied.
    """
    converted: dict[str, Any] = {"role": message.role, "content": message.content}
    metadata = message.metadata
    if metadata:
        converted["metadata"] = metadata
    return converted
def get_intake_choices() -> dict[str, Any]:
    """Collect every option list the intake form needs, keyed by field name.

    Countries and demo personas are produced by calling their helpers. The
    remaining entries are the module-level choice constants, returned as-is
    (not copied).
    """
    return dict(
        countries=country_choices(),
        education=EDUCATION_CHOICES,
        occupation=OCCUPATION_CHOICES,
        experience=EXPERIENCE_CHOICES,
        budget=BUDGET_CHOICES,
        family=FAMILY_CHOICES,
        timeline=TIMELINE_CHOICES,
        residence_status=RESIDENCE_STATUS_CHOICES,
        personas=demo_personas(),
    )
def build_research_prompt(
    current_country: DropdownValue,
    residence_status: DropdownValue,
    education: DropdownValue,
    occupation: DropdownValue,
    experience: DropdownValue,
    budget: DropdownValue,
    family: DropdownValue,
    timeline: DropdownValue,
    goals: str,
) -> dict[str, str]:
    """Turn the intake form answers into a chat prompt plus a session title.

    Returns a dict with two keys:

    - ``"text"``: the ``"text"`` entry of ``build_profile_prompt``'s result,
      coerced to ``str``. It is an empty string when that entry is missing
      or falsy.
    - ``"title"``: the value returned by ``build_session_title``.
    """
    # Positions 1 and 7 of build_profile_prompt are passed as None. These are
    # presumably profile fields this form does not collect; confirm against
    # ui.intake.prompts.
    profile_prompt = build_profile_prompt(
        None,
        current_country,
        residence_status,
        education,
        occupation,
        experience,
        None,
        budget,
        family,
        timeline,
        goals,
    )
    prompt_text = str(profile_prompt.get("text") or "")
    session_title = build_session_title(
        current_country,
        residence_status,
        education,
        occupation,
        experience,
        budget,
        family,
        timeline,
        goals,
    )
    return {"text": prompt_text, "title": session_title}
def build_persona_prompt(persona_id: str) -> str:
    """Return the prompt text for the demo persona whose ``id`` matches.

    Only the first matching persona is used. Returns an empty string when no
    persona matches, or when the persona's prompt has no truthy ``"text"``.
    """
    chosen = next(
        (candidate for candidate in demo_personas() if candidate["id"] == persona_id),
        None,
    )
    if chosen is None:
        return ""
    return str(persona_prompt(chosen["profile"]).get("text") or "")
def _merge_chat_history(
    history: list[dict[str, Any]],
    message: str,
    ui_messages: list[ChatMessage],
    assistant_text: str,
) -> list[dict[str, Any]]:
    """Build a new history list: prior turns, this user message, then the agent output.

    ``history`` itself is not modified. Each UI message is appended as a plain
    dict. ``assistant_text`` is appended as a final assistant turn unless the
    last UI message is already an assistant message without metadata. In that
    case the plain answer is assumed to be represented there already.
    """
    merged = [*history, {"role": "user", "content": message}]
    merged.extend(_chat_message_to_dict(entry) for entry in ui_messages)

    if not assistant_text:
        return merged

    tail = ui_messages[-1] if ui_messages else None
    tail_is_plain_answer = (
        tail is not None and tail.role == "assistant" and not tail.metadata
    )
    if not tail_is_plain_answer:
        merged.append({"role": "assistant", "content": assistant_text})
    return merged
def _stream_agent_response(
    message: str,
    history: list[dict[str, Any]],
    globe_state: dict[str, Any],
    hf_token: gr.OAuthToken | None,
) -> Iterator[tuple[Any, dict[str, Any]]]:
    """Stream ``(payload, globe_state)`` chunks from the agent chosen by ``_respond_fn``.

    The agent receives ``BORDERLESS_SYSTEM_PROMPT`` and the default generation
    settings (max tokens, temperature, top-p) from ``ui.chat.defaults``. Every
    chunk it yields is re-yielded unchanged.
    """
    respond = _respond_fn()
    yield from respond(
        message,
        history,
        BORDERLESS_SYSTEM_PROMPT,
        DEFAULT_MAX_TOKENS,
        DEFAULT_TEMPERATURE,
        DEFAULT_TOP_P,
        globe_state,
        hf_token,
    )
def stream_chat(
    message: str,
    history: list[dict[str, Any]],
    globe_state: dict[str, Any] | None,
    hf_token: gr.OAuthToken | None,
):
    """Run one chat turn and stream progress snapshots.

    Each yielded dict has three keys:

    - ``"history"``: ``history`` plus this turn, merged via ``_merge_chat_history``.
    - ``"globe_state"``: the latest state reported by the agent.
    - ``"done"``: ``False`` for intermediate snapshots, ``True`` for the final one.

    The final snapshot is yielded once after the agent stream is exhausted.

    When ``local_inference_enabled()`` is true and ``AGENT_MODE`` is not
    ``"legacy"``, the agent runs inside a generator decorated with
    ``spaces.GPU(duration=GPU_DURATION)``. Otherwise it runs directly.
    """
    # A falsy globe_state (None or empty) is replaced by a fresh empty state.
    state = globe_state or empty_globe_state()
    ui_messages: list[ChatMessage] = []
    assistant_text = ""

    def _snapshot(done: bool) -> dict[str, Any]:
        # Reads the current bindings of ui_messages / assistant_text / state.
        return {
            "history": _merge_chat_history(
                history, message, ui_messages, assistant_text
            ),
            "globe_state": state,
            "done": done,
        }

    if local_inference_enabled() and AGENT_MODE != "legacy":
        # Imported lazily so `spaces` is only required on the local-GPU path.
        import spaces

        @spaces.GPU(duration=GPU_DURATION)
        def _run_agent_graph_gpu(
            message: str,
            history: list[dict[str, Any]],
            globe_state: dict[str, Any],
            hf_token: gr.OAuthToken | None,
        ) -> Iterator[tuple[Any, dict[str, Any]]]:
            yield from _stream_agent_response(message, history, globe_state, hf_token)

        agent_stream = _run_agent_graph_gpu(message, history, state, hf_token)
    else:
        agent_stream = _stream_agent_response(message, history, state, hf_token)

    # A list payload replaces the collected UI messages and a str payload
    # replaces the assistant text. Any other payload only updates the state.
    for payload, state in agent_stream:
        if isinstance(payload, list):
            ui_messages = payload
        elif isinstance(payload, str):
            assistant_text = payload
        yield _snapshot(False)

    yield _snapshot(True)