| """Callback Handler streams to stdout on new llm token.""" | |
| import sys | |
| from typing import Any, Dict, List, Optional | |
| from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler | |
| DEFAULT_ANSWER_PREFIX_TOKENS = ["Final", "Answer", ":"] | |


class FinalStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
    """Callback handler for streaming in agents.

    Only works with agents using LLMs that support streaming.
    Only the final output of the agent will be streamed.
    """

    def append_to_last_tokens(self, token: str) -> None:
        # Keep a rolling window of the most recent tokens (raw and
        # whitespace-stripped), sized to the answer prefix sequence.
        self.last_tokens.append(token)
        self.last_tokens_stripped.append(token.strip())
        if len(self.last_tokens) > len(self.answer_prefix_tokens):
            self.last_tokens.pop(0)
            self.last_tokens_stripped.pop(0)

    def check_if_answer_reached(self) -> bool:
        if self.strip_tokens:
            return self.last_tokens_stripped == self.answer_prefix_tokens_stripped
        else:
            return self.last_tokens == self.answer_prefix_tokens
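
    # Illustrative trace (assumed tokenization, not from the original source):
    # with the default prefix and strip_tokens=True, a stream ending in the
    # tokens "Final", " Answer", ":" leaves
    # last_tokens_stripped == ["Final", "Answer", ":"], so the check above
    # matches even though the raw tokens carry leading whitespace.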

    def __init__(
        self,
        *,
        answer_prefix_tokens: Optional[List[str]] = None,
        strip_tokens: bool = True,
        stream_prefix: bool = False,
    ) -> None:
        """Instantiate FinalStreamingStdOutCallbackHandler.

        Args:
            answer_prefix_tokens: Token sequence that prefixes the answer.
                Default is ["Final", "Answer", ":"].
            strip_tokens: Whether to ignore whitespace and newlines when
                comparing answer_prefix_tokens to the last tokens, to
                determine whether the answer has been reached.
            stream_prefix: Whether the answer prefix itself should also be
                streamed.
        """
        super().__init__()
        if answer_prefix_tokens is None:
            self.answer_prefix_tokens = DEFAULT_ANSWER_PREFIX_TOKENS
        else:
            self.answer_prefix_tokens = answer_prefix_tokens
        if strip_tokens:
            self.answer_prefix_tokens_stripped = [
                token.strip() for token in self.answer_prefix_tokens
            ]
        else:
            self.answer_prefix_tokens_stripped = self.answer_prefix_tokens
        self.last_tokens = [""] * len(self.answer_prefix_tokens)
        self.last_tokens_stripped = [""] * len(self.answer_prefix_tokens)
        self.strip_tokens = strip_tokens
        self.stream_prefix = stream_prefix
        self.answer_reached = False

    def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Run when LLM starts running."""
        self.answer_reached = False

    def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Run on new LLM token. Only available when streaming is enabled."""
        # Remember the last n tokens, where n = len(answer_prefix_tokens)
        self.append_to_last_tokens(token)

        # Check if the last n tokens match the answer_prefix_tokens list ...
        if self.check_if_answer_reached():
            self.answer_reached = True
            if self.stream_prefix:
                for t in self.last_tokens:
                    sys.stdout.write(t)
                sys.stdout.flush()
            return

        # ... if yes, then print tokens from now on
        if self.answer_reached:
            sys.stdout.write(token)
            sys.stdout.flush()
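

# Minimal usage sketch (not part of the original module): wiring the handler
# into a streaming LLM behind an agent. Assumes the classic `langchain`
# package with its OpenAI wrapper and a configured SerpAPI key; the tool
# choice and the question are placeholders. A custom prefix could be matched
# by passing e.g. answer_prefix_tokens=["The", "answer", ":"].
if __name__ == "__main__":
    from langchain.agents import AgentType, initialize_agent, load_tools
    from langchain.llms import OpenAI

    llm = OpenAI(
        streaming=True,  # required: tokens are only delivered when streaming
        callbacks=[FinalStreamingStdOutCallbackHandler()],
        temperature=0,
    )
    tools = load_tools(["serpapi"], llm=llm)
    agent = initialize_agent(
        tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=False
    )
    # Only tokens arriving after the "Final Answer:" prefix reach stdout.
    agent.run("What is the population of Paris?")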