| import express from "express"; |
| import { pipeline, Readable, Transform } from "stream"; |
| import { StringDecoder } from "string_decoder"; |
| import { promisify } from "util"; |
| import type { logger } from "../../../logger"; |
| import { BadRequestError, RetryableError } from "../../../shared/errors"; |
| import { APIFormat, keyPool } from "../../../shared/key-management"; |
| import { |
| copySseResponseHeaders, |
| initializeSseStream, |
| } from "../../../shared/streaming"; |
| import { reenqueueRequest } from "../../queue"; |
| import type { RawResponseBodyHandler } from "."; |
| import { handleBlockingResponse } from "./handle-blocking-response"; |
| import { buildSpoofedSSE, sendErrorToClient } from "./error-generator"; |
| import { getAwsEventStreamDecoder } from "./streaming/aws-event-stream-decoder"; |
| import { EventAggregator } from "./streaming/event-aggregator"; |
| import { SSEMessageTransformer } from "./streaming/sse-message-transformer"; |
| import { SSEStreamAdapter } from "./streaming/sse-stream-adapter"; |
| import { getStreamDecompressor } from "./compression"; |
|
|
// Promise-returning form of `stream.pipeline`, so the decode/transform chain in
// handleStreamedResponse can be awaited and raced against client disconnects.
const pipelineAsync = promisify(pipeline);
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
/**
 * Proxies a streaming (SSE) upstream response to the client and returns the
 * aggregated final response once streaming ends.
 *
 * Stream chain: proxyRes -> decompressor -> decoder -> SSEStreamAdapter ->
 * SSEMessageTransformer. When the inbound and outbound API formats match, the
 * transformer's original upstream messages are forwarded verbatim. Otherwise
 * each transformed event is re-serialized as an SSE `data:` line.
 *
 * Upstream status codes above 201 are not streamed. The request is flagged as
 * non-streaming and handed to `handleBlockingResponse` instead.
 *
 * Errors raised while streaming:
 * - `RetryableError`: the key is marked rate limited and the request is
 *   re-enqueued.
 * - `BadRequestError`: an error message is sent to the client.
 * - anything else: a spoofed SSE error event plus `[DONE]` is written and the
 *   response is ended.
 * In every case, if any events were already aggregated the partial response is
 * returned; otherwise the original error is rethrown.
 *
 * @throws Error if called for a request that is not marked as streaming.
 */
export const handleStreamedResponse: RawResponseBodyHandler = async (
  proxyRes,
  req,
  res
) => {
  const { headers, statusCode } = proxyRes;
  if (!req.isStreaming) {
    throw new Error("handleStreamedResponse called for non-streaming request.");
  }

  // Error statuses are delegated to the non-streaming handler rather than
  // being piped through the SSE machinery.
  if (statusCode! > 201) {
    req.isStreaming = false;
    req.log.warn(
      { statusCode },
      `Streaming request returned error status code. Falling back to non-streaming response handler.`
    );
    return handleBlockingResponse(proxyRes, req, res);
  }

  req.log.debug({ headers }, `Starting to proxy SSE stream.`);

  // Only set up the SSE response if nothing has been sent to the client yet.
  if (!res.headersSent) {
    copySseResponseHeaders(proxyRes, res);
    initializeSseStream(res);
  }

  // Same API on both sides: forward upstream messages untouched.
  const prefersNativeEvents = req.inboundApi === req.outboundApi;
  const streamOptions = {
    contentType: headers["content-type"],
    api: req.outboundApi,
    logger: req.log,
  };

  // Accumulates each transformed event so a final response can be returned.
  const aggregator = new EventAggregator(req);

  const decompressor = getStreamDecompressor(headers["content-encoding"]);
  const decoder = getDecoder({ ...streamOptions, input: proxyRes });
  const adapter = new SSEStreamAdapter(streamOptions);
  const transformer = new SSEMessageTransformer({
    inputFormat: req.outboundApi,
    outputFormat: req.inboundApi,
    inputApiVersion: String(req.headers["anthropic-version"]),
    logger: req.log,
    requestId: String(req.id),
    requestedModel: req.body.model,
  })
    .on("originalMessage", (msg: string) => {
      if (prefersNativeEvents) res.write(msg);
    })
    .on("data", (msg) => {
      if (!prefersNativeEvents) res.write(`data: ${JSON.stringify(msg)}\n\n`);
      aggregator.addEvent(msg);
    });

  try {
    // Settles when either the whole chain completes or the client's response
    // emits `close`, whichever happens first.
    await Promise.race([
      handleAbortedStream(req, res),
      pipelineAsync(proxyRes, decompressor, decoder, adapter, transformer),
    ]);
    // If the client disconnected first, the upstream response has not been
    // fully read. Destroy it so the chain is torn down and the upstream
    // connection is released instead of being drained for nobody. The
    // resulting premature-close rejection of the pipeline promise is already
    // observed by Promise.race, so it is not reported as unhandled.
    if (!proxyRes.readableEnded) {
      proxyRes.destroy();
    }
    req.log.debug(`Finished proxying SSE stream.`);
    res.end();
    return aggregator.getFinalResponse();
  } catch (err) {
    if (err instanceof RetryableError) {
      keyPool.markRateLimited(req.key!);
      await reenqueueRequest(req);
    } else if (err instanceof BadRequestError) {
      sendErrorToClient({
        req,
        res,
        options: {
          format: req.inboundApi,
          title: "Proxy streaming error (Bad Request)",
          message: `The API returned an error while streaming your request. Your prompt might not be formatted correctly.\n\n*${err.message}*`,
          reqId: req.id,
          model: req.body?.model,
        },
      });
    } else {
      // `lastEvent` is read off the error if present; JSON.stringify returns
      // undefined for an undefined value, hence the fallback string.
      const { message, stack, lastEvent } = err;
      const eventText = JSON.stringify(lastEvent, null, 2) ?? "undefined";
      const errorEvent = buildSpoofedSSE({
        format: req.inboundApi,
        title: "Proxy stream error",
        message: "An unexpected error occurred while streaming the response.",
        obj: { message, stack, lastEvent: eventText },
        reqId: req.id,
        model: req.body?.model,
      });
      res.write(errorEvent);
      res.write(`data: [DONE]\n\n`);
      res.end();
    }

    // Partial output is still returned to the caller when some events made it
    // through before the failure.
    if (aggregator.hasEvents()) {
      return aggregator.getFinalResponse();
    } else {
      throw err;
    }
  }
};
|
|
/**
 * Returns a promise that settles as soon as the client response emits `close`.
 *
 * If the response had not been ended when it closed, the client disconnected
 * mid-stream and an info line is logged. The promise only ever resolves.
 */
function handleAbortedStream(req: express.Request, res: express.Response) {
  return new Promise<void>((resolve) => {
    const onResponseClosed = () => {
      const endedByServer = res.writableEnded;
      if (!endedByServer) {
        req.log.info("Client prematurely closed connection during stream.");
      }
      resolve();
    };
    res.on("close", onResponseClosed);
  });
}
|
|
/**
 * Picks the first decoding stage for an upstream streaming response based on
 * its content type.
 *
 * - `application/vnd.amazon.eventstream`: delegates to the AWS event stream
 *   decoder.
 * - `application/json`: unsupported; throws.
 * - anything else: treated as UTF-8 text and emitted as string chunks in
 *   object mode.
 *
 * @throws Error when the content type is `application/json`.
 */
function getDecoder(options: {
  input: Readable;
  api: APIFormat;
  logger: typeof logger;
  contentType?: string;
}) {
  const { contentType, input, logger } = options;
  if (contentType?.includes("application/vnd.amazon.eventstream")) {
    return getAwsEventStreamDecoder({ input, logger });
  } else if (contentType?.includes("application/json")) {
    throw new Error("JSON streaming not supported, request SSE instead");
  } else {
    // StringDecoder holds back partial multi-byte sequences, so a UTF-8
    // character split across two network chunks is reassembled correctly.
    const stringDecoder = new StringDecoder("utf8");
    return new Transform({
      readableObjectMode: true,
      writableObjectMode: false,
      transform(chunk, _encoding, callback) {
        const text = stringDecoder.write(chunk);
        if (text) this.push(text);
        callback();
      },
      // Emit whatever the decoder is still holding when input ends; without
      // this, trailing bytes of an incomplete character are silently dropped.
      flush(callback) {
        const rest = stringDecoder.end();
        if (rest) this.push(rest);
        callback();
      },
    });
  }
}
|
|