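/**
 * Tokenization facade: dispatches prompt/completion token-count requests to
 * the per-service tokenizer implementations (Claude, OpenAI, Google AI,
 * Mistral AI).
 */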
import { Request } from "express";
import { assertNever } from "../utils";
import {
  getTokenCount as getClaudeTokenCount,
  init as initClaude,
} from "./claude";
import {
  estimateGoogleAITokenCount,
  getOpenAIImageCost,
  getTokenCount as getOpenAITokenCount,
  init as initOpenAi,
} from "./openai";
import {
  getTokenCount as getMistralAITokenCount,
  init as initMistralAI,
} from "./mistral";
import { APIFormat } from "../key-management";
import {
  AnthropicChatMessage,
  GoogleAIChatMessage,
  MistralAIChatMessage,
  OpenAIChatMessage,
} from "../api-schemas";

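/** Loads and initializes the per-service tokenizers. */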
export async function init() {
  initClaude();
  initOpenAi();
  initMistralAI();
}

type OpenAIChatTokenCountRequest = {
  prompt: OpenAIChatMessage[];
  completion?: never;
  service: "openai" | "openai-responses";
};

type AnthropicChatTokenCountRequest = {
  prompt: { system: string; messages: AnthropicChatMessage[] };
  completion?: never;
  service: "anthropic-chat";
};

type GoogleAIChatTokenCountRequest = {
  prompt: GoogleAIChatMessage[];
  completion?: never;
  service: "google-ai";
};

type MistralAIChatTokenCountRequest = {
  prompt: string | MistralAIChatMessage[];
  completion?: never;
  service: "mistral-ai" | "mistral-text";
};

type FlatPromptTokenCountRequest = {
  prompt: string;
  completion?: never;
  service: "openai-text" | "anthropic-text" | "google-ai";
};

type StringCompletionTokenCountRequest = {
  prompt?: never;
  completion: string;
  service: APIFormat;
};

type OpenAIImageCompletionTokenCountRequest = {
  prompt?: never;
  completion?: never;
  service: "openai-image";
};

/**
 * Tagged union, discriminated by the `service` field, of the request shapes
 * accepted by the tokenization service for both prompts and completions.
 */
type TokenCountRequest = { req: Request } & (
  | OpenAIChatTokenCountRequest
  | AnthropicChatTokenCountRequest
  | GoogleAIChatTokenCountRequest
  | MistralAIChatTokenCountRequest
  | FlatPromptTokenCountRequest
  | StringCompletionTokenCountRequest
  | OpenAIImageCompletionTokenCountRequest
);

type TokenCountResult = {
  token_count: number;
  /** Additional tokens for reasoning, if applicable. */
  reasoning_tokens?: number;
  tokenizer: string;
  tokenization_duration_ms: number;
};

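/**
 * Counts tokens for either a prompt or a completion using the tokenizer for
 * the given service, and measures how long tokenization took. For OpenAI
 * image generation, the returned count is a token-equivalent cost rather
 * than a true tokenization.
 */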
export async function countTokens({
  req,
  service,
  prompt,
  completion,
}: TokenCountRequest): Promise<TokenCountResult> {
  const time = process.hrtime();
  switch (service) {
    case "anthropic-chat":
    case "anthropic-text":
      return {
        ...(await getClaudeTokenCount(prompt ?? completion)),
        tokenization_duration_ms: getElapsedMs(time),
      };
    case "openai":
    case "openai-text":
    case "openai-responses":
      return {
        ...(await getOpenAITokenCount(prompt ?? completion, req.body.model)),
        tokenization_duration_ms: getElapsedMs(time),
      };
    case "openai-image":
      return {
        ...getOpenAIImageCost({
          model: req.body.model,
          quality: req.body.quality,
          resolution: req.body.size,
          n: parseInt(req.body.n, 10) || null,
        }),
        tokenization_duration_ms: getElapsedMs(time),
      };
    case "google-ai":
      // TODO: Can't find a tokenization library for Gemini. There is an API
      // endpoint for it but it adds significant latency to the request.
      return {
        ...estimateGoogleAITokenCount(prompt ?? (completion || [])),
        tokenization_duration_ms: getElapsedMs(time),
      };
    case "mistral-ai":
    case "mistral-text":
      return {
        ...getMistralAITokenCount(prompt ?? completion),
        tokenization_duration_ms: getElapsedMs(time),
      };
    default:
      assertNever(service);
  }
}

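/** Milliseconds elapsed since the given `process.hrtime` reference time. */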
function getElapsedMs(time: [number, number]) {
  const diff = process.hrtime(time);
  return diff[0] * 1000 + diff[1] / 1e6;
}
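
// Example usage (a minimal sketch, not part of the module's public surface;
// the `req` object and the role/content message shape are assumptions based
// on the schemas imported above):
//
//   const result = await countTokens({
//     req, // Express request carrying `body.model` for model-aware tokenizers
//     service: "openai",
//     prompt: [{ role: "user", content: "Hello, world" }],
//   });
//   // result.token_count, result.tokenizer, result.tokenization_duration_ms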