|
|
import { RequestHandler } from "express"; |
|
|
import { ZodIssue } from "zod"; |
|
|
import { initializeSseStream } from "../../../shared/streaming"; |
|
|
import { classifyErrorAndSend } from "../common"; |
|
|
import { |
|
|
RequestPreprocessor, |
|
|
blockZoomerOrigins, |
|
|
countPromptTokens, |
|
|
languageFilter, |
|
|
setApiFormat, |
|
|
transformOutboundPayload, |
|
|
validateContextSize, |
|
|
validateModelFamily, |
|
|
validateVision, |
|
|
applyQuotaLimits, |
|
|
} from "."; |
|
|
|
|
|
/**
 * Extra preprocessors a caller can splice into the pipeline built by
 * `createPreprocessorMiddleware`. Both lists are optional.
 */
type RequestPreprocessorOptions = {
  /**
   * Run after `setApiFormat` and `blockZoomerOrigins`, immediately before
   * `transformOutboundPayload`.
   */
  beforeTransform?: Array<RequestPreprocessor>;
  /**
   * Run after `transformOutboundPayload`, `countPromptTokens` and
   * `languageFilter`, immediately before the validation stages
   * (`validateContextSize` onward).
   */
  afterTransform?: Array<RequestPreprocessor>;
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Builds an Express middleware that runs the standard request preprocessor
 * pipeline for the given API format.
 *
 * Stage order is fixed: format setup and origin blocking, then any
 * caller-supplied `beforeTransform` stages, then payload transformation,
 * token counting and language filtering, then any caller-supplied
 * `afterTransform` stages, and finally the validation and quota stages.
 *
 * @param apiFormat Forwarded unchanged to `setApiFormat`.
 * @param options Optional extra stages to insert around the transform step.
 * @returns A request handler that delegates to `executePreprocessors`.
 */
export const createPreprocessorMiddleware = (
  apiFormat: Parameters<typeof setApiFormat>[0],
  { beforeTransform, afterTransform }: RequestPreprocessorOptions = {}
): RequestHandler => {
  const setupStages: RequestPreprocessor[] = [
    setApiFormat(apiFormat),
    blockZoomerOrigins,
  ];
  const transformStages: RequestPreprocessor[] = [
    transformOutboundPayload,
    countPromptTokens,
    languageFilter,
  ];
  const validationStages: RequestPreprocessor[] = [
    validateContextSize,
    validateVision,
    validateModelFamily,
    applyQuotaLimits,
  ];

  // Missing option lists contribute no stages.
  const pipeline: RequestPreprocessor[] = [
    ...setupStages,
    ...(beforeTransform ?? []),
    ...transformStages,
    ...(afterTransform ?? []),
    ...validationStages,
  ];

  return async (...handlerArgs) => executePreprocessors(pipeline, handlerArgs);
};
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Builds the reduced preprocessor middleware used for embeddings requests.
 *
 * Only two stages run: the API format is pinned to OpenAI on both the inbound
 * and outbound side, and the request's prompt/output token counters are set
 * to zero. None of the transform or validation stages of the main pipeline
 * are included.
 *
 * @returns A request handler that delegates to `executePreprocessors`.
 */
export const createEmbeddingsPreprocessorMiddleware = (): RequestHandler => {
  const pinOpenAiFormat = setApiFormat({
    inApi: "openai",
    outApi: "openai",
    service: "openai",
  });

  const zeroTokenCounts: RequestPreprocessor = (req) => {
    req.outputTokens = 0;
    req.promptTokens = 0;
  };

  const pipeline: RequestPreprocessor[] = [pinOpenAiFormat, zeroTokenCounts];

  return async (...handlerArgs) => executePreprocessors(pipeline, handlerArgs);
};
|
|
|
|
|
/**
 * Runs a list of request preprocessors in order and then passes control to
 * the next middleware.
 *
 * Before anything else the request is offered to `handleTestMessage`; if that
 * already sent a response (`res.headersSent`), no preprocessor runs and
 * `next` is not called.
 *
 * If any preprocessor throws or rejects, the failure is logged (validation
 * errors named `ZodError` at warn level with a flattened issue list, anything
 * else at error level) and reported to the client via `classifyErrorAndSend`.
 * When the request body asked for streaming and no headers have been sent
 * yet, `initializeSseStream` is called first.
 *
 * @param preprocessors Stages to await sequentially; each receives `req`.
 * @param handlerArgs The `[req, res, next]` tuple of the Express handler.
 */
async function executePreprocessors(
  preprocessors: RequestPreprocessor[],
  [req, res, next]: Parameters<RequestHandler>
) {
  handleTestMessage(req, res, next);
  if (res.headersSent) return;

  try {
    for (const preprocessor of preprocessors) {
      await preprocessor(req);
    }
    next();
  } catch (error) {
    // The error path must never throw itself: an exception here would skip
    // `classifyErrorAndSend` and reject the async handler without a response.
    // Anything can be thrown (including null/undefined or prototype-less
    // objects), so the constructor name is read defensively.
    const errorName =
      typeof error === "object" && error !== null
        ? error.constructor?.name
        : undefined;

    // Matched by name rather than `instanceof`, as in the original code.
    if (errorName === "ZodError") {
      const issues = (error as { issues?: ZodIssue[] }).issues
        ?.map((issue: ZodIssue) => `${issue.path.join(".")}: ${issue.message}`)
        .join("; ");
      req.log.warn({ issues }, "Prompt failed preprocessor validation.");
    } else {
      req.log.error(error, "Error while executing request preprocessor");
    }

    // `req.body` may be undefined (e.g. no body parsed), so read it null-safely.
    const stream = req.body?.stream;
    const isStreaming = stream === "true" || stream === true;
    if (isStreaming && !res.headersSent) {
      initializeSseStream(res);
    }
    classifyErrorAndSend(error as Error, req, res);
  }
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Answers connection-test prompts directly instead of forwarding them.
 *
 * Only POST requests are inspected. When `isTestMessage` recognises the body,
 * a canned "Hello!" reply is sent with `res.json` and the event is logged;
 * otherwise the function does nothing and leaves the response untouched.
 * Callers detect the short-circuit by checking `res.headersSent` afterwards.
 *
 * The reply carries the greeting in several response shapes at once
 * (`choices`, `completion`, `content`, `candidates`) so that a client
 * expecting any of those formats can find the text.
 */
const handleTestMessage: RequestHandler = (req, res) => {
  const { method, body } = req;
  if (method !== "POST") {
    return;
  }

  if (isTestMessage(body)) {
    req.log.info({ body }, "Received test message. Skipping API call.");
    res.json({
      id: "test-message",
      object: "chat.completion",
      // `created` on an OpenAI chat completion is a Unix timestamp in
      // seconds; `Date.now()` alone would report milliseconds.
      created: Math.floor(Date.now() / 1000),
      model: body.model,

      // Chat-completion style reply.
      choices: [
        {
          message: { role: "assistant", content: "Hello!" },
          finish_reason: "stop",
          index: 0,
        },
      ],

      // Plain text-completion style reply.
      completion: "Hello!",

      // Content-block style reply.
      content: [{ type: "text", text: "Hello!" }],

      // Candidates/parts style reply.
      candidates: [
        {
          content: { parts: [{ text: "Hello!" }] },
          finishReason: "stop",
        },
      ],
      proxy_note:
        "SillyTavern connection test detected. Your prompt was not sent to the actual model and this response was generated by the proxy.",
    });
  }
};
|
|
|
|
|
/**
 * Reports whether a request body looks like a client connection test, i.e. a
 * lone "Hi" prompt.
 *
 * Exactly one body shape is examined, chosen in this order of precedence:
 *  1. `messages` (if truthy): a single entry with role "user" and content "Hi".
 *  2. `contents` (if truthy): a single entry whose first part has text "Hi".
 *  3. `prompt`: a string that trims to "Human: Hi\n\nAssistant:" or that
 *     starts with "Hi\n\n".
 *
 * This runs on the raw, not-yet-validated request body, so it must tolerate
 * arbitrary input: missing bodies, null entries, entries without `parts`, and
 * non-string prompts all yield `false` instead of throwing.
 *
 * @param body Parsed request body; may be any JSON value or undefined.
 * @returns `true` only for the recognised test prompts, otherwise `false`.
 */
function isTestMessage(body: any): boolean {
  const { messages, prompt, contents } = body ?? {};

  if (messages) {
    return (
      messages.length === 1 &&
      messages[0]?.role === "user" &&
      messages[0]?.content === "Hi"
    );
  }

  if (contents) {
    return contents.length === 1 && contents[0]?.parts?.[0]?.text === "Hi";
  }

  // A prompt may legitimately be a non-string (e.g. an array); only strings
  // can match the test patterns.
  if (typeof prompt !== "string") {
    return false;
  }
  return (
    prompt.trim() === "Human: Hi\n\nAssistant:" || prompt.startsWith("Hi\n\n")
  );
}
|
|
|