| import pino from "pino"; |
| import { Transform, TransformOptions } from "stream"; |
| import { Message } from "@smithy/eventstream-codec"; |
| import { APIFormat } from "../../../../shared/key-management"; |
| import { BadRequestError, RetryableError } from "../../../../shared/errors"; |
|
|
/**
 * Constructor options for the SSE stream adapter: every standard
 * `TransformOptions` field plus the adapter-specific settings below.
 */
interface SSEStreamAdapterOptions extends TransformOptions {
  /**
   * Content type of the upstream response. The adapter compares it against
   * `application/vnd.amazon.eventstream` to decide how chunks are handled.
   */
  contentType?: string;
  /** API format associated with the stream. */
  api: APIFormat;
  /** Parent logger; the adapter derives a child logger from it. */
  logger: pino.Logger;
}
|
|
| |
| |
| |
| |
| |
| |
| |
/**
 * Transform stream that turns an upstream streaming response into a sequence
 * of complete server-sent-event strings, each terminated by a blank line
 * (`"\n\n"`). The readable side runs in object mode, one event per chunk.
 *
 * Two input shapes are handled, selected by the `contentType` option:
 * - `application/vnd.amazon.eventstream`: every incoming chunk is treated as
 *   an event-stream `Message` (presumably already decoded upstream — confirm
 *   against the caller) and converted by `processAwsMessage`.
 * - anything else: incoming chunks are treated as raw SSE text that may be
 *   split at arbitrary byte positions; text is buffered until a full event
 *   (terminated by a blank line) is available.
 */
export class SSEStreamAdapter extends Transform {
  private readonly isAwsStream: boolean;
  /** Stored from the options; not read by this class itself. */
  private api: APIFormat;
  /** Text received after the last complete event on the raw SSE path. */
  private partialMessage = "";
  /** One-shot decoder for AWS message bodies (each body is self-contained). */
  private textDecoder = new TextDecoder("utf8");
  /**
   * Streaming decoder for the raw SSE path. It carries incomplete multi-byte
   * sequences over to the next chunk. `ignoreBOM: true` keeps a leading BOM in
   * the output, matching what `Buffer#toString` produced before.
   */
  private sseDecoder = new TextDecoder("utf8", { ignoreBOM: true });
  private log: pino.Logger;

  constructor(options: SSEStreamAdapterOptions) {
    super({ ...options, objectMode: true });
    this.isAwsStream =
      options?.contentType === "application/vnd.amazon.eventstream";
    this.api = options.api;
    this.log = options.logger.child({ module: "sse-stream-adapter" });
  }

  /**
   * Converts one AWS event-stream message into SSE text (without the trailing
   * blank line).
   *
   * @param message Event-stream message with `headers` and a binary `body`.
   * @returns The SSE text for a JSON `chunk` event, or `null` when nothing
   *   should be pushed (an `error` event was emitted or the message type was
   *   unrecognized).
   * @throws RetryableError when AWS reports a throttling exception.
   * @throws SyntaxError (or similar) when a chunk payload is not valid JSON.
   */
  protected processAwsMessage(message: Message): string | null {
    const { headers, body } = message;
    const eventType = headers[":event-type"]?.value;
    const messageType = headers[":message-type"]?.value;
    const contentType = headers[":content-type"]?.value;
    const exceptionType = headers[":exception-type"]?.value;
    const errorCode = headers[":error-code"]?.value;
    const bodyStr = this.textDecoder.decode(body);

    switch (messageType) {
      case "event":
        if (contentType === "application/json" && eventType === "chunk") {
          return this.formatAwsChunk(bodyStr);
        }
        // An event that is not a JSON chunk is unexpected; it is reported
        // through the same path as an explicit error message.
        return this.handleAwsError(exceptionType || errorCode, bodyStr);
      case "exception":
      case "error":
        return this.handleAwsError(exceptionType || errorCode, bodyStr);
      default:
        this.log.error({ message }, "Received very bad AWS stream event");
        return null;
    }
  }

  /**
   * Unwraps a JSON `chunk` body (`{ bytes: <base64> }`) and renders the
   * decoded payload as an SSE event.
   */
  private formatAwsChunk(bodyStr: string): string {
    const { bytes } = JSON.parse(bodyStr);
    const event = Buffer.from(bytes, "base64").toString("utf8");
    const eventObj = JSON.parse(event);

    if ("completion" in eventObj) {
      return ["event: completion", `data: ${event}`].join(`\n`);
    }
    if (eventObj.type) {
      return [`event: ${eventObj.type}`, `data: ${event}`].join(`\n`);
    }
    return `data: ${event}`;
  }

  /**
   * Reports an AWS exception/error message: throws for throttling, otherwise
   * emits an `error` event on this stream. Always yields `null` when it
   * returns, because there is nothing to push downstream.
   */
  private handleAwsError(rawType: unknown, bodyStr: string): null {
    const type = String(rawType || "UnknownError").toLowerCase();

    if (type === "throttlingexception") {
      this.log.warn(
        "AWS request throttled after streaming has already started; retrying"
      );
      throw new RetryableError("AWS request throttled mid-stream");
    }

    if (type === "validationexception") {
      let detail: unknown;
      let parsed = false;
      // Only the parse is guarded: an exception raised while emitting the
      // error must not be mistaken for a parse failure and reported twice.
      try {
        ({ message: detail } = JSON.parse(bodyStr));
        parsed = true;
      } catch (error) {
        this.log.error(
          { body: bodyStr, error },
          "Could not parse AWS validation error"
        );
      }
      if (parsed) {
        this.log.error({ message: detail }, "Received AWS validation error");
        this.emit(
          "error",
          new BadRequestError(`AWS validation error: ${detail}`)
        );
        return null;
      }
      // Unparseable validation errors are reported as a generic error below.
    }

    let text;
    try {
      text = JSON.parse(bodyStr).message;
    } catch {
      text = bodyStr;
    }
    const error = Object.assign(
      new Error(`Got mysterious error chunk: [${type}] ${text}`),
      { lastEvent: text }
    );
    this.emit("error", error);
    return null;
  }

  /**
   * Stream hook: converts one upstream chunk into zero or more SSE events.
   * Any exception is passed to `callback` with a `lastEvent` property holding
   * the stringified input chunk.
   */
  _transform(data: any, _enc: string, callback: (err?: Error | null) => void) {
    try {
      if (this.isAwsStream) {
        const message = this.processAwsMessage(data);
        if (message) this.push(message + "\n\n");
      } else {
        // Binary chunks go through the streaming decoder so that a multi-byte
        // character split across two chunks is not corrupted.
        const text =
          typeof data === "string"
            ? data
            : data instanceof Uint8Array
            ? this.sseDecoder.decode(data, { stream: true })
            : String(data);

        // Events are separated by a blank line in any of the three line-ending
        // styles; the last piece may be incomplete and is kept for later.
        const fullMessages = (this.partialMessage + text).split(
          /\r\r|\n\n|\r\n\r\n/
        );
        this.partialMessage = fullMessages.pop() || "";

        for (const message of fullMessages) {
          // Normalize CR and CRLF line endings inside the event to LF.
          this.push(message.replace(/\r\n?/g, "\n") + "\n\n");
        }
      }
      callback();
    } catch (error) {
      const failure = error instanceof Error ? error : new Error(String(error));
      callback(
        Object.assign(failure, {
          lastEvent: data?.toString() ?? "[SSEStreamAdapter] no data",
        })
      );
    }
  }

  /**
   * Stream hook called when the input ends. Text still held in
   * `partialMessage` (an event with no terminating blank line) is not emitted.
   */
  _flush(callback: (err?: Error | null) => void) {
    callback();
  }
}
|
|