|
|
import express, { Request } from "express"; |
|
|
import { Sha256 } from "@aws-crypto/sha256-js"; |
|
|
import { SignatureV4 } from "@smithy/signature-v4"; |
|
|
import { HttpRequest } from "@smithy/protocol-http"; |
|
|
import { |
|
|
AnthropicV1TextSchema, |
|
|
AnthropicV1MessagesSchema, |
|
|
} from "../../../../shared/api-schemas"; |
|
|
import { AwsBedrockKey, keyPool } from "../../../../shared/key-management"; |
|
|
import { |
|
|
AWSMistralV1ChatCompletionsSchema, |
|
|
AWSMistralV1TextCompletionsSchema, |
|
|
} from "../../../../shared/api-schemas/mistral-ai"; |
|
|
import { ProxyReqMutator } from "../index"; |
|
|
|
|
|
const AMZ_HOST = |
|
|
process.env.AMZ_HOST || "bedrock-runtime.%REGION%.amazonaws.com"; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
export const signAwsRequest: ProxyReqMutator = async (manager) => { |
|
|
const req = manager.request; |
|
|
const { model, stream } = req.body; |
|
|
const key = keyPool.get(model, "aws") as AwsBedrockKey; |
|
|
manager.setKey(key); |
|
|
|
|
|
let system = req.body.system ?? ""; |
|
|
if (Array.isArray(system)) { |
|
|
system = system |
|
|
.map((m: { type: string; text: string }) => m.text) |
|
|
.join("\n"); |
|
|
req.body.system = system; |
|
|
} |
|
|
|
|
|
const credential = getCredentialParts(req); |
|
|
const host = AMZ_HOST.replace("%REGION%", credential.region); |
|
|
|
|
|
|
|
|
|
|
|
manager.setHeader("anthropic-version", "2023-06-01"); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const profile = |
|
|
key.inferenceProfileIds.find((p) => p.includes(model)) || model; |
|
|
|
|
|
|
|
|
|
|
|
const newRequest = new HttpRequest({ |
|
|
method: "POST", |
|
|
protocol: "https:", |
|
|
hostname: host, |
|
|
path: `/model/${profile}/invoke${stream ? "-with-response-stream" : ""}`, |
|
|
headers: { |
|
|
["Host"]: host, |
|
|
["content-type"]: "application/json", |
|
|
}, |
|
|
body: JSON.stringify(getStrictlyValidatedBodyForAws(req)), |
|
|
}); |
|
|
|
|
|
if (stream) { |
|
|
newRequest.headers["x-amzn-bedrock-accept"] = "application/json"; |
|
|
} else { |
|
|
newRequest.headers["accept"] = "*/*"; |
|
|
} |
|
|
|
|
|
const { body, inboundApi, outboundApi } = req; |
|
|
req.log.info( |
|
|
{ key: key.hash, model: body.model, profile, inboundApi, outboundApi }, |
|
|
"Assigned AWS credentials to request" |
|
|
); |
|
|
|
|
|
manager.setSignedRequest(await sign(newRequest, getCredentialParts(req))); |
|
|
}; |
|
|
|
|
|
type Credential = { |
|
|
accessKeyId: string; |
|
|
secretAccessKey: string; |
|
|
region: string; |
|
|
}; |
|
|
|
|
|
function getCredentialParts(req: express.Request): Credential { |
|
|
const [accessKeyId, secretAccessKey, region] = req.key!.key.split(":"); |
|
|
|
|
|
if (!accessKeyId || !secretAccessKey || !region) { |
|
|
req.log.error( |
|
|
{ key: req.key!.hash }, |
|
|
"AWS_CREDENTIALS isn't correctly formatted; refer to the docs" |
|
|
); |
|
|
throw new Error("The key assigned to this request is invalid."); |
|
|
} |
|
|
|
|
|
return { accessKeyId, secretAccessKey, region }; |
|
|
} |
|
|
|
|
|
async function sign(request: HttpRequest, credential: Credential) { |
|
|
const { accessKeyId, secretAccessKey, region } = credential; |
|
|
|
|
|
const signer = new SignatureV4({ |
|
|
sha256: Sha256, |
|
|
credentials: { accessKeyId, secretAccessKey }, |
|
|
region, |
|
|
service: "bedrock", |
|
|
}); |
|
|
|
|
|
return signer.sign(request); |
|
|
} |
|
|
|
|
|
function getStrictlyValidatedBodyForAws(req: Readonly<Request>): unknown { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let strippedParams: Record<string, unknown> = {}; |
|
|
switch (req.outboundApi) { |
|
|
case "anthropic-text": |
|
|
strippedParams = AnthropicV1TextSchema.pick({ |
|
|
prompt: true, |
|
|
max_tokens_to_sample: true, |
|
|
stop_sequences: true, |
|
|
temperature: true, |
|
|
top_k: true, |
|
|
top_p: true, |
|
|
}) |
|
|
.strip() |
|
|
.parse(req.body); |
|
|
break; |
|
|
case "anthropic-chat": |
|
|
strippedParams = AnthropicV1MessagesSchema.pick({ |
|
|
messages: true, |
|
|
system: true, |
|
|
max_tokens: true, |
|
|
stop_sequences: true, |
|
|
temperature: true, |
|
|
top_k: true, |
|
|
top_p: true, |
|
|
tools: true, |
|
|
tool_choice: true, |
|
|
thinking: true |
|
|
}) |
|
|
.strip() |
|
|
.parse(req.body); |
|
|
strippedParams.anthropic_version = "bedrock-2023-05-31"; |
|
|
break; |
|
|
case "mistral-ai": |
|
|
strippedParams = AWSMistralV1ChatCompletionsSchema.parse(req.body); |
|
|
break; |
|
|
case "mistral-text": |
|
|
strippedParams = AWSMistralV1TextCompletionsSchema.parse(req.body); |
|
|
break; |
|
|
default: |
|
|
throw new Error("Unexpected outbound API for AWS."); |
|
|
} |
|
|
return strippedParams; |
|
|
} |
|
|
|