github-actions[bot]
Sync from GitHub Viciy2023/Qwen2API-A@b372de2fdb435c7fa78fc69c146257a58c842fba
4289eb1
const { isJson, generateUUID } = require('../utils/tools.js')
const { createUsageObject } = require('../utils/precise-tokenizer.js')
const { sendChatRequest } = require('../utils/request.js')
const accountManager = require('../utils/account.js')
const config = require('../config/index.js')
const axios = require('axios')
const { logger } = require('../utils/logger')
const usageStats = require('../utils/usage-stats')
/**
 * Set the HTTP response headers for a chat completion reply.
 * Streaming replies get Server-Sent-Events headers; all others are JSON.
 * A failure while setting headers is logged and swallowed.
 * @param {object} res - Express response object
 * @param {boolean} stream - true for a streaming (SSE) response
 */
const setResponseHeaders = (res, stream) => {
    const headers = stream
        ? {
            'Content-Type': 'text/event-stream',
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
        }
        : { 'Content-Type': 'application/json' }
    try {
        res.set(headers)
    } catch (e) {
        logger.error('处理聊天请求时发生错误', 'CHAT', '', e)
    }
}
/**
 * Relay an upstream SSE chat stream to the client as OpenAI-style
 * `chat.completion.chunk` events.
 *
 * Listeners are attached to `response` and the function returns right away;
 * the client stream is completed from the upstream `end` / `error` handlers.
 * Upstream chunks are processed strictly one after another, so an `await`
 * while handling one chunk (e.g. rendering the web-search table) cannot let
 * later chunks, or the final `[DONE]`, overtake it.
 *
 * @param {object} res - Express response object
 * @param {object} response - upstream response stream (emits `data`, `end`, `error`)
 * @param {boolean} enable_thinking - whether thinking mode is enabled
 * @param {boolean} enable_web_search - whether web search is enabled (not read in this function)
 * @param {object} requestBody - original request body; its `messages` feed prompt-token
 *   estimation and its `model` is recorded in usage stats
 */
const handleStreamResponse = async (res, response, enable_thinking, enable_web_search, requestBody = null) => {
    try {
        const message_id = generateUUID()
        const decoder = new TextDecoder('utf-8')
        let web_search_info = null
        let thinking_start = false
        let thinking_end = false
        let buffer = ''
        // Set once the client stream has been completed (normally or after an error)
        let finished = false
        // Tail of the in-order processing queue for upstream work
        let pending = Promise.resolve()
        // Token usage; replaced by the upstream's own figures when it reports them
        let totalTokens = {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0
        }
        // Full reply text, used for token estimation when upstream reports no usage
        let completionContent = ''

        // Flatten the prompt text for token estimation
        let promptText = ''
        if (requestBody && requestBody.messages) {
            promptText = requestBody.messages.map(msg => {
                if (typeof msg.content === 'string') {
                    return msg.content
                } else if (Array.isArray(msg.content)) {
                    return msg.content.map(item => item.text || '').join('')
                }
                return ''
            }).join('\n')
        }

        // Build one stream chunk envelope (same key order as the client has always received)
        const chunkOf = (choices) => ({
            "id": `chatcmpl-${message_id}`,
            "object": "chat.completion.chunk",
            "created": new Date().getTime(),
            "choices": choices
        })
        const writeEvent = (payload) => {
            res.write(`data: ${JSON.stringify(payload)}\n\n`)
        }

        // Complete the client stream after an error. Once SSE data has been
        // written the headers are gone, so a JSON 500 is impossible; the
        // stream is simply closed instead.
        const failStream = async (message, error) => {
            logger.error(message, 'CHAT', '', error)
            if (finished) return
            finished = true
            if (!res.headersSent) {
                res.status(500).json({ error: "服务错误!!!" })
            } else if (!res.writableEnded) {
                res.end()
            }
            try {
                await usageStats.track({ model: requestBody?.model, success: false, usage: { total_tokens: 0 } })
            } catch (trackError) {
                logger.error('流式响应处理错误', 'CHAT', '', trackError)
            }
        }

        // Handle one `data: ...` event; writes a content chunk when it carries text
        const processEvent = async (item) => {
            const dataContent = item.replace("data: ", '')
            const decodeJson = isJson(dataContent) ? JSON.parse(dataContent) : null
            if (decodeJson === null || !decodeJson.choices || decodeJson.choices.length === 0) {
                return
            }
            // Prefer real usage figures when the upstream provides them
            if (decodeJson.usage) {
                totalTokens = {
                    prompt_tokens: decodeJson.usage.prompt_tokens || totalTokens.prompt_tokens,
                    completion_tokens: decodeJson.usage.completion_tokens || totalTokens.completion_tokens,
                    total_tokens: decodeJson.usage.total_tokens || totalTokens.total_tokens
                }
            }
            const delta = decodeJson.choices[0].delta
            // Remember web-search results; `extra` may be missing on some events
            if (delta && delta.name === 'web_search') {
                web_search_info = delta.extra?.web_search_info ?? web_search_info
            }
            if (!delta || !delta.content || (delta.phase !== 'think' && delta.phase !== 'answer')) {
                return
            }
            let content = delta.content
            completionContent += content
            // Open <think> on the first reasoning delta, prefixed by search results if any
            if (delta.phase === 'think' && !thinking_start) {
                thinking_start = true
                if (web_search_info) {
                    content = `<think>\n\n${await accountManager.generateMarkdownTable(web_search_info, config.searchInfoMode)}\n\n${content}`
                } else {
                    content = `<think>\n\n${content}`
                }
            }
            // Close </think> on the first answer delta that follows reasoning
            if (delta.phase === 'answer' && !thinking_end && thinking_start) {
                thinking_end = true
                content = `\n\n</think>\n${content}`
            }
            if (!finished) {
                writeEvent(chunkOf([{ "index": 0, "delta": { "content": content }, "finish_reason": null }]))
            }
        }

        // Run events one at a time; a malformed event is logged and skipped
        // instead of tearing down the whole stream.
        const processEvents = async (items) => {
            for (const item of items) {
                try {
                    await processEvent(item)
                } catch (error) {
                    logger.error('流式数据处理错误', 'CHAT', '', error)
                }
            }
        }

        // Append a raw chunk to the buffer and process every complete event in it;
        // an incomplete trailing event stays buffered for the next chunk.
        const processChunk = async (chunk) => {
            buffer += decoder.decode(chunk, { stream: true })
            const events = []
            let startIndex = 0
            while (true) {
                const dataStart = buffer.indexOf('data: ', startIndex)
                if (dataStart === -1) break
                const dataEnd = buffer.indexOf('\n\n', dataStart)
                if (dataEnd === -1) break
                events.push(buffer.substring(dataStart, dataEnd).trim())
                startIndex = dataEnd + 2
            }
            if (startIndex > 0) {
                buffer = buffer.substring(startIndex)
            }
            await processEvents(events)
        }

        // Emit trailing search info, the finish chunk, usage and [DONE], then close
        const finalize = async () => {
            if (finished) return
            // Flush bytes held by the decoder and a last event lacking its blank-line terminator
            buffer += decoder.decode()
            const tailStart = buffer.indexOf('data: ')
            if (tailStart !== -1) {
                await processEvents([buffer.substring(tailStart).trim()])
            }
            buffer = ''
            // Search results go after the answer when they were not put inside <think>
            if ((config.outThink === false || !enable_thinking) && web_search_info && config.searchInfoMode === "text") {
                const webSearchTable = await accountManager.generateMarkdownTable(web_search_info, "text")
                writeEvent(chunkOf([{ "index": 0, "delta": { "content": `\n\n---\n${webSearchTable}` }, "finish_reason": null }]))
            }
            // Final token usage: estimate locally unless upstream reported it
            if (totalTokens.prompt_tokens === 0 && totalTokens.completion_tokens === 0) {
                totalTokens = createUsageObject(requestBody?.messages || promptText, completionContent, null)
                logger.info(`流式使用tiktoken计算 - Prompt: ${totalTokens.prompt_tokens}, Completion: ${totalTokens.completion_tokens}, Total: ${totalTokens.total_tokens}`, 'CHAT')
            } else {
                logger.info(`流式使用上游真实Token - Prompt: ${totalTokens.prompt_tokens}, Completion: ${totalTokens.completion_tokens}, Total: ${totalTokens.total_tokens}`, 'CHAT')
            }
            // Clamp counts to non-negative and keep the total consistent with its parts
            totalTokens.prompt_tokens = Math.max(0, totalTokens.prompt_tokens || 0)
            totalTokens.completion_tokens = Math.max(0, totalTokens.completion_tokens || 0)
            totalTokens.total_tokens = totalTokens.prompt_tokens + totalTokens.completion_tokens
            // Finish chunk carrying finish_reason
            writeEvent(chunkOf([{ "index": 0, "delta": {}, "finish_reason": "stop" }]))
            // Usage chunk (empty choices, OpenAI style)
            writeEvent({ ...chunkOf([]), "usage": totalTokens })
            res.write(`data: [DONE]\n\n`)
            finished = true
            res.end()
            await usageStats.track({ model: requestBody?.model, success: true, usage: totalTokens })
        }

        response.on('data', (chunk) => {
            pending = pending
                .then(() => processChunk(chunk))
                .catch((error) => logger.error('流式数据处理错误', 'CHAT', '', error))
        })
        response.on('end', () => {
            pending = pending
                .then(finalize)
                .catch((error) => failStream('流式响应处理错误', error))
        })
        // Without this listener an upstream failure would surface as an unhandled
        // 'error' event and the client connection would never be closed.
        response.on('error', (error) => {
            pending = pending.then(() => failStream('流式响应流读取错误', error))
        })
    } catch (error) {
        logger.error('聊天处理错误', 'CHAT', '', error)
        if (!res.headersSent) {
            res.status(500).json({ error: "服务错误!!!" })
        }
    }
}
/**
 * Build a non-streaming chat completion by consuming the upstream SSE stream
 * to completion and replying with a single `chat.completion` JSON body.
 *
 * Upstream chunks are processed strictly in order and the reply is only built
 * after every queued chunk has finished, so an `await` while handling a chunk
 * (e.g. rendering the web-search table) cannot cause content to be lost.
 *
 * @param {object} res - Express response object
 * @param {object} response - upstream response stream (emits `data`, `end`, `error`)
 * @param {boolean} enable_thinking - whether thinking mode is enabled
 * @param {boolean} enable_web_search - whether web search is enabled (not read in this function)
 * @param {string} model - model name echoed in the reply and usage stats
 * @param {object} requestBody - original request body; its `messages` feed prompt-token estimation
 */
const handleNonStreamResponse = async (res, response, enable_thinking, enable_web_search, model, requestBody = null) => {
    // Usage tracking is best-effort: a tracking failure is logged and must not
    // turn an already-delivered reply into an error (or a second reply).
    const trackUsage = async (entry) => {
        try {
            await usageStats.track(entry)
        } catch (trackError) {
            logger.error('非流式聊天处理错误', 'CHAT', '', trackError)
        }
    }
    // Token usage; replaced by the upstream's own figures when it reports them
    let totalTokens = {
        prompt_tokens: 0,
        completion_tokens: 0,
        total_tokens: 0
    }
    try {
        const decoder = new TextDecoder('utf-8')
        let buffer = ''
        let fullContent = ''
        let web_search_info = null
        let thinking_start = false
        let thinking_end = false
        // Tail of the in-order processing queue for upstream chunks
        let pending = Promise.resolve()

        // Flatten the prompt text for token estimation
        let promptText = ''
        if (requestBody && requestBody.messages) {
            promptText = requestBody.messages.map(msg => {
                if (typeof msg.content === 'string') {
                    return msg.content
                } else if (Array.isArray(msg.content)) {
                    return msg.content.map(item => item.text || '').join('')
                }
                return ''
            }).join('\n')
        }

        // Handle one `data: ...` event, appending any text to fullContent
        const processEvent = async (item) => {
            const dataContent = item.replace("data: ", '')
            const decodeJson = isJson(dataContent) ? JSON.parse(dataContent) : null
            if (decodeJson === null || !decodeJson.choices || decodeJson.choices.length === 0) {
                return
            }
            // Prefer real usage figures when the upstream provides them
            if (decodeJson.usage) {
                totalTokens = {
                    prompt_tokens: decodeJson.usage.prompt_tokens || totalTokens.prompt_tokens,
                    completion_tokens: decodeJson.usage.completion_tokens || totalTokens.completion_tokens,
                    total_tokens: decodeJson.usage.total_tokens || totalTokens.total_tokens
                }
            }
            const delta = decodeJson.choices[0].delta
            // Remember web-search results; `extra` may be missing on some events
            if (delta && delta.name === 'web_search') {
                web_search_info = delta.extra?.web_search_info ?? web_search_info
            }
            if (!delta || !delta.content || (delta.phase !== 'think' && delta.phase !== 'answer')) {
                return
            }
            let content = delta.content
            // Open <think> on the first reasoning delta, prefixed by search results if any
            if (delta.phase === 'think' && !thinking_start) {
                thinking_start = true
                if (web_search_info) {
                    const webSearchTable = await accountManager.generateMarkdownTable(web_search_info, config.searchInfoMode)
                    content = `<think>\n\n${webSearchTable}\n\n${content}`
                } else {
                    content = `<think>\n\n${content}`
                }
            }
            // Close </think> on the first answer delta that follows reasoning
            if (delta.phase === 'answer' && !thinking_end && thinking_start) {
                thinking_end = true
                content = `\n\n</think>\n${content}`
            }
            fullContent += content
        }

        // Run events one at a time; a malformed event is logged and skipped
        const processEvents = async (items) => {
            for (const item of items) {
                try {
                    await processEvent(item)
                } catch (error) {
                    logger.error('非流式数据处理错误', 'CHAT', '', error)
                }
            }
        }

        // Append a raw chunk to the buffer and process every complete event in it;
        // an incomplete trailing event stays buffered for the next chunk.
        const processChunk = async (chunk) => {
            buffer += decoder.decode(chunk, { stream: true })
            const events = []
            let startIndex = 0
            while (true) {
                const dataStart = buffer.indexOf('data: ', startIndex)
                if (dataStart === -1) break
                const dataEnd = buffer.indexOf('\n\n', dataStart)
                if (dataEnd === -1) break
                events.push(buffer.substring(dataStart, dataEnd).trim())
                startIndex = dataEnd + 2
            }
            if (startIndex > 0) {
                buffer = buffer.substring(startIndex)
            }
            await processEvents(events)
        }

        // Consume the whole upstream stream
        await new Promise((resolve, reject) => {
            response.on('data', (chunk) => {
                pending = pending.then(() => processChunk(chunk))
            })
            // Resolve only after every queued chunk has been applied to fullContent
            response.on('end', () => {
                pending.then(resolve, reject)
            })
            response.on('error', (error) => {
                logger.error('非流式响应流读取错误', 'CHAT', '', error)
                reject(error)
            })
        })

        // Flush bytes held by the decoder and a last event lacking its blank-line terminator
        buffer += decoder.decode()
        const tailStart = buffer.indexOf('data: ')
        if (tailStart !== -1) {
            await processEvents([buffer.substring(tailStart).trim()])
        }

        // Search results go after the answer when they were not put inside <think>
        if ((config.outThink === false || !enable_thinking) && web_search_info && config.searchInfoMode === "text") {
            const webSearchTable = await accountManager.generateMarkdownTable(web_search_info, "text")
            fullContent += `\n\n---\n${webSearchTable}`
        }
        // Final token usage: estimate locally unless upstream reported it
        if (totalTokens.prompt_tokens === 0 && totalTokens.completion_tokens === 0) {
            totalTokens = createUsageObject(requestBody?.messages || promptText, fullContent, null)
            logger.info(`非流式使用tiktoken计算 - Prompt: ${totalTokens.prompt_tokens}, Completion: ${totalTokens.completion_tokens}, Total: ${totalTokens.total_tokens}`, 'CHAT')
        } else {
            logger.info(`非流式使用上游真实Token - Prompt: ${totalTokens.prompt_tokens}, Completion: ${totalTokens.completion_tokens}, Total: ${totalTokens.total_tokens}`, 'CHAT')
        }
        // Clamp counts to non-negative and keep the total consistent with its parts
        totalTokens.prompt_tokens = Math.max(0, totalTokens.prompt_tokens || 0)
        totalTokens.completion_tokens = Math.max(0, totalTokens.completion_tokens || 0)
        totalTokens.total_tokens = totalTokens.prompt_tokens + totalTokens.completion_tokens
        // Complete JSON reply
        const bodyTemplate = {
            "id": `chatcmpl-${generateUUID()}`,
            "object": "chat.completion",
            "created": new Date().getTime(),
            "model": model,
            "choices": [
                {
                    "index": 0,
                    "message": {
                        "role": "assistant",
                        "content": fullContent
                    },
                    "finish_reason": "stop"
                }
            ],
            "usage": totalTokens
        }
        res.json(bodyTemplate)
    } catch (error) {
        logger.error('非流式聊天处理错误', 'CHAT', '', error)
        // Never attempt a second reply once one has gone out
        if (!res.headersSent) {
            res.status(500)
                .json({
                    error: "服务错误!!!"
                })
        }
        await trackUsage({ model, success: false, usage: { total_tokens: 0 } })
        return
    }
    await trackUsage({ model, success: true, usage: totalTokens })
}
/**
 * Main chat-completion route handler.
 * Sends `req.body` upstream via `sendChatRequest`, then dispatches to the
 * streaming or non-streaming response handler depending on `req.body.stream`.
 * Failures are answered with HTTP 500 when no response has been started yet,
 * and recorded in usage stats as unsuccessful.
 * @param {object} req - Express request object; `enable_thinking` / `enable_web_search`
 *   are read from `req` itself (presumably set by earlier middleware — confirm)
 * @param {object} res - Express response object
 */
const handleChatCompletion = async (req, res) => {
    const { stream, model } = req.body
    const enable_thinking = req.enable_thinking
    const enable_web_search = req.enable_web_search
    // Record a failed request. Tracking problems are logged rather than rethrown
    // so they can never prevent the client from receiving a response.
    const trackFailure = async () => {
        try {
            await usageStats.track({ model, success: false, usage: { total_tokens: 0 } })
        } catch (trackError) {
            logger.error('聊天处理错误', 'CHAT', '', trackError)
        }
    }
    try {
        const response_data = await sendChatRequest(req.body)
        if (!response_data?.status || !response_data?.response) {
            res.status(500)
                .json({
                    error: "请求发送失败!!!"
                })
            await trackFailure()
            return
        }
        if (stream) {
            setResponseHeaders(res, true)
            await handleStreamResponse(res, response_data.response, enable_thinking, enable_web_search, req.body)
        } else {
            setResponseHeaders(res, false)
            await handleNonStreamResponse(res, response_data.response, enable_thinking, enable_web_search, model, req.body)
        }
    } catch (error) {
        logger.error('聊天处理错误', 'CHAT', '', error)
        if (res.headersSent) {
            // A reply is already under way (e.g. SSE started); it can only be closed
            if (!res.writableEnded) {
                res.end()
            }
        } else {
            res.status(500)
                .json({
                    error: "token无效,请求发送失败!!!"
                })
        }
        await trackFailure()
    }
}
module.exports = {
handleChatCompletion,
handleStreamResponse,
handleNonStreamResponse,
setResponseHeaders
}