import json
from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar, overload

from pydantic import BaseModel

from browser_use.llm.base import BaseChatModel
from browser_use.llm.exceptions import ModelProviderError
from browser_use.llm.messages import BaseMessage
from browser_use.llm.views import ChatInvokeCompletion, ChatInvokeUsage
from examples.models.langchain.serializer import LangChainMessageSerializer

if TYPE_CHECKING:
    from langchain_core.language_models.chat_models import BaseChatModel as LangChainBaseChatModel  # type: ignore
    from langchain_core.messages import AIMessage as LangChainAIMessage  # type: ignore

T = TypeVar('T', bound=BaseModel)


@dataclass
class ChatLangchain(BaseChatModel):
    """
    A wrapper around a LangChain BaseChatModel that implements the browser-use
    BaseChatModel protocol.

    This class allows you to use any LangChain-compatible chat model with
    browser-use. Plain completions are returned as text. Structured completions
    use the wrapped model's ``with_structured_output`` when available, and
    otherwise fall back to parsing the model's reply as a JSON object.
    """

    # The LangChain model to wrap
    chat: 'LangChainBaseChatModel'

    @property
    def model(self) -> str:
        """Return the model identifier (same value as ``name``)."""
        return self.name

    @property
    def provider(self) -> str:
        """Guess the provider from the wrapped LangChain model's class name.

        Matching is a case-insensitive substring check on the class name;
        returns ``'langchain'`` when no known provider keyword matches.
        """
        model_class_name = self.chat.__class__.__name__.lower()
        if 'openai' in model_class_name:
            return 'openai'
        elif 'anthropic' in model_class_name or 'claude' in model_class_name:
            return 'anthropic'
        elif 'google' in model_class_name or 'gemini' in model_class_name:
            return 'google'
        elif 'groq' in model_class_name:
            return 'groq'
        elif 'ollama' in model_class_name:
            return 'ollama'
        elif 'deepseek' in model_class_name:
            return 'deepseek'
        else:
            return 'langchain'

    @property
    def name(self) -> str:
        """Return the model name.

        Tries the wrapped model's ``model_name`` attribute, then ``model``, and
        finally falls back to the wrapped model's class name.
        """
        # getattr avoids type errors: not every LangChain model defines these attributes.
        model_name = getattr(self.chat, 'model_name', None)
        if model_name:
            return str(model_name)

        model_attr = getattr(self.chat, 'model', None)
        if model_attr:
            return str(model_attr)

        return self.chat.__class__.__name__

    def _get_usage(self, response: 'LangChainAIMessage') -> ChatInvokeUsage | None:
        """Convert LangChain ``usage_metadata`` into a browser-use ``ChatInvokeUsage``.

        Returns None when the response carries no usage metadata. Missing or
        falsy token counts are reported as 0. Cache read/creation counts are
        taken from ``input_token_details`` when present, else None.
        """
        usage = response.usage_metadata
        if usage is None:
            return None

        # An absent (or None) details mapping yields None for both cache counters.
        input_token_details = usage.get('input_token_details') or {}

        return ChatInvokeUsage(
            prompt_tokens=usage.get('input_tokens') or 0,
            prompt_cached_tokens=input_token_details.get('cache_read'),
            prompt_cache_creation_tokens=input_token_details.get('cache_creation'),
            prompt_image_tokens=None,
            completion_tokens=usage.get('output_tokens') or 0,
            total_tokens=usage.get('total_tokens') or 0,
        )

    @staticmethod
    def _content_to_text(content: object) -> str:
        """Flatten LangChain message content into plain text.

        A string is returned as-is. For a list of content blocks, string items
        and ``{'type': 'text', 'text': ...}`` dicts are concatenated and other
        block kinds are skipped. Anything else is converted with ``str()``.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts: list[str] = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and block.get('type') == 'text':
                    parts.append(str(block.get('text', '')))
            return ''.join(parts)
        return str(content)

    def _require_ai_message(self, response: object) -> 'LangChainAIMessage':
        """Return *response* unchanged if it is a LangChain AIMessage.

        Raises:
            ModelProviderError: if *response* is not an AIMessage.
        """
        # Runtime import: at module level langchain_core is imported only under TYPE_CHECKING.
        from langchain_core.messages import AIMessage as LangChainAIMessage  # type: ignore

        if not isinstance(response, LangChainAIMessage):
            raise ModelProviderError(
                message=f'Response is not an AIMessage: {type(response)}',
                model=self.name,
            )
        return response

    @overload
    async def ainvoke(self, messages: list[BaseMessage], output_format: None = None) -> ChatInvokeCompletion[str]: ...

    @overload
    async def ainvoke(self, messages: list[BaseMessage], output_format: type[T]) -> ChatInvokeCompletion[T]: ...

    async def ainvoke(
        self, messages: list[BaseMessage], output_format: type[T] | None = None
    ) -> ChatInvokeCompletion[T] | ChatInvokeCompletion[str]:
        """
        Invoke the LangChain model with the given messages.

        Args:
            messages: List of browser-use chat messages.
            output_format: Optional Pydantic model class. When given, the reply
                is returned as an instance of this class.

        Returns:
            A ChatInvokeCompletion holding either the reply text or an instance
            of ``output_format``.

        Raises:
            ModelProviderError: on any failure while invoking the model or
                interpreting its reply. Errors already of this type propagate
                unchanged; any other exception is wrapped.
        """
        # Convert browser-use messages to LangChain messages
        langchain_messages = LangChainMessageSerializer.serialize_messages(messages)

        try:
            if output_format is None:
                return await self._ainvoke_text(langchain_messages)
            return await self._ainvoke_structured(langchain_messages, output_format)
        except ModelProviderError:
            # Already specific; re-wrapping would bury the message under a generic prefix.
            raise
        except Exception as e:
            # Convert any LangChain errors to browser-use ModelProviderError
            raise ModelProviderError(
                message=f'LangChain model error: {str(e)}',
                model=self.name,
            ) from e

    async def _ainvoke_text(self, langchain_messages: list) -> ChatInvokeCompletion[str]:
        """Invoke the model and return its reply as plain text with usage."""
        response = self._require_ai_message(await self.chat.ainvoke(langchain_messages))  # type: ignore
        return ChatInvokeCompletion(
            completion=self._content_to_text(response.content),
            usage=self._get_usage(response),
        )

    async def _ainvoke_structured(self, langchain_messages: list, output_format: type[T]) -> ChatInvokeCompletion[T]:
        """Invoke the model through ``with_structured_output(output_format)``.

        Falls back to JSON-parsing a plain reply when the wrapped model lacks
        ``with_structured_output`` (AttributeError) or does not implement it
        (NotImplementedError).
        """
        # Guard only the capability lookup: an AttributeError raised while the
        # model is running must not trigger a second, fallback model call.
        try:
            structured_chat = self.chat.with_structured_output(output_format)
        except (AttributeError, NotImplementedError):
            return await self._ainvoke_json_fallback(langchain_messages, output_format)

        parsed_object = await structured_chat.ainvoke(langchain_messages)
        if isinstance(parsed_object, output_format):
            pass
        elif isinstance(parsed_object, dict):
            # Tolerate integrations that hand back a plain dict; validate it against the schema.
            parsed_object = output_format.model_validate(parsed_object)
        else:
            # e.g. None when the model produced no parsable structured reply.
            raise ModelProviderError(
                message=f'Structured output is not {output_format.__name__}: {type(parsed_object)}',
                model=self.name,
            )

        # The structured runnable yields the parsed object rather than an
        # AIMessage, so no usage metadata is available on this path.
        return ChatInvokeCompletion(completion=parsed_object, usage=None)

    async def _ainvoke_json_fallback(
        self, langchain_messages: list, output_format: type[T]
    ) -> ChatInvokeCompletion[T]:
        """Invoke the model and parse its reply as a JSON object of ``output_format``.

        Raises:
            ModelProviderError: if the reply is not an AIMessage, is not a JSON
                object, or does not validate against ``output_format``.
        """
        response = self._require_ai_message(await self.chat.ainvoke(langchain_messages))  # type: ignore
        content = self._content_to_text(response.content)

        try:
            parsed_data = json.loads(content)
            if not isinstance(parsed_data, dict):
                raise ValueError('Parsed JSON is not a dictionary')
            parsed_object = output_format.model_validate(parsed_data)
        except Exception as e:
            raise ModelProviderError(
                message=f'Failed to parse response as {output_format.__name__}: {e}',
                model=self.name,
            ) from e

        return ChatInvokeCompletion(
            completion=parsed_object,
            usage=self._get_usage(response),
        )