from dataclasses import dataclass
from typing import TYPE_CHECKING, TypeVar, overload

from pydantic import BaseModel

from browser_use.llm.base import BaseChatModel
from browser_use.llm.exceptions import ModelProviderError
from browser_use.llm.messages import BaseMessage
from browser_use.llm.views import ChatInvokeCompletion, ChatInvokeUsage
from examples.models.langchain.serializer import LangChainMessageSerializer

if TYPE_CHECKING:
	from langchain_core.language_models.chat_models import BaseChatModel as LangChainBaseChatModel  # type: ignore
	from langchain_core.messages import AIMessage as LangChainAIMessage  # type: ignore

T = TypeVar('T', bound=BaseModel)


@dataclass
class ChatLangchain(BaseChatModel):
	"""
	A wrapper around a LangChain BaseChatModel that implements the browser-use BaseChatModel protocol.

	This class allows you to use any LangChain-compatible model with browser-use.
	"""

	# The LangChain model to wrap
	chat: 'LangChainBaseChatModel'

	@property
	def model(self) -> str:
		return self.name

	@property
	def provider(self) -> str:
		"""Infer the provider name from the LangChain model class name (e.g. ChatOpenAI -> 'openai')."""
		model_class_name = self.chat.__class__.__name__.lower()
		if 'openai' in model_class_name:
			return 'openai'
		elif 'anthropic' in model_class_name or 'claude' in model_class_name:
			return 'anthropic'
		elif 'google' in model_class_name or 'gemini' in model_class_name:
			return 'google'
		elif 'groq' in model_class_name:
			return 'groq'
		elif 'ollama' in model_class_name:
			return 'ollama'
		elif 'deepseek' in model_class_name:
			return 'deepseek'
		else:
			return 'langchain'

	@property
	def name(self) -> str:
		"""Return the model name."""
		# Try to get the model name from the LangChain model, using getattr to avoid type errors
		model_name = getattr(self.chat, 'model_name', None)
		if model_name:
			return str(model_name)
		model_attr = getattr(self.chat, 'model', None)
		if model_attr:
			return str(model_attr)
		return self.chat.__class__.__name__
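
	# For reference, LangChain's `usage_metadata` is a `UsageMetadata` TypedDict shaped roughly
	# like the sketch below (values are illustrative; `input_token_details` is provider-dependent
	# and may be absent, which is why `_get_usage` reads it defensively):
	#   {'input_tokens': 1200, 'output_tokens': 80, 'total_tokens': 1280,
	#    'input_token_details': {'cache_read': 1024, 'cache_creation': 0}}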
	def _get_usage(self, response: 'LangChainAIMessage') -> ChatInvokeUsage | None:
		usage = response.usage_metadata
		if usage is None:
			return None

		prompt_tokens = usage['input_tokens'] or 0
		completion_tokens = usage['output_tokens'] or 0
		total_tokens = usage['total_tokens'] or 0

		input_token_details = usage.get('input_token_details', None)
		if input_token_details is not None:
			prompt_cached_tokens = input_token_details.get('cache_read', None)
			prompt_cache_creation_tokens = input_token_details.get('cache_creation', None)
		else:
			prompt_cached_tokens = None
			prompt_cache_creation_tokens = None

		return ChatInvokeUsage(
			prompt_tokens=prompt_tokens,
			prompt_cached_tokens=prompt_cached_tokens,
			prompt_cache_creation_tokens=prompt_cache_creation_tokens,
			prompt_image_tokens=None,
			completion_tokens=completion_tokens,
			total_tokens=total_tokens,
		)

	@overload
	async def ainvoke(self, messages: list[BaseMessage], output_format: None = None) -> ChatInvokeCompletion[str]: ...

	@overload
	async def ainvoke(self, messages: list[BaseMessage], output_format: type[T]) -> ChatInvokeCompletion[T]: ...

	async def ainvoke(
		self, messages: list[BaseMessage], output_format: type[T] | None = None
	) -> ChatInvokeCompletion[T] | ChatInvokeCompletion[str]:
		"""
		Invoke the LangChain model with the given messages.

		Args:
			messages: List of browser-use chat messages
			output_format: Optional Pydantic model class for structured output (handled via LangChain's `with_structured_output` when the wrapped model supports it)

		Returns:
			Either a string response or an instance of output_format
		"""

		# Convert browser-use messages to LangChain messages
		langchain_messages = LangChainMessageSerializer.serialize_messages(messages)

		try:
			if output_format is None:
				# Return a string response
				response = await self.chat.ainvoke(langchain_messages)  # type: ignore

				# Import at runtime for the isinstance check
				from langchain_core.messages import AIMessage as LangChainAIMessage  # type: ignore

				if not isinstance(response, LangChainAIMessage):
					raise ModelProviderError(
						message=f'Response is not an AIMessage: {type(response)}',
						model=self.name,
					)

				# Extract content from the LangChain response
				content = response.content if hasattr(response, 'content') else str(response)
				usage = self._get_usage(response)

				return ChatInvokeCompletion(
					completion=str(content),
					usage=usage,
				)
			else:
				# Use LangChain's structured output capability
				try:
					structured_chat = self.chat.with_structured_output(output_format)
					parsed_object = await structured_chat.ainvoke(langchain_messages)

					# For structured output, usage metadata is typically not available,
					# since the parsed object is a Pydantic model, not an AIMessage
					usage = None

					# Type cast, since LangChain's with_structured_output returns the correct type
					return ChatInvokeCompletion(
						completion=parsed_object,  # type: ignore
						usage=usage,
					)
				except AttributeError:
					# Fall back to manual parsing if with_structured_output is not available
					response = await self.chat.ainvoke(langchain_messages)  # type: ignore

					# Import the class at runtime: isinstance against the string
					# 'LangChainAIMessage' would raise a TypeError
					from langchain_core.messages import AIMessage as LangChainAIMessage  # type: ignore

					if not isinstance(response, LangChainAIMessage):
						raise ModelProviderError(
							message=f'Response is not an AIMessage: {type(response)}',
							model=self.name,
						)

					content = response.content if hasattr(response, 'content') else str(response)

					try:
						if isinstance(content, str):
							import json

							parsed_data = json.loads(content)
							if isinstance(parsed_data, dict):
								parsed_object = output_format(**parsed_data)
							else:
								raise ValueError('Parsed JSON is not a dictionary')
						else:
							raise ValueError('Content is not a string and structured output is not supported')
					except Exception as e:
						raise ModelProviderError(
							message=f'Failed to parse response as {output_format.__name__}: {e}',
							model=self.name,
						) from e

					usage = self._get_usage(response)
					return ChatInvokeCompletion(
						completion=parsed_object,
						usage=usage,
					)
		except ModelProviderError:
			# Re-raise our own errors unchanged rather than double-wrapping them below
			raise
		except Exception as e:
			# Convert any other LangChain errors to a browser-use ModelProviderError
			raise ModelProviderError(
				message=f'LangChain model error: {str(e)}',
				model=self.name,
			) from e
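

if __name__ == '__main__':
	# Minimal usage sketch, not part of the wrapper itself. Assumes `langchain-openai`
	# is installed and OPENAI_API_KEY is set; the model name and prompts are illustrative.
	import asyncio

	from langchain_openai import ChatOpenAI

	from browser_use.llm.messages import UserMessage

	class Greeting(BaseModel):
		greeting: str

	async def _demo() -> None:
		llm = ChatLangchain(chat=ChatOpenAI(model='gpt-4o-mini'))

		# Plain string completion, with usage metadata extracted by _get_usage
		result = await llm.ainvoke([UserMessage(content='Say hello in one word.')])
		print(result.completion, result.usage)

		# Structured output via output_format (routed through with_structured_output)
		structured = await llm.ainvoke([UserMessage(content='Greet me.')], output_format=Greeting)
		print(structured.completion)

	asyncio.run(_demo())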