spagestic committed on
Commit
39fd04c
·
1 Parent(s): eb5680f

Make /chat a generator: stream_chat in server_api.py, generator api_chat in app.py, chunked final-answer yields in streaming.py

Browse files
Files changed (3) hide show
  1. app.py +2 -2
  2. ui/agent/streaming.py +15 -4
  3. ui/server_api.py +31 -14
app.py CHANGED
@@ -113,8 +113,8 @@ def api_chat(
113
  history: list[dict],
114
  globe_state: dict | None,
115
  hf_token: gr.OAuthToken | None,
116
- ) -> dict:
117
- return server_api.run_chat(message, history, globe_state, hf_token)
118
 
119
 
120
  app.mount("/assets", StaticFiles(directory=str(ASSETS_DIR)), name="assets")
 
113
  history: list[dict],
114
  globe_state: dict | None,
115
  hf_token: gr.OAuthToken | None,
116
+ ):
117
+ yield from server_api.stream_chat(message, history, globe_state, hf_token)
118
 
119
 
120
  app.mount("/assets", StaticFiles(directory=str(ASSETS_DIR)), name="assets")
ui/agent/streaming.py CHANGED
@@ -5,13 +5,24 @@ from typing import Any
5
 
6
  from gradio import ChatMessage
7
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def yield_streaming_string(text: str, globe_state: dict[str, Any]):
10
  if not text:
11
  yield "", globe_state
12
  return
13
- for index in range(len(text)):
14
- yield text[: index + 1], globe_state
15
 
16
 
17
  def yield_streaming_messages(
@@ -26,8 +37,8 @@ def yield_streaming_messages(
26
  return
27
 
28
  messages.append(ChatMessage(role="assistant", content=""))
29
- for index in range(len(text)):
30
- messages[-1] = ChatMessage(role="assistant", content=text[: index + 1])
31
  yield messages, globe_state
32
 
33
 
 
5
 
6
  from gradio import ChatMessage
7
 
8
+ STREAM_CHUNK_SIZE = 40
9
+
10
+
11
+ def _chunk_end_indices(text: str) -> list[int]:
12
+ if not text:
13
+ return []
14
+ indices = list(range(STREAM_CHUNK_SIZE, len(text), STREAM_CHUNK_SIZE))
15
+ if not indices or indices[-1] != len(text):
16
+ indices.append(len(text))
17
+ return indices
18
+
19
 
20
  def yield_streaming_string(text: str, globe_state: dict[str, Any]):
21
  if not text:
22
  yield "", globe_state
23
  return
24
+ for end in _chunk_end_indices(text):
25
+ yield text[:end], globe_state
26
 
27
 
28
  def yield_streaming_messages(
 
37
  return
38
 
39
  messages.append(ChatMessage(role="assistant", content=""))
40
+ for end in _chunk_end_indices(text):
41
+ messages[-1] = ChatMessage(role="assistant", content=text[:end])
42
  yield messages, globe_state
43
 
44
 
ui/server_api.py CHANGED
@@ -85,12 +85,31 @@ def build_persona_prompt(persona_id: str) -> str:
85
  return ""
86
 
87
 
88
- def run_chat(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  message: str,
90
  history: list[dict[str, Any]],
91
  globe_state: dict[str, Any] | None,
92
  hf_token: gr.OAuthToken | None,
93
- ) -> dict[str, Any]:
94
  state = globe_state if globe_state else empty_globe_state()
95
  ui_messages: list[ChatMessage] = []
96
  assistant_text = ""
@@ -111,18 +130,16 @@ def run_chat(
111
  elif isinstance(payload, str):
112
  assistant_text = payload
113
 
114
- updated_history = list(history)
115
- updated_history.append({"role": "user", "content": message})
116
- for ui_message in ui_messages:
117
- updated_history.append(_chat_message_to_dict(ui_message))
118
- if assistant_text and (
119
- not ui_messages
120
- or ui_messages[-1].role != "assistant"
121
- or ui_messages[-1].metadata
122
- ):
123
- updated_history.append({"role": "assistant", "content": assistant_text})
124
 
125
- return {
126
- "history": updated_history,
127
  "globe_state": state,
 
128
  }
 
85
  return ""
86
 
87
 
88
+ def _merge_chat_history(
89
+ history: list[dict[str, Any]],
90
+ message: str,
91
+ ui_messages: list[ChatMessage],
92
+ assistant_text: str,
93
+ ) -> list[dict[str, Any]]:
94
+ updated_history = list(history)
95
+ updated_history.append({"role": "user", "content": message})
96
+ for ui_message in ui_messages:
97
+ updated_history.append(_chat_message_to_dict(ui_message))
98
+ if assistant_text and (
99
+ not ui_messages
100
+ or ui_messages[-1].role != "assistant"
101
+ or ui_messages[-1].metadata
102
+ ):
103
+ updated_history.append({"role": "assistant", "content": assistant_text})
104
+ return updated_history
105
+
106
+
107
+ def stream_chat(
108
  message: str,
109
  history: list[dict[str, Any]],
110
  globe_state: dict[str, Any] | None,
111
  hf_token: gr.OAuthToken | None,
112
+ ):
113
  state = globe_state if globe_state else empty_globe_state()
114
  ui_messages: list[ChatMessage] = []
115
  assistant_text = ""
 
130
  elif isinstance(payload, str):
131
  assistant_text = payload
132
 
133
+ yield {
134
+ "history": _merge_chat_history(
135
+ history, message, ui_messages, assistant_text
136
+ ),
137
+ "globe_state": state,
138
+ "done": False,
139
+ }
 
 
 
140
 
141
+ yield {
142
+ "history": _merge_chat_history(history, message, ui_messages, assistant_text),
143
  "globe_state": state,
144
+ "done": True,
145
  }