File size: 3,563 Bytes
5c5b371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import type { Request, Response, RequestHandler } from "express";
import { config } from "../config";
import { authenticate, getUser } from "../shared/users/user-store";
import { sendErrorToClient } from "./middleware/response/error-generator";

// Gatekeeper settings, read once from `config` when this module is loaded.
//  - GATEKEEPER: auth mode; the handler below understands "none", "proxy_key"
//    and "user_token".
//  - PROXY_KEY:  shared key accepted in "proxy_key" mode.
//  - ADMIN_KEY:  optional key that bypasses every mode when it matches.
const {
  gatekeeper: GATEKEEPER,
  proxyKey: PROXY_KEY,
  adminKey: ADMIN_KEY,
} = config;

/** Leading `Bearer` auth scheme; schemes are case-insensitive (RFC 7235 §2.1). */
const BEARER_SCHEME = /^bearer(?:\s+|$)/i;

/**
 * Extracts the client's proxy credential from the request and removes it so it
 * is not forwarded upstream.
 *
 * Locations are checked in priority order; the first one present wins and only
 * that one is removed from the request:
 *  1. `Authorization` header (a `Bearer` scheme, any case, is stripped)
 *  2. `x-api-key` header
 *  3. `x-goog-api-key` header
 *  4. `key` querystring parameter
 *
 * @param req - Incoming request. Mutated: the header or query parameter the
 *   token was read from is deleted.
 * @returns The credential string, or `undefined` if no location supplied one.
 */
function getProxyAuthorizationFromRequest(req: Request): string | undefined {
  // Anthropic's API uses x-api-key instead of Authorization. Some clients will
  // pass the _proxy_ key in this header too, instead of providing it as a
  // Bearer token in the Authorization header, so several locations are checked.
  // The Authorization header is preferred when more than one is present.
  // Google AI clients use x-goog-api-key or a `key` querystring parameter.

  const authorization = req.headers.authorization;
  if (authorization) {
    delete req.headers.authorization;
    // Remove the scheme instead of blindly slicing off 7 characters, so a raw
    // key sent without "Bearer " (or with lowercase "bearer") is not mangled.
    return authorization.trim().replace(BEARER_SCHEME, "");
  }

  for (const header of ["x-api-key", "x-goog-api-key"] as const) {
    const value = req.headers[header];
    if (value) {
      delete req.headers[header];
      return value.toString();
    }
  }

  const queryKey = req.query.key;
  if (queryKey) {
    delete req.query.key;
    return queryKey.toString();
  }

  return undefined;
}

/**
 * Express middleware that authenticates incoming proxy requests.
 *
 * The credential is extracted (and removed from the request) by
 * `getProxyAuthorizationFromRequest`, then checked in this order:
 *  1. A configured admin key that matches always allows the request.
 *  2. In "none" mode every request is allowed.
 *  3. In "proxy_key" mode the token must equal the configured proxy key.
 *  4. In "user_token" mode the token is checked with `authenticate`; on success
 *     the resolved user is attached to `req.user`. IP-limited users and
 *     disabled users (with `disabledAt` set) are rejected with 403.
 * Every other case is rejected with 401.
 */
export const gatekeeper: RequestHandler = (req, res, next) => {
  const token = getProxyAuthorizationFromRequest(req);

  // TODO: Generate anonymous users based on IP address for public or proxy_key
  // modes so that all middleware can assume a user of some sort is present.

  if (ADMIN_KEY && token === ADMIN_KEY) {
    return next();
  }

  if (GATEKEEPER === "none") {
    return next();
  }

  // Require a configured (non-empty) proxy key, mirroring the ADMIN_KEY check.
  // Without this guard, a missing key would match a request with no credential
  // (undefined === undefined), and an empty key would match a bare "Bearer ".
  if (GATEKEEPER === "proxy_key" && PROXY_KEY && token === PROXY_KEY) {
    return next();
  }

  if (GATEKEEPER === "user_token" && token) {
    // RisuAI users all come from a handful of aws lambda IPs so we cannot use
    // IP alone to distinguish between them and prevent usertoken sharing.
    // Risu sends a signed token in the request headers with an anonymous user
    // ID that we can instead use to associate requests with an individual.
    const ip = req.risuToken?.length
      ? `risu${req.risuToken}-${req.ip}`
      : req.ip;

    const { user, result } = authenticate(token, ip);

    switch (result) {
      case "success":
        req.user = user;
        return next();
      case "limited":
        return sendError(
          req,
          res,
          403,
          `Forbidden: no more IP addresses allowed for this user token`,
          { currentIp: ip, maxIps: user?.maxIps }
        );
      case "disabled": {
        // Block scope keeps the declaration local to this case.
        const bannedUser = getUser(token);
        if (bannedUser?.disabledAt) {
          const reason = bannedUser.disabledReason || "User token disabled";
          return sendError(req, res, 403, `Forbidden: ${reason}`);
        }
        // No disabledAt: fall out to the generic 401 below.
        break;
      }
    }
  }

  sendError(req, res, 401, "Unauthorized");
};

/**
 * Sends a gatekeeper rejection to the client.
 *
 * A POST whose body carries a truthy `model` field is passed to
 * `sendErrorToClient` with a titled error (`format: "unknown"`) and `data` as
 * extra detail. Any other request gets a plain `{ error: message }` JSON body
 * with the given status.
 *
 * @param req - The rejected request.
 * @param res - Response the error is written to.
 * @param status - HTTP status code to report.
 * @param message - Human-readable error message.
 * @param data - Extra details forwarded as `obj` on the `sendErrorToClient`
 *   path; ignored by the plain JSON path.
 * @returns The Express response on the plain JSON path; `undefined` otherwise.
 */
function sendError(
  req: Request,
  res: Response,
  status: number,
  message: string,
  data: any = {}
) {
  const targetsModel = req.method === "POST" && req.body && req.body.model;

  if (targetsModel) {
    sendErrorToClient({
      req,
      res,
      options: {
        title: `Proxy gatekeeper error (HTTP ${status})`,
        message,
        format: "unknown",
        statusCode: status,
        reqId: req.id,
        obj: data,
      },
    });
    return undefined;
  }

  return res.status(status).json({ error: message });
}