smgc commited on
Commit
a50a9e8
·
verified ·
1 Parent(s): f8fecf7

Create app.js

Browse files
Files changed (1) hide show
  1. app.js +265 -0
app.js ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // app.js
2
+ const express = require('express');
3
+ const { v4: uuidv4 } = require('uuid');
4
+ require('dotenv').config();
5
+
6
+ const app = express();
7
+
8
+ // 中间件配置
9
+ app.use(express.json());
10
+ app.use(express.urlencoded({ extended: true }));
11
+
12
+ // Helper function to convert string to hex bytes
13
+ function stringToHex(str, modelName) {
14
+ const bytes = Buffer.from(str, 'utf-8');
15
+ const byteLength = bytes.length;
16
+
17
+ // Calculate lengths and fields similar to Python version
18
+ const FIXED_HEADER = 2;
19
+ const SEPARATOR = 1;
20
+ const FIXED_SUFFIX_LENGTH = 0xA3 + modelName.length;
21
+
22
+ // 计算文本长度字段 (类似 Python 中的 base_length1)
23
+ let textLengthField1, textLengthFieldSize1;
24
+ if (byteLength < 128) {
25
+ textLengthField1 = byteLength.toString(16).padStart(2, '0');
26
+ textLengthFieldSize1 = 1;
27
+ } else {
28
+ const lowByte1 = (byteLength & 0x7F) | 0x80;
29
+ const highByte1 = (byteLength >> 7) & 0xFF;
30
+ textLengthField1 = lowByte1.toString(16).padStart(2, '0') + highByte1.toString(16).padStart(2, '0');
31
+ textLengthFieldSize1 = 2;
32
+ }
33
+
34
+ // 计算基础长度 (类似 Python 中的 base_length)
35
+ const baseLength = byteLength + 0x2A;
36
+ let textLengthField, textLengthFieldSize;
37
+ if (baseLength < 128) {
38
+ textLengthField = baseLength.toString(16).padStart(2, '0');
39
+ textLengthFieldSize = 1;
40
+ } else {
41
+ const lowByte = (baseLength & 0x7F) | 0x80;
42
+ const highByte = (baseLength >> 7) & 0xFF;
43
+ textLengthField = lowByte.toString(16).padStart(2, '0') + highByte.toString(16).padStart(2, '0');
44
+ textLengthFieldSize = 2;
45
+ }
46
+
47
+ // 计算总消息长度
48
+ const messageTotalLength = FIXED_HEADER + textLengthFieldSize + SEPARATOR +
49
+ textLengthFieldSize1 + byteLength + FIXED_SUFFIX_LENGTH;
50
+
51
+ const messageLengthHex = messageTotalLength.toString(16).padStart(10, '0');
52
+
53
+ // 构造完整的十六进制字符串
54
+ const hexString = (
55
+ messageLengthHex +
56
+ '12' +
57
+ textLengthField +
58
+ '0A' +
59
+ textLengthField1 +
60
+ bytes.toString('hex') +
61
+ '10016A2432343163636435662D393162612D343131382D393239612D3936626330313631626432612' +
62
+ '2002A132F643A2F6964656150726F2F656475626F73733A1E0A' +
63
+ // 将模型名称长度转换为两位十六进制,并确保是大写
64
+ Buffer.from(modelName, 'utf-8').length.toString(16).padStart(2, '0').toUpperCase() +
65
+ Buffer.from(modelName, 'utf-8').toString('hex').toUpperCase() +
66
+ '22004A' +
67
+ '24' + '61383761396133342D323164642D343863372D623434662D616636633365636536663765' +
68
+ '680070007A2436393337376535612D386332642D343835342D623564392D653062623232336163303061' +
69
+ '800101B00100C00100E00100E80100'
70
+ ).toUpperCase();
71
+ return Buffer.from(hexString, 'hex');
72
+ }
73
+
74
+ // 封装函数,用于将 chunk 转换为 UTF-8 字符串
75
+ function chunkToUtf8String(chunk) {
76
+ // 只处理以 0x00 0x00 0x00 0x00 开头的 chunk,其他不处理,不然会有乱码
77
+ if (!(chunk[0] === 0x00 && chunk[1] === 0x00)) {
78
+ return '';
79
+ }
80
+
81
+ console.log('chunk:', Buffer.from(chunk).toString('hex'));
82
+ console.log('chunk string:', Buffer.from(chunk).toString('utf-8'));
83
+
84
+ // 去掉 chunk 中 0x0A 以及之前的字符
85
+ chunk = chunk.slice(chunk.indexOf(0x0A) + 1);
86
+
87
+ let filteredChunk = [];
88
+ let i = 0;
89
+ while (i < chunk.length) {
90
+ // 新的条件过滤:如果遇到连续4个0x00,则移除其之后所有的以 0 开头的字节(0x00 到 0x0F)
91
+ if (chunk.slice(i, i + 4).every(byte => byte === 0x00)) {
92
+ i += 4; // 跳过这4个0x00
93
+ while (i < chunk.length && chunk[i] >= 0x00 && chunk[i] <= 0x0F) {
94
+ i++; // 跳过所有以 0 开头的字节
95
+ }
96
+ continue;
97
+ }
98
+
99
+ if (chunk[i] === 0x0C) {
100
+ // 遇到 0x0C 时,跳过 0x0C 以及后续的所有连续的 0x0A
101
+ i++; // 跳过 0x0C
102
+ while (i < chunk.length && chunk[i] === 0x0A) {
103
+ i++; // 跳过所有连续的 0x0A
104
+ }
105
+ } else if (
106
+ i > 0 &&
107
+ chunk[i] === 0x0A &&
108
+ chunk[i - 1] >= 0x00 &&
109
+ chunk[i - 1] <= 0x09
110
+ ) {
111
+ // 如果当前字节是 0x0A,且前一个字节在 0x00 至 0x09 之间,跳过前一个字节和当前字节
112
+ filteredChunk.pop(); // 移除已添加的前一个字节
113
+ i++; // 跳过当前的 0x0A
114
+ } else {
115
+ filteredChunk.push(chunk[i]);
116
+ i++;
117
+ }
118
+ }
119
+
120
+ // 第二步:去除所有的 0x00 和 0x0C
121
+ filteredChunk = filteredChunk.filter((byte) => byte !== 0x00 && byte !== 0x0C);
122
+
123
+ // 去除小于 0x0A 的字节
124
+ filteredChunk = filteredChunk.filter((byte) => byte >= 0x0A);
125
+
126
+ const hexString = Buffer.from(filteredChunk).toString('hex');
127
+ console.log('hexString:', hexString);
128
+ const utf8String = Buffer.from(filteredChunk).toString('utf-8');
129
+ console.log('utf8String:', utf8String);
130
+ return utf8String;
131
+ }
132
+
133
+ app.post('/ai/v1/chat/completions', async (req, res) => {
134
+ // o1开头的模型,不支持流式输出
135
+ if (req.body.model.startsWith('o1-') && req.body.stream) {
136
+ return res.status(400).json({
137
+ error: 'Model not supported stream'
138
+ });
139
+ }
140
+
141
+ let currentKeyIndex = 0;
142
+ try {
143
+ const { model, messages, stream = false } = req.body;
144
+ let authToken = req.headers.authorization?.replace('Bearer ', '');
145
+ // 处理逗号分隔的密钥
146
+ const keys = authToken.split(',').map(key => key.trim());
147
+ if (keys.length > 0) {
148
+ // 确保 currentKeyIndex 不会越界
149
+ if (currentKeyIndex >= keys.length) {
150
+ currentKeyIndex = 0;
151
+ }
152
+ // 使用当前索引获取密钥
153
+ authToken = keys[currentKeyIndex];
154
+ // 更新索引
155
+ currentKeyIndex = (currentKeyIndex + 1);
156
+ }
157
+ if (authToken && authToken.includes('%3A%3A')) {
158
+ authToken = authToken.split('%3A%3A')[1];
159
+ }
160
+ if (!messages || !Array.isArray(messages) || messages.length === 0 || !authToken) {
161
+ return res.status(400).json({
162
+ error: 'Invalid request. Messages should be a non-empty array and authorization is required'
163
+ });
164
+ }
165
+
166
+ const formattedMessages = messages.map(msg => `${msg.role}:${msg.content}`).join('\n');
167
+ const hexData = stringToHex(formattedMessages, model);
168
+
169
+ const response = await fetch('https://api2.cursor.sh/aiserver.v1.AiService/StreamChat', {
170
+ method: 'POST',
171
+ headers: {
172
+ 'Content-Type': 'application/connect+proto',
173
+ authorization: `Bearer ${authToken}`,
174
+ 'connect-accept-encoding': 'gzip,br',
175
+ 'connect-protocol-version': '1',
176
+ 'user-agent': 'connect-es/1.4.0',
177
+ 'x-amzn-trace-id': `Root=${uuidv4()}`,
178
+ 'x-cursor-checksum': 'zo6Qjequ9b9734d1f13c3438ba25ea31ac93d9287248b9d30434934e9fcbfa6b3b22029e/7e4af391f67188693b722eff0090e8e6608bca8fa320ef20a0ccb5d7d62dfdef',
179
+ 'x-cursor-client-version': '0.42.3',
180
+ 'x-cursor-timezone': 'Asia/Shanghai',
181
+ 'x-ghost-mode': 'false',
182
+ 'x-request-id': uuidv4(),
183
+ Host: 'api2.cursor.sh'
184
+ },
185
+ body: hexData
186
+ });
187
+
188
+ if (stream) {
189
+ res.setHeader('Content-Type', 'text/event-stream');
190
+ res.setHeader('Cache-Control', 'no-cache');
191
+ res.setHeader('Connection', 'keep-alive');
192
+
193
+ const responseId = `chatcmpl-${uuidv4()}`;
194
+
195
+ // 使用封装的函数处理 chunk
196
+ for await (const chunk of response.body) {
197
+ const text = chunkToUtf8String(chunk);
198
+
199
+ if (text.length > 0) {
200
+ res.write(`data: ${JSON.stringify({
201
+ id: responseId,
202
+ object: 'chat.completion.chunk',
203
+ created: Math.floor(Date.now() / 1000),
204
+ model,
205
+ choices: [{
206
+ index: 0,
207
+ delta: {
208
+ content: text
209
+ }
210
+ }]
211
+ })}\n\n`);
212
+ }
213
+ }
214
+
215
+ res.write('data: [DONE]\n\n');
216
+ return res.end();
217
+ } else {
218
+ let text = '';
219
+ // 在非流模式下也使用封装的函数
220
+ for await (const chunk of response.body) {
221
+ text += chunkToUtf8String(chunk);
222
+ }
223
+ // 对解析后的字符串进行进一步处理
224
+ text = text.replace(/^.*<\|END_USER\|>/s, '');
225
+ text = text.replace(/^\n[a-zA-Z]?/, '').trim();
226
+ console.log(text);
227
+
228
+ return res.json({
229
+ id: `chatcmpl-${uuidv4()}`,
230
+ object: 'chat.completion',
231
+ created: Math.floor(Date.now() / 1000),
232
+ model,
233
+ choices: [{
234
+ index: 0,
235
+ message: {
236
+ role: 'assistant',
237
+ content: text
238
+ },
239
+ finish_reason: 'stop'
240
+ }],
241
+ usage: {
242
+ prompt_tokens: 0,
243
+ completion_tokens: 0,
244
+ total_tokens: 0
245
+ }
246
+ });
247
+ }
248
+ } catch (error) {
249
+ console.error('Error:', error);
250
+ if (!res.headersSent) {
251
+ if (req.body.stream) {
252
+ res.write(`data: ${JSON.stringify({ error: 'Internal server error' })}\n\n`);
253
+ return res.end();
254
+ } else {
255
+ return res.status(500).json({ error: 'Internal server error' });
256
+ }
257
+ }
258
+ }
259
+ });
260
+
261
+ // 启动服务器
262
+ const PORT = process.env.PORT || 3000;
263
+ app.listen(PORT, () => {
264
+ console.log(`服务器运行在端口 ${PORT}`);
265
+ });