| | import copy
|
| | import os
|
| | import time
|
| | import warnings
|
| | from functools import partial
|
| | from typing import Any
|
| |
|
| | import requests
|
| |
|
| | from openhands.core.config import LLMConfig
|
| |
|
| | with warnings.catch_warnings():
|
| | warnings.simplefilter('ignore')
|
| | import litellm
|
| |
|
| | from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
|
| | from litellm import Message as LiteLLMMessage
|
| | from litellm import completion as litellm_completion
|
| | from litellm import completion_cost as litellm_completion_cost
|
| | from litellm.exceptions import (
|
| | APIConnectionError,
|
| | APIError,
|
| | InternalServerError,
|
| | RateLimitError,
|
| | ServiceUnavailableError,
|
| | )
|
| | from litellm.types.utils import CostPerToken, ModelResponse, Usage
|
| | from litellm.utils import create_pretrained_tokenizer
|
| |
|
| | from openhands.core.exceptions import CloudFlareBlockageError
|
| | from openhands.core.logger import openhands_logger as logger
|
| | from openhands.core.message import Message
|
| | from openhands.llm.debug_mixin import DebugMixin
|
| | from openhands.llm.fn_call_converter import (
|
| | STOP_WORDS,
|
| | convert_fncall_messages_to_non_fncall_messages,
|
| | convert_non_fncall_messages_to_fncall_messages,
|
| | )
|
| | from openhands.llm.metrics import Metrics
|
| | from openhands.llm.retry_mixin import RetryMixin
|
| |
|
__all__ = ['LLM']


# Exceptions considered transient: the retry decorator applied to the
# completion wrapper (see LLM.__init__) retries on these.
LLM_RETRY_EXCEPTIONS: tuple[type[Exception], ...] = (
    APIConnectionError,
    APIError,
    InternalServerError,
    RateLimitError,
    ServiceUnavailableError,
)

# Models known to support Anthropic-style prompt caching
# (see is_caching_prompt_active / the anthropic-beta header in the wrapper).
CACHE_PROMPT_SUPPORTED_MODELS = [
    'claude-3-5-sonnet-20241022',
    'claude-3-5-sonnet-20240620',
    'claude-3-5-haiku-20241022',
    'claude-3-haiku-20240307',
    'claude-3-opus-20240229',
]

# Models known to support native function (tool) calling
# (matched by exact name, suffix, or substring in is_function_calling_active).
FUNCTION_CALLING_SUPPORTED_MODELS = [
    'claude-3-5-sonnet',
    'claude-3-5-sonnet-20240620',
    'claude-3-5-sonnet-20241022',
    'claude-3.5-haiku',
    'claude-3-5-haiku-20241022',
    'gpt-4o-mini',
    'gpt-4o',
    'o1-2024-12-17',
]

# Models that accept the `reasoning_effort` completion parameter.
REASONING_EFFORT_SUPPORTED_MODELS = [
    'o1-2024-12-17',
]

# Models for which the `stop` parameter is not sent when mocking function calls.
MODELS_WITHOUT_STOP_WORDS = [
    'o1-mini',
]
|
| |
|
| |
|
class LLM(RetryMixin, DebugMixin):
    """The LLM class represents a Language Model instance.

    Attributes:
        config: an LLMConfig object specifying the configuration of the LLM.
    """

    def __init__(
        self,
        config: LLMConfig,
        metrics: Metrics | None = None,
    ):
        """Initializes the LLM. If LLMConfig is passed, its values will be the fallback.

        Passing simple parameters always overrides config.

        Args:
            config: The LLM configuration.
            metrics: The metrics to use.

        Raises:
            RuntimeError: if log_completions is enabled without a log folder.
        """
        # Guard so init_model_info() performs its (possibly remote) lookups only once.
        self._tried_model_info = False
        self.metrics: Metrics = (
            metrics if metrics is not None else Metrics(model_name=config.model)
        )
        # Flipped to False the first time cost calculation fails for this model
        # (see _completion_cost); avoids repeated failing attempts.
        self.cost_metric_supported: bool = True
        # Deep-copied so later mutations (e.g. token-limit defaults set in
        # init_model_info) do not leak back into the caller's config object.
        self.config: LLMConfig = copy.deepcopy(config)

        self.model_info: ModelInfo | None = None

        if self.config.log_completions:
            if self.config.log_completions_folder is None:
                raise RuntimeError(
                    'log_completions_folder is required when log_completions is enabled'
                )
            os.makedirs(self.config.log_completions_folder, exist_ok=True)

        # litellm can emit noisy warnings while resolving model metadata.
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            self.init_model_info()
        if self.vision_is_active():
            logger.debug('LLM: model has vision enabled')
        if self.is_caching_prompt_active():
            logger.debug('LLM: caching prompt enabled')
        if self.is_function_calling_active():
            logger.debug('LLM: model supports function calling')

        if self.config.custom_tokenizer is not None:
            self.tokenizer = create_pretrained_tokenizer(self.config.custom_tokenizer)
        else:
            self.tokenizer = None

        # Bind the per-instance static completion parameters once.
        self._completion = partial(
            litellm_completion,
            model=self.config.model,
            api_key=self.config.api_key.get_secret_value()
            if self.config.api_key
            else None,
            base_url=self.config.base_url,
            api_version=self.config.api_version,
            custom_llm_provider=self.config.custom_llm_provider,
            max_tokens=self.config.max_output_tokens,
            timeout=self.config.timeout,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            drop_params=self.config.drop_params,
        )

        # Keep the undecorated callable so the wrapper below can invoke it
        # without re-entering the retry logic.
        self._completion_unwrapped = self._completion

        @self.retry_decorator(
            num_retries=self.config.num_retries,
            retry_exceptions=LLM_RETRY_EXCEPTIONS,
            retry_min_wait=self.config.retry_min_wait,
            retry_max_wait=self.config.retry_max_wait,
            retry_multiplier=self.config.retry_multiplier,
        )
        def wrapper(*args, **kwargs):
            """Wrapper for the litellm completion function. Logs the input and output of the completion function."""
            from openhands.core.utils import json

            messages: list[dict[str, Any]] | dict[str, Any] = []
            mock_function_calling = kwargs.pop('mock_function_calling', False)

            # litellm also accepts positional (model, messages, ...) arguments;
            # ignore the positional model and normalize messages into kwargs.
            if len(args) > 1:
                messages = args[1] if len(args) > 1 else args[0]
                kwargs['messages'] = messages

                args = args[2:]
            elif 'messages' in kwargs:
                messages = kwargs['messages']

            # Normalize a single message into a list.
            messages = messages if isinstance(messages, list) else [messages]
            original_fncall_messages = copy.deepcopy(messages)
            mock_fncall_tools = None
            # When the model lacks native tool calling, emulate it by converting
            # tool definitions and calls into plain prompt text.
            if mock_function_calling:
                assert (
                    'tools' in kwargs
                ), "'tools' must be in kwargs when mock_function_calling is True"
                messages = convert_fncall_messages_to_non_fncall_messages(
                    messages, kwargs['tools']
                )
                kwargs['messages'] = messages
                if self.config.model not in MODELS_WITHOUT_STOP_WORDS:
                    kwargs['stop'] = STOP_WORDS
                mock_fncall_tools = kwargs.pop('tools')

            if not messages:
                raise ValueError(
                    'The messages list is empty. At least one message is required.'
                )

            self.log_prompt(messages)

            if self.is_caching_prompt_active():
                # Anthropic prompt caching requires an opt-in beta header.
                if 'claude-3' in self.config.model:
                    kwargs['extra_headers'] = {
                        'anthropic-beta': 'prompt-caching-2024-07-31',
                    }

            if self.config.model.lower() in REASONING_EFFORT_SUPPORTED_MODELS:
                kwargs['reasoning_effort'] = self.config.reasoning_effort

            # NOTE: this mutates litellm module-global state on every call.
            litellm.modify_params = self.config.modify_params

            try:
                start_time = time.time()

                # Call the raw completion; retries are handled by the decorator.
                resp: ModelResponse = self._completion_unwrapped(*args, **kwargs)

                latency = time.time() - start_time
                response_id = resp.get('id', 'unknown')
                self.metrics.add_response_latency(latency, response_id)

                non_fncall_response = copy.deepcopy(resp)
                if mock_function_calling:
                    assert len(resp.choices) == 1
                    assert mock_fncall_tools is not None
                    # Convert the plain-text response back into a tool-call message.
                    non_fncall_response_message = resp.choices[0].message
                    fn_call_messages_with_response = (
                        convert_non_fncall_messages_to_fncall_messages(
                            messages + [non_fncall_response_message], mock_fncall_tools
                        )
                    )
                    fn_call_response_message = fn_call_messages_with_response[-1]
                    if not isinstance(fn_call_response_message, LiteLLMMessage):
                        fn_call_response_message = LiteLLMMessage(
                            **fn_call_response_message
                        )
                    resp.choices[0].message = fn_call_response_message

                message_back: str = resp['choices'][0]['message']['content'] or ''
                tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
                    'message'
                ].get('tool_calls', [])
                if tool_calls:
                    for tool_call in tool_calls:
                        fn_name = tool_call.function.name
                        fn_args = tool_call.function.arguments
                        message_back += f'\nFunction call: {fn_name}({fn_args})'

                self.log_response(message_back)

                cost = self._post_completion(resp)

                # Optionally persist the full request/response pair for debugging.
                if self.config.log_completions:
                    assert self.config.log_completions_folder is not None
                    log_file = os.path.join(
                        self.config.log_completions_folder,
                        f'{self.metrics.model_name.replace("/", "__")}-{time.time()}.json',
                    )

                    _d = {
                        'messages': messages,
                        'response': resp,
                        'args': args,
                        'kwargs': {k: v for k, v in kwargs.items() if k != 'messages'},
                        'timestamp': time.time(),
                        'cost': cost,
                    }

                    if mock_function_calling:
                        # Log what was actually sent to / received from the model,
                        # plus the fncall-form messages for traceability.
                        _d['response'] = non_fncall_response

                        _d['fncall_messages'] = original_fncall_messages
                        _d['fncall_response'] = resp
                    with open(log_file, 'w') as f:
                        f.write(json.dumps(_d))

                return resp
            except APIError as e:
                # Surface CloudFlare blocks distinctly so callers can react.
                if 'Attention Required! | Cloudflare' in str(e):
                    raise CloudFlareBlockageError(
                        'Request blocked by CloudFlare'
                    ) from e
                raise

        self._completion = wrapper
|
| |
|
    @property
    def completion(self):
        """The retry-wrapped litellm completion function (set up in __init__).

        Check the complete documentation at https://litellm.vercel.app/docs/completion
        """
        return self._completion
|
| |
|
    def init_model_info(self):
        """Resolve model metadata and fill in token-limit defaults (runs at most once).

        Lookup order: litellm's registry for openrouter models, the LiteLLM
        proxy's /v1/model/info endpoint, then litellm's registry again with
        progressively trimmed model names. Afterwards, max_input_tokens and
        max_output_tokens defaults are derived from the metadata (or 4096).
        """
        if self._tried_model_info:
            return
        self._tried_model_info = True
        try:
            if self.config.model.startswith('openrouter'):
                self.model_info = litellm.get_model_info(self.config.model)
        except Exception as e:
            logger.debug(f'Error getting model info: {e}')

        if self.config.model.startswith('litellm_proxy/'):
            # Ask the LiteLLM proxy directly for the model list it serves.
            # NOTE(review): this request is not wrapped in try/except, so a
            # network failure propagates to the caller — confirm intended.
            response = requests.get(
                f'{self.config.base_url}/v1/model/info',
                headers={
                    'Authorization': f'Bearer {self.config.api_key.get_secret_value() if self.config.api_key else None}'
                },
            )
            resp_json = response.json()
            if 'data' not in resp_json:
                logger.error(
                    f'Error getting model info from LiteLLM proxy: {resp_json}'
                )
            all_model_info = resp_json.get('data', [])
            current_model_info = next(
                (
                    info
                    for info in all_model_info
                    if info['model_name']
                    == self.config.model.removeprefix('litellm_proxy/')
                ),
                None,
            )
            if current_model_info:
                self.model_info = current_model_info['model_info']

        # Fallback 1: strip an optional ':tag' suffix (e.g. ollama-style tags).
        if not self.model_info:
            try:
                self.model_info = litellm.get_model_info(
                    self.config.model.split(':')[0]
                )
            except Exception:
                pass
        # Fallback 2: try the bare model name without the provider prefix.
        if not self.model_info:
            try:
                self.model_info = litellm.get_model_info(
                    self.config.model.split('/')[-1]
                )
            except Exception:
                pass
        from openhands.core.utils import json

        logger.debug(f'Model info: {json.dumps(self.model_info, indent=2)}')

        if self.config.model.startswith('huggingface'):
            # NOTE(review): presumably some Hugging Face endpoints reject
            # top_p == 1, hence the 0.9 fallback — confirm.
            logger.debug(
                f'Setting top_p to 0.9 for Hugging Face model: {self.config.model}'
            )
            self.config.top_p = 0.9 if self.config.top_p == 1 else self.config.top_p

        # Default the input-token limit from model info, else a conservative 4096.
        if self.config.max_input_tokens is None:
            if (
                self.model_info is not None
                and 'max_input_tokens' in self.model_info
                and isinstance(self.model_info['max_input_tokens'], int)
            ):
                self.config.max_input_tokens = self.model_info['max_input_tokens']
            else:
                self.config.max_input_tokens = 4096

        # Same for the output-token limit; 'max_tokens' is an older metadata key.
        if self.config.max_output_tokens is None:
            self.config.max_output_tokens = 4096
            if self.model_info is not None:
                if 'max_output_tokens' in self.model_info and isinstance(
                    self.model_info['max_output_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_output_tokens']
                elif 'max_tokens' in self.model_info and isinstance(
                    self.model_info['max_tokens'], int
                ):
                    self.config.max_output_tokens = self.model_info['max_tokens']
|
| |
|
| | def vision_is_active(self) -> bool:
|
| | with warnings.catch_warnings():
|
| | warnings.simplefilter('ignore')
|
| | return not self.config.disable_vision and self._supports_vision()
|
| |
|
| | def _supports_vision(self) -> bool:
|
| | """Acquire from litellm if model is vision capable.
|
| |
|
| | Returns:
|
| | bool: True if model is vision capable. Return False if model not supported by litellm.
|
| | """
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | return (
|
| | litellm.supports_vision(self.config.model)
|
| | or litellm.supports_vision(self.config.model.split('/')[-1])
|
| | or (
|
| | self.model_info is not None
|
| | and self.model_info.get('supports_vision', False)
|
| | )
|
| | )
|
| |
|
| | def is_caching_prompt_active(self) -> bool:
|
| | """Check if prompt caching is supported and enabled for current model.
|
| |
|
| | Returns:
|
| | boolean: True if prompt caching is supported and enabled for the given model.
|
| | """
|
| | return (
|
| | self.config.caching_prompt is True
|
| | and (
|
| | self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
|
| | or self.config.model.split('/')[-1] in CACHE_PROMPT_SUPPORTED_MODELS
|
| | )
|
| |
|
| | )
|
| |
|
| | def is_function_calling_active(self) -> bool:
|
| |
|
| | model_name_supported = (
|
| | self.config.model in FUNCTION_CALLING_SUPPORTED_MODELS
|
| | or self.config.model.split('/')[-1] in FUNCTION_CALLING_SUPPORTED_MODELS
|
| | or any(m in self.config.model for m in FUNCTION_CALLING_SUPPORTED_MODELS)
|
| | )
|
| |
|
| |
|
| | if self.config.native_tool_calling is None:
|
| | return model_name_supported
|
| | elif self.config.native_tool_calling is False:
|
| | return False
|
| | else:
|
| |
|
| | supports_fn_call = litellm.supports_function_calling(
|
| | model=self.config.model
|
| | )
|
| | return supports_fn_call
|
| |
|
    def _post_completion(self, response: ModelResponse) -> float:
        """Post-process the completion response.

        Logs the cost and usage stats of the completion call.

        Returns:
            float: the cost of this completion (0 when cost calculation fails).
        """
        try:
            cur_cost = self._completion_cost(response)
        except Exception:
            cur_cost = 0

        stats = ''
        if self.cost_metric_supported:
            stats = 'Cost: %.2f USD | Accumulated Cost: %.2f USD\n' % (
                cur_cost,
                self.metrics.accumulated_cost,
            )

        # Latency of the call just recorded by the wrapper (most recent entry).
        if self.metrics.response_latencies:
            latest_latency = self.metrics.response_latencies[-1]
            stats += 'Response Latency: %.3f seconds\n' % latest_latency.latency

        usage: Usage | None = response.get('usage')

        if usage:
            input_tokens = usage.get('prompt_tokens')
            output_tokens = usage.get('completion_tokens')

            if input_tokens:
                stats += 'Input tokens: ' + str(input_tokens)

            if output_tokens:
                stats += (
                    (' | ' if input_tokens else '')
                    + 'Output tokens: '
                    + str(output_tokens)
                    + '\n'
                )

            # Tokens served from the provider's prompt cache, if reported.
            prompt_tokens_details: PromptTokensDetails = usage.get(
                'prompt_tokens_details'
            )
            cache_hit_tokens = (
                prompt_tokens_details.cached_tokens if prompt_tokens_details else None
            )
            if cache_hit_tokens:
                stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n'

            # Cache-write token counts arrive via usage.model_extra
            # (e.g. Anthropic's cache_creation_input_tokens field).
            model_extra = usage.get('model_extra', {})
            cache_write_tokens = model_extra.get('cache_creation_input_tokens')
            if cache_write_tokens:
                stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n'

        if stats:
            logger.debug(stats)

        return cur_cost
|
| |
|
| | def get_token_count(self, messages: list[dict] | list[Message]) -> int:
|
| | """Get the number of tokens in a list of messages. Use dicts for better token counting.
|
| |
|
| | Args:
|
| | messages (list): A list of messages, either as a list of dicts or as a list of Message objects.
|
| | Returns:
|
| | int: The number of tokens.
|
| | """
|
| |
|
| | if (
|
| | isinstance(messages, list)
|
| | and len(messages) > 0
|
| | and isinstance(messages[0], Message)
|
| | ):
|
| | logger.info(
|
| | 'Message objects now include serialized tool calls in token counting'
|
| | )
|
| | messages = self.format_messages_for_llm(messages)
|
| |
|
| |
|
| |
|
| | try:
|
| | return litellm.token_counter(
|
| | model=self.config.model,
|
| | messages=messages,
|
| | custom_tokenizer=self.tokenizer,
|
| | )
|
| | except Exception as e:
|
| |
|
| | logger.error(
|
| | f'Error getting token count for\n model {self.config.model}\n{e}'
|
| | + (
|
| | f'\ncustom_tokenizer: {self.config.custom_tokenizer}'
|
| | if self.config.custom_tokenizer is not None
|
| | else ''
|
| | )
|
| | )
|
| | return 0
|
| |
|
| | def _is_local(self) -> bool:
|
| | """Determines if the system is using a locally running LLM.
|
| |
|
| | Returns:
|
| | boolean: True if executing a local model.
|
| | """
|
| | if self.config.base_url is not None:
|
| | for substring in ['localhost', '127.0.0.1' '0.0.0.0']:
|
| | if substring in self.config.base_url:
|
| | return True
|
| | elif self.config.model is not None:
|
| | if self.config.model.startswith('ollama'):
|
| | return True
|
| | return False
|
| |
|
    def _completion_cost(self, response) -> float:
        """Calculate the cost of a completion response based on the model. Local models are treated as free.
        Add the current cost into total cost in metrics.

        Args:
            response: A response from a model invocation.

        Returns:
            number: The cost of the response.
        """
        # Once a model proves uncostable, skip all further attempts
        # (the flag is cleared in the except block below).
        if not self.cost_metric_supported:
            return 0.0

        extra_kwargs = {}
        if (
            self.config.input_cost_per_token is not None
            and self.config.output_cost_per_token is not None
        ):
            # User-supplied per-token prices override litellm's pricing table.
            cost_per_token = CostPerToken(
                input_cost_per_token=self.config.input_cost_per_token,
                output_cost_per_token=self.config.output_cost_per_token,
            )
            logger.debug(f'Using custom cost per token: {cost_per_token}')
            extra_kwargs['custom_cost_per_token'] = cost_per_token

        # Prefer a cost already computed upstream (e.g. by a LiteLLM proxy)
        # and delivered via the response's hidden headers.
        _hidden_params = getattr(response, '_hidden_params', {})
        cost = _hidden_params.get('additional_headers', {}).get(
            'llm_provider-x-litellm-response-cost', None
        )
        if cost is not None:
            cost = float(cost)
            logger.debug(f'Got response_cost from response: {cost}')

        try:
            if cost is None:
                try:
                    cost = litellm_completion_cost(
                        completion_response=response, **extra_kwargs
                    )
                except Exception as e:
                    logger.error(f'Error getting cost from litellm: {e}')

            if cost is None:
                # Retry with the provider prefix stripped from the model name.
                _model_name = '/'.join(self.config.model.split('/')[1:])
                cost = litellm_completion_cost(
                    completion_response=response, model=_model_name, **extra_kwargs
                )
                logger.debug(
                    f'Using fallback model name {_model_name} to get cost: {cost}'
                )
            self.metrics.add_cost(cost)
            return cost
        except Exception:
            # Model not in litellm's pricing table: disable cost tracking.
            self.cost_metric_supported = False
            logger.debug('Cost calculation not supported for this model.')
            return 0.0
|
| |
|
| | def __str__(self):
|
| | if self.config.api_version:
|
| | return f'LLM(model={self.config.model}, api_version={self.config.api_version}, base_url={self.config.base_url})'
|
| | elif self.config.base_url:
|
| | return f'LLM(model={self.config.model}, base_url={self.config.base_url})'
|
| | return f'LLM(model={self.config.model})'
|
| |
|
| | def __repr__(self):
|
| | return str(self)
|
| |
|
    def reset(self) -> None:
        """Reset accumulated metrics (costs, latencies) for this LLM instance."""
        self.metrics.reset()
|
| |
|
| | def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
|
| | if isinstance(messages, Message):
|
| | messages = [messages]
|
| |
|
| |
|
| | for message in messages:
|
| | message.cache_enabled = self.is_caching_prompt_active()
|
| | message.vision_enabled = self.vision_is_active()
|
| | message.function_calling_enabled = self.is_function_calling_active()
|
| | if 'deepseek' in self.config.model:
|
| | message.force_string_serializer = True
|
| |
|
| |
|
| | return [message.model_dump() for message in messages]
|
| |
|