import { RequestHandler } from "express"; import { ZodIssue } from "zod"; import { initializeSseStream } from "../../../shared/streaming"; import { classifyErrorAndSend } from "../common"; import { RequestPreprocessor, blockZoomerOrigins, countPromptTokens, languageFilter, setApiFormat, transformOutboundPayload, validateContextSize, validateModelFamily, validateVision, applyQuotaLimits, } from "."; type RequestPreprocessorOptions = { /** * Functions to run before the request body is transformed between API * formats. Use this to change the behavior of the transformation, such as for * endpoints which can accept multiple API formats. */ beforeTransform?: RequestPreprocessor[]; /** * Functions to run after the request body is transformed and token counts are * assigned. Use this to perform validation or other actions that depend on * the request body being in the final API format. */ afterTransform?: RequestPreprocessor[]; }; /** * Returns a middleware function that processes the request body into the given * API format, and then sequentially runs the given additional preprocessors. * These should be used for validation and transformations that only need to * happen once per request. * * These run first in the request lifecycle, a single time per request before it * is added to the request queue. They aren't run again if the request is * re-attempted after a rate limit. * * To run functions against requests every time they are re-attempted, write a * ProxyReqMutator and pass it to createQueuedProxyMiddleware instead. */ export const createPreprocessorMiddleware = ( apiFormat: Parameters[0], { beforeTransform, afterTransform }: RequestPreprocessorOptions = {} ): RequestHandler => { const preprocessors: RequestPreprocessor[] = [ setApiFormat(apiFormat), blockZoomerOrigins, ...(beforeTransform ?? []), transformOutboundPayload, countPromptTokens, languageFilter, ...(afterTransform ?? []), validateContextSize, validateVision, validateModelFamily, applyQuotaLimits, ]; return async (...args) => executePreprocessors(preprocessors, args); }; /** * Returns a middleware function that specifically prepares requests for * OpenAI's embeddings API. Tokens are not counted because embeddings requests * are basically free. */ export const createEmbeddingsPreprocessorMiddleware = (): RequestHandler => { const preprocessors: RequestPreprocessor[] = [ setApiFormat({ inApi: "openai", outApi: "openai", service: "openai" }), (req) => void (req.promptTokens = req.outputTokens = 0), ]; return async (...args) => executePreprocessors(preprocessors, args); }; async function executePreprocessors( preprocessors: RequestPreprocessor[], [req, res, next]: Parameters ) { handleTestMessage(req, res, next); if (res.headersSent) return; try { for (const preprocessor of preprocessors) { await preprocessor(req); } next(); } catch (error) { if (error.constructor.name === "ZodError") { const issues = error?.issues ?.map((issue: ZodIssue) => `${issue.path.join(".")}: ${issue.message}`) .join("; "); req.log.warn({ issues }, "Prompt failed preprocessor validation."); } else { req.log.error(error, "Error while executing request preprocessor"); } // If the requested has opted into streaming, the client probably won't // handle a non-eventstream response, but we haven't initialized the SSE // stream yet as that is typically done later by the request queue. We'll // do that here and then call classifyErrorAndSend to use the streaming // error handler. const { stream } = req.body; const isStreaming = stream === "true" || stream === true; if (isStreaming && !res.headersSent) { initializeSseStream(res); } classifyErrorAndSend(error as Error, req, res); } } /** * Bypasses the API call and returns a test message response if the request body * is a known test message from SillyTavern. Otherwise these messages just waste * API request quota and confuse users when the proxy is busy, because ST always * makes them with `stream: false` (which is not allowed when the proxy is busy) */ const handleTestMessage: RequestHandler = (req, res) => { const { method, body } = req; if (method !== "POST") { return; } if (isTestMessage(body)) { req.log.info({ body }, "Received test message. Skipping API call."); res.json({ id: "test-message", object: "chat.completion", created: Date.now(), model: body.model, // openai chat choices: [ { message: { role: "assistant", content: "Hello!" }, finish_reason: "stop", index: 0, }, ], // anthropic text completion: "Hello!", // anthropic chat content: [{ type: "text", text: "Hello!" }], // gemini candidates: [ { content: { parts: [{ text: "Hello!" }] }, finishReason: "stop", }, ], proxy_note: "SillyTavern connection test detected. Your prompt was not sent to the actual model and this response was generated by the proxy.", }); } }; function isTestMessage(body: any) { const { messages, prompt, contents } = body; if (messages) { return ( messages.length === 1 && messages[0].role === "user" && messages[0].content === "Hi" ); } else if (contents) { return contents.length === 1 && contents[0].parts[0]?.text === "Hi"; } else { return ( prompt?.trim() === "Human: Hi\n\nAssistant:" || prompt?.startsWith("Hi\n\n") ); } }