Spaces:
Paused
Paused
| """ | |
| Common utilities used across bedrock chat/embedding/image generation | |
| """ | |
| import os | |
| from typing import List, Literal, Optional, Union | |
| import httpx | |
| import litellm | |
| from litellm.llms.base_llm.base_utils import BaseLLMModelInfo | |
| from litellm.llms.base_llm.chat.transformation import BaseLLMException | |
| from litellm.secret_managers.main import get_secret | |
class BedrockError(BaseLLMException):
    """
    Exception type for the AWS Bedrock integration.

    In this module it is raised when no AWS region can be determined and when
    an OIDC web-identity token cannot be retrieved. All behaviour is inherited
    from ``BaseLLMException``.
    """
class AmazonBedrockGlobalConfig:
    """
    Provider-wide Bedrock settings: the special auth-param mapping and the
    lists of AWS regions known to this module.
    """

    def __init__(self):
        # No per-instance state; every method builds and returns fresh values.
        pass

    def get_mapped_special_auth_params(self) -> dict:
        """
        Mapping of common auth params across bedrock/vertex/azure/watsonx

        Keys are the generic parameter names; values are the Bedrock names.
        """
        return dict(region_name="aws_region_name")

    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
        """
        Copy recognised auth params from ``non_default_params`` into
        ``optional_params`` under their Bedrock-specific names.

        Unrecognised params are ignored. ``optional_params`` is modified in
        place and also returned.
        """
        renames = self.get_mapped_special_auth_params()
        matched = (
            (renames[name], val)
            for name, val in non_default_params.items()
            if name in renames
        )
        for bedrock_name, val in matched:
            optional_params[bedrock_name] = val
        return optional_params

    def get_all_regions(self) -> List[str]:
        """Every known region, grouped in the order US, EU, AP, CA, SA."""
        return [
            *self.get_us_regions(),
            *self.get_eu_regions(),
            *self.get_ap_regions(),
            *self.get_ca_regions(),
            *self.get_sa_regions(),
        ]

    def get_us_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        us_regions = [
            "us-east-1",  # US East (N. Virginia)
            "us-east-2",  # US East (Ohio)
            "us-west-1",  # US West (N. California)
            "us-west-2",  # US West (Oregon)
            "us-gov-east-1",  # AWS GovCloud (US-East)
            "us-gov-west-1",  # AWS GovCloud (US-West)
        ]
        return us_regions

    def get_eu_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        eu_regions = [
            "eu-west-1",  # Europe (Ireland)
            "eu-west-2",  # Europe (London)
            "eu-west-3",  # Europe (Paris)
            "eu-central-1",  # Europe (Frankfurt)
            "eu-central-2",  # Europe (Zurich)
            "eu-south-1",  # Europe (Milan)
            "eu-south-2",  # Europe (Spain)
            "eu-north-1",  # Europe (Stockholm)
        ]
        return eu_regions

    def get_ap_regions(self) -> List[str]:
        """
        Source: https://www.aws-services.info/bedrock.html
        """
        ap_regions = [
            "ap-northeast-1",  # Asia Pacific (Tokyo)
            "ap-northeast-2",  # Asia Pacific (Seoul)
            "ap-northeast-3",  # Asia Pacific (Osaka)
            "ap-south-1",  # Asia Pacific (Mumbai)
            "ap-south-2",  # Asia Pacific (Hyderabad)
            "ap-southeast-1",  # Asia Pacific (Singapore)
            "ap-southeast-2",  # Asia Pacific (Sydney)
        ]
        return ap_regions

    def get_ca_regions(self) -> List[str]:
        """Canadian regions."""
        ca_regions = ["ca-central-1"]
        return ca_regions

    def get_sa_regions(self) -> List[str]:
        """South American regions."""
        sa_regions = ["sa-east-1"]
        return sa_regions
def add_custom_header(headers):
    """
    Build a boto3 event callback that adds ``headers`` to each request.

    The returned closure captures ``headers``. When boto3 invokes it, every
    name/value pair is appended with ``request.headers.add_header``. Extra
    keyword arguments supplied by the event system are ignored.
    """

    def _inject_headers(request, **kwargs):
        """Event handler: append each captured header to ``request``."""
        for name, value in headers.items():
            request.headers.add_header(name, value)

    return _inject_headers
def init_bedrock_client(
    region_name=None,
    aws_access_key_id: Optional[str] = None,
    aws_secret_access_key: Optional[str] = None,
    aws_region_name: Optional[str] = None,
    aws_bedrock_runtime_endpoint: Optional[str] = None,
    aws_session_name: Optional[str] = None,
    aws_profile_name: Optional[str] = None,
    aws_role_name: Optional[str] = None,
    aws_web_identity_token: Optional[str] = None,
    extra_headers: Optional[dict] = None,
    timeout: Optional[Union[float, httpx.Timeout]] = None,
):
    """
    Create a boto3 ``bedrock-runtime`` client.

    Credentials are chosen by the first matching rule:
      1. ``aws_web_identity_token`` + ``aws_role_name`` + ``aws_session_name``:
         STS ``assume_role_with_web_identity``. The token argument is itself
         looked up via ``get_secret``.
      2. ``aws_role_name`` + ``aws_session_name``: STS ``assume_role``, using
         the given access keys (if any) for the STS call.
      3. ``aws_access_key_id`` (with ``aws_secret_access_key``).
      4. ``aws_profile_name``: a named boto3 session profile.
      5. Otherwise boto3's own credential resolution (e.g. env variables).

    Any of the AWS credential / region / endpoint / STS string arguments given
    as ``"os.environ/<NAME>"`` is first resolved through ``get_secret``.

    Args:
        region_name: Explicit region. If falsy, falls back to
            ``aws_region_name``, then the ``AWS_REGION_NAME`` secret, then the
            ``AWS_REGION`` secret.
        aws_bedrock_runtime_endpoint: Custom endpoint URL. Falls back to the
            ``AWS_BEDROCK_RUNTIME_ENDPOINT`` secret, then
            ``https://bedrock-runtime.<region>.amazonaws.com``.
        extra_headers: Headers added to every request via a ``before-sign``
            event hook.
        timeout: Number of seconds (int or float) used for both the connect
            and read timeouts, or an ``httpx.Timeout`` whose ``connect`` and
            ``read`` components are used. ``None`` keeps botocore defaults.

    Returns:
        The boto3 ``bedrock-runtime`` client.

    Raises:
        BedrockError: if no region can be determined, or if the OIDC token
            cannot be retrieved.
    """
    # check for custom AWS_REGION_NAME and use it if not passed to init_bedrock_client
    litellm_aws_region_name = get_secret("AWS_REGION_NAME", None)
    standard_aws_region_name = get_secret("AWS_REGION", None)

    ## CHECK IS 'os.environ/' passed in
    # Resolve any "os.environ/<NAME>" references through the secret manager.
    params_to_check = [
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ]
    for i, param in enumerate(params_to_check):
        if param and param.startswith("os.environ/"):
            params_to_check[i] = get_secret(param)  # type: ignore
    (
        aws_access_key_id,
        aws_secret_access_key,
        aws_region_name,
        aws_bedrock_runtime_endpoint,
        aws_session_name,
        aws_profile_name,
        aws_role_name,
        aws_web_identity_token,
    ) = params_to_check

    # SSL certificates (a.k.a CA bundle) used to verify the identity of requested hosts.
    # NOTE(review): when SSL_VERIFY is set, os.getenv returns a *string*, so a
    # value such as "False" reaches boto3's `verify` as a string rather than a
    # bool — confirm this is intended.
    ssl_verify = os.getenv("SSL_VERIFY", litellm.ssl_verify)

    ### SET REGION NAME
    # Precedence: region_name > aws_region_name > AWS_REGION_NAME > AWS_REGION.
    if region_name:
        pass
    elif aws_region_name:
        region_name = aws_region_name
    elif litellm_aws_region_name:
        region_name = litellm_aws_region_name
    elif standard_aws_region_name:
        region_name = standard_aws_region_name
    else:
        raise BedrockError(
            message="AWS region not set: set AWS_REGION_NAME or AWS_REGION env variable or in .env file",
            status_code=401,
        )

    # check for custom AWS_BEDROCK_RUNTIME_ENDPOINT and use it if not passed to init_bedrock_client
    env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
    if aws_bedrock_runtime_endpoint:
        endpoint_url = aws_bedrock_runtime_endpoint
    elif env_aws_bedrock_runtime_endpoint:
        endpoint_url = env_aws_bedrock_runtime_endpoint
    else:
        endpoint_url = f"https://bedrock-runtime.{region_name}.amazonaws.com"

    import boto3

    # Any real number (int or float, but not bool) sets both connect and read
    # timeouts; an int must not fall through to the "no timeout" branch.
    if isinstance(timeout, (int, float)) and not isinstance(timeout, bool):
        config = boto3.session.Config(connect_timeout=timeout, read_timeout=timeout)  # type: ignore
    elif isinstance(timeout, httpx.Timeout):
        config = boto3.session.Config(  # type: ignore
            connect_timeout=timeout.connect, read_timeout=timeout.read
        )
    else:
        config = boto3.session.Config()  # type: ignore

    # Arguments shared by every bedrock-runtime client created below.
    runtime_client_kwargs = {
        "service_name": "bedrock-runtime",
        "region_name": region_name,
        "endpoint_url": endpoint_url,
        "config": config,
        "verify": ssl_verify,
    }

    ### CHECK STS ###
    if (
        aws_web_identity_token is not None
        and aws_role_name is not None
        and aws_session_name is not None
    ):
        oidc_token = get_secret(aws_web_identity_token)
        if oidc_token is None:
            raise BedrockError(
                message="OIDC token could not be retrieved from secret manager.",
                status_code=401,
            )
        sts_client = boto3.client("sts")
        # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html
        # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html
        sts_response = sts_client.assume_role_with_web_identity(
            RoleArn=aws_role_name,
            RoleSessionName=aws_session_name,
            WebIdentityToken=oidc_token,
            DurationSeconds=3600,
        )
        credentials = sts_response["Credentials"]
        client = boto3.client(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            **runtime_client_kwargs,
        )
    elif aws_role_name is not None and aws_session_name is not None:
        # use sts if role name passed in
        sts_client = boto3.client(
            "sts",
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
        )
        sts_response = sts_client.assume_role(
            RoleArn=aws_role_name, RoleSessionName=aws_session_name
        )
        credentials = sts_response["Credentials"]
        client = boto3.client(
            aws_access_key_id=credentials["AccessKeyId"],
            aws_secret_access_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            **runtime_client_kwargs,
        )
    elif aws_access_key_id is not None:
        # Explicit keys passed by the caller (e.g. via litellm.completion).
        client = boto3.client(
            aws_access_key_id=aws_access_key_id,
            aws_secret_access_key=aws_secret_access_key,
            **runtime_client_kwargs,
        )
    elif aws_profile_name is not None:
        # uses auth values from AWS profile usually stored in ~/.aws/credentials
        client = boto3.Session(profile_name=aws_profile_name).client(
            **runtime_client_kwargs
        )
    else:
        # No explicit credentials: let boto3 resolve them (e.g. env variables).
        client = boto3.client(**runtime_client_kwargs)

    if extra_headers:
        # Hook the before-sign event so the headers are added before signing.
        client.meta.events.register(
            "before-sign.bedrock-runtime.*", add_custom_header(extra_headers)
        )

    return client
class ModelResponseIterator:
    """
    Expose one already-complete model response as a single-item stream.

    Supports both the sync and async iteration protocols. They share the same
    ``is_done`` flag, so the response is produced at most once in total.
    """

    def __init__(self, model_response):
        self.model_response = model_response
        self.is_done = False

    def _take(self, stop_exc):
        """Return the response on the first call; afterwards raise ``stop_exc``."""
        if not self.is_done:
            self.is_done = True
            return self.model_response
        raise stop_exc

    # --- sync protocol ---
    def __iter__(self):
        return self

    def __next__(self):
        return self._take(StopIteration)

    # --- async protocol ---
    def __aiter__(self):
        return self

    async def __anext__(self):
        return self._take(StopAsyncIteration)
def get_bedrock_tool_name(response_tool_name: str) -> str:
    """
    Map a tool name returned by Bedrock back to the caller's original name.

    If litellm rewrote the tool name when sending the request, the original is
    stored in ``litellm.bedrock_tool_name_mappings.cache_dict``. Names with no
    mapping are returned unchanged.

    Args:
        response_tool_name (str): Tool name as it appears in the response.

    Returns:
        str: The original tool name.
    """
    name_map = litellm.bedrock_tool_name_mappings.cache_dict
    if response_tool_name not in name_map:
        return response_tool_name
    return name_map[response_tool_name]
class BedrockModelInfo(BaseLLMModelInfo):
    """
    Helpers for normalising Bedrock model identifiers and choosing which
    Bedrock API route (converse / invoke / converse_like) to use for a model.

    All helpers are static: they are called as ``BedrockModelInfo.<name>(...)``
    and are also safe to call on an instance.
    """

    global_config = AmazonBedrockGlobalConfig()
    # Every region known to AmazonBedrockGlobalConfig; get_base_model uses it to
    # recognise `<region>/<model>` style names.
    all_global_regions = global_config.get_all_regions()

    @staticmethod
    def extract_model_name_from_arn(model: str) -> str:
        """
        Extract the model name from an AWS Bedrock ARN.
        Returns the string after the last '/' if 'arn' is in the input string
        (case-insensitive).

        Args:
            model (str): The model name or ARN string to parse

        Returns:
            str: The extracted model name if 'arn' is in the string,
                otherwise returns the original string
        """
        # NOTE(review): a plain substring test also matches non-ARN names that
        # merely contain "arn" (e.g. "...learn..."); consider "arn:" — confirm
        # no caller passes URL-encoded ARNs before tightening.
        if "arn" in model.lower():
            return model.split("/")[-1]
        return model

    @staticmethod
    def get_non_litellm_routing_model_name(model: str) -> str:
        """
        Strip litellm routing prefixes from ``model``.

        Removes, in this order and at most once each, a leading "bedrock/",
        then "converse/", then "invoke/".
        """
        if model.startswith("bedrock/"):
            model = model.split("/", 1)[1]
        if model.startswith("converse/"):
            model = model.split("/", 1)[1]
        if model.startswith("invoke/"):
            model = model.split("/", 1)[1]
        return model

    @staticmethod
    def get_base_model(model: str) -> str:
        """
        Get the base model from the given model name.

        Handle model names like - "us.meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1:0"
        AND "us-west-2/meta.llama3-2-11b-instruct-v1:0" -> "meta.llama3-2-11b-instruct-v1:0".

        Routing prefixes and ARN paths are stripped first. Then a leading
        cross-region inference prefix (see
        ``_supported_cross_region_inference_region``) followed by "." is
        removed. Otherwise a leading "<known region>/" is removed. Any version
        suffix such as ":0" is kept.
        """
        model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model)
        model = BedrockModelInfo.extract_model_name_from_arn(model)

        # partition avoids an IndexError when the name has no "." at all
        # (e.g. a bare "us").
        potential_region, dot, after_dot = model.partition(".")
        if (
            dot
            and potential_region
            in BedrockModelInfo._supported_cross_region_inference_region()
        ):
            return after_dot

        # in the model cost map we store regional information like `us-west-2/bedrock-model`
        alt_potential_region, slash, after_slash = model.partition("/")
        if slash and alt_potential_region in BedrockModelInfo.all_global_regions:
            return after_slash

        return model

    @staticmethod
    def _supported_cross_region_inference_region() -> List[str]:
        """
        Abbreviations of regions AWS Bedrock supports for cross region inference
        """
        return ["us", "eu", "apac"]

    @staticmethod
    def get_bedrock_route(model: str) -> Literal["converse", "invoke", "converse_like"]:
        """
        Get the bedrock route for the given model.

        Explicit markers in ``model`` win, checked in order: "invoke/",
        "converse_like", "converse/". Otherwise the route is "converse" when
        the base model or the prefix-stripped name is listed in
        ``litellm.bedrock_converse_models``, and "invoke" in all other cases.
        """
        if "invoke/" in model:
            return "invoke"
        if "converse_like" in model:
            return "converse_like"
        if "converse/" in model:
            return "converse"

        # Only normalise the name when no explicit route marker decided it.
        base_model = BedrockModelInfo.get_base_model(model)
        alt_model = BedrockModelInfo.get_non_litellm_routing_model_name(model=model)
        if (
            base_model in litellm.bedrock_converse_models
            or alt_model in litellm.bedrock_converse_models
        ):
            return "converse"
        return "invoke"