|
|
| from __future__ import annotations
|
|
|
| from typing import Any
|
|
|
| from gradio import ChatMessage
|
|
|
# Step size, in characters, between successive streamed prefixes: each
# update grows the emitted text by this many characters, except the last
# one, which may be shorter (see _chunk_end_indices).
STREAM_CHUNK_SIZE = 40
|
|
|
|
|
def _chunk_end_indices(text: str, chunk_size: int | None = None) -> list[int]:
    """Return the exclusive end offsets at which to slice *text* for streaming.

    The offsets are the multiples of *chunk_size* strictly below
    ``len(text)``, followed by ``len(text)`` itself. Slicing ``text[:end]``
    for each offset therefore yields progressively longer prefixes, and the
    last one is always the whole string.

    Args:
        text: The string to be streamed. Falsy input yields no offsets.
        chunk_size: Number of characters added per step. Defaults to the
            module-level ``STREAM_CHUNK_SIZE``, read at call time.

    Returns:
        Strictly increasing end offsets, or an empty list for falsy *text*.

    Raises:
        ValueError: If *chunk_size* is less than 1 (only checked for
            non-empty *text*).
    """
    if not text:
        return []
    if chunk_size is None:
        chunk_size = STREAM_CHUNK_SIZE
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
    # range() stops *before* len(text), so the full length is never already
    # present and must always be appended to finish on the complete string.
    return [*range(chunk_size, len(text), chunk_size), len(text)]
|
|
|
|
|
def yield_streaming_string(text: str, globe_state: dict[str, Any]):
    """Stream *text* as a sequence of progressively longer prefixes.

    Every yielded item is ``(prefix, globe_state)``. The same *globe_state*
    object is forwarded unchanged on each step, and the final prefix is the
    complete text.

    Args:
        text: String to stream. A falsy value (``""`` or ``None``) produces
            exactly one ``("", globe_state)`` item.
        globe_state: State object passed through alongside every prefix.

    Yields:
        ``(str, dict)`` tuples holding the current prefix and *globe_state*.
    """
    if text:
        for stop in _chunk_end_indices(text):
            yield text[:stop], globe_state
    else:
        yield "", globe_state
|
|
|
|
|
def yield_streaming_messages(
    ui_messages: list[ChatMessage],
    text: str,
    globe_state: dict[str, Any],
    *,
    assistant_metadata: dict[str, Any] | None = None,
):
    """Stream *text* as a growing assistant message after *ui_messages*.

    Every yielded item is ``(messages, globe_state)``. ``messages`` is a new
    list on each step, containing:

    - the entries of *ui_messages*, followed by
    - one assistant ``ChatMessage`` whose content is the current prefix of
      *text*.

    Because each step builds its own list, a value that was already yielded
    is never altered by a later step. Collecting the generator (e.g. with
    ``list(...)``) therefore keeps every intermediate state.
    *ui_messages* itself is not modified.

    Args:
        ui_messages: Existing chat history; copied, never mutated.
        text: Assistant reply to stream. A falsy value produces a single
            update with an empty assistant message.
        globe_state: State object forwarded unchanged with every update.
        assistant_metadata: Optional metadata for the assistant message.
            It is copied once up front. An empty dict or ``None`` results
            in ``metadata=None``.

    Yields:
        ``(list[ChatMessage], dict)`` tuples.
    """
    history = list(ui_messages)
    # Copy so later edits to the caller's dict do not leak into emitted
    # messages; empty metadata is normalized to None.
    metadata = dict(assistant_metadata) if assistant_metadata else None

    def _snapshot(content: str) -> list[ChatMessage]:
        # A new list per update, so previously yielded lists are never
        # mutated by later steps.
        return [
            *history,
            ChatMessage(role="assistant", content=content, metadata=metadata),
        ]

    if not text:
        yield _snapshot(""), globe_state
        return

    for end in _chunk_end_indices(text):
        yield _snapshot(text[:end]), globe_state
|
|
|
|
|
def yield_response(
    ui_messages: list[ChatMessage],
    text: str,
    globe_state: dict[str, Any],
    *,
    assistant_metadata: dict[str, Any] | None = None,
):
    """Stream *text* either as chat messages or as a plain string.

    The output shape is chosen by the truthiness of *ui_messages*:

    - Non-empty: delegates to ``yield_streaming_messages``, which also
      receives *assistant_metadata*.
    - Empty or otherwise falsy: delegates to ``yield_streaming_string``.
      In this case *assistant_metadata* is not used.

    Args:
        ui_messages: Current chat history, or an empty/falsy value to
            request plain-string streaming.
        text: Text to stream.
        globe_state: State object forwarded unchanged with every update.
        assistant_metadata: Metadata for the assistant message (messages
            mode only).

    Yields:
        Whatever the selected streaming helper yields.
    """
    if not ui_messages:
        yield from yield_streaming_string(text, globe_state)
        return

    yield from yield_streaming_messages(
        ui_messages, text, globe_state, assistant_metadata=assistant_metadata
    )
|
|
|