grok-project / core /main.js
kevin
修改输出顺序
139a44f
import express from 'express';
import crypto from 'crypto';
import { CONFIG, HEADERS } from './config.js';
import { ApiError, AuthenticationError, ValidationError } from './errors.js';
import { getRandomIDPro, validateRequest, withTimeout } from './utils.js';
/**
 * Xgrok2Worker - handles every interaction with the Grok-2 chat endpoint:
 * building the upstream request, relaying the reply (streamed or buffered)
 * in OpenAI chat-completion format, and deleting the upstream conversation.
 */
class Xgrok2Worker {
  /**
   * Create a worker bound to a single model.
   * @param {string} modelId - Model ID; must be listed in CONFIG.SUPPORTED_MODELS.
   * @throws {ValidationError} When the model is not supported.
   */
  constructor(modelId) {
    if (!CONFIG.SUPPORTED_MODELS.includes(modelId)) {
      throw new ValidationError(`Unsupported model: ${modelId}`);
    }
    this.modelId = modelId;
  }

  /**
   * Headers shared by every upstream call: the static HEADERS plus the CSRF
   * token and the auth_token/ct0 cookies taken from CONFIG.
   * @returns {Object} Header map for fetch.
   */
  buildAuthHeaders() {
    return {
      ...HEADERS,
      'x-csrf-token': CONFIG.CT0,
      'cookie': `auth_token=${CONFIG.AUTH_TOKEN};ct0=${CONFIG.CT0}`
    };
  }

  /**
   * Build the JSON body for the upstream add_response call.
   * @param {Array<Object>} messages - Messages already converted by transformMessages.
   * @param {string} conversationId - Conversation ID to attach the request to.
   * @returns {Object} Request body.
   */
  constructRequestBody(messages, conversationId) {
    return {
      responses: messages,
      systemPromptName: "",
      grokModelOptionId: this.modelId,
      conversationId,
      returnSearchResults: false,
      returnCitations: false,
      promptMetadata: {
        promptSource: "NATURAL",
        action: "INPUT"
      },
      imageGenerationCount: 1,
      requestFeatures: {
        eagerTweets: false,
        serverHistory: false
      }
    };
  }

  /**
   * Convert `{ role, content }` messages to the upstream message shape.
   * Assistant messages get sender 2; every other role gets sender 1 and an
   * empty fileAttachments list. Consecutive sender-1 messages are merged into
   * one, joined with "\n".
   * @param {Array<{role: string, content: string}>} messages - Incoming messages.
   * @returns {Array<Object>} Converted messages.
   */
  transformMessages(messages) {
    return messages.reduce((acc, msg) => {
      const transformed = {
        message: msg.content,
        sender: msg.role === 'assistant' ? 2 : 1,
        ...(msg.role !== 'assistant' && { fileAttachments: [] })
      };
      const previous = acc[acc.length - 1];
      // Merge consecutive non-assistant messages into a single upstream message.
      if (previous && previous.sender === 1 && transformed.sender === 1) {
        previous.message += "\n" + transformed.message;
      } else {
        acc.push(transformed);
      }
      return acc;
    }, []);
  }

  /**
   * Send a chat request upstream.
   * For stream requests the raw fetch response is returned and the caller is
   * responsible for reading it and calling cleanupConversation. Otherwise the
   * body is read, the upstream conversation is deleted (even if reading
   * fails), and an OpenAI-style completion object is returned.
   * @param {Object} request - Validated request ({ messages, model, stream }).
   * @returns {Promise<Object>} `{ response, conversationId }` when streaming,
   *   otherwise the formatted completion.
   * @throws {ApiError} When the request or response handling fails.
   */
  async sendChatRequest(request) {
    const conversationId = `18758${getRandomIDPro(14)}`;
    const transformedMessages = this.transformMessages(request.messages);
    const requestBody = this.constructRequestBody(transformedMessages, conversationId);
    try {
      const response = await withTimeout(
        fetch(`${CONFIG.BASE_URL}/2/grok/add_response.json`, {
          method: 'POST',
          headers: this.buildAuthHeaders(),
          body: JSON.stringify(requestBody)
        }),
        CONFIG.REQUEST_TIMEOUT
      );
      if (!response.ok) {
        throw new ApiError(`Upstream service error: ${response.status}`);
      }
      if (request.stream) {
        return { response, conversationId };
      }
      let content;
      try {
        content = await this.processResponse(response);
      } finally {
        // Delete the upstream conversation even when reading the body failed.
        await this.cleanupConversation(conversationId);
      }
      return this.formatResponse(content, request.model);
    } catch (error) {
      // Errors that are already ApiErrors (e.g. the upstream status error
      // above) are passed through unchanged instead of being re-wrapped.
      if (error instanceof ApiError) {
        throw error;
      }
      throw new ApiError(error.message);
    }
  }

  /**
   * Rewrite Grok's internal tweet references `[link](#tweet=<id>)` into
   * public status URLs, removing any `==` separators (with surrounding
   * whitespace) that directly follow the link.
   * @param {string} text - Text returned by the upstream.
   * @returns {string} Text with rewritten links.
   */
  rewriteTweetLinks(text) {
    return text.replace(
      /\[link\]\(#tweet=(\d+)\)(\s*==\s*)*/g,
      '[link](https://twitter.com/i/status/$1)'
    );
  }

  /**
   * Parse one newline-delimited JSON line from the upstream stream and, if it
   * carries `result.message`, write it to `res` as an SSE chunk. Blank lines
   * are skipped; parse/write errors are logged and ignored.
   * @param {string} line - One raw line.
   * @param {string} messageId - Completion ID shared by every chunk.
   * @param {string} model - Model name echoed in the chunk.
   * @param {Object} res - Express response (anything with `write`).
   */
  writeStreamLine(line, messageId, model, res) {
    if (!line.trim()) return;
    try {
      const json = JSON.parse(line);
      const message = json.result?.message;
      if (!message) return;
      const chunk = this.rewriteTweetLinks(message);
      // Console trace of each forwarded chunk, "|"-separated.
      process.stdout.write(chunk);
      process.stdout.write("|");
      const payload = {
        id: messageId,
        object: 'chat.completion.chunk',
        created: Math.floor(Date.now() / 1000),
        model,
        choices: [{
          index: 0,
          delta: { content: chunk },
          finish_reason: null
        }]
      };
      res.write(`data: ${JSON.stringify(payload)}\n\n`);
    } catch (e) {
      console.error('Error parsing JSON:', e);
    }
  }

  /**
   * Relay the upstream newline-delimited JSON stream to `res` as SSE
   * `chat.completion.chunk` events, then send a final `finish_reason: 'stop'`
   * chunk and `data: [DONE]`. Does not end `res`.
   * @param {Response} response - fetch response from sendChatRequest.
   * @param {string} model - Model name echoed in every chunk.
   * @param {Object} res - Express response object.
   */
  async handleStreamResponse(response, model, res) {
    const reader = response.body.getReader();
    const decoder = new TextDecoder();
    const messageId = `chatcmpl-${crypto.randomUUID()}`;
    let buffer = '';
    try {
      while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        buffer += decoder.decode(value, { stream: true });
        const lines = buffer.split('\n');
        buffer = lines.pop() || ''; // keep the incomplete trailing line for the next read
        for (const line of lines) {
          this.writeStreamLine(line, messageId, model, res);
        }
      }
      // Flush any bytes still held by the decoder and emit the final line,
      // which may not be newline-terminated.
      buffer += decoder.decode();
      this.writeStreamLine(buffer, messageId, model, res);

      const endPayload = {
        id: messageId,
        object: 'chat.completion.chunk',
        created: Math.floor(Date.now() / 1000),
        model,
        choices: [{
          index: 0,
          delta: {},
          finish_reason: 'stop'
        }]
      };
      res.write(`data: ${JSON.stringify(endPayload)}\n\n`);
      res.write('data: [DONE]\n\n');
      // Terminate the console trace line.
      process.stdout.write('\n');
    } finally {
      reader.releaseLock();
    }
  }

  /**
   * Read the whole upstream body and concatenate its messages.
   * @param {Response} response - fetch response object.
   * @returns {Promise<string>} Concatenated message text.
   */
  async processResponse(response) {
    const reader = response.body.getReader();
    const decoder = new TextDecoder();
    let fullResponse = '';
    try {
      while (true) {
        const { done, value } = await reader.read();
        if (done) break;
        // `stream: true` keeps a multi-byte UTF-8 sequence that is split
        // across chunks intact instead of decoding it as U+FFFD.
        fullResponse += decoder.decode(value, { stream: true });
      }
      fullResponse += decoder.decode();
    } finally {
      reader.releaseLock();
    }
    return this.extractMessages(fullResponse);
  }

  /**
   * Extract and concatenate `result.message` from newline-delimited JSON.
   * Lines that fail to parse or carry no message are skipped.
   * @param {string} data - Raw response text.
   * @returns {Promise<string>} Concatenated message text.
   */
  async extractMessages(data) {
    return data
      .trim()
      .split('\n')
      .map(line => {
        try {
          const json = JSON.parse(line);
          return json.result?.message || '';
        } catch {
          return '';
        }
      })
      .filter(Boolean)
      .join('');
  }

  /**
   * Delete the upstream conversation via the DeleteConversation GraphQL
   * mutation. The HTTP status of the call is not checked.
   * @param {string} conversationId - Conversation ID to delete.
   */
  async cleanupConversation(conversationId) {
    await fetch(`${CONFIG.BASE_URL}/i/api/graphql/TlKHSWVMVeaa-i7dqQqFQA/ConversationItem_DeleteConversationMutation`, {
      method: 'POST',
      headers: this.buildAuthHeaders(),
      body: JSON.stringify({
        variables: { conversationId },
        queryId: "TlKHSWVMVeaa-i7dqQqFQA"
      })
    });
  }

  /**
   * Wrap buffered content in an OpenAI `chat.completion` object, after
   * rewriting tweet links. Also echoes the content to stdout.
   * @param {string} content - Response content.
   * @param {string} model - Model name.
   * @returns {Object} Completion object (`usage` is null).
   */
  formatResponse(content, model) {
    const text = this.rewriteTweetLinks(content);
    process.stdout.write(text);
    process.stdout.write("|");
    process.stdout.write("\n");
    return {
      id: `chatcmpl-${crypto.randomUUID()}`,
      object: "chat.completion",
      created: Math.floor(Date.now() / 1000),
      model,
      choices: [{
        index: 0,
        message: {
          role: "assistant",
          content: text
        },
        finish_reason: "stop"
      }],
      usage: null
    };
  }
}
// Create the Express application.
const app = express();

// Body parsers: JSON and URL-encoded bodies up to 5 MB.
app.use(express.json({ limit: '5mb' }));
app.use(express.urlencoded({ extended: true, limit: '5mb' }));

/**
 * Authentication middleware: the request must carry
 * `Authorization: Bearer <CONFIG.API_KEY>`. A synchronous throw inside
 * middleware is forwarded by Express to the error-handling middleware.
 * NOTE(review): if CONFIG.API_KEY is unset, requests without an
 * Authorization header pass (undefined === undefined) — confirm intended.
 */
const authenticate = (req, res, next) => {
  const authToken = req.headers.authorization?.replace('Bearer ', '');
  if (authToken !== CONFIG.API_KEY) {
    throw new AuthenticationError();
  }
  next();
};

// List the supported models in OpenAI `/models` format.
app.get('/api/v1/models', (req, res) => {
  res.json({
    object: "list",
    data: CONFIG.SUPPORTED_MODELS.map(id => ({
      id,
      object: "model",
      created: 1706745937,
      owned_by: "xai",
    }))
  });
});

/**
 * Chat completion route. Streams SSE chunks when `stream` is set, otherwise
 * returns a single completion object. Errors are passed to the error handler.
 */
app.post('/api/v1/chat/completions', authenticate, async (req, res, next) => {
  try {
    const validatedRequest = validateRequest(req.body);
    const worker = new Xgrok2Worker(validatedRequest.model);
    if (validatedRequest.stream) {
      const { response, conversationId } = await worker.sendChatRequest(validatedRequest);
      // Set SSE headers only once the upstream call succeeded, so an early
      // failure is reported by the error handler with a JSON content type.
      res.setHeader('Content-Type', 'text/event-stream');
      res.setHeader('Cache-Control', 'no-cache');
      res.setHeader('Connection', 'keep-alive');
      try {
        await worker.handleStreamResponse(response, validatedRequest.model, res);
      } finally {
        try {
          // Always delete the upstream conversation...
          await worker.cleanupConversation(conversationId);
        } finally {
          // ...and always end the response, even if cleanup fails.
          res.end();
        }
      }
    } else {
      // Buffered (non-stream) response.
      const result = await worker.sendChatRequest(validatedRequest);
      res.json(result);
    }
  } catch (error) {
    next(error);
  }
});

// Fallback for unknown paths.
app.use((req, res) => {
  res.status(404).send('请使用正确请求路径');
});

/**
 * Error-handling middleware. It must be registered after all routes: Express
 * forwards `next(err)` only to error handlers that come later in the stack.
 */
app.use((err, req, res, next) => {
  console.error('Error:', err);
  // If a streamed response has already started, delegate to Express's
  // default handler, which closes the connection.
  if (res.headersSent) {
    return next(err);
  }
  const status = err.status || 500;
  res.status(status).json({
    error: {
      message: err.message,
      type: err.code || 'server_error',
      param: null,
      code: err.code || null
    }
  });
});

// Start the server.
app.listen(CONFIG.PORT, () => {
  console.log(`服务器运行在端口 ${CONFIG.PORT}`);
});