|
|
import json
import os

import tiktoken
from flask import Flask, request, jsonify
|
|
|
|
|
|
# The Flask application object; the route handlers in this module register on it.
app = Flask(import_name=__name__)
|
|
|
|
|
|
|
|
|
MODEL_MAPPINGS = {
|
|
|
|
|
|
"gpt-4o": "o200k_base",
|
|
|
"gpt-4-turbo": "cl100k_base",
|
|
|
"gpt-4": "cl100k_base",
|
|
|
|
|
|
|
|
|
"gpt-3.5-turbo": "cl100k_base",
|
|
|
"gpt-35-turbo": "cl100k_base",
|
|
|
|
|
|
|
|
|
"text-davinci-003": "p50k_base",
|
|
|
"text-davinci-002": "p50k_base",
|
|
|
"davinci": "r50k_base",
|
|
|
|
|
|
|
|
|
"text-embedding-ada-002": "cl100k_base",
|
|
|
}
|
|
|
|
|
|
def _content_to_text(content):
    """Flatten a message/system field value into plain text for token counting.

    Strings are returned unchanged and ``None`` becomes ``""``. Lists (e.g. a
    list of content parts) are flattened element by element and joined with
    newlines. A dict carrying a string ``"text"`` entry contributes that text;
    any other value (numbers, other dicts) is serialized with ``json.dumps`` so
    it still counts toward the total.

    NOTE(review): serializing non-text parts (e.g. base64 image payloads, if
    clients send them) can overestimate their real token cost.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "\n".join(_content_to_text(part) for part in content)
    if isinstance(content, dict) and isinstance(content.get("text"), str):
        return content["text"]
    # Request bodies come from JSON, so any remaining value is JSON-serializable.
    return json.dumps(content, ensure_ascii=False)


def _resolve_encoding_name(model):
    """Return the tiktoken encoding name for *model* (matched case-insensitively).

    An exact MODEL_MAPPINGS key wins. Otherwise the longest key that occurs
    inside the model name is used, so the most specific family matches
    regardless of dict order (e.g. "gpt-4o-mini" resolves via "gpt-4o", not
    via the shorter "gpt-4"). Unknown models fall back to "cl100k_base".
    """
    model_key = model.lower()
    if model_key in MODEL_MAPPINGS:
        return MODEL_MAPPINGS[model_key]
    best_key = max(
        (key for key in MODEL_MAPPINGS if key in model_key),
        key=len,
        default=None,
    )
    return MODEL_MAPPINGS[best_key] if best_key is not None else "cl100k_base"


def _load_encoding(encoding_name):
    """Return ``(encoding, name_actually_used)``.

    ``tiktoken.get_encoding`` raises ``ValueError`` for an encoding the
    installed tiktoken does not provide (``KeyError`` is caught as well for
    safety). In that case we fall back to cl100k_base and return that name, so
    the response reports the encoding that was really used.
    """
    try:
        return tiktoken.get_encoding(encoding_name), encoding_name
    except (KeyError, ValueError):
        return tiktoken.get_encoding("cl100k_base"), "cl100k_base"


def _encoded_length(encoding, text):
    """Return how many tokens *encoding* produces for *text*.

    ``disallowed_special=()`` makes tiktoken encode special-token strings such
    as "<|endoftext|>" as ordinary text instead of raising ``ValueError``, so
    arbitrary user content can always be counted.
    """
    return len(encoding.encode(text, disallowed_special=()))


def _count_chat_tokens(encoding, messages, system):
    """Count tokens with per-message overheads (used for cl100k/o200k encodings)."""
    # Fixed 3-token overhead; presumably reply priming as in OpenAI's
    # token-counting recipe.
    total = 3
    for message in messages:
        total += 4  # fixed per-message framing overhead
        for key, value in message.items():
            total += _encoded_length(encoding, _content_to_text(value))
            if key == "name":
                total -= 1  # a "name" field is credited back one token
    if system:
        total += 4
        total += _encoded_length(encoding, _content_to_text(system))
    return total


def _count_plain_tokens(encoding, messages, system):
    """Count tokens of a "role: content" transcript (used for legacy encodings)."""
    parts = []
    if system:
        parts.append(_content_to_text(system) + "\n\n")
    for message in messages:
        role = message.get('role', '')
        content = _content_to_text(message.get('content', ''))
        parts.append(f"{role}: {content}\n")
    return _encoded_length(encoding, "".join(parts))


@app.route('/count_tokens', methods=['POST'])
def count_tokens():
    """Estimate the input token count of a chat-style request.

    Expects a JSON object with optional keys ``messages`` (list of objects),
    ``system`` and ``model`` (default ``'gpt-3.5-turbo'``). Responds with
    ``{'input_tokens', 'model', 'encoding'}``. Malformed input, or any failure
    while counting, yields ``{'error': ...}`` with HTTP 400.
    """
    try:
        # silent=True returns None (instead of raising) for non-JSON bodies.
        data = request.get_json(silent=True)
        if not isinstance(data, dict):
            return jsonify({'error': 'Request body must be a JSON object'}), 400

        messages = data.get('messages') or []
        system = data.get('system')
        model = data.get('model') or 'gpt-3.5-turbo'

        if not isinstance(messages, list) or not all(
            isinstance(message, dict) for message in messages
        ):
            return jsonify({'error': "'messages' must be a list of objects"}), 400
        if not isinstance(model, str):
            return jsonify({'error': "'model' must be a string"}), 400

        encoding, encoding_name = _load_encoding(_resolve_encoding_name(model))

        if encoding_name in ("cl100k_base", "o200k_base"):
            total_tokens = _count_chat_tokens(encoding, messages, system)
        else:
            total_tokens = _count_plain_tokens(encoding, messages, system)

        return jsonify({
            'input_tokens': total_tokens,
            'model': model,
            'encoding': encoding_name,
        })

    except Exception as e:
        # Top-level request boundary: keep the existing 400 contract but
        # record the traceback instead of silently discarding it.
        app.logger.exception("count_tokens failed")
        return jsonify({
            'error': str(e)
        }), 400
|
|
|
|
|
|
@app.route('/health', methods=['GET'])
def health():
    """Liveness probe.

    Returns a JSON object with a fixed status, the tokenizer backend label and
    the model names listed in MODEL_MAPPINGS, in their declaration order.
    """
    payload = dict(
        status='healthy',
        tokenizer='openai-tiktoken',
        supported_models=[*MODEL_MAPPINGS],
    )
    return jsonify(payload)
|
|
|
|
|
|
if __name__ == '__main__':
    # Serve on the loopback interface only, port 7862.
    server_options = {'host': '127.0.0.1', 'port': 7862}
    app.run(**server_options)