Make /chat a generator: stream_chat in server_api.py, generator api_chat in app.py, chunked final-answer yields in streaming.py
Browse files
- app.py +2 -2
- ui/agent/streaming.py +15 -4
- ui/server_api.py +31 -14
app.py
CHANGED
|
@@ -113,8 +113,8 @@ def api_chat(
|
|
| 113 |
history: list[dict],
|
| 114 |
globe_state: dict | None,
|
| 115 |
hf_token: gr.OAuthToken | None,
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
|
| 119 |
|
| 120 |
app.mount("/assets", StaticFiles(directory=str(ASSETS_DIR)), name="assets")
|
|
|
|
| 113 |
history: list[dict],
|
| 114 |
globe_state: dict | None,
|
| 115 |
hf_token: gr.OAuthToken | None,
|
| 116 |
+
):
|
| 117 |
+
yield from server_api.stream_chat(message, history, globe_state, hf_token)
|
| 118 |
|
| 119 |
|
| 120 |
app.mount("/assets", StaticFiles(directory=str(ASSETS_DIR)), name="assets")
|
ui/agent/streaming.py
CHANGED
|
@@ -5,13 +5,24 @@ from typing import Any
|
|
| 5 |
|
| 6 |
from gradio import ChatMessage
|
| 7 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
def yield_streaming_string(text: str, globe_state: dict[str, Any]):
|
| 10 |
if not text:
|
| 11 |
yield "", globe_state
|
| 12 |
return
|
| 13 |
-
for
|
| 14 |
-
yield text[:
|
| 15 |
|
| 16 |
|
| 17 |
def yield_streaming_messages(
|
|
@@ -26,8 +37,8 @@ def yield_streaming_messages(
|
|
| 26 |
return
|
| 27 |
|
| 28 |
messages.append(ChatMessage(role="assistant", content=""))
|
| 29 |
-
for
|
| 30 |
-
messages[-1] = ChatMessage(role="assistant", content=text[:
|
| 31 |
yield messages, globe_state
|
| 32 |
|
| 33 |
|
|
|
|
| 5 |
|
| 6 |
from gradio import ChatMessage
|
| 7 |
|
| 8 |
+
# Number of characters added to the streamed prefix on each step.
STREAM_CHUNK_SIZE = 40


def _chunk_end_indices(text: str) -> list[int]:
    """Return the end offsets of the growing prefixes used to stream *text*.

    Offsets advance in steps of ``STREAM_CHUNK_SIZE`` and the last offset is
    always ``len(text)``, so slicing ``text[:end]`` for each offset finishes
    with the complete string. Empty text produces an empty list.
    """
    if not text:
        return []
    # range() stops strictly before len(text), so the final offset is never
    # produced by it and can be appended unconditionally.
    return [*range(STREAM_CHUNK_SIZE, len(text), STREAM_CHUNK_SIZE), len(text)]


def yield_streaming_string(text: str, globe_state: dict[str, Any]):
    """Yield ``(prefix, globe_state)`` pairs with progressively longer prefixes.

    Each prefix is ``text[:end]`` for the offsets from ``_chunk_end_indices``;
    the last pair carries the full text. Empty text yields a single
    ``("", globe_state)`` pair so the consumer still receives one update.
    The same ``globe_state`` object is passed through unchanged on every yield.
    """
    if not text:
        yield "", globe_state
        return
    for end in _chunk_end_indices(text):
        yield text[:end], globe_state
|
| 26 |
|
| 27 |
|
| 28 |
def yield_streaming_messages(
|
|
|
|
| 37 |
return
|
| 38 |
|
| 39 |
messages.append(ChatMessage(role="assistant", content=""))
|
| 40 |
+
for end in _chunk_end_indices(text):
|
| 41 |
+
messages[-1] = ChatMessage(role="assistant", content=text[:end])
|
| 42 |
yield messages, globe_state
|
| 43 |
|
| 44 |
|
ui/server_api.py
CHANGED
|
@@ -85,12 +85,31 @@ def build_persona_prompt(persona_id: str) -> str:
|
|
| 85 |
return ""
|
| 86 |
|
| 87 |
|
| 88 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
message: str,
|
| 90 |
history: list[dict[str, Any]],
|
| 91 |
globe_state: dict[str, Any] | None,
|
| 92 |
hf_token: gr.OAuthToken | None,
|
| 93 |
-
)
|
| 94 |
state = globe_state if globe_state else empty_globe_state()
|
| 95 |
ui_messages: list[ChatMessage] = []
|
| 96 |
assistant_text = ""
|
|
@@ -111,18 +130,16 @@ def run_chat(
|
|
| 111 |
elif isinstance(payload, str):
|
| 112 |
assistant_text = payload
|
| 113 |
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
or ui_messages[-1].metadata
|
| 122 |
-
):
|
| 123 |
-
updated_history.append({"role": "assistant", "content": assistant_text})
|
| 124 |
|
| 125 |
-
|
| 126 |
-
"history":
|
| 127 |
"globe_state": state,
|
|
|
|
| 128 |
}
|
|
|
|
| 85 |
return ""
|
| 86 |
|
| 87 |
|
| 88 |
+
def _merge_chat_history(
    history: list[dict[str, Any]],
    message: str,
    ui_messages: list[ChatMessage],
    assistant_text: str,
) -> list[dict[str, Any]]:
    """Build the history payload for one chat turn without mutating *history*.

    The result is a new list holding, in order: the prior ``history`` entries,
    the user's ``message`` as a ``{"role": "user", ...}`` entry, and every item
    of ``ui_messages`` converted with ``_chat_message_to_dict``.

    ``assistant_text`` is appended as a trailing assistant entry only when it
    is non-empty and the last UI message is not already a metadata-free
    assistant message (presumably because that message already shows the
    final answer -- confirm against the producer of ``ui_messages``).
    """
    merged = [
        *history,
        {"role": "user", "content": message},
        *(_chat_message_to_dict(entry) for entry in ui_messages),
    ]
    if not assistant_text:
        return merged

    # True when the UI stream already ends in a plain assistant message.
    plain_reply_shown = (
        bool(ui_messages)
        and ui_messages[-1].role == "assistant"
        and not ui_messages[-1].metadata
    )
    if not plain_reply_shown:
        merged.append({"role": "assistant", "content": assistant_text})
    return merged
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def stream_chat(
|
| 108 |
message: str,
|
| 109 |
history: list[dict[str, Any]],
|
| 110 |
globe_state: dict[str, Any] | None,
|
| 111 |
hf_token: gr.OAuthToken | None,
|
| 112 |
+
):
|
| 113 |
state = globe_state if globe_state else empty_globe_state()
|
| 114 |
ui_messages: list[ChatMessage] = []
|
| 115 |
assistant_text = ""
|
|
|
|
| 130 |
elif isinstance(payload, str):
|
| 131 |
assistant_text = payload
|
| 132 |
|
| 133 |
+
yield {
|
| 134 |
+
"history": _merge_chat_history(
|
| 135 |
+
history, message, ui_messages, assistant_text
|
| 136 |
+
),
|
| 137 |
+
"globe_state": state,
|
| 138 |
+
"done": False,
|
| 139 |
+
}
|
|
|
|
|
|
|
|
|
|
| 140 |
|
| 141 |
+
yield {
|
| 142 |
+
"history": _merge_chat_history(history, message, ui_messages, assistant_text),
|
| 143 |
"globe_state": state,
|
| 144 |
+
"done": True,
|
| 145 |
}
|