| # SPDX-License-Identifier: Apache-2.0 | |
| # Copied from vLLM | |
| import logging | |
| from abc import ABC, abstractmethod | |
| from typing import Union | |
| import orjson | |
| logger = logging.getLogger(__name__) | |
| try: | |
| from mcp import ClientSession | |
| except ImportError as e: | |
| mcp = e | |
| from openai_harmony import Author, Message, Role, StreamState, TextContent | |
| from sglang.srt.entrypoints.harmony_utils import ( | |
| get_encoding, | |
| get_streamable_parser_for_assistant, | |
| render_for_completion, | |
| ) | |
| from sglang.srt.entrypoints.tool import Tool | |
class ConversationContext(ABC):
    """Abstract interface for a server-driven, possibly multi-turn conversation.

    Concrete contexts record model/tool outputs, report whether a built-in
    tool call is pending, run that tool, and render the conversation back
    into prompt token ids for the next completion.
    """

    @abstractmethod
    def append_output(self, output) -> None:
        """Record one model or tool output in the conversation."""

    @abstractmethod
    async def call_tool(self) -> list[Message]:
        """Run the pending tool call and return the resulting messages."""

    @abstractmethod
    def need_builtin_tool_call(self) -> bool:
        """Return True if the conversation is waiting on a built-in tool."""

    @abstractmethod
    def render_for_completion(self) -> list[int]:
        """Render the conversation into token ids for the next completion."""
class SimpleContext(ConversationContext):
    """Minimal context that only remembers the latest output.

    It never requests a tool call; the tool-calling and rendering hooks are
    not supported and raise ``NotImplementedError``.
    """

    def __init__(self):
        # Most recent value passed to append_output (None before the first call).
        self.last_output = None

    def need_builtin_tool_call(self) -> bool:
        """Always False: this context has no built-in tools."""
        return False

    def append_output(self, output) -> None:
        """Replace the remembered output with *output*."""
        self.last_output = output

    def render_for_completion(self) -> list[int]:
        """Unsupported for this context.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Should not be called.")

    async def call_tool(self) -> list[Message]:
        """Unsupported for this context.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("Should not be called.")
class HarmonyContext(ConversationContext):
    """Non-streaming conversation context for harmony-format output.

    Keeps the full message history (the initial messages plus every parsed
    model message and tool result) and dispatches built-in tool calls
    (``browser.*`` and ``python``) to the configured tool sessions.

    Args:
        messages: Initial conversation messages. This list object is kept and
            extended in place as the conversation progresses.
        tool_sessions: Maps a tool namespace (``"browser"`` or ``"python"``)
            to either an MCP ``ClientSession`` or a local ``Tool``.
    """

    def __init__(
        self,
        messages: list,
        tool_sessions: dict[str, Union["ClientSession", Tool]],
    ):
        # TODO: Remove the hack of Union[ClientSession, Tool] by using MCP
        # when demo.
        self._messages = messages
        self.tool_sessions = tool_sessions
        self.parser = get_streamable_parser_for_assistant()
        self.num_init_messages = len(messages)
        # TODO
        self.num_prompt_tokens = 0
        self.num_cached_tokens = 0
        self.num_output_tokens = 0
        self.num_reasoning_tokens = 0

    @property
    def messages(self) -> list:
        """The full conversation history (initial plus appended messages)."""
        # Must be a property: callers and methods below index and
        # truth-test ``self.messages`` as a list.
        return self._messages

    def append_output(self, output) -> None:
        """Append one model output or a list of ready-made messages.

        Args:
            output: Either a generation-result ``dict`` holding
                ``"output_ids"`` (and optionally ``"meta_info"``), which is
                parsed token by token into harmony messages, or an iterable
                of already-built messages (e.g. tool results) that is
                appended as-is.
        """
        if isinstance(output, dict) and "output_ids" in output:
            output_msgs = self._parse_output_tokens(output["output_ids"])
            self._update_usage(output.get("meta_info"))
        else:
            output_msgs = output
        self._messages.extend(output_msgs)

    def _parse_output_tokens(self, output_token_ids) -> list:
        """Feed token ids to the parser and return only newly completed messages."""
        # TODO: REMOVE here:
        # Very hacky, find the first occurrence of token 200006 and cut from there
        # NOTE(review): 200006 is presumably the harmony ``<|start|>`` token
        # id — confirm against the encoding.
        try:
            start_index = output_token_ids.index(200006)
            output_token_ids = output_token_ids[start_index:]
        except ValueError:
            pass
        # The parser is reused across calls and ``parser.messages`` keeps
        # every message it has completed so far; only the ones produced by
        # this call are new, otherwise earlier turns would be duplicated.
        num_parsed_before = len(self.parser.messages)
        for token_id in output_token_ids:
            self.parser.process(token_id)
        return self.parser.messages[num_parsed_before:]

    def _update_usage(self, meta_info) -> None:
        """Copy token-usage counters from *meta_info* when it is a dict."""
        if not isinstance(meta_info, dict):
            return
        # Test for the same key that is read below.
        if "prompt_tokens" in meta_info:
            self.num_prompt_tokens = meta_info["prompt_tokens"]
        if "cached_tokens" in meta_info:
            self.num_cached_tokens = meta_info["cached_tokens"]
        if "completion_tokens" in meta_info:
            # Accumulated: each generation in a multi-turn loop adds its own.
            self.num_output_tokens += meta_info["completion_tokens"]

    def need_builtin_tool_call(self) -> bool:
        """Return True if the last message is addressed to a built-in tool.

        Built-in tools are recognised by the recipient prefix ``browser.`` or
        ``python``.
        """
        if not self.messages:
            return False
        recipient = self.messages[-1].recipient
        return recipient is not None and recipient.startswith(("browser.", "python"))

    async def call_tool(self) -> list[Message]:
        """Run the tool addressed by the last message and return its replies.

        Returns an empty list when there are no messages.

        Raises:
            ValueError: if the last message is not addressed to a known tool.
            KeyError: if the matching entry is missing from ``tool_sessions``.
        """
        if not self.messages:
            return []
        last_msg = self.messages[-1]
        recipient = last_msg.recipient
        if recipient is not None:
            if recipient.startswith("browser."):
                return await self.call_search_tool(
                    self.tool_sessions["browser"], last_msg
                )
            if recipient.startswith("python"):
                return await self.call_python_tool(
                    self.tool_sessions["python"], last_msg
                )
        raise ValueError("No tool call found")

    def render_for_completion(self) -> list[int]:
        """Render ``self.messages`` into prompt token ids for the next turn."""
        return render_for_completion(self.messages)

    async def call_search_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        """Execute a ``browser.<name>`` call and wrap the result as a tool message.

        A local ``Tool`` handles the whole conversation itself via
        ``get_result``. Otherwise ``<name>`` is taken from the recipient, the
        arguments are JSON-decoded from the message's first text part, and
        the first text part of the session's result is returned.
        """
        if isinstance(tool_session, Tool):
            return await tool_session.get_result(self)
        tool_name = last_msg.recipient.split(".")[1]
        args = orjson.loads(last_msg.content[0].text)
        result = await tool_session.call_tool(tool_name, args)
        result_str = result.content[0].text
        content = TextContent(text=result_str)
        author = Author(role=Role.TOOL, name=last_msg.recipient)
        return [Message(author=author, content=[content], recipient=Role.ASSISTANT)]

    async def call_python_tool(
        self, tool_session: Union["ClientSession", Tool], last_msg: Message
    ) -> list[Message]:
        """Execute the message's text as code via the ``python`` tool.

        A local ``Tool`` handles the whole conversation itself via
        ``get_result``. Otherwise the first text part is sent as ``code`` and
        the reply keeps the caller's channel.
        """
        if isinstance(tool_session, Tool):
            return await tool_session.get_result(self)
        param = {
            "code": last_msg.content[0].text,
        }
        result = await tool_session.call_tool("python", param)
        result_str = result.content[0].text
        content = TextContent(text=result_str)
        author = Author(role=Role.TOOL, name="python")
        return [
            Message(
                author=author,
                content=[content],
                channel=last_msg.channel,
                recipient=Role.ASSISTANT,
            )
        ]
class StreamingHarmonyContext(HarmonyContext):
    """Streaming variant of ``HarmonyContext``.

    Every output token (and every rendered tool-result token) is fed into a
    single long-lived harmony parser. ``messages`` is that parser's list of
    completed messages. ``last_tok`` records the last token fed to the
    parser so ``render_for_completion`` can tell which rendered prompt
    tokens the parser has not consumed yet.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.last_output = None

        self.parser = get_streamable_parser_for_assistant()
        self.encoding = get_encoding()
        # Last token id fed to the parser; None until the first output.
        self.last_tok = None

    @property
    def messages(self) -> list:
        """Messages completed so far by the streaming parser."""
        # Must be a property: inherited methods index and truth-test it.
        return self.parser.messages

    def append_output(self, output) -> None:
        """Feed a model output or a single tool message into the parser.

        Args:
            output: Either a generation-result ``dict`` holding
                ``"output_ids"``, or a one-element list containing a tool
                message, which is rendered to tokens before parsing.
        """
        if isinstance(output, dict) and "output_ids" in output:
            # RequestOutput from SGLang with outputs
            output_token_ids = output["output_ids"]
            # TODO: REMOVE here:
            # Very hacky, find the first occurrence of token 200006 and cut from there
            # NOTE(review): 200006 is presumably the harmony ``<|start|>``
            # token id — confirm against the encoding.
            try:
                start_index = output_token_ids.index(200006)
                output_token_ids = output_token_ids[start_index:]
            except ValueError:
                pass
            for token_id in output_token_ids:
                self.parser.process(token_id)
            # Both branches must keep last_tok in sync with the parser.
            # is_assistant_action_turn() and render_for_completion() rely on it.
            if output_token_ids:
                self.last_tok = output_token_ids[-1]
        else:
            # Handle the case of tool output in direct message format
            assert len(output) == 1, "Tool output should be a single message"
            msg = output[0]
            # Sometimes the recipient is not set for tool messages,
            # so we set it to "assistant"
            if msg.author.role == Role.TOOL and msg.recipient is None:
                msg.recipient = "assistant"
            toks = self.encoding.render(msg)
            for tok in toks:
                self.parser.process(tok)
            self.last_tok = toks[-1]

    def is_expecting_start(self) -> bool:
        """Return True if the parser is waiting for the start of a new message."""
        return self.parser.state == StreamState.EXPECT_START

    def is_assistant_action_turn(self) -> bool:
        """Return True if the last consumed token ends an assistant action."""
        return self.last_tok in self.encoding.stop_tokens_for_assistant_actions()

    def render_for_completion(self) -> list[int]:
        """Render the prompt for the next turn and sync the parser with it.

        The rendered prompt ends with tokens such as ``<|start|>assistant``
        that the parser has not seen. Those trailing tokens, meaning
        everything after the last occurrence of ``last_tok``, are fed to the
        parser before the full rendering is returned.

        Raises:
            IndexError: if ``last_tok`` does not occur in the rendering.
        """
        # NOTE(review): ``self.messages`` here is the parser's messages only,
        # so the initial prompt messages stored in ``_messages`` are not part
        # of this rendering — confirm that is intended.
        rendered_tokens = super().render_for_completion()
        # Walk backwards from the end, collecting tokens until the last
        # token the parser already consumed is reached.
        last_n = -1
        to_process = []
        while rendered_tokens[last_n] != self.last_tok:
            to_process.append(rendered_tokens[last_n])
            last_n -= 1
        for tok in reversed(to_process):
            self.parser.process(tok)
        return rendered_tokens
Xet Storage Details
- Size:
- 8.04 kB
- Xet hash:
- eb6aba6e7f0a1c09610b4462acb6b2acdae92e6cf710ad8059075c923aeb2e5f
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.