|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import crypto from "crypto"; |
|
|
import { Handler, Request } from "express"; |
|
|
import { config } from "../config"; |
|
|
import { BadRequestError, TooManyRequestsError } from "../shared/errors"; |
|
|
import { keyPool } from "../shared/key-management"; |
|
|
import { |
|
|
getModelFamilyForRequest, |
|
|
MODEL_FAMILIES, |
|
|
ModelFamily, |
|
|
} from "../shared/models"; |
|
|
import { initializeSseStream } from "../shared/streaming"; |
|
|
import { logger } from "../logger"; |
|
|
import { getUniqueIps } from "./rate-limit"; |
|
|
import { ProxyReqMutator, RequestPreprocessor } from "./middleware/request"; |
|
|
import { sendErrorToClient } from "./middleware/response/error-generator"; |
|
|
import { ProxyReqManager } from "./middleware/request/proxy-req-manager"; |
|
|
import { classifyErrorAndSend } from "./middleware/common"; |
|
|
|
|
|
// FIFO backing store for requests waiting on an available upstream key.
// Dequeue order is decided by `dequeue`, not by position alone.
const queue: Request[] = [];

// Logger scoped to this module.
const log = logger.child({ module: "request-queue" });
|
|
|
|
|
|
|
|
// Max number of requests a single user token / IP may have enqueued at once
// (special-type users are exempt; see `enqueue`).
const USER_CONCURRENCY_LIMIT = parseInt(
  process.env.USER_CONCURRENCY_LIMIT ?? "1"
);
// Smallest SSE heartbeat payload, in bytes.
const MIN_HEARTBEAT_SIZE = parseInt(process.env.MIN_HEARTBEAT_SIZE_B ?? "512");
// Largest SSE heartbeat payload, in bytes (env var is expressed in KB).
const MAX_HEARTBEAT_SIZE =
  1024 * parseInt(process.env.MAX_HEARTBEAT_SIZE_KB ?? "1024");
// Interval between heartbeat writes to queued streaming clients, in ms
// (env var is expressed in seconds).
const HEARTBEAT_INTERVAL =
  1000 * parseInt(process.env.HEARTBEAT_INTERVAL_SEC ?? "5");
// Proxy load (see `getProxyLoad`) above which non-streaming requests are
// rejected and heartbeat payloads start growing.
const LOAD_THRESHOLD = parseFloat(process.env.LOAD_THRESHOLD ?? "150");
// Controls how quickly heartbeat payloads grow once load exceeds
// LOAD_THRESHOLD (see `getHeartbeatSize`).
const PAYLOAD_SCALE_FACTOR = parseFloat(
  process.env.PAYLOAD_SCALE_FACTOR ?? "6"
);
// How long (ms) to wait for a client to consume the queue-join SSE message
// before destroying the connection (see `initStreaming`).
const QUEUE_JOIN_TIMEOUT = 5000;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function getIdentifier(req: Request) { |
|
|
if (req.user) return req.user.token; |
|
|
if (req.risuToken) return req.risuToken; |
|
|
|
|
|
return req.ip; |
|
|
} |
|
|
|
|
|
const sharesIdentifierWith = (incoming: Request) => (queued: Request) => |
|
|
getIdentifier(queued) === getIdentifier(incoming); |
|
|
|
|
|
async function enqueue(req: Request) { |
|
|
if (req.socket.destroyed || req.res?.writableEnded) { |
|
|
|
|
|
|
|
|
|
|
|
req.log.warn("Attempt to enqueue aborted request."); |
|
|
throw new Error("Attempt to enqueue aborted request."); |
|
|
} |
|
|
|
|
|
const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length; |
|
|
|
|
|
if (enqueuedRequestCount >= USER_CONCURRENCY_LIMIT && req.user?.type !== "special") { |
|
|
throw new TooManyRequestsError( |
|
|
"Your IP or user token already has another request in the queue." |
|
|
); |
|
|
} |
|
|
|
|
|
|
|
|
removeProxyMiddlewareEventListeners(req); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const { stream } = req.body; |
|
|
if (stream === "true" || stream === true || req.isStreaming) { |
|
|
const res = req.res!; |
|
|
if (!res.headersSent) { |
|
|
await initStreaming(req); |
|
|
} |
|
|
registerHeartbeat(req); |
|
|
} else if (getProxyLoad() > LOAD_THRESHOLD) { |
|
|
throw new BadRequestError( |
|
|
"Due to heavy traffic on this proxy, you must enable streaming in your chat client to use this endpoint." |
|
|
); |
|
|
} |
|
|
|
|
|
queue.push(req); |
|
|
req.queueOutTime = 0; |
|
|
|
|
|
const removeFromQueue = () => { |
|
|
req.log.info(`Removing aborted request from queue.`); |
|
|
const index = queue.indexOf(req); |
|
|
if (index !== -1) { |
|
|
queue.splice(index, 1); |
|
|
} |
|
|
if (req.heartbeatInterval) clearInterval(req.heartbeatInterval); |
|
|
if (req.monitorInterval) clearInterval(req.monitorInterval); |
|
|
}; |
|
|
req.onAborted = removeFromQueue; |
|
|
req.res!.once("close", removeFromQueue); |
|
|
|
|
|
if (req.retryCount ?? 0 > 0) { |
|
|
req.log.info({ retries: req.retryCount }, `Enqueued request for retry.`); |
|
|
} else { |
|
|
const size = req.socket.bytesRead; |
|
|
const endpoint = req.url?.split("?")[0]; |
|
|
req.log.info({ size, endpoint }, `Enqueued new request.`); |
|
|
} |
|
|
} |
|
|
|
|
|
export async function reenqueueRequest(req: Request) { |
|
|
req.log.info( |
|
|
{ key: req.key?.hash, retryCount: req.retryCount }, |
|
|
`Re-enqueueing request due to retryable error` |
|
|
); |
|
|
req.retryCount++; |
|
|
await enqueue(req); |
|
|
} |
|
|
|
|
|
function getQueueForPartition(partition: ModelFamily): Request[] { |
|
|
return queue.filter((req) => getModelFamilyForRequest(req) === partition); |
|
|
} |
|
|
|
|
|
export function dequeue(partition: ModelFamily): Request | undefined { |
|
|
const modelQueue = getQueueForPartition(partition); |
|
|
|
|
|
if (modelQueue.length === 0) { |
|
|
return undefined; |
|
|
} |
|
|
|
|
|
const req = modelQueue.reduce((prev, curr) => |
|
|
prev.startTime + |
|
|
config.tokensPunishmentFactor * |
|
|
((prev.promptTokens ?? 0) + (prev.outputTokens ?? 0)) < |
|
|
curr.startTime + |
|
|
config.tokensPunishmentFactor * |
|
|
((curr.promptTokens ?? 0) + (curr.outputTokens ?? 0)) |
|
|
? prev |
|
|
: curr |
|
|
); |
|
|
queue.splice(queue.indexOf(req), 1); |
|
|
|
|
|
if (req.onAborted) { |
|
|
req.res!.off("close", req.onAborted); |
|
|
req.onAborted = undefined; |
|
|
} |
|
|
|
|
|
if (req.heartbeatInterval) clearInterval(req.heartbeatInterval); |
|
|
if (req.monitorInterval) clearInterval(req.monitorInterval); |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
req.queueOutTime = Date.now(); |
|
|
return req; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function processQueue() { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const reqs: (Request | undefined)[] = []; |
|
|
MODEL_FAMILIES.forEach((modelFamily) => { |
|
|
const lockout = keyPool.getLockoutPeriod(modelFamily); |
|
|
if (lockout === 0) { |
|
|
reqs.push(dequeue(modelFamily)); |
|
|
} |
|
|
}); |
|
|
|
|
|
reqs.filter(Boolean).forEach((req) => { |
|
|
if (req?.proceed) { |
|
|
const modelFamily = getModelFamilyForRequest(req!); |
|
|
req.log.info( |
|
|
{ retries: req.retryCount, partition: modelFamily }, |
|
|
`Dequeuing request.` |
|
|
); |
|
|
req.proceed(); |
|
|
} |
|
|
}); |
|
|
setTimeout(processQueue, 50); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function cleanQueue() { |
|
|
const now = Date.now(); |
|
|
const oldRequests = queue.filter( |
|
|
(req) => now - (req.startTime ?? now) > 5 * 60 * 1000 |
|
|
); |
|
|
oldRequests.forEach((req) => { |
|
|
req.log.info(`Removing request from queue after 5 minutes.`); |
|
|
killQueuedRequest(req); |
|
|
}); |
|
|
|
|
|
const index = waitTimes.findIndex( |
|
|
(waitTime) => now - waitTime.end > 300 * 1000 |
|
|
); |
|
|
const removed = waitTimes.splice(0, index + 1); |
|
|
log.trace( |
|
|
{ stalledRequests: oldRequests.length, prunedWaitTimes: removed.length }, |
|
|
`Cleaning up request queue.` |
|
|
); |
|
|
setTimeout(cleanQueue, 20 * 1000); |
|
|
} |
|
|
|
|
|
export function start() { |
|
|
MODEL_FAMILIES.forEach((modelFamily) => { |
|
|
historicalEmas.set(modelFamily, 0); |
|
|
currentEmas.set(modelFamily, 0); |
|
|
estimates.set(modelFamily, 0); |
|
|
}); |
|
|
processQueue(); |
|
|
cleanQueue(); |
|
|
log.info(`Started request queue.`); |
|
|
} |
|
|
|
|
|
// Completed queue-wait samples used to estimate future wait times.
// `start`/`end` are epoch ms; entries older than five minutes are pruned by
// `cleanQueue`.
let waitTimes: {
  partition: ModelFamily;
  start: number;
  end: number;
}[] = [];
|
|
|
|
|
|
|
|
export function trackWaitTime(req: Request) { |
|
|
waitTimes.push({ |
|
|
partition: getModelFamilyForRequest(req), |
|
|
start: req.startTime!, |
|
|
end: req.queueOutTime ?? Date.now(), |
|
|
}); |
|
|
} |
|
|
|
|
|
// How often wait-time estimates are recomputed, in ms.
const WAIT_TIME_INTERVAL = 3000;
// EMA smoothing factor for completed (historical) waits.
const ALPHA_HISTORICAL = 0.2;
// EMA smoothing factor for in-flight (current) waits.
const ALPHA_CURRENT = 0.3;
// Per-family EMA of recently completed wait times.
const historicalEmas: Map<ModelFamily, number> = new Map();
// Per-family EMA of the longest wait currently sitting in the queue.
const currentEmas: Map<ModelFamily, number> = new Map();
// Per-family blended wait-time estimate served by getEstimatedWaitTime.
const estimates: Map<ModelFamily, number> = new Map();
|
|
|
|
|
export function getEstimatedWaitTime(partition: ModelFamily) { |
|
|
return estimates.get(partition) ?? 0; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function calculateWaitTime(partition: ModelFamily) { |
|
|
const now = Date.now(); |
|
|
const recentWaits = waitTimes |
|
|
.filter((wait) => { |
|
|
const isSamePartition = wait.partition === partition; |
|
|
const isRecent = now - wait.end < 300 * 1000; |
|
|
return isSamePartition && isRecent; |
|
|
}) |
|
|
.map((wait) => wait.end - wait.start); |
|
|
const recentAverage = recentWaits.length |
|
|
? recentWaits.reduce((sum, wait) => sum + wait, 0) / recentWaits.length |
|
|
: 0; |
|
|
|
|
|
const historicalEma = historicalEmas.get(partition) ?? 0; |
|
|
historicalEmas.set( |
|
|
partition, |
|
|
ALPHA_HISTORICAL * recentAverage + (1 - ALPHA_HISTORICAL) * historicalEma |
|
|
); |
|
|
|
|
|
const currentWaits = queue |
|
|
.filter((req) => getModelFamilyForRequest(req) === partition) |
|
|
.map((req) => now - req.startTime!); |
|
|
const longestCurrentWait = Math.max(...currentWaits, 0); |
|
|
|
|
|
const currentEma = currentEmas.get(partition) ?? 0; |
|
|
currentEmas.set( |
|
|
partition, |
|
|
ALPHA_CURRENT * longestCurrentWait + (1 - ALPHA_CURRENT) * currentEma |
|
|
); |
|
|
|
|
|
return (historicalEma + currentEma) / 2; |
|
|
} |
|
|
|
|
|
// Periodically refresh the wait-time estimate for every model family.
setInterval(() => {
  MODEL_FAMILIES.forEach((modelFamily) => {
    estimates.set(modelFamily, calculateWaitTime(modelFamily));
  });
}, WAIT_TIME_INTERVAL);
|
|
|
|
|
export function getQueueLength(partition: ModelFamily | "all" = "all") { |
|
|
if (partition === "all") { |
|
|
return queue.length; |
|
|
} |
|
|
const modelQueue = getQueueForPartition(partition); |
|
|
return modelQueue.length; |
|
|
} |
|
|
|
|
|
/**
 * Returns middleware that enqueues the incoming request instead of proxying
 * it immediately. When the queue loop selects the request, its `proceed`
 * callback runs the given mutators and then invokes `proxyMiddleware`.
 * Queue-admission failures (concurrency limit, streaming required) are
 * reported to the client via `sendErrorToClient`.
 */
export function createQueueMiddleware({
  mutations = [],
  proxyMiddleware,
}: {
  mutations?: ProxyReqMutator[];
  proxyMiddleware: Handler;
}): Handler {
  return async (req, res, next) => {
    req.proceed = async () => {
      // Normalize the stream flag to a boolean before proxying.
      req.isStreaming = req.isStreaming || String(req.body.stream) === "true";
      req.body.stream = req.isStreaming;

      try {
        // Run each registered mutator against a change manager tracking
        // modifications to the outbound request.
        const changeManager = new ProxyReqManager(req);
        req.changeManager = changeManager;
        for (const mutator of mutations) {
          await mutator(changeManager);
        }
      } catch (err) {
        // A mutator failed; classify and report instead of proxying.
        return classifyErrorAndSend(err, req, res);
      }

      proxyMiddleware(req, res, next);
    };

    try {
      await enqueue(req);
    } catch (err: any) {
      const title =
        err.status === 429
          ? "Proxy queue error (too many concurrent requests)"
          : "Proxy queue error (streaming required)";
      sendErrorToClient({
        options: {
          title,
          message: err.message,
          format: req.inboundApi,
          reqId: req.id,
          model: req.body?.model,
        },
        req,
        res,
      });
    }
  };
}
|
|
|
|
|
function killQueuedRequest(req: Request) { |
|
|
if (!req.res || req.res.writableEnded) { |
|
|
req.log.warn(`Attempted to terminate request that has already ended.`); |
|
|
queue.splice(queue.indexOf(req), 1); |
|
|
return; |
|
|
} |
|
|
const res = req.res; |
|
|
try { |
|
|
const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes.`; |
|
|
sendErrorToClient({ |
|
|
options: { |
|
|
title: "Proxy queue error (request killed)", |
|
|
message, |
|
|
format: req.inboundApi, |
|
|
reqId: req.id, |
|
|
model: req.body?.model, |
|
|
}, |
|
|
req, |
|
|
res, |
|
|
}); |
|
|
} catch (e) { |
|
|
req.log.error(e, `Error killing stalled request.`); |
|
|
} |
|
|
} |
|
|
|
|
|
async function initStreaming(req: Request) { |
|
|
const res = req.res!; |
|
|
initializeSseStream(res); |
|
|
|
|
|
const joinMsg = `: joining queue at position ${ |
|
|
queue.length |
|
|
}\n\n${getHeartbeatPayload()}`; |
|
|
|
|
|
let drainTimeout: NodeJS.Timeout; |
|
|
const welcome = new Promise<void>((resolve, reject) => { |
|
|
const onDrain = () => { |
|
|
clearTimeout(drainTimeout); |
|
|
req.log.debug(`Client finished consuming join message.`); |
|
|
res.off("drain", onDrain); |
|
|
resolve(); |
|
|
}; |
|
|
|
|
|
drainTimeout = setTimeout(() => { |
|
|
res.off("drain", onDrain); |
|
|
res.destroy(); |
|
|
reject(new Error("Unreponsive streaming client; killing connection")); |
|
|
}, QUEUE_JOIN_TIMEOUT); |
|
|
|
|
|
if (!res.write(joinMsg)) { |
|
|
req.log.warn("Kernel buffer is full; holding client request."); |
|
|
res.once("drain", onDrain); |
|
|
} else { |
|
|
clearTimeout(drainTimeout); |
|
|
resolve(); |
|
|
} |
|
|
}); |
|
|
|
|
|
await welcome; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/**
 * Strips the listeners that the proxy middleware attached to the request and
 * response, so they don't fire while the request waits in the queue.
 * Listeners are located by matching distinctive substrings of their source
 * text (`Function.prototype.toString`) since there are no direct references
 * to them here. NOTE(review): inherently fragile — this silently stops
 * matching if the middleware's internals are minified or change between
 * versions; verify after upgrading the proxy middleware dependency.
 */
function removeProxyMiddlewareEventListeners(req: Request) {
  // Source substring identifying the response 'close' listener to remove.
  const RES_ONCLOSE = `Destroying proxyRes in proxyRes close event`;
  // Source substring identifying the response 'error' listener to remove.
  const RES_ONERROR = `Socket error in proxyReq event`;
  // Source substring identifying the request 'aborted' listener to remove.
  const REQ_ONABORTED = `proxyReq.abort()`;
  // Source substring identifying the request 'error' listener to remove.
  const REQ_ONERROR = `if (req.socket.destroyed`;

  const res = req.res!;

  const resOnClose = res
    .listeners("close")
    .find((listener) => listener.toString().includes(RES_ONCLOSE));
  if (resOnClose) {
    res.removeListener("close", resOnClose as any);
  }

  const resOnError = res
    .listeners("error")
    .find((listener) => listener.toString().includes(RES_ONERROR));
  if (resOnError) {
    res.removeListener("error", resOnError as any);
  }

  const reqOnAborted = req
    .listeners("aborted")
    .find((listener) => listener.toString().includes(REQ_ONABORTED));
  if (reqOnAborted) {
    req.removeListener("aborted", reqOnAborted as any);
  }

  const reqOnError = req
    .listeners("error")
    .find((listener) => listener.toString().includes(REQ_ONERROR));
  if (reqOnError) {
    req.removeListener("error", reqOnError as any);
  }
}
|
|
|
|
|
export function registerHeartbeat(req: Request) { |
|
|
const res = req.res!; |
|
|
|
|
|
let isBufferFull = false; |
|
|
let bufferFullCount = 0; |
|
|
req.heartbeatInterval = setInterval(() => { |
|
|
if (isBufferFull) { |
|
|
bufferFullCount++; |
|
|
if (bufferFullCount >= 3) { |
|
|
req.log.error("Heartbeat skipped too many times; killing connection."); |
|
|
res.destroy(); |
|
|
} else { |
|
|
req.log.warn({ bufferFullCount }, "Heartbeat skipped; buffer is full."); |
|
|
} |
|
|
return; |
|
|
} |
|
|
|
|
|
const data = getHeartbeatPayload(); |
|
|
if (!res.write(data)) { |
|
|
isBufferFull = true; |
|
|
res.once("drain", () => (isBufferFull = false)); |
|
|
} |
|
|
}, HEARTBEAT_INTERVAL); |
|
|
monitorHeartbeat(req); |
|
|
} |
|
|
|
|
|
function monitorHeartbeat(req: Request) { |
|
|
const res = req.res!; |
|
|
|
|
|
let lastBytesSent = 0; |
|
|
req.monitorInterval = setInterval(() => { |
|
|
const bytesSent = res.socket?.bytesWritten ?? 0; |
|
|
const bytesSinceLast = bytesSent - lastBytesSent; |
|
|
req.log.debug( |
|
|
{ |
|
|
previousBytesSent: lastBytesSent, |
|
|
currentBytesSent: bytesSent, |
|
|
}, |
|
|
"Heartbeat monitor check." |
|
|
); |
|
|
lastBytesSent = bytesSent; |
|
|
|
|
|
const minBytes = Math.floor(getHeartbeatSize() / 2); |
|
|
if (bytesSinceLast < minBytes) { |
|
|
req.log.warn( |
|
|
{ minBytes, bytesSinceLast }, |
|
|
"Queued request is not processing heartbeats enough data or server is overloaded; killing connection." |
|
|
); |
|
|
res.destroy(); |
|
|
} |
|
|
}, HEARTBEAT_INTERVAL * 2); |
|
|
} |
|
|
|
|
|
|
|
|
function getHeartbeatSize() { |
|
|
const load = getProxyLoad(); |
|
|
|
|
|
if (load <= LOAD_THRESHOLD) { |
|
|
return MIN_HEARTBEAT_SIZE; |
|
|
} else { |
|
|
const excessLoad = load - LOAD_THRESHOLD; |
|
|
const size = |
|
|
MIN_HEARTBEAT_SIZE + Math.pow(excessLoad * PAYLOAD_SCALE_FACTOR, 2); |
|
|
if (size > MAX_HEARTBEAT_SIZE) return MAX_HEARTBEAT_SIZE; |
|
|
return size; |
|
|
} |
|
|
} |
|
|
|
|
|
function getHeartbeatPayload() { |
|
|
const size = getHeartbeatSize(); |
|
|
const data = |
|
|
process.env.NODE_ENV === "production" |
|
|
? crypto.randomBytes(size).toString("base64") |
|
|
: `payload size: ${size}`; |
|
|
|
|
|
return `: queue heartbeat ${data}\n\n`; |
|
|
} |
|
|
|
|
|
function getProxyLoad() { |
|
|
return Math.max(getUniqueIps(), queue.length); |
|
|
} |
|
|
|