Spaces:
Sleeping
Sleeping
File size: 4,277 Bytes
d7b3d84 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 | import json
from typing import overload
from langchain_core.messages import ( # pyright: ignore
AIMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.messages import ( # pyright: ignore
ToolCall as LangChainToolCall,
)
from langchain_core.messages.base import BaseMessage as LangChainBaseMessage # pyright: ignore
from browser_use.llm.messages import (
AssistantMessage,
BaseMessage,
ContentPartImageParam,
ContentPartRefusalParam,
ContentPartTextParam,
ToolCall,
UserMessage,
)
from browser_use.llm.messages import (
SystemMessage as BrowserUseSystemMessage,
)
class LangChainMessageSerializer:
    """Convert browser-use messages into their LangChain counterparts.

    All methods are static; the class only serves as a namespace. The public
    entry points are :meth:`serialize` and :meth:`serialize_messages`.
    """

    @staticmethod
    def _serialize_user_content(
        content: str | list[ContentPartTextParam | ContentPartImageParam],
    ) -> str | list[str | dict]:
        """Convert user message content for LangChain compatibility.

        Args:
            content: Either a plain string or a list of text / image parts.

        Returns:
            The string unchanged, or a list of dicts: ``{'type': 'text', ...}``
            for text parts and ``{'type': 'image_url', ...}`` for image parts.
            Parts of any other type are omitted.
        """
        if isinstance(content, str):
            return content

        serialized_parts: list[str | dict] = []
        for part in content:
            if part.type == 'text':
                serialized_parts.append({'type': 'text', 'text': part.text})
            elif part.type == 'image_url':
                # LangChain format for images
                serialized_parts.append(
                    {
                        'type': 'image_url',
                        'image_url': {
                            'url': part.image_url.url,
                            'detail': part.image_url.detail,
                        },
                    }
                )
        return serialized_parts

    @staticmethod
    def _serialize_system_content(
        content: str | list[ContentPartTextParam],
    ) -> str:
        """Convert system message content to a text string.

        Args:
            content: Either a plain string or a list of content parts.

        Returns:
            The string unchanged, or the ``text`` of every text part joined
            with newlines (an empty string when there are no text parts).
        """
        if isinstance(content, str):
            return content
        return '\n'.join(part.text for part in content if part.type == 'text')

    @staticmethod
    def _serialize_assistant_content(
        content: str | list[ContentPartTextParam | ContentPartRefusalParam] | None,
    ) -> str:
        """Convert assistant message content to a text string.

        Args:
            content: ``None``, a plain string, or a list of content parts.

        Returns:
            ``''`` for ``None``, the string unchanged, or the ``text`` of every
            text part joined with newlines.
        """
        if content is None:
            return ''
        if isinstance(content, str):
            return content
        # Only text parts are kept; refusal parts are left out of the output.
        return '\n'.join(part.text for part in content if part.type == 'text')

    @staticmethod
    def _parse_tool_call_arguments(arguments: str) -> dict:
        """Decode a tool call's JSON argument string into a dict.

        Args:
            arguments: The raw arguments string of a tool call.

        Returns:
            The decoded JSON object. When the string is not valid JSON, or is
            valid JSON that is not an object (a list, number, string, ``null``
            ...), the raw string is wrapped as ``{'arguments': arguments}`` so
            the result is always a dict.
        """
        try:
            parsed = json.loads(arguments)
        except json.JSONDecodeError:
            parsed = None
        # A successful parse is not enough: only a JSON object maps to a dict.
        return parsed if isinstance(parsed, dict) else {'arguments': arguments}

    @staticmethod
    def _serialize_tool_call(tool_call: ToolCall) -> LangChainToolCall:
        """Convert a browser-use ToolCall to a LangChain ToolCall.

        The function name and call id are copied over; the JSON argument
        string is decoded into a dict via :meth:`_parse_tool_call_arguments`.
        """
        return LangChainToolCall(
            name=tool_call.function.name,
            args=LangChainMessageSerializer._parse_tool_call_arguments(tool_call.function.arguments),
            id=tool_call.id,
        )

    # region - Serialize overloads
    @overload
    @staticmethod
    def serialize(message: UserMessage) -> HumanMessage: ...

    @overload
    @staticmethod
    def serialize(message: BrowserUseSystemMessage) -> SystemMessage: ...

    @overload
    @staticmethod
    def serialize(message: AssistantMessage) -> AIMessage: ...

    @staticmethod
    def serialize(message: BaseMessage) -> LangChainBaseMessage:
        """Serialize a browser-use message to a LangChain message.

        Args:
            message: A user, system or assistant message.

        Returns:
            A ``HumanMessage``, ``SystemMessage`` or ``AIMessage`` carrying the
            converted content and the original ``name``.

        Raises:
            ValueError: If ``message`` is none of the supported message types.
        """
        if isinstance(message, UserMessage):
            content = LangChainMessageSerializer._serialize_user_content(message.content)
            return HumanMessage(content=content, name=message.name)

        if isinstance(message, BrowserUseSystemMessage):
            content = LangChainMessageSerializer._serialize_system_content(message.content)
            return SystemMessage(content=content, name=message.name)

        if isinstance(message, AssistantMessage):
            content = LangChainMessageSerializer._serialize_assistant_content(message.content)
            # Tool calls are intentionally not forwarded to LangChain here;
            # only the text content and name are carried over.
            return AIMessage(content=content, name=message.name)

        raise ValueError(f'Unknown message type: {type(message)}')

    @staticmethod
    def serialize_messages(messages: list[BaseMessage]) -> list[LangChainBaseMessage]:
        """Serialize a list of browser-use messages to LangChain messages.

        Each message is converted with :meth:`serialize`; order is preserved.
        """
        return [LangChainMessageSerializer.serialize(m) for m in messages]
|