Spaces:
Sleeping
Sleeping
| """AWS Bedrock ModelClient integration.""" | |
| import os | |
| import json | |
| import logging | |
| import boto3 | |
| import botocore | |
| import backoff | |
| from typing import Dict, Any, Optional, List, Generator, Union, AsyncGenerator | |
| from adalflow.core.model_client import ModelClient | |
| from adalflow.core.types import ModelType, GeneratorOutput | |
| # Configure logging | |
| from api.logging_config import setup_logging | |
| setup_logging() | |
| log = logging.getLogger(__name__) | |
| class BedrockClient(ModelClient): | |
| __doc__ = r"""A component wrapper for the AWS Bedrock API client. | |
| AWS Bedrock provides a unified API that gives access to various foundation models | |
| including Amazon's own models and third-party models like Anthropic Claude. | |
| Example: | |
| ```python | |
| from api.bedrock_client import BedrockClient | |
| client = BedrockClient() | |
| generator = adal.Generator( | |
| model_client=client, | |
| model_kwargs={"model": "anthropic.claude-3-sonnet-20240229-v1:0"} | |
| ) | |
| ``` | |
| """ | |
| def __init__( | |
| self, | |
| aws_access_key_id: Optional[str] = None, | |
| aws_secret_access_key: Optional[str] = None, | |
| aws_region: Optional[str] = None, | |
| aws_role_arn: Optional[str] = None, | |
| *args, | |
| **kwargs | |
| ) -> None: | |
| """Initialize the AWS Bedrock client. | |
| Args: | |
| aws_access_key_id: AWS access key ID. If not provided, will use environment variable AWS_ACCESS_KEY_ID. | |
| aws_secret_access_key: AWS secret access key. If not provided, will use environment variable AWS_SECRET_ACCESS_KEY. | |
| aws_region: AWS region. If not provided, will use environment variable AWS_REGION. | |
| aws_role_arn: AWS IAM role ARN for role-based authentication. If not provided, will use environment variable AWS_ROLE_ARN. | |
| """ | |
| super().__init__(*args, **kwargs) | |
| from api.config import AWS_ACCESS_KEY_ID, AWS_SECRET_ACCESS_KEY, AWS_REGION, AWS_ROLE_ARN | |
| self.aws_access_key_id = aws_access_key_id or AWS_ACCESS_KEY_ID | |
| self.aws_secret_access_key = aws_secret_access_key or AWS_SECRET_ACCESS_KEY | |
| self.aws_region = aws_region or AWS_REGION or "us-east-1" | |
| self.aws_role_arn = aws_role_arn or AWS_ROLE_ARN | |
| self.sync_client = self.init_sync_client() | |
| self.async_client = None # Initialize async client only when needed | |
| def init_sync_client(self): | |
| """Initialize the synchronous AWS Bedrock client.""" | |
| try: | |
| # Create a session with the provided credentials | |
| session = boto3.Session( | |
| aws_access_key_id=self.aws_access_key_id, | |
| aws_secret_access_key=self.aws_secret_access_key, | |
| region_name=self.aws_region | |
| ) | |
| # If a role ARN is provided, assume that role | |
| if self.aws_role_arn: | |
| sts_client = session.client('sts') | |
| assumed_role = sts_client.assume_role( | |
| RoleArn=self.aws_role_arn, | |
| RoleSessionName="DeepWikiBedrockSession" | |
| ) | |
| credentials = assumed_role['Credentials'] | |
| # Create a new session with the assumed role credentials | |
| session = boto3.Session( | |
| aws_access_key_id=credentials['AccessKeyId'], | |
| aws_secret_access_key=credentials['SecretAccessKey'], | |
| aws_session_token=credentials['SessionToken'], | |
| region_name=self.aws_region | |
| ) | |
| # Create the Bedrock client | |
| bedrock_runtime = session.client( | |
| service_name='bedrock-runtime', | |
| region_name=self.aws_region | |
| ) | |
| return bedrock_runtime | |
| except Exception as e: | |
| log.error(f"Error initializing AWS Bedrock client: {str(e)}") | |
| # Return None to indicate initialization failure | |
| return None | |
| def init_async_client(self): | |
| """Initialize the asynchronous AWS Bedrock client. | |
| Note: boto3 doesn't have native async support, so we'll use the sync client | |
| in async methods and handle async behavior at a higher level. | |
| """ | |
| # For now, just return the sync client | |
| return self.sync_client | |
| def _get_model_provider(self, model_id: str) -> str: | |
| """Extract the provider from the model ID. | |
| Args: | |
| model_id: The model ID, e.g., "anthropic.claude-3-sonnet-20240229-v1:0" | |
| Returns: | |
| The provider name, e.g., "anthropic" | |
| """ | |
| if "." in model_id: | |
| return model_id.split(".")[0] | |
| return "amazon" # Default provider | |
| def _format_prompt_for_provider(self, provider: str, prompt: str, messages=None) -> Dict[str, Any]: | |
| """Format the prompt according to the provider's requirements. | |
| Args: | |
| provider: The provider name, e.g., "anthropic" | |
| prompt: The prompt text | |
| messages: Optional list of messages for chat models | |
| Returns: | |
| A dictionary with the formatted prompt | |
| """ | |
| if provider == "anthropic": | |
| # Format for Claude models | |
| if messages: | |
| # Format as a conversation | |
| formatted_messages = [] | |
| for msg in messages: | |
| role = "user" if msg.get("role") == "user" else "assistant" | |
| formatted_messages.append({ | |
| "role": role, | |
| "content": [{"type": "text", "text": msg.get("content", "")}] | |
| }) | |
| return { | |
| "anthropic_version": "bedrock-2023-05-31", | |
| "messages": formatted_messages, | |
| "max_tokens": 4096 | |
| } | |
| else: | |
| # Format as a single prompt | |
| return { | |
| "anthropic_version": "bedrock-2023-05-31", | |
| "messages": [ | |
| {"role": "user", "content": [{"type": "text", "text": prompt}]} | |
| ], | |
| "max_tokens": 4096 | |
| } | |
| elif provider == "amazon": | |
| # Format for Amazon Titan models | |
| return { | |
| "inputText": prompt, | |
| "textGenerationConfig": { | |
| "maxTokenCount": 4096, | |
| "stopSequences": [], | |
| "temperature": 0.7, | |
| "topP": 0.8 | |
| } | |
| } | |
| elif provider == "cohere": | |
| # Format for Cohere models | |
| return { | |
| "prompt": prompt, | |
| "max_tokens": 4096, | |
| "temperature": 0.7, | |
| "p": 0.8 | |
| } | |
| elif provider == "ai21": | |
| # Format for AI21 models | |
| return { | |
| "prompt": prompt, | |
| "maxTokens": 4096, | |
| "temperature": 0.7, | |
| "topP": 0.8 | |
| } | |
| else: | |
| # Default format | |
| return {"prompt": prompt} | |
| def _extract_response_text(self, provider: str, response: Dict[str, Any]) -> str: | |
| """Extract the generated text from the response. | |
| Args: | |
| provider: The provider name, e.g., "anthropic" | |
| response: The response from the Bedrock API | |
| Returns: | |
| The generated text | |
| """ | |
| if provider == "anthropic": | |
| return response.get("content", [{}])[0].get("text", "") | |
| elif provider == "amazon": | |
| return response.get("results", [{}])[0].get("outputText", "") | |
| elif provider == "cohere": | |
| return response.get("generations", [{}])[0].get("text", "") | |
| elif provider == "ai21": | |
| return response.get("completions", [{}])[0].get("data", {}).get("text", "") | |
| else: | |
| # Try to extract text from the response | |
| if isinstance(response, dict): | |
| for key in ["text", "content", "output", "completion"]: | |
| if key in response: | |
| return response[key] | |
| return str(response) | |
| def call(self, api_kwargs: Dict = None, model_type: ModelType = None) -> Any: | |
| """Make a synchronous call to the AWS Bedrock API.""" | |
| api_kwargs = api_kwargs or {} | |
| # Check if client is initialized | |
| if not self.sync_client: | |
| error_msg = "AWS Bedrock client not initialized. Check your AWS credentials and region." | |
| log.error(error_msg) | |
| return error_msg | |
| if model_type == ModelType.LLM: | |
| model_id = api_kwargs.get("model", "anthropic.claude-3-sonnet-20240229-v1:0") | |
| provider = self._get_model_provider(model_id) | |
| # Get the prompt from api_kwargs | |
| prompt = api_kwargs.get("input", "") | |
| messages = api_kwargs.get("messages") | |
| # Format the prompt according to the provider | |
| request_body = self._format_prompt_for_provider(provider, prompt, messages) | |
| # Add model parameters if provided | |
| if "temperature" in api_kwargs: | |
| if provider == "anthropic": | |
| request_body["temperature"] = api_kwargs["temperature"] | |
| elif provider == "amazon": | |
| request_body["textGenerationConfig"]["temperature"] = api_kwargs["temperature"] | |
| elif provider == "cohere": | |
| request_body["temperature"] = api_kwargs["temperature"] | |
| elif provider == "ai21": | |
| request_body["temperature"] = api_kwargs["temperature"] | |
| if "top_p" in api_kwargs: | |
| if provider == "anthropic": | |
| request_body["top_p"] = api_kwargs["top_p"] | |
| elif provider == "amazon": | |
| request_body["textGenerationConfig"]["topP"] = api_kwargs["top_p"] | |
| elif provider == "cohere": | |
| request_body["p"] = api_kwargs["top_p"] | |
| elif provider == "ai21": | |
| request_body["topP"] = api_kwargs["top_p"] | |
| # Convert request body to JSON | |
| body = json.dumps(request_body) | |
| try: | |
| # Make the API call | |
| response = self.sync_client.invoke_model( | |
| modelId=model_id, | |
| body=body | |
| ) | |
| # Parse the response | |
| response_body = json.loads(response["body"].read()) | |
| # Extract the generated text | |
| generated_text = self._extract_response_text(provider, response_body) | |
| return generated_text | |
| except Exception as e: | |
| log.error(f"Error calling AWS Bedrock API: {str(e)}") | |
| return f"Error: {str(e)}" | |
| else: | |
| raise ValueError(f"Model type {model_type} is not supported by AWS Bedrock client") | |
| async def acall(self, api_kwargs: Dict = None, model_type: ModelType = None) -> Any: | |
| """Make an asynchronous call to the AWS Bedrock API.""" | |
| # For now, just call the sync method | |
| # In a real implementation, you would use an async library or run the sync method in a thread pool | |
| return self.call(api_kwargs, model_type) | |
| def convert_inputs_to_api_kwargs( | |
| self, input: Any = None, model_kwargs: Dict = None, model_type: ModelType = None | |
| ) -> Dict: | |
| """Convert inputs to API kwargs for AWS Bedrock.""" | |
| model_kwargs = model_kwargs or {} | |
| api_kwargs = {} | |
| if model_type == ModelType.LLM: | |
| api_kwargs["model"] = model_kwargs.get("model", "anthropic.claude-3-sonnet-20240229-v1:0") | |
| api_kwargs["input"] = input | |
| # Add model parameters | |
| if "temperature" in model_kwargs: | |
| api_kwargs["temperature"] = model_kwargs["temperature"] | |
| if "top_p" in model_kwargs: | |
| api_kwargs["top_p"] = model_kwargs["top_p"] | |
| return api_kwargs | |
| else: | |
| raise ValueError(f"Model type {model_type} is not supported by AWS Bedrock client") | |