File size: 4,277 Bytes
d7b3d84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
import json
from typing import overload

from langchain_core.messages import (  # pyright: ignore
	AIMessage,
	HumanMessage,
	SystemMessage,
)
from langchain_core.messages import (  # pyright: ignore
	ToolCall as LangChainToolCall,
)
from langchain_core.messages.base import BaseMessage as LangChainBaseMessage  # pyright: ignore

from browser_use.llm.messages import (
	AssistantMessage,
	BaseMessage,
	ContentPartImageParam,
	ContentPartRefusalParam,
	ContentPartTextParam,
	ToolCall,
	UserMessage,
)
from browser_use.llm.messages import (
	SystemMessage as BrowserUseSystemMessage,
)


class LangChainMessageSerializer:
	"""Serializer for converting browser-use message types into LangChain message types.

	Only the browser-use -> LangChain direction is implemented:

	- ``UserMessage`` -> ``HumanMessage``
	- browser-use ``SystemMessage`` -> LangChain ``SystemMessage``
	- ``AssistantMessage`` -> ``AIMessage``
	"""

	@staticmethod
	def _serialize_user_content(
		content: str | list[ContentPartTextParam | ContentPartImageParam],
	) -> str | list[str | dict]:
		"""Convert user message content for LangChain compatibility.

		A plain string is returned unchanged. A list of content parts is converted into a
		list of part dicts:

		- text parts become ``{'type': 'text', 'text': ...}``
		- image parts become ``{'type': 'image_url', 'image_url': {'url': ..., 'detail': ...}}``

		Parts of any other type are skipped.
		"""
		if isinstance(content, str):
			return content

		serialized_parts: list[str | dict] = []
		for part in content:
			if part.type == 'text':
				serialized_parts.append({'type': 'text', 'text': part.text})
			elif part.type == 'image_url':
				# LangChain format for images
				serialized_parts.append(
					{'type': 'image_url', 'image_url': {'url': part.image_url.url, 'detail': part.image_url.detail}}
				)

		return serialized_parts

	@staticmethod
	def _serialize_system_content(
		content: str | list[ContentPartTextParam],
	) -> str:
		"""Convert system message content to a single text string for LangChain.

		A plain string is returned unchanged. For a list of parts, the text of every
		``'text'`` part is joined with newlines; other part types are skipped.
		"""
		if isinstance(content, str):
			return content

		return '\n'.join(part.text for part in content if part.type == 'text')

	@staticmethod
	def _serialize_assistant_content(
		content: str | list[ContentPartTextParam | ContentPartRefusalParam] | None,
	) -> str:
		"""Convert assistant message content to a single text string for LangChain.

		``None`` becomes the empty string, and a plain string is returned unchanged. For a
		list of parts, the text of every ``'text'`` part is joined with newlines.
		"""
		if content is None:
			return ''
		if isinstance(content, str):
			return content

		# Refusal parts (and any other non-text parts) are intentionally dropped; only
		# plain text is forwarded to LangChain.
		return '\n'.join(part.text for part in content if part.type == 'text')

	@staticmethod
	def _serialize_tool_call(tool_call: ToolCall) -> LangChainToolCall:
		"""Convert a browser-use ToolCall to a LangChain ToolCall.

		The browser-use tool call carries its arguments as a JSON string, while LangChain
		expects ``args`` to be a dict. The string is decoded when it holds a JSON object.
		Otherwise the raw string is wrapped as ``{'arguments': <raw string>}``, so ``args``
		is always a dict. "Otherwise" covers invalid JSON and valid JSON that is not an
		object (e.g. a list, number or ``null``).

		Note: ``serialize`` does not currently call this helper, because assistant tool
		calls are not forwarded to LangChain.
		"""
		raw_arguments = tool_call.function.arguments
		try:
			args_dict = json.loads(raw_arguments)
		except json.JSONDecodeError:
			args_dict = None

		# A successful decode may still yield a non-dict (e.g. "[1, 2]"); LangChain's
		# args must be a mapping, so fall back to wrapping the raw string.
		if not isinstance(args_dict, dict):
			args_dict = {'arguments': raw_arguments}

		return LangChainToolCall(
			name=tool_call.function.name,
			args=args_dict,
			id=tool_call.id,
		)

	# region - Serialize overloads
	@overload
	@staticmethod
	def serialize(message: UserMessage) -> HumanMessage: ...

	@overload
	@staticmethod
	def serialize(message: BrowserUseSystemMessage) -> SystemMessage: ...

	@overload
	@staticmethod
	def serialize(message: AssistantMessage) -> AIMessage: ...

	@staticmethod
	def serialize(message: BaseMessage) -> LangChainBaseMessage:
		"""Serialize a browser-use message to a LangChain message.

		The message's ``name`` is carried over in every case.

		Raises:
			ValueError: if ``message`` is not a ``UserMessage``, browser-use
				``SystemMessage`` or ``AssistantMessage``.
		"""

		if isinstance(message, UserMessage):
			content = LangChainMessageSerializer._serialize_user_content(message.content)
			return HumanMessage(content=content, name=message.name)

		elif isinstance(message, BrowserUseSystemMessage):
			content = LangChainMessageSerializer._serialize_system_content(message.content)
			return SystemMessage(content=content, name=message.name)

		elif isinstance(message, AssistantMessage):
			content = LangChainMessageSerializer._serialize_assistant_content(message.content)

			# Tool calls are intentionally not forwarded in the LangChain integration;
			# only the assistant's text content is serialized.
			return AIMessage(
				content=content,
				name=message.name,
			)

		else:
			raise ValueError(f'Unknown message type: {type(message)}')

	@staticmethod
	def serialize_messages(messages: list[BaseMessage]) -> list[LangChainBaseMessage]:
		"""Serialize a list of browser-use messages to LangChain messages, preserving order."""
		return [LangChainMessageSerializer.serialize(m) for m in messages]