| package aws |
|
|
| import ( |
| "fmt" |
| "io" |
| "net/http" |
| "strings" |
|
|
| "github.com/QuantumNous/new-api/dto" |
| "github.com/QuantumNous/new-api/relay/channel" |
| "github.com/QuantumNous/new-api/relay/channel/claude" |
| relaycommon "github.com/QuantumNous/new-api/relay/common" |
| "github.com/QuantumNous/new-api/service" |
| "github.com/QuantumNous/new-api/types" |
| "github.com/aws/aws-sdk-go-v2/service/bedrockruntime" |
| "github.com/pkg/errors" |
|
|
| "github.com/gin-gonic/gin" |
| ) |
|
|
| type ClientMode int |
|
|
| const ( |
| ClientModeApiKey ClientMode = iota + 1 |
| ClientModeAKSK |
| ) |
|
|
| type Adaptor struct { |
| ClientMode ClientMode |
| AwsClient *bedrockruntime.Client |
| AwsModelId string |
| AwsReq any |
| IsNova bool |
| } |
|
|
| func (a *Adaptor) ConvertGeminiRequest(*gin.Context, *relaycommon.RelayInfo, *dto.GeminiChatRequest) (any, error) { |
| |
| return nil, errors.New("not implemented") |
| } |
|
|
| func (a *Adaptor) ConvertClaudeRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.ClaudeRequest) (any, error) { |
| for i, message := range request.Messages { |
| updated := false |
| if !message.IsStringContent() { |
| content, err := message.ParseContent() |
| if err != nil { |
| return nil, errors.Wrap(err, "failed to parse message content") |
| } |
| for i2, mediaMessage := range content { |
| if mediaMessage.Source != nil { |
| if mediaMessage.Source.Type == "url" { |
| fileData, err := service.GetFileBase64FromUrl(c, mediaMessage.Source.Url, "formatting image for Claude") |
| if err != nil { |
| return nil, fmt.Errorf("get file base64 from url failed: %s", err.Error()) |
| } |
| mediaMessage.Source.MediaType = fileData.MimeType |
| mediaMessage.Source.Data = fileData.Base64Data |
| mediaMessage.Source.Url = "" |
| mediaMessage.Source.Type = "base64" |
| content[i2] = mediaMessage |
| updated = true |
| } |
| } |
| } |
| if updated { |
| message.SetContent(content) |
| } |
| } |
| if updated { |
| request.Messages[i] = message |
| } |
| } |
| return request, nil |
| } |
|
|
| func (a *Adaptor) ConvertAudioRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.AudioRequest) (io.Reader, error) { |
| |
| return nil, errors.New("not implemented") |
| } |
|
|
| func (a *Adaptor) ConvertImageRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.ImageRequest) (any, error) { |
| |
| return nil, errors.New("not implemented") |
| } |
|
|
| func (a *Adaptor) Init(info *relaycommon.RelayInfo) { |
| } |
|
|
| func (a *Adaptor) GetRequestURL(info *relaycommon.RelayInfo) (string, error) { |
| if info.ChannelOtherSettings.AwsKeyType == dto.AwsKeyTypeApiKey { |
| awsModelId := getAwsModelID(info.UpstreamModelName) |
| a.ClientMode = ClientModeApiKey |
| awsSecret := strings.Split(info.ApiKey, "|") |
| if len(awsSecret) != 2 { |
| return "", errors.New("invalid aws api key, should be in format of <api-key>|<region>") |
| } |
| return fmt.Sprintf("https://bedrock-runtime.%s.amazonaws.com/model/%s/converse", awsModelId, awsSecret[1]), nil |
| } else { |
| a.ClientMode = ClientModeAKSK |
| return "", nil |
| } |
| } |
|
|
| func (a *Adaptor) SetupRequestHeader(c *gin.Context, req *http.Header, info *relaycommon.RelayInfo) error { |
| claude.CommonClaudeHeadersOperation(c, req, info) |
| if a.ClientMode == ClientModeApiKey { |
| req.Set("Authorization", "Bearer "+info.ApiKey) |
| } |
| return nil |
| } |
|
|
| func (a *Adaptor) ConvertOpenAIRequest(c *gin.Context, info *relaycommon.RelayInfo, request *dto.GeneralOpenAIRequest) (any, error) { |
| if request == nil { |
| return nil, errors.New("request is nil") |
| } |
| |
| if isNovaModel(request.Model) { |
| novaReq := convertToNovaRequest(request) |
| a.IsNova = true |
| return novaReq, nil |
| } |
|
|
| |
| claudeReq, err := claude.RequestOpenAI2ClaudeMessage(c, *request) |
| if err != nil { |
| return nil, errors.Wrap(err, "failed to convert openai request to claude request") |
| } |
| info.UpstreamModelName = claudeReq.Model |
| return claudeReq, err |
| } |
|
|
| func (a *Adaptor) ConvertRerankRequest(c *gin.Context, relayMode int, request dto.RerankRequest) (any, error) { |
| return nil, nil |
| } |
|
|
| func (a *Adaptor) ConvertEmbeddingRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.EmbeddingRequest) (any, error) { |
| |
| return nil, errors.New("not implemented") |
| } |
|
|
| func (a *Adaptor) ConvertOpenAIResponsesRequest(c *gin.Context, info *relaycommon.RelayInfo, request dto.OpenAIResponsesRequest) (any, error) { |
| |
| return nil, errors.New("not implemented") |
| } |
|
|
| func (a *Adaptor) DoRequest(c *gin.Context, info *relaycommon.RelayInfo, requestBody io.Reader) (any, error) { |
| if a.ClientMode == ClientModeApiKey { |
| return channel.DoApiRequest(a, c, info, requestBody) |
| } else { |
| return doAwsClientRequest(c, info, a, requestBody) |
| } |
| } |
|
|
| func (a *Adaptor) DoResponse(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo) (usage any, err *types.NewAPIError) { |
| if a.ClientMode == ClientModeApiKey { |
| claudeAdaptor := claude.Adaptor{} |
| usage, err = claudeAdaptor.DoResponse(c, resp, info) |
| } else { |
| if a.IsNova { |
| err, usage = handleNovaRequest(c, info, a) |
| } else { |
| if info.IsStream { |
| err, usage = awsStreamHandler(c, info, a) |
| } else { |
| err, usage = awsHandler(c, info, a) |
| } |
| } |
| } |
| return |
| } |
|
|
| func (a *Adaptor) GetModelList() (models []string) { |
| for n := range awsModelIDMap { |
| models = append(models, n) |
| } |
|
|
| return |
| } |
|
|
| func (a *Adaptor) GetChannelName() string { |
| return ChannelName |
| } |
|
|