import { RequestHandler } from "express";
import { ZodIssue } from "zod";
import { initializeSseStream } from "../../../shared/streaming";
import { classifyErrorAndSend } from "../common";
import {
  RequestPreprocessor,
  blockZoomerOrigins,
  countPromptTokens,
  languageFilter,
  setApiFormat,
  transformOutboundPayload,
  validateContextSize,
  validateModelFamily,
  validateVision,
  applyQuotaLimits,
} from ".";

type RequestPreprocessorOptions = {
  /**
   * Functions to run before the request body is transformed between API
   * formats. Use this to change the behavior of the transformation, such as for
   * endpoints which can accept multiple API formats.
   */
  beforeTransform?: RequestPreprocessor[];
  /**
   * Functions to run after the request body is transformed and token counts are
   * assigned. Use this to perform validation or other actions that depend on
   * the request body being in the final API format.
   */
  afterTransform?: RequestPreprocessor[];
};
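
// A minimal sketch of a custom RequestPreprocessor (hypothetical, for
// illustration only). Preprocessors receive just the request and may be
// async; a guard like this would belong in `afterTransform`, since
// `promptTokens` is only assigned once `countPromptTokens` has run:
//
//   const rejectEmptyPrompts: RequestPreprocessor = (req) => {
//     if (req.promptTokens === 0) {
//       throw new Error("Prompt must not be empty.");
//     }
//   };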

/**
 * Returns a middleware function that processes the request body into the given
 * API format, and then sequentially runs the given additional preprocessors.
 * These should be used for validation and transformations that only need to
 * happen once per request.
 *
 * These run first in the request lifecycle, a single time per request before it
 * is added to the request queue. They aren't run again if the request is
 * re-attempted after a rate limit.
 *
 * To run functions against requests every time they are re-attempted, write a
 * ProxyReqMutator and pass it to createQueuedProxyMiddleware instead.
 */
export const createPreprocessorMiddleware = (
  apiFormat: Parameters<typeof setApiFormat>[0],
  { beforeTransform, afterTransform }: RequestPreprocessorOptions = {}
): RequestHandler => {
  const preprocessors: RequestPreprocessor[] = [
    setApiFormat(apiFormat),
    blockZoomerOrigins,
    ...(beforeTransform ?? []),
    transformOutboundPayload,
    countPromptTokens,
    languageFilter,
    ...(afterTransform ?? []),
    validateContextSize,
    validateVision,
    validateModelFamily,
    applyQuotaLimits,
  ];
  return async (...args) => executePreprocessors(preprocessors, args);
};
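
// Hypothetical route wiring, shown only to illustrate where this middleware
// sits in the chain; the real routers live elsewhere in the codebase, and the
// proxy handler name below is an assumption:
//
//   router.post(
//     "/v1/chat/completions",
//     createPreprocessorMiddleware(
//       { inApi: "openai", outApi: "openai", service: "openai" },
//       { afterTransform: [rejectEmptyPrompts] }
//     ),
//     queuedProxyMiddleware // built via createQueuedProxyMiddleware
//   );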

/**
 * Returns a middleware function that specifically prepares requests for
 * OpenAI's embeddings API. Tokens are not counted because embeddings requests
 * are basically free.
 */
export const createEmbeddingsPreprocessorMiddleware = (): RequestHandler => {
  const preprocessors: RequestPreprocessor[] = [
    setApiFormat({ inApi: "openai", outApi: "openai", service: "openai" }),
    (req) => void (req.promptTokens = req.outputTokens = 0),
  ];
  return async (...args) => executePreprocessors(preprocessors, args);
};

async function executePreprocessors(
  preprocessors: RequestPreprocessor[],
  [req, res, next]: Parameters<RequestHandler>
) {
  handleTestMessage(req, res, next);
  if (res.headersSent) return;

  try {
    for (const preprocessor of preprocessors) {
      await preprocessor(req);
    }
    next();
  } catch (error) {
    if (error.constructor.name === "ZodError") {
      const issues = error?.issues
        ?.map((issue: ZodIssue) => `${issue.path.join(".")}: ${issue.message}`)
        .join("; ");
      req.log.warn({ issues }, "Prompt failed preprocessor validation.");
    } else {
      req.log.error(error, "Error while executing request preprocessor");
    }

    // If the requester has opted into streaming, the client probably won't
    // handle a non-eventstream response, but we haven't initialized the SSE
    // stream yet as that is typically done later by the request queue. We'll
    // do that here and then call classifyErrorAndSend to use the streaming
    // error handler.
    const { stream } = req.body;
    const isStreaming = stream === "true" || stream === true;
    if (isStreaming && !res.headersSent) {
      initializeSseStream(res);
    }
    classifyErrorAndSend(error as Error, req, res);
  }
}

/**
 * Bypasses the API call and returns a canned response if the request body is a
 * known connection-test message from SillyTavern. Otherwise these messages
 * just waste API request quota and confuse users when the proxy is busy,
 * because ST always sends them with `stream: false` (which is not allowed
 * while the proxy is busy).
 */
const handleTestMessage: RequestHandler = (req, res) => {
  const { method, body } = req;
  if (method !== "POST") {
    return;
  }

  if (isTestMessage(body)) {
    req.log.info({ body }, "Received test message. Skipping API call.");
    res.json({
      id: "test-message",
      object: "chat.completion",
      created: Math.floor(Date.now() / 1000), // OpenAI's `created` is unix seconds
      model: body.model,
      // openai chat
      choices: [
        {
          message: { role: "assistant", content: "Hello!" },
          finish_reason: "stop",
          index: 0,
        },
      ],
      // anthropic text
      completion: "Hello!",
      // anthropic chat
      content: [{ type: "text", text: "Hello!" }],
      // gemini
      candidates: [
        {
          content: { parts: [{ text: "Hello!" }] },
          finishReason: "stop",
        },
      ],
      proxy_note:
        "SillyTavern connection test detected. Your prompt was not sent to the actual model and this response was generated by the proxy.",
    });
  }
};

function isTestMessage(body: any) {
  const { messages, prompt, contents } = body;

  if (messages) {
    return (
      messages.length === 1 &&
      messages[0].role === "user" &&
      messages[0].content === "Hi"
    );
  } else if (contents) {
    // Optional-chain `parts` as well; this runs before the try/catch in
    // executePreprocessors, so a malformed body must not throw here.
    return contents.length === 1 && contents[0].parts?.[0]?.text === "Hi";
  } else {
    return (
      prompt?.trim() === "Human: Hi\n\nAssistant:" ||
      prompt?.startsWith("Hi\n\n")
    );
  }
}
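
// For reference, the OpenAI-format connection test that SillyTavern sends
// looks roughly like this (reconstructed from the checks above rather than
// from captured traffic; per the note on handleTestMessage, ST always sends
// `stream: false`):
//
//   { "messages": [{ "role": "user", "content": "Hi" }], "stream": false }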