File size: 18,408 Bytes
5c5b371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
/**
 * Very scuffed request queue. OpenAI's GPT-4 keys have a very strict rate limit
 * of 40000 generated tokens per minute. We don't actually know how many tokens
 * a given key has generated, so our queue will simply retry requests that fail
 * with a non-billing related 429 over and over again until they succeed.
 *
 * When a request to a proxied endpoint is received, we create a closure around
 * the call to http-proxy-middleware and attach it to the request. This allows
 * us to pause the request until we have a key available. Further, if the
 * proxied request encounters a retryable error, we can simply put the request
 * back in the queue and it will be retried later using the same closure.
 */

import crypto from "crypto";
import { Handler, Request } from "express";
import { config } from "../config";
import { BadRequestError, TooManyRequestsError } from "../shared/errors";
import { keyPool } from "../shared/key-management";
import {
  getModelFamilyForRequest,
  MODEL_FAMILIES,
  ModelFamily,
} from "../shared/models";
import { initializeSseStream } from "../shared/streaming";
import { logger } from "../logger";
import { getUniqueIps } from "./rate-limit";
import { ProxyReqMutator, RequestPreprocessor } from "./middleware/request";
import { sendErrorToClient } from "./middleware/response/error-generator";
import { ProxyReqManager } from "./middleware/request/proxy-req-manager";
import { classifyErrorAndSend } from "./middleware/common";

/** Requests currently waiting to be dequeued; one array holds every partition. */
const queue: Array<Request> = [];
/** Logger scoped to this module. */
const log = logger.child({ module: "request-queue" });

/**
 * Reads an integer setting from the environment.
 *
 * @param name Environment variable to read.
 * @param fallback Value used when the variable is unset or is not a number.
 * @returns The parsed integer, or `fallback` if parsing yields `NaN`.
 */
function readIntEnv(name: string, fallback: number): number {
  const parsed = parseInt(process.env[name] ?? "", 10);
  return Number.isNaN(parsed) ? fallback : parsed;
}

/**
 * Reads a floating-point setting from the environment.
 *
 * @param name Environment variable to read.
 * @param fallback Value used when the variable is unset or is not a number.
 * @returns The parsed number, or `fallback` if parsing yields `NaN`.
 */
function readFloatEnv(name: string, fallback: number): number {
  const parsed = parseFloat(process.env[name] ?? "");
  return Number.isNaN(parsed) ? fallback : parsed;
}

// A malformed value must never become NaN here: `count >= NaN` is always
// false (which would disable the concurrency limit) and timers treat a NaN
// delay as the minimum delay.

/** Maximum number of queue slots for individual users. */
const USER_CONCURRENCY_LIMIT = readIntEnv("USER_CONCURRENCY_LIMIT", 1);
/** Heartbeat payload size, in bytes, while load is at or below the threshold. */
const MIN_HEARTBEAT_SIZE = readIntEnv("MIN_HEARTBEAT_SIZE_B", 512);
/** Upper bound for the heartbeat payload size, in bytes (configured in KB). */
const MAX_HEARTBEAT_SIZE = 1024 * readIntEnv("MAX_HEARTBEAT_SIZE_KB", 1024);
/** Delay between heartbeats, in milliseconds (configured in seconds). */
const HEARTBEAT_INTERVAL = 1000 * readIntEnv("HEARTBEAT_INTERVAL_SEC", 5);
/**
 * Proxy load above which non-streaming requests are refused and heartbeat
 * payloads start growing.
 */
const LOAD_THRESHOLD = readFloatEnv("LOAD_THRESHOLD", 150);
/** Multiplier applied to the excess load before it is squared. */
const PAYLOAD_SCALE_FACTOR = readFloatEnv("PAYLOAD_SCALE_FACTOR", 6);
/** Time, in milliseconds, a streaming client has to drain the join message. */
const QUEUE_JOIN_TIMEOUT = 5000;

/**
 * Derives the key under which a request is counted in the queue, so that
 * several requests from the same requester can be recognised.
 *
 * The first available value wins:
 * 1. the token of the authenticated user (`req.user.token`)
 * 2. the RisuAI token (`req.risuToken`)
 * 3. the client IP address (`req.ip`)
 *
 * A fourth option, a fixed 'shared-ip' identifier for shared IP addresses, is
 * currently disabled.
 */
function getIdentifier(req: Request) {
  // if (isFromSharedIp(req)) return "shared-ip";
  return req.user ? req.user.token : req.risuToken || req.ip;
}

/**
 * Builds a predicate that tells whether a queued request has the same
 * identifier (see `getIdentifier`) as `incoming`.
 */
function sharesIdentifierWith(incoming: Request) {
  return function hasSameIdentifier(queued: Request) {
    return getIdentifier(queued) === getIdentifier(incoming);
  };
}

/**
 * Adds a request to the shared queue, where it waits until `processQueue`
 * dequeues it and invokes `req.proceed`.
 *
 * Streaming requests get their SSE stream initialized (unless headers were
 * already sent) and a heartbeat registered so the connection stays alive while
 * queued. Non-streaming requests are refused while the proxy load exceeds
 * `LOAD_THRESHOLD`. A `close` listener is installed on the response so that an
 * aborted request is removed from the queue and its timers are cleared.
 *
 * @param req The new or retried request to enqueue.
 * @throws {Error} If the socket is destroyed or the response already ended;
 *   errors raised by `initStreaming` also propagate.
 * @throws {TooManyRequestsError} If the requester already holds
 *   `USER_CONCURRENCY_LIMIT` queue slots and is not a "special" user.
 * @throws {BadRequestError} If the request is not streaming while the proxy
 *   load is above `LOAD_THRESHOLD`.
 */
async function enqueue(req: Request) {
  if (req.socket.destroyed || req.res?.writableEnded) {
    // In rare cases, a request can be disconnected after it is dequeued for a
    // retry, but before it is re-enqueued. In this case we may miss the abort
    // and the request will loop in the queue forever.
    req.log.warn("Attempt to enqueue aborted request.");
    throw new Error("Attempt to enqueue aborted request.");
  }

  const enqueuedRequestCount = queue.filter(sharesIdentifierWith(req)).length;
  // Do not apply concurrency limit to "special" users
  if (
    enqueuedRequestCount >= USER_CONCURRENCY_LIMIT &&
    req.user?.type !== "special"
  ) {
    throw new TooManyRequestsError(
      "Your IP or user token already has another request in the queue."
    );
  }

  // Workaround: strip the listeners http-proxy-middleware left on a retried
  // request so they do not fire for the next attempt.
  removeProxyMiddlewareEventListeners(req);

  // If the request opted into streaming, we need to register a heartbeat
  // handler to keep the connection alive while it waits in the queue. We
  // deregister the handler when the request is dequeued.
  const { stream } = req.body;
  if (stream === "true" || stream === true || req.isStreaming) {
    const res = req.res!;
    if (!res.headersSent) {
      await initStreaming(req);
    }
    registerHeartbeat(req);
  } else if (getProxyLoad() > LOAD_THRESHOLD) {
    throw new BadRequestError(
      "Due to heavy traffic on this proxy, you must enable streaming in your chat client to use this endpoint."
    );
  }

  queue.push(req);
  req.queueOutTime = 0;

  const removeFromQueue = () => {
    req.log.info(`Removing aborted request from queue.`);
    const index = queue.indexOf(req);
    if (index !== -1) {
      queue.splice(index, 1);
    }
    if (req.heartbeatInterval) clearInterval(req.heartbeatInterval);
    if (req.monitorInterval) clearInterval(req.monitorInterval);
  };
  // Keep a reference so `dequeue` can detach this exact listener again.
  req.onAborted = removeFromQueue;
  req.res!.once("close", removeFromQueue);

  // `??` binds more loosely than `>`, so without the parentheses this would
  // read as `req.retryCount ?? (0 > 0)`.
  if ((req.retryCount ?? 0) > 0) {
    req.log.info({ retries: req.retryCount }, `Enqueued request for retry.`);
  } else {
    const size = req.socket.bytesRead;
    const endpoint = req.url?.split("?")[0];
    req.log.info({ size, endpoint }, `Enqueued new request.`);
  }
}

/**
 * Puts a request back into the queue after a retryable failure and bumps its
 * retry counter. The log entry records the counter as it was before the bump.
 *
 * @param req The request to retry.
 * @throws Whatever `enqueue` throws for this request.
 */
export async function reenqueueRequest(req: Request) {
  const logContext = { key: req.key?.hash, retryCount: req.retryCount };
  req.log.info(logContext, `Re-enqueueing request due to retryable error`);
  req.retryCount += 1;
  await enqueue(req);
}

/**
 * Collects, in queue order, the queued requests whose model family equals
 * `partition`. The returned array is a new one; the queue is not modified.
 */
function getQueueForPartition(partition: ModelFamily): Request[] {
  const matching: Request[] = [];
  for (const queued of queue) {
    if (getModelFamilyForRequest(queued) === partition) {
      matching.push(queued);
    }
  }
  return matching;
}

/**
 * Removes and returns the next request of the given partition.
 *
 * Each candidate is scored as `startTime` plus `config.tokensPunishmentFactor`
 * times its prompt and output tokens; the candidate with the lowest score is
 * picked, and on equal scores the one later in the queue wins. The picked
 * request has its abort listener detached and its heartbeat and monitor timers
 * cleared, and `queueOutTime` is set to the current time.
 *
 * @param partition Model family to dequeue from.
 * @returns The selected request, or `undefined` if none is queued.
 */
export function dequeue(partition: ModelFamily): Request | undefined {
  const candidates = getQueueForPartition(partition);
  if (candidates.length === 0) return undefined;

  const scoreOf = (candidate: Request) =>
    candidate.startTime +
    config.tokensPunishmentFactor *
      ((candidate.promptTokens ?? 0) + (candidate.outputTokens ?? 0));

  const selected = candidates.reduce((best, contender) =>
    scoreOf(best) < scoreOf(contender) ? best : contender
  );
  queue.splice(queue.indexOf(selected), 1);

  const abortListener = selected.onAborted;
  if (abortListener) {
    selected.res!.off("close", abortListener);
    selected.onAborted = undefined;
  }

  if (selected.heartbeatInterval) clearInterval(selected.heartbeatInterval);
  if (selected.monitorInterval) clearInterval(selected.monitorInterval);

  // Stamp the moment the request leaves the queue. It is recorded here rather
  // than on completion so that model processing time is not counted, and it is
  // not added to the wait times yet because the outcome is still unknown.
  selected.queueOutTime = Date.now();
  return selected;
}

/**
 * Queue pump: on every tick, takes at most one request from each model family
 * whose lockout period is zero, lets those requests proceed, and schedules
 * itself again 50 ms later.
 *
 * Simple on purpose; throughput is capped at one request per family per tick.
 * Dequeuing several per tick would be the next step if that becomes a limit.
 **/
function processQueue() {
  // This isn't completely correct, because a key can service multiple models.
  // Currently if a key is locked out on one model it will also stop servicing
  // the others, because we only track rate limits for the key as a whole.

  // Phase 1: pick the requests to release.
  const released: Request[] = [];
  for (const family of MODEL_FAMILIES) {
    if (keyPool.getLockoutPeriod(family) !== 0) continue;
    const next = dequeue(family);
    if (next) released.push(next);
  }

  // Phase 2: let each released request run its `proceed` closure.
  for (const req of released) {
    if (!req.proceed) continue;
    const partition = getModelFamilyForRequest(req);
    req.log.info({ retries: req.retryCount, partition }, `Dequeuing request.`);
    req.proceed();
  }

  setTimeout(processQueue, 50);
}

/**
 * Periodic housekeeping, rescheduled every 20 seconds.
 *
 * - Terminates requests that have been queued for more than 5 minutes.
 * - Drops tracked wait times whose `end` is more than 5 minutes old.
 **/
function cleanQueue() {
  const now = Date.now();
  const oldRequests = queue.filter(
    (req) => now - (req.startTime ?? now) > 5 * 60 * 1000
  );
  oldRequests.forEach((req) => {
    req.log.info(`Removing request from queue after 5 minutes.`);
    killQueuedRequest(req);
  });

  // Keep exactly the entries that are still fresh. Entries are not guaranteed
  // to be ordered by `end`, so every entry is checked individually instead of
  // cutting off a prefix of the array.
  const trackedBefore = waitTimes.length;
  waitTimes = waitTimes.filter((waitTime) => now - waitTime.end <= 300 * 1000);
  const prunedWaitTimes = trackedBefore - waitTimes.length;

  log.trace(
    { stalledRequests: oldRequests.length, prunedWaitTimes },
    `Cleaning up request queue.`
  );
  setTimeout(cleanQueue, 20 * 1000);
}

/**
 * Boots the queue: zeroes the wait-time statistics of every model family, then
 * kicks off the self-rescheduling `processQueue` and `cleanQueue` loops.
 */
export function start() {
  for (const family of MODEL_FAMILIES) {
    for (const series of [historicalEmas, currentEmas, estimates]) {
      series.set(family, 0);
    }
  }
  processQueue();
  cleanQueue();
  log.info(`Started request queue.`);
}

/**
 * Recorded queue waits used for the wait-time estimates. `start` and `end` are
 * timestamps in milliseconds.
 */
let waitTimes: Array<{
  partition: ModelFamily;
  start: number;
  end: number;
}> = [];

/**
 * Records how long a request waited in the queue. Callers are expected to
 * invoke this for requests that completed successfully.
 *
 * The wait ends at `req.queueOutTime` when set, otherwise at the current time.
 */
export function trackWaitTime(req: Request) {
  const partition = getModelFamilyForRequest(req);
  const start = req.startTime!;
  const end = req.queueOutTime ?? Date.now();
  waitTimes.push({ partition, start, end });
}

/** Milliseconds between recalculations of the wait-time estimates. */
const WAIT_TIME_INTERVAL = 3000;
/** Weight of the newest sample in the EMA of recently completed waits. */
const ALPHA_HISTORICAL = 0.2;
/** Weight of the newest sample in the EMA of the longest ongoing wait. */
const ALPHA_CURRENT = 0.3;
/** Per-family EMA of the average wait of recently completed requests. */
const historicalEmas = new Map<ModelFamily, number>();
/** Per-family EMA of the longest wait among requests still queued. */
const currentEmas = new Map<ModelFamily, number>();
/** Per-family wait estimate, in milliseconds, as last calculated. */
const estimates = new Map<ModelFamily, number>();

/**
 * Looks up the most recently calculated wait estimate for a partition.
 *
 * @returns The estimate in milliseconds, or 0 when none has been stored.
 */
export function getEstimatedWaitTime(partition: ModelFamily) {
  const estimate = estimates.get(partition);
  return estimate ?? 0;
}

/**
 * Updates both moving averages of a partition and returns a wait estimate in
 * milliseconds.
 *
 * - The historical EMA is fed the mean wait of tracked requests of this
 *   partition that left the queue less than 5 minutes ago (0 if there are none).
 * - The current EMA is fed the longest time any request of this partition has
 *   been sitting in the queue so far (0 if the partition is empty).
 *
 * NOTE(review): the returned estimate averages the EMA values read *before*
 * this call updated them, so it trails the stored values by one refresh.
 */
function calculateWaitTime(partition: ModelFamily) {
  const now = Date.now();

  // Mean of the recent completed waits for this partition.
  let waitSum = 0;
  let waitCount = 0;
  for (const { partition: family, start, end } of waitTimes) {
    if (family === partition && now - end < 300 * 1000) {
      waitSum += end - start;
      waitCount++;
    }
  }
  const recentAverage = waitCount ? waitSum / waitCount : 0;

  const previousHistorical = historicalEmas.get(partition) ?? 0;
  const nextHistorical =
    ALPHA_HISTORICAL * recentAverage +
    (1 - ALPHA_HISTORICAL) * previousHistorical;
  historicalEmas.set(partition, nextHistorical);

  // Longest wait among the requests of this partition that are still queued.
  let longestCurrentWait = 0;
  for (const queued of queue) {
    if (getModelFamilyForRequest(queued) !== partition) continue;
    longestCurrentWait = Math.max(longestCurrentWait, now - queued.startTime!);
  }

  const previousCurrent = currentEmas.get(partition) ?? 0;
  const nextCurrent =
    ALPHA_CURRENT * longestCurrentWait + (1 - ALPHA_CURRENT) * previousCurrent;
  currentEmas.set(partition, nextCurrent);

  return (previousHistorical + previousCurrent) / 2;
}

// Recalculate the stored wait estimate of every model family at a fixed
// cadence. The timer starts as soon as this module is loaded.
setInterval(() => {
  for (const family of MODEL_FAMILIES) {
    estimates.set(family, calculateWaitTime(family));
  }
}, WAIT_TIME_INTERVAL);

/**
 * Counts queued requests.
 *
 * @param partition Model family to count, or "all" (the default) for the whole
 *   queue.
 * @returns Number of requests currently queued for that selection.
 */
export function getQueueLength(partition: ModelFamily | "all" = "all") {
  return partition === "all"
    ? queue.length
    : getQueueForPartition(partition).length;
}

/**
 * Creates the Express handler that routes a request through the queue before
 * it is handed to the proxy middleware.
 *
 * The handler attaches a `proceed` closure to the request and enqueues it.
 * When the queue later calls `proceed`, the closure normalises the streaming
 * flag, applies every mutator through a fresh `ProxyReqManager`, and then
 * invokes `proxyMiddleware`. If enqueueing fails, an error response is sent to
 * the client right away.
 *
 * @param mutations Mutators applied, in order, just before proxying. Defaults
 *   to none.
 * @param proxyMiddleware Handler that performs the proxied request.
 * @returns The queueing request handler.
 */
export function createQueueMiddleware({
  mutations = [],
  proxyMiddleware,
}: {
  mutations?: ProxyReqMutator[];
  proxyMiddleware: Handler;
}): Handler {
  return async (req, res, next) => {
    req.proceed = async () => {
      // canonicalize the stream field which is set in a few places not always
      // consistently
      req.isStreaming = req.isStreaming || String(req.body.stream) === "true";
      req.body.stream = req.isStreaming;

      try {
        // Just before executing the proxyMiddleware, we will create a
        // ProxyReqManager to track modifications to the request. This allows
        // us to revert those changes if the proxied request fails with a
        // retryable error. That happens in proxyMiddleware's onProxyRes
        // handler.
        const changeManager = new ProxyReqManager(req);
        req.changeManager = changeManager;
        for (const mutator of mutations) {
          await mutator(changeManager);
        }
      } catch (err) {
        // Failure during request preparation is a fatal error.
        return classifyErrorAndSend(err, req, res);
      }

      proxyMiddleware(req, res, next);
    };

    try {
      await enqueue(req);
    } catch (err: unknown) {
      const status = (err as { status?: unknown } | null | undefined)?.status;
      const message = err instanceof Error ? err.message : String(err);

      // Only describe the failure as "streaming required" when that is what
      // actually happened; enqueueing can also fail because the client went
      // away or stopped reading.
      let title = "Proxy queue error";
      if (status === 429) {
        title = "Proxy queue error (too many concurrent requests)";
      } else if (err instanceof BadRequestError) {
        title = "Proxy queue error (streaming required)";
      }

      sendErrorToClient({
        options: {
          title,
          message,
          format: req.inboundApi,
          reqId: req.id,
          model: req.body?.model,
        },
        req,
        res,
      });
    }
  };
}

/**
 * Terminates a request that has been waiting in the queue for too long.
 *
 * If the response is missing or already finished, the request is only removed
 * from the queue. Otherwise an error response is sent to the client; failures
 * while sending are logged and not rethrown.
 *
 * @param req The stalled request.
 */
function killQueuedRequest(req: Request) {
  if (!req.res || req.res.writableEnded) {
    req.log.warn(`Attempted to terminate request that has already ended.`);
    // Guard the index: `splice(-1, 1)` would drop the last queued request if
    // this one is no longer in the queue.
    const index = queue.indexOf(req);
    if (index !== -1) {
      queue.splice(index, 1);
    }
    return;
  }
  const res = req.res;
  try {
    const message = `Your request has been terminated by the proxy because it has been in the queue for more than 5 minutes.`;
    sendErrorToClient({
      options: {
        title: "Proxy queue error (request killed)",
        message,
        format: req.inboundApi,
        reqId: req.id,
        model: req.body?.model,
      },
      req,
      res,
    });
  } catch (e) {
    req.log.error(e, `Error killing stalled request.`);
  }
}

/**
 * Opens the SSE stream for a request and writes the initial "joining queue"
 * comment followed by a heartbeat payload.
 *
 * Resolves immediately if the write is accepted. If the write reports
 * backpressure, it waits for the response to drain; when that does not happen
 * within `QUEUE_JOIN_TIMEOUT` milliseconds the response is destroyed and the
 * returned promise rejects.
 *
 * @param req Request whose response should start streaming.
 */
async function initStreaming(req: Request) {
  const res = req.res!;
  initializeSseStream(res);

  const position = queue.length;
  const joinMsg = `: joining queue at position ${position}\n\n${getHeartbeatPayload()}`;

  await new Promise<void>((resolve, reject) => {
    const handleDrain = () => {
      clearTimeout(giveUpTimer);
      req.log.debug(`Client finished consuming join message.`);
      res.off("drain", handleDrain);
      resolve();
    };

    // Armed before the write so a client that never drains cannot hold the
    // request open indefinitely.
    const giveUpTimer = setTimeout(() => {
      res.off("drain", handleDrain);
      res.destroy();
      reject(new Error("Unreponsive streaming client; killing connection"));
    }, QUEUE_JOIN_TIMEOUT);

    if (res.write(joinMsg)) {
      clearTimeout(giveUpTimer);
      resolve();
      return;
    }

    req.log.warn("Kernel buffer is full; holding client request.");
    res.once("drain", handleDrain);
  });
}

/**
 * Detaches the listeners that http-proxy-middleware and http-proxy leave on
 * the request and response, which otherwise interfere with re-enqueuing a
 * failed proxied request.
 *
 * There are no references to the original listener functions, so each one is
 * recognised by a snippet of its source text. This is brittle, but no better
 * hook is available. For every event at most one matching listener is removed.
 */
function removeProxyMiddlewareEventListeners(req: Request) {
  const detach = (
    emitter: {
      listeners(event: string): Function[];
      removeListener(event: string, listener: (...args: any[]) => void): unknown;
    },
    event: string,
    sourceSnippet: string
  ) => {
    const match = emitter
      .listeners(event)
      .find((candidate) => candidate.toString().includes(sourceSnippet));
    if (match) {
      emitter.removeListener(event, match as any);
    }
  };

  const res = req.res!;

  // node_modules/http-proxy-middleware/dist/plugins/default/debug-proxy-errors-plugin.js:29
  detach(res, "close", `Destroying proxyRes in proxyRes close event`);
  // node_modules/http-proxy-middleware/dist/plugins/default/debug-proxy-errors-plugin.js:19
  detach(res, "error", `Socket error in proxyReq event`);
  // node_modules/http-proxy/lib/http-proxy/passes/web-incoming.js:146
  detach(req, "aborted", `proxyReq.abort()`);
  // node_modules/http-proxy/lib/http-proxy/passes/web-incoming.js:156
  detach(req, "error", `if (req.socket.destroyed`);
}

/**
 * Starts the periodic heartbeat for a queued streaming request and the monitor
 * that watches it. The timer handle is stored on `req.heartbeatInterval`.
 *
 * While a previous write is still waiting for the response to drain, beats are
 * skipped and counted; on the third skipped beat the response is destroyed.
 * The skip counter is never reset.
 *
 * @param req Queued request whose response receives the heartbeats.
 */
export function registerHeartbeat(req: Request) {
  const res = req.res!;

  let awaitingDrain = false;
  let skippedBeats = 0;

  const beat = () => {
    if (!awaitingDrain) {
      const payload = getHeartbeatPayload();
      if (!res.write(payload)) {
        // Backpressure: hold further beats until the response drains.
        awaitingDrain = true;
        res.once("drain", () => (awaitingDrain = false));
      }
      return;
    }

    skippedBeats++;
    if (skippedBeats >= 3) {
      req.log.error("Heartbeat skipped too many times; killing connection.");
      res.destroy();
    } else {
      req.log.warn(
        { bufferFullCount: skippedBeats },
        "Heartbeat skipped; buffer is full."
      );
    }
  };

  req.heartbeatInterval = setInterval(beat, HEARTBEAT_INTERVAL);
  monitorHeartbeat(req);
}

/**
 * Watches how many bytes the response socket has written and destroys the
 * response when too little was written since the previous check.
 *
 * The check runs every two heartbeat intervals; the minimum is half the
 * current heartbeat size. The timer handle is stored on `req.monitorInterval`.
 *
 * @param req Queued request to monitor.
 */
function monitorHeartbeat(req: Request) {
  const res = req.res!;
  let previousTotal = 0;

  const check = () => {
    const currentTotal = res.socket?.bytesWritten ?? 0;
    const bytesSinceLast = currentTotal - previousTotal;
    req.log.debug(
      { previousBytesSent: previousTotal, currentBytesSent: currentTotal },
      "Heartbeat monitor check."
    );
    previousTotal = currentTotal;

    const minBytes = Math.floor(getHeartbeatSize() / 2);
    if (bytesSinceLast < minBytes) {
      req.log.warn(
        { minBytes, bytesSinceLast },
        "Queued request is not processing heartbeats enough data or server is overloaded; killing connection."
      );
      res.destroy();
    }
  };

  req.monitorInterval = setInterval(check, HEARTBEAT_INTERVAL * 2);
}

/**
 * Computes the heartbeat payload size for the current proxy load.
 *
 * At or below `LOAD_THRESHOLD` the size is `MIN_HEARTBEAT_SIZE`. Above it, the
 * size grows with the square of the scaled excess load and is capped at
 * `MAX_HEARTBEAT_SIZE`.
 */
function getHeartbeatSize() {
  const load = getProxyLoad();
  if (load <= LOAD_THRESHOLD) return MIN_HEARTBEAT_SIZE;

  const scaledExcess = (load - LOAD_THRESHOLD) * PAYLOAD_SCALE_FACTOR;
  const size = MIN_HEARTBEAT_SIZE + scaledExcess ** 2;
  return size > MAX_HEARTBEAT_SIZE ? MAX_HEARTBEAT_SIZE : size;
}

/**
 * Builds one heartbeat message as an SSE comment line followed by a blank
 * line.
 *
 * In production the message carries base64-encoded random bytes of the current
 * heartbeat size; otherwise it just states that size.
 */
function getHeartbeatPayload() {
  const size = getHeartbeatSize();

  let data: string;
  if (process.env.NODE_ENV === "production") {
    data = crypto.randomBytes(size).toString("base64");
  } else {
    data = `payload size: ${size}`;
  }

  return `: queue heartbeat ${data}\n\n`;
}

/**
 * Measures proxy load as the larger of the value reported by `getUniqueIps`
 * and the number of queued requests.
 */
function getProxyLoad() {
  const uniqueIps = getUniqueIps();
  const queuedRequests = queue.length;
  return Math.max(uniqueIps, queuedRequests);
}