from __future__ import annotations

from typing import TYPE_CHECKING, Any
from uuid import UUID

from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks.base import AsyncCallbackHandler
from loguru import logger
from typing_extensions import override

from langflow.api.v1.schemas import ChatResponse, PromptResponse
from langflow.services.deps import get_chat_service, get_socket_service
from langflow.utils.util import remove_ansi_escape_codes

if TYPE_CHECKING:
    from langflow.services.socket.service import SocketIOService


class AsyncStreamingLLMCallbackHandleSIO(AsyncCallbackHandler):
    """Callback handler for streaming LLM responses."""

    @property
    def ignore_chain(self) -> bool:
        """Whether to ignore chain callbacks."""
        return False

    def __init__(self, session_id: str):
        self.chat_service = get_chat_service()
        self.client_id = session_id
        self.socketio_service: SocketIOService = get_socket_service()
        self.sid = session_id
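        # Note: client_id and sid both hold the caller's session id; the sid is
        # used below as the Socket.IO destination for every emitted event.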

    @override
    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Stream each new LLM token to the client as it is generated."""
        resp = ChatResponse(message=token, type="stream", intermediate_steps="")
        await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())

    @override
    async def on_tool_start(self, serialized: dict[str, Any], input_str: str, **kwargs: Any) -> Any:
        """Run when tool starts running."""
        resp = ChatResponse(
            message="",
            type="stream",
            intermediate_steps=f"Tool input: {input_str}",
        )
        await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())

    async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
        """Run when tool ends running."""
        observation_prefix = kwargs.get("observation_prefix", "Tool output: ")
        # Stream the tool output word by word; the first word carries the
        # observation prefix so the client can label the block.
        split_output = output.split()
        if not split_output:
            # Nothing to stream for an empty tool output.
            return
        first_word = split_output[0]
        rest_of_output = split_output[1:]

        intermediate_steps = f"{observation_prefix}{first_word}"

        resp = ChatResponse(
            message="",
            type="stream",
            intermediate_steps=intermediate_steps,
        )
        rest_of_resps = [
            ChatResponse(
                message="",
                type="stream",
                intermediate_steps=word,
            )
            for word in rest_of_output
        ]
        resps = [resp, *rest_of_resps]

        try:
            for resp in resps:
                await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
        except Exception:
            logger.exception("Error sending response")

    async def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: list[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Run when tool errors. This handler emits nothing on tool errors."""

    @override
    async def on_text(self, text: str, **kwargs: Any) -> Any:
        """Run on arbitrary text."""
        # Only the formatted prompt is forwarded; it goes out via emit_message
        # rather than the token stream. All other callback text is ignored.
        if "Prompt after formatting" in text:
            text = text.replace("Prompt after formatting:\n", "")
            text = remove_ansi_escape_codes(text)
            resp = PromptResponse(prompt=text)
            await self.socketio_service.emit_message(to=self.sid, data=resp.model_dump())

    @override
    async def on_agent_action(self, action: AgentAction, **kwargs: Any) -> None:
        log = f"Thought: {action.log}"
        # Emit multi-line thoughts one line at a time so the client can stream
        # them; a single-line thought is emitted as-is.
        for line in log.split("\n"):
            resp = ChatResponse(message="", type="stream", intermediate_steps=line)
            await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())

    @override
    async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
        """Run on agent end."""
        resp = ChatResponse(
            message="",
            type="stream",
            intermediate_steps=finish.log,
        )
        await self.socketio_service.emit_token(to=self.sid, data=resp.model_dump())
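

# Usage sketch (illustrative only): the handler is created per session and
# passed to the LangChain run through the callbacks config. The names
# `agent_executor`, `sid`, and `message` below are assumptions, not part of
# this module.
#
#     handler = AsyncStreamingLLMCallbackHandleSIO(session_id=sid)
#     result = await agent_executor.ainvoke(
#         {"input": message},
#         config={"callbacks": [handler]},
#     )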