Spaces:
Paused
Paused
import copy
import json
import os
from typing import TYPE_CHECKING, Any, Optional, Union

import httpx
from pydantic import BaseModel

import litellm
from litellm._logging import verbose_logger
from litellm.litellm_core_utils.litellm_logging import Logging as LitellmLogging
from litellm.llms.custom_httpx.http_handler import (
    AsyncHTTPHandler,
    HTTPHandler,
    _get_httpx_client,
    get_async_httpx_client,
)
from litellm.types.utils import ImageResponse

from ..base_aws_llm import BaseAWSLLM
from ..common_utils import BedrockError

if TYPE_CHECKING:
    # botocore is only needed for type checking here; the real import is
    # deferred to request time (see _prepare_request).
    from botocore.awsrequest import AWSPreparedRequest
else:
    # Runtime fallback so the annotation on BedrockImagePreparedRequest
    # resolves without requiring botocore at module import time.
    AWSPreparedRequest = Any
class BedrockImagePreparedRequest(BaseModel):
    """
    Internal/Helper class for preparing the request for bedrock image generation
    """

    # Fully qualified Bedrock runtime URL, e.g. ".../model/<modelId>/invoke".
    endpoint_url: str
    # SigV4-signed botocore prepared request (typed as Any at runtime when
    # botocore is unavailable — see the TYPE_CHECKING fallback at module top).
    prepped: AWSPreparedRequest
    # JSON-encoded request body actually sent over the wire.
    body: bytes
    # Un-encoded request payload; kept for logging and response transformation.
    data: dict
class BedrockImageGeneration(BaseAWSLLM):
    """
    Bedrock Image Generation handler.

    Builds a SigV4-signed request for the Bedrock runtime ``/model/{id}/invoke``
    endpoint and transforms the provider-specific response (Stability v1/v3,
    Amazon Nova Canvas) into the OpenAI image-response format.
    """

    def image_generation(
        self,
        model: str,
        prompt: str,
        model_response: ImageResponse,
        optional_params: dict,
        logging_obj: LitellmLogging,
        timeout: Optional[Union[float, httpx.Timeout]],
        aimg_generation: bool = False,
        api_base: Optional[str] = None,
        extra_headers: Optional[dict] = None,
        client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
    ):
        """
        Synchronous entry point for Bedrock image generation.

        When ``aimg_generation`` is True, returns the un-awaited coroutine
        from ``async_image_generation`` for the caller to await.

        Raises:
            BedrockError: on HTTP error responses (original status code
                preserved) or request timeouts (status 408).
        """
        prepared_request = self._prepare_request(
            model=model,
            optional_params=optional_params,
            api_base=api_base,
            extra_headers=extra_headers,
            logging_obj=logging_obj,
            prompt=prompt,
        )

        if aimg_generation is True:
            return self.async_image_generation(
                prepared_request=prepared_request,
                timeout=timeout,
                model=model,
                logging_obj=logging_obj,
                prompt=prompt,
                model_response=model_response,
                client=(
                    client
                    if client is not None and isinstance(client, AsyncHTTPHandler)
                    else None
                ),
            )

        if client is None or not isinstance(client, HTTPHandler):
            # BUGFIX: previously the sync client was created with no arguments,
            # silently dropping the caller-supplied timeout; the async path
            # already forwarded it. Forward it here too for consistency.
            client = _get_httpx_client(params={"timeout": timeout})

        try:
            response = client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text) from err
        except httpx.TimeoutException as err:
            raise BedrockError(status_code=408, message="Timeout error occurred.") from err

        ### FORMAT RESPONSE TO OPENAI FORMAT ###
        model_response = self._transform_response_dict_to_openai_response(
            model_response=model_response,
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            response=response,
            data=prepared_request.data,
        )
        return model_response

    async def async_image_generation(
        self,
        prepared_request: BedrockImagePreparedRequest,
        timeout: Optional[Union[float, httpx.Timeout]],
        model: str,
        logging_obj: LitellmLogging,
        prompt: str,
        model_response: ImageResponse,
        client: Optional[AsyncHTTPHandler] = None,
    ) -> ImageResponse:
        """
        Asynchronous handler for bedrock image generation

        Awaits the response from the bedrock image generation endpoint

        Raises:
            BedrockError: on HTTP error responses or request timeouts (408).
        """
        async_client = client or get_async_httpx_client(
            llm_provider=litellm.LlmProviders.BEDROCK,
            params={"timeout": timeout},
        )

        try:
            response = await async_client.post(url=prepared_request.endpoint_url, headers=prepared_request.prepped.headers, data=prepared_request.body)  # type: ignore
            response.raise_for_status()
        except httpx.HTTPStatusError as err:
            error_code = err.response.status_code
            raise BedrockError(status_code=error_code, message=err.response.text) from err
        except httpx.TimeoutException as err:
            raise BedrockError(status_code=408, message="Timeout error occurred.") from err

        ### FORMAT RESPONSE TO OPENAI FORMAT ###
        model_response = self._transform_response_dict_to_openai_response(
            model=model,
            logging_obj=logging_obj,
            prompt=prompt,
            response=response,
            data=prepared_request.data,
            model_response=model_response,
        )
        return model_response

    def _prepare_request(
        self,
        model: str,
        optional_params: dict,
        api_base: Optional[str],
        extra_headers: Optional[dict],
        logging_obj: LitellmLogging,
        prompt: str,
    ) -> BedrockImagePreparedRequest:
        """
        Prepare the request body, headers, and endpoint URL for the Bedrock Image Generation API

        Args:
            model (str): The model to use for the image generation
            optional_params (dict): The optional parameters for the image generation
            api_base (Optional[str]): The base URL for the Bedrock API
            extra_headers (Optional[dict]): The extra headers to include in the request
            logging_obj (LitellmLogging): The logging object to use for logging
            prompt (str): The prompt to use for the image generation
        Returns:
            BedrockImagePreparedRequest: The prepared request object

            The BedrockImagePreparedRequest contains:
                endpoint_url (str): The endpoint URL for the Bedrock Image Generation API
                prepped (httpx.Request): The prepared request object
                body (bytes): The request body

        Raises:
            ImportError: if botocore/boto3 is not installed.
        """
        try:
            # Deferred import: boto3/botocore is an optional dependency.
            from botocore.auth import SigV4Auth
            from botocore.awsrequest import AWSRequest
        except ImportError:
            raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")

        boto3_credentials_info = self._get_boto_credentials_from_optional_params(
            optional_params, model
        )

        ### SET RUNTIME ENDPOINT ###
        modelId = model
        _, proxy_endpoint_url = self.get_runtime_endpoint(
            api_base=api_base,
            aws_bedrock_runtime_endpoint=boto3_credentials_info.aws_bedrock_runtime_endpoint,
            aws_region_name=boto3_credentials_info.aws_region_name,
        )
        proxy_endpoint_url = f"{proxy_endpoint_url}/model/{modelId}/invoke"

        sigv4 = SigV4Auth(
            boto3_credentials_info.credentials,
            "bedrock",
            boto3_credentials_info.aws_region_name,
        )

        data = self._get_request_body(
            model=model, prompt=prompt, optional_params=optional_params
        )

        # Make POST Request
        body = json.dumps(data).encode("utf-8")

        headers = {"Content-Type": "application/json"}
        if extra_headers is not None:
            headers = {"Content-Type": "application/json", **extra_headers}

        request = AWSRequest(
            method="POST", url=proxy_endpoint_url, data=body, headers=headers
        )
        sigv4.add_auth(request)
        if (
            extra_headers is not None and "Authorization" in extra_headers
        ):  # prevent sigv4 from overwriting the auth header
            request.headers["Authorization"] = extra_headers["Authorization"]
        prepped = request.prepare()

        ## LOGGING
        logging_obj.pre_call(
            input=prompt,
            api_key="",
            additional_args={
                "complete_input_dict": data,
                "api_base": proxy_endpoint_url,
                "headers": prepped.headers,
            },
        )
        return BedrockImagePreparedRequest(
            endpoint_url=proxy_endpoint_url,
            prepped=prepped,
            body=body,
            data=data,
        )

    def _get_request_body(
        self,
        model: str,
        prompt: str,
        optional_params: dict,
    ) -> dict:
        """
        Get the request body for the Bedrock Image Generation API

        Checks the model/provider and transforms the request body accordingly

        Returns:
            dict: The request body to use for the Bedrock Image Generation API

        Raises:
            BedrockError: status 422 for providers other than stability/amazon.
        """
        provider = model.split(".")[0]
        inference_params = copy.deepcopy(optional_params)
        inference_params.pop(
            "user", None
        )  # make sure user is not passed in for bedrock call
        data = {}
        if provider == "stability":
            if litellm.AmazonStability3Config._is_stability_3_model(model):
                # NOTE(review): this branch passes optional_params (with `user`
                # still present) to the transform — presumably filtered there;
                # confirm before changing.
                request_body = litellm.AmazonStability3Config.transform_request_body(
                    prompt=prompt, optional_params=optional_params
                )
                return dict(request_body)
            else:
                prompt = prompt.replace(os.linesep, " ")
                ## LOAD CONFIG
                config = litellm.AmazonStabilityConfig.get_config()
                for k, v in config.items():
                    if (
                        k not in inference_params
                    ):  # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
                        inference_params[k] = v
                data = {
                    "text_prompts": [{"text": prompt, "weight": 1}],
                    **inference_params,
                }
        elif provider == "amazon":
            return dict(
                litellm.AmazonNovaCanvasConfig.transform_request_body(
                    text=prompt, optional_params=optional_params
                )
            )
        else:
            raise BedrockError(
                status_code=422, message=f"Unsupported model={model}, passed in"
            )
        return data

    def _transform_response_dict_to_openai_response(
        self,
        model_response: ImageResponse,
        model: str,
        logging_obj: LitellmLogging,
        prompt: str,
        response: httpx.Response,
        data: dict,
    ) -> ImageResponse:
        """
        Transforms the Image Generation response from Bedrock to OpenAI format

        Mutates and returns the passed-in ``model_response``.

        Raises:
            ValueError: if the response body parses to ``None``.
        """
        ## LOGGING
        if logging_obj is not None:
            logging_obj.post_call(
                input=prompt,
                api_key="",
                original_response=response.text,
                additional_args={"complete_input_dict": data},
            )
        verbose_logger.debug("raw model_response: %s", response.text)
        response_dict = response.json()
        if response_dict is None:
            raise ValueError("Error in response object format, got None")

        # Pick the provider config that knows how to parse this model's output.
        config_class = (
            litellm.AmazonStability3Config
            if litellm.AmazonStability3Config._is_stability_3_model(model=model)
            else (
                litellm.AmazonNovaCanvasConfig
                if litellm.AmazonNovaCanvasConfig._is_nova_model(model=model)
                else litellm.AmazonStabilityConfig
            )
        )
        config_class.transform_response_dict_to_openai_response(
            model_response=model_response,
            response_dict=response_dict,
        )
        return model_response