File size: 5,683 Bytes
b466a63
 
 
b23426c
147e220
b466a63
 
 
 
 
147e220
b23426c
 
b466a63
 
 
 
 
 
 
 
 
 
 
 
 
 
dbdb569
b466a63
 
 
b23426c
 
 
 
 
 
 
 
b466a63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbdb569
b466a63
0c56308
b466a63
 
 
 
 
0c56308
b466a63
 
 
 
 
dbdb569
 
 
 
 
 
 
 
 
 
 
 
 
 
b466a63
 
 
 
 
 
 
 
 
 
39fd04c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147e220
b466a63
 
147e220
b466a63
147e220
 
b466a63
 
 
 
 
 
147e220
b466a63
147e220
 
 
 
 
 
 
 
 
 
 
 
 
71a7158
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147e220
 
b466a63
 
 
 
 
 
39fd04c
 
 
 
 
 
 
b466a63
39fd04c
 
b466a63
39fd04c
b466a63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# ui/server_api.py
from __future__ import annotations

import os
from collections.abc import Iterator
from typing import Any

import gradio as gr
from gradio import ChatMessage

from ui.agent.config import GPU_DURATION, local_inference_enabled
from ui.agent.graph import respond_with_graph
from ui.agent.respond import respond as respond_legacy
from ui.agent.system_prompt import BORDERLESS_SYSTEM_PROMPT
from ui.chat.defaults import DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P
from ui.globe_commands import empty_globe_state
from ui.intake.choices import (
    BUDGET_CHOICES,
    EDUCATION_CHOICES,
    EXPERIENCE_CHOICES,
    FAMILY_CHOICES,
    OCCUPATION_CHOICES,
    RESIDENCE_STATUS_CHOICES,
    TIMELINE_CHOICES,
    country_choices,
)
from ui.intake.examples import demo_personas, persona_prompt
from ui.intake.prompts import build_profile_prompt, build_session_title

DropdownValue = str | int | float | list[str | int | float] | None

AGENT_MODE = os.environ.get("BORDERLESS_AGENT_MODE", "graph").strip().lower()


def _respond_fn():
    """Pick the responder callable according to ``AGENT_MODE``.

    Returns ``respond_legacy`` when the mode is ``"legacy"``; for every other
    mode the graph-based ``respond_with_graph`` is returned.
    """
    return respond_legacy if AGENT_MODE == "legacy" else respond_with_graph


def _chat_message_to_dict(message: ChatMessage) -> dict[str, Any]:
    payload: dict[str, Any] = {
        "role": message.role,
        "content": message.content,
    }
    if message.metadata:
        payload["metadata"] = message.metadata
    return payload


def get_intake_choices() -> dict[str, Any]:
    """Collect the option lists from ``ui.intake`` into a single mapping.

    Keys are ``countries``, ``education``, ``occupation``, ``experience``,
    ``budget``, ``family``, ``timeline``, ``residence_status`` and
    ``personas``. Countries and personas are produced by calling
    ``country_choices()`` and ``demo_personas()``; the rest are the imported
    choice constants.
    """
    return dict(
        countries=country_choices(),
        education=EDUCATION_CHOICES,
        occupation=OCCUPATION_CHOICES,
        experience=EXPERIENCE_CHOICES,
        budget=BUDGET_CHOICES,
        family=FAMILY_CHOICES,
        timeline=TIMELINE_CHOICES,
        residence_status=RESIDENCE_STATUS_CHOICES,
        personas=demo_personas(),
    )


def build_research_prompt(
    current_country: DropdownValue,
    residence_status: DropdownValue,
    education: DropdownValue,
    occupation: DropdownValue,
    experience: DropdownValue,
    budget: DropdownValue,
    family: DropdownValue,
    timeline: DropdownValue,
    goals: str,
) -> dict[str, str]:
    """Turn intake answers into a prompt text and a session title.

    Returns a dict with ``text`` (the ``"text"`` entry of the profile prompt,
    coerced to ``str``; ``""`` when missing or falsy) and ``title`` (the
    result of ``build_session_title`` for the same answers).
    """
    # The first and seventh positional slots of build_profile_prompt are not
    # exposed by this endpoint and are passed as None; check that function's
    # signature for what they represent.
    profile = build_profile_prompt(
        None,
        current_country,
        residence_status,
        education,
        occupation,
        experience,
        None,
        budget,
        family,
        timeline,
        goals,
    )
    prompt_text = str(profile.get("text") or "")
    session_title = build_session_title(
        current_country,
        residence_status,
        education,
        occupation,
        experience,
        budget,
        family,
        timeline,
        goals,
    )
    return {"text": prompt_text, "title": session_title}


def build_persona_prompt(persona_id: str) -> str:
    """Return the prompt text for the demo persona whose ``id`` is *persona_id*.

    The first persona with a matching ``id`` is used; its ``profile`` is passed
    to ``persona_prompt`` and the ``"text"`` entry is returned as ``str``.
    Returns ``""`` when no persona matches or the prompt text is falsy.
    """
    chosen = next(
        (candidate for candidate in demo_personas() if candidate["id"] == persona_id),
        None,
    )
    if chosen is None:
        return ""
    rendered = persona_prompt(chosen["profile"])
    return str(rendered.get("text") or "")


def _merge_chat_history(

    history: list[dict[str, Any]],

    message: str,

    ui_messages: list[ChatMessage],

    assistant_text: str,

) -> list[dict[str, Any]]:
    updated_history = list(history)
    updated_history.append({"role": "user", "content": message})
    for ui_message in ui_messages:
        updated_history.append(_chat_message_to_dict(ui_message))
    if assistant_text and (
        not ui_messages
        or ui_messages[-1].role != "assistant"
        or ui_messages[-1].metadata
    ):
        updated_history.append({"role": "assistant", "content": assistant_text})
    return updated_history


def _stream_agent_response(
    message: str,
    history: list[dict[str, Any]],
    globe_state: dict[str, Any],
    hf_token: gr.OAuthToken | None,
) -> Iterator[tuple[Any, dict[str, Any]]]:
    """Run one turn through the configured responder using default settings.

    The responder chosen by ``_respond_fn()`` receives the message, history,
    ``BORDERLESS_SYSTEM_PROMPT``, the default max-tokens / temperature / top-p,
    the globe state and the token; everything it yields is re-yielded as-is.
    """
    responder = _respond_fn()
    sampling = (DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE, DEFAULT_TOP_P)
    yield from responder(
        message,
        history,
        BORDERLESS_SYSTEM_PROMPT,
        *sampling,
        globe_state,
        hf_token,
    )


def stream_chat(
    message: str,
    history: list[dict[str, Any]],
    globe_state: dict[str, Any] | None,
    hf_token: gr.OAuthToken | None,
):
    """Stream a chat turn as a series of snapshot dicts.

    Each yielded dict has ``history`` (prior history merged with this turn),
    ``globe_state`` (the latest state from the agent) and ``done``. Snapshots
    emitted while the agent streams have ``done=False``; one final snapshot
    with ``done=True`` follows once the agent stream is exhausted.

    A falsy *globe_state* is replaced by ``empty_globe_state()``. When local
    inference is enabled and the agent mode is not ``"legacy"``, the agent
    runs inside a ``spaces.GPU``-decorated generator (``GPU_DURATION``).
    """
    state = globe_state or empty_globe_state()
    ui_log: list[ChatMessage] = []
    latest_text = ""

    def _snapshot(done: bool) -> dict[str, Any]:
        # Reads the enclosing variables at call time, so every snapshot
        # reflects the newest payloads and state seen so far.
        return {
            "history": _merge_chat_history(history, message, ui_log, latest_text),
            "globe_state": state,
            "done": done,
        }

    wants_gpu = local_inference_enabled() and AGENT_MODE != "legacy"
    if not wants_gpu:
        agent_stream = _stream_agent_response(message, history, state, hf_token)
    else:
        # Imported lazily so the `spaces` package is only required on this path.
        import spaces

        @spaces.GPU(duration=GPU_DURATION)
        def _run_agent_graph_gpu(
            message: str,
            history: list[dict[str, Any]],
            globe_state: dict[str, Any],
            hf_token: gr.OAuthToken | None,
        ) -> Iterator[tuple[Any, dict[str, Any]]]:
            yield from _stream_agent_response(message, history, globe_state, hf_token)

        agent_stream = _run_agent_graph_gpu(message, history, state, hf_token)

    # Each chunk is (payload, new_state): a list payload replaces the UI
    # messages, a str payload replaces the assistant text; others are ignored.
    for payload, state in agent_stream:
        if isinstance(payload, list):
            ui_log = payload
        elif isinstance(payload, str):
            latest_text = payload
        yield _snapshot(done=False)

    yield _snapshot(done=True)