Rfym21 committed on
Commit
1d562b0
·
verified ·
1 Parent(s): d13129a

Update index.js

Browse files
Files changed (1) hide show
  1. index.js +267 -269
index.js CHANGED
@@ -1,270 +1,268 @@
1
- import grpc from '@grpc/grpc-js';
2
- import protoLoader from '@grpc/proto-loader';
3
- import { AutoRouter, json, error, cors } from 'itty-router';
4
- import dotenv from 'dotenv';
5
- import { createServerAdapter } from '@whatwg-node/server';
6
- import { createServer } from 'http';
7
-
8
- // 加载环境变量
9
- dotenv.config();
10
- // 初始化配置
11
- class Config {
12
- constructor() {
13
- this.API_PREFIX = process.env.API_PREFIX || '/';
14
- this.API_KEY = process.env.API_KEY || '';
15
- this.MAX_RETRY_COUNT = process.env.MAX_RETRY_COUNT || 3;
16
- this.RETRY_DELAY = process.env.RETRY_DELAY || 5000;
17
- this.COMMON_GRPC = 'runtime-native-io-vertex-inference-grpc-service-lmuw6mcn3q-ul.a.run.app';
18
- this.COMMON_PROTO = './VertexInferenceService.proto';
19
- this.GPT_GRPC = 'runtime-native-io-gpt-inference-grpc-service-lmuw6mcn3q-ul.a.run.app';
20
- this.GPT_PROTO = './GPTInferenceService.proto';
21
- this.PORT = process.env.PORT || 8787;
22
- }
23
- }
24
- const config = new Config();
25
- // 中间件
26
- // 添加运行回源
27
- const { preflight, corsify } = cors({
28
- origin: '*',
29
- allowMethods: '*',
30
- exposeHeaders: '*',
31
- });
32
-
33
- // 添加认证
34
- const withAuth = (request) => {
35
- if (config.API_KEY) {
36
- const authHeader = request.headers.get('Authorization');
37
- if (!authHeader || !authHeader.startsWith('Bearer ')) {
38
- return error(401, 'Unauthorized: Missing or invalid Authorization header');
39
- }
40
- const token = authHeader.substring(7);
41
- if (token !== config.API_KEY) {
42
- return error(403, 'Forbidden: Invalid API key');
43
- }
44
- }
45
- };
46
- // 返回运行信息
47
- const logger = (res, req) => {
48
- console.log(req.method, res.status, req.url, Date.now() - req.start, 'ms');
49
- };
50
- const router = AutoRouter({
51
- before: [preflight, withAuth],
52
- missing: () => error(404, '404 not found.'),
53
- finally: [corsify, logger],
54
- });
55
- // Router路径
56
- router.get('/', () => json({ message: 'API 服务运行中~' }));
57
- router.get('/ping', () => json({ message: 'pong' }));
58
- router.post(config.API_PREFIX + '/v1/chat/completions', (req) => handleCompletion(req));
59
-
60
- async function GrpcToPieces(models, message, rules,stream,temperature,top_p) {
61
- // 在非GPT类型的模型中,temperature和top_p是无效的
62
- // 使用系统的根证书
63
- const credentials = grpc.credentials.createSsl();
64
- if (models.includes('gpt')){
65
- // 加载proto文件
66
- const packageDefinition = protoLoader.loadSync(config.GPT_PROTO, {
67
- keepCase: true,
68
- longs: String,
69
- enums: String,
70
- defaults: true,
71
- oneofs: true
72
- });
73
- // 构建请求消息
74
- const request = {
75
- models: models,
76
- messages: [
77
- {role: 0, message: rules}, // system
78
- {role: 1, message: message} // user
79
- ],
80
- temperature:temperature || 0.1,
81
- top_p:top_p ?? 1,
82
- }
83
- // 获取gRPC对象
84
- const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.gpt;
85
- const client = new GRPCobjects.GPTInferenceService(config.GPT_GRPC, credentials);
86
- for (let retryCount = 0; retryCount <= config.MAX_RETRY_COUNT; retryCount++) {
87
- try {
88
- // 使用 Promise 包装异步 gRPC 调用
89
- const response = await new Promise((resolve, reject) => {
90
- client.Predict(request, (err, response) => {
91
- if (err) {
92
- reject(err);
93
- } else {
94
- resolve(response);
95
- }
96
- });
97
- });
98
- // 处理响应
99
- let response_code = response.response_code;
100
- let response_message = response.body.message_warpper.message;
101
- // 检查解构结果
102
- if (!response_code || !response_message) {
103
- console.error('Invalid response format, retrying...');
104
- continue; // 继续重试
105
- }
106
- console.log('Received response from server', response);
107
-
108
- // 如果响应成功,返回结果
109
- if (+response_code === 200) {
110
- return { response_code, response_message };
111
- } else {
112
- // 如果响应码不是200,继续重试
113
- console.error('Non-success response code, retrying...');
114
- }
115
- } catch (err) {
116
- // 捕获错误并重试
117
- console.error('Error occurred during gRPC call:', err);
118
- }
119
- }
120
- } else {
121
- // 加载proto文件
122
- const packageDefinition = protoLoader.loadSync(config.COMMON_PROTO,{
123
- keepCase: true,
124
- longs: String,
125
- enums: String,
126
- defaults: true,
127
- oneofs: true
128
- });
129
- // 构建请求消息
130
- const request = {
131
- models: models,
132
- args: {
133
- messages: {
134
- unknown: 1,
135
- message: message
136
- },
137
- rules: rules
138
- }
139
- };
140
- // 获取gRPC对象
141
- const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.vertex;
142
- const client = new GRPCobjects.VertexInferenceService(config.COMMON_GRPC, credentials);
143
- for (let retryCount = 0; retryCount <= config.MAX_RETRY_COUNT; retryCount++) {
144
- try {
145
- // 使用 Promise 包装异步 gRPC 调用
146
- const response = await new Promise((resolve, reject) => {
147
- client.Predict(request, (err, response) => {
148
- if (err) {
149
- reject(err);
150
- } else {
151
- resolve(response);
152
- }
153
- });
154
- });
155
-
156
- // 处理响应
157
- let response_code = response.response_code;
158
- let response_message = response.args.args.args.message;
159
- // 检查解构结果
160
- if (!response_code || !response_message) {
161
- console.error('Invalid response format, retrying...');
162
- continue; // 继续重试
163
- }
164
- console.log('Received response from server', response);
165
-
166
- // 如果响应成功,返回结果
167
- if (+response_code === 200) {
168
- return { response_code, response_message };
169
- } else {
170
- // 如果响应码不是200,继续重试
171
- console.error('Non-success response code, retrying...');
172
- }
173
- } catch (err) {
174
- // 捕获错误并重试
175
- console.error('Error occurred during gRPC call:', err);
176
- }
177
- }
178
- }
179
- }
180
-
181
- async function messagesProcess(messages) {
182
- let rules = '';
183
- let message = '';
184
-
185
- for (const msg of messages) {
186
- let role = msg.role;
187
- // 格式化为字符串
188
- const contentStr = Array.isArray(msg.content)
189
- ? msg.content
190
- .filter((item) => item.text)
191
- .map((item) => item.text)
192
- .join('') || ''
193
- : msg.content;
194
- // 判断身份
195
- if (role === 'system') {
196
- rules += `system:${contentStr};\r\n`;
197
- } else if (['user', 'assistant'].includes(role)) {
198
- message += `${role}:${contentStr};\r\n`;
199
- }
200
- }
201
-
202
- return { rules, message };
203
- }
204
-
205
- async function ConvertOpenai(messages,response_code,stream) {
206
- if (response_code !== 200) {
207
- //todo 不知道返回什么
208
- }
209
- if (stream){
210
- // todo
211
- } else {
212
- return new Response(JSON.stringify(ChatCompletionWithModel(messages, response_code)), {
213
- headers: {
214
- 'Content-Type': 'application/json',
215
- },
216
- });
217
- }
218
- }
219
-
220
- function ChatCompletionWithModel(message, model) {
221
- return {
222
- id: 'Chat-Nekohy',
223
- object: 'chat.completion',
224
- created: Date.now(),
225
- model,
226
- usage: {
227
- prompt_tokens: 0,
228
- completion_tokens: 0,
229
- total_tokens: 0,
230
- },
231
- choices: [
232
- {
233
- message: {
234
- content: message,
235
- role: 'assistant',
236
- },
237
- index: 0,
238
- },
239
- ],
240
- };
241
- }
242
-
243
- async function handleCompletion(request) {
244
- try {
245
- // todo stream逆向接口
246
- // 解析openai格式API请求
247
- const { model: inputModel, messages, stream:todo,temperature,top_p} = await request.json();
248
- console.log(inputModel,messages,todo)
249
- let stream = false;
250
- // 解析system和user/assistant消息
251
- const { rules, message:content } = await messagesProcess(messages);
252
- console.log(rules,content)
253
- // 响应码,回复的消息
254
- const { response_code, response_message } = await GrpcToPieces(inputModel, content, rules, stream, temperature, top_p);
255
- // 转换为OpenAi格式
256
- return await ConvertOpenai(response_message,response_code,stream)
257
- } catch (err) {
258
- return error(500, err.message);
259
- }
260
- }
261
-
262
- (async () => {
263
- //For Cloudflare Workers
264
- if (typeof addEventListener === 'function') return;
265
- // For Nodejs
266
- const ittyServer = createServerAdapter(router.fetch);
267
- console.log(`Listening on http://localhost:${config.PORT}`);
268
- const httpServer = createServer(ittyServer);
269
- httpServer.listen(config.PORT);
270
  })();
 
1
+ import grpc from '@grpc/grpc-js';
2
+ import protoLoader from '@grpc/proto-loader';
3
+ import {AutoRouter, cors, error, json} from 'itty-router';
4
+ import dotenv from 'dotenv';
5
+ import {createServerAdapter} from '@whatwg-node/server';
6
+ import {createServer} from 'http';
7
+
8
+ // 加载环境变量
9
+ dotenv.config();
10
+ // 初始化配置
11
+ class Config {
12
+ constructor() {
13
+ this.API_PREFIX = process.env.API_PREFIX || '/';
14
+ this.API_KEY = process.env.API_KEY || '';
15
+ this.MAX_RETRY_COUNT = process.env.MAX_RETRY_COUNT || 3;
16
+ this.RETRY_DELAY = process.env.RETRY_DELAY || 5000;
17
+ this.COMMON_GRPC = 'runtime-native-io-vertex-inference-grpc-service-lmuw6mcn3q-ul.a.run.app';
18
+ this.COMMON_PROTO = './VertexInferenceService.proto';
19
+ this.GPT_GRPC = 'runtime-native-io-gpt-inference-grpc-service-lmuw6mcn3q-ul.a.run.app';
20
+ this.GPT_PROTO = './GPTInferenceService.proto';
21
+ this.PORT = process.env.PORT || 8787;
22
+ }
23
+ }
// Loads a gRPC package definition from a .proto file path.
// protoLoader.loadSync reads and parses the file synchronously, so cache
// the definition per path — the original re-parsed the proto on every
// request, which is wasted filesystem and CPU work on a hot path.
class GRPCHandler {
    // path -> parsed package definition
    static #cache = new Map();

    constructor(protoFilePath) {
        if (!GRPCHandler.#cache.has(protoFilePath)) {
            GRPCHandler.#cache.set(protoFilePath, protoLoader.loadSync(protoFilePath, {
                keepCase: true,
                longs: String,
                enums: String,
                defaults: true,
                oneofs: true
            }));
        }
        this.packageDefinition = GRPCHandler.#cache.get(protoFilePath);
    }
}
const config = new Config();

// CORS middleware: permit any origin, method, and header.
const { preflight, corsify } = cors({
    origin: '*',
    allowMethods: '*',
    exposeHeaders: '*',
});

// Bearer-token authentication; a no-op unless API_KEY is configured.
const withAuth = (request) => {
    if (!config.API_KEY) return;
    const authHeader = request.headers.get('Authorization');
    if (!authHeader || !authHeader.startsWith('Bearer ')) {
        return error(401, 'Unauthorized: Missing or invalid Authorization header');
    }
    if (authHeader.substring(7) !== config.API_KEY) {
        return error(403, 'Forbidden: Invalid API key');
    }
};

// Request logging. NOTE(review): req.start is assumed to be stamped by an
// upstream middleware/adapter; if nothing sets it the elapsed time logs as
// NaN — confirm.
const logger = (res, req) => {
    console.log(req.method, res.status, req.url, Date.now() - req.start, 'ms');
};

const router = AutoRouter({
    before: [preflight, withAuth],
    missing: () => error(404, '404 not found.'),
    finally: [corsify, logger],
});

// Routes
router.get('/', () => json({ message: 'API 服务运行中~' }));
router.get('/ping', () => json({ message: 'pong' }));
router.post(config.API_PREFIX + '/v1/chat/completions', (req) => handleCompletion(req));
// Builds the gRPC client and request for the given model family and hands
// off to ConvertOpenai. GPT-family models (name contains "gpt") go to the
// GPT inference service with temperature/top_p; all other models go to the
// Vertex service, whose request shape carries no sampling parameters.
async function GrpcToPieces(models, message, rules, stream, temperature, top_p) {
    // Use the system root certificates for TLS.
    const credentials = grpc.credentials.createSsl();
    let client, request;
    if (models.includes('gpt')) {
        // Load (cached) proto definition.
        const packageDefinition = new GRPCHandler(config.GPT_PROTO).packageDefinition;
        request = {
            models: models,
            messages: [
                { role: 0, message: rules },   // system
                { role: 1, message: message }, // user
            ],
            // ?? (not ||) so an explicit temperature of 0 is respected
            // instead of being silently replaced by the 0.1 default.
            temperature: temperature ?? 0.1,
            top_p: top_p ?? 1,
        };
        const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.gpt;
        client = new GRPCobjects.GPTInferenceService(config.GPT_GRPC, credentials);
    } else {
        // Load (cached) proto definition.
        const packageDefinition = new GRPCHandler(config.COMMON_PROTO).packageDefinition;
        request = {
            models: models,
            args: {
                messages: {
                    unknown: 1,
                    message: message
                },
                rules: rules
            }
        };
        const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.vertex;
        client = new GRPCobjects.VertexInferenceService(config.COMMON_GRPC, credentials);
    }
    return await ConvertOpenai(client, request, models, stream);
}
// Flattens an OpenAI-style message list into two strings: `rules` collects
// the system prompts, `message` collects the user/assistant turns, each
// entry terminated with ";\r\n". Array-form (multi-part) content is joined
// from its text fragments.
async function messagesProcess(messages) {
    let rules = '';
    let message = '';

    for (const { role, content } of messages) {
        let contentStr;
        if (Array.isArray(content)) {
            const fragments = content
                .filter((part) => part.text)
                .map((part) => part.text);
            contentStr = fragments.join('') || '';
        } else {
            contentStr = content;
        }
        if (role === 'system') {
            rules += `system:${contentStr};\r\n`;
        } else if (role === 'user' || role === 'assistant') {
            message += `${role}:${contentStr};\r\n`;
        }
        // Any other role is ignored on purpose.
    }

    return { rules, message };
}
// Calls the gRPC backend (streaming or unary) and adapts the reply to the
// OpenAI chat-completions wire format. Retries up to config.MAX_RETRY_COUNT
// times, sleeping config.RETRY_DELAY ms after a thrown failure, and returns
// an HTTP 500 when every attempt fails.
async function ConvertOpenai(client, request, model, stream) {
    let lastError;
    for (let i = 0; i < config.MAX_RETRY_COUNT; i++) {
        try {
            if (stream) {
                const call = client.PredictWithStream(request);
                const encoder = new TextEncoder();
                const ReturnStream = new ReadableStream({
                    start(controller) {
                        call.on('data', (response) => {
                            let response_code = Number(response.response_code);
                            if (response_code === 204) {
                                // 204 marks end-of-stream: emit the OpenAI SSE
                                // terminator, then close the stream and the call.
                                controller.enqueue(encoder.encode('data: [DONE]\n\n'));
                                controller.close();
                                call.destroy();
                            } else if (response_code === 200) {
                                let response_message;
                                if (model.includes('gpt')) {
                                    response_message = response.body.message_warpper.message.message;
                                } else {
                                    response_message = response.args.args.args.message;
                                }
                                // Forward the chunk as an SSE event.
                                controller.enqueue(encoder.encode(`data: ${JSON.stringify(ChatCompletionStreamWithModel(response_message, model))}\n\n`));
                            } else {
                                // BUG FIX: the original called controller.close()
                                // right after controller.error(), which throws on
                                // an already-errored stream. Tear down the call
                                // instead.
                                controller.error(new Error(`Error: stream chunk is not success`));
                                call.destroy();
                            }
                        });
                        // BUG FIX: without an 'error' handler a gRPC failure is
                        // an unhandled 'error' event and crashes the process.
                        call.on('error', (err) => {
                            controller.error(err);
                        });
                        call.on('end', () => {
                            try { controller.close(); } catch { /* already closed/errored */ }
                        });
                    }
                });
                return new Response(ReturnStream, {
                    headers: {
                        'Content-Type': 'text/event-stream',
                    },
                });
            } else {
                // Promisify the unary callback API.
                const reply = await new Promise((resolve, reject) => {
                    client.Predict(request, (err, response) => {
                        if (err) reject(err);
                        else resolve(response);
                    });
                });
                let response_code = Number(reply.response_code);
                if (response_code === 200) {
                    let response_message;
                    if (model.includes('gpt')) {
                        response_message = reply.body.message_warpper.message.message;
                    } else {
                        response_message = reply.args.args.args.message;
                    }
                    return new Response(JSON.stringify(ChatCompletionWithModel(response_message, model)), {
                        headers: {
                            'Content-Type': 'application/json',
                        },
                    });
                }
                // Non-200 unary reply: remember it and retry.
                lastError = new Error(`Unexpected response_code ${response_code}`);
            }
        } catch (err) {
            console.error(err);
            lastError = err;
            await new Promise((resolve) => setTimeout(resolve, config.RETRY_DELAY));
        }
    }
    // BUG FIX: the original referenced `err` here, but `err` is scoped to the
    // catch block above — reaching this line threw a ReferenceError instead of
    // returning the intended 500.
    return error(500, lastError?.message ?? 'All retry attempts failed');
}
// Wraps a reply string in an OpenAI chat.completion response object.
// Token usage is not reported by the upstream service, so it is zeroed.
function ChatCompletionWithModel(message, model) {
    return {
        id: 'Chat-Nekohy',
        object: 'chat.completion',
        // BUG FIX: OpenAI's `created` field is a Unix timestamp in SECONDS;
        // Date.now() returns milliseconds, so convert.
        created: Math.floor(Date.now() / 1000),
        model,
        usage: {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
        },
        choices: [
            {
                message: {
                    content: message,
                    role: 'assistant',
                },
                index: 0,
            },
        ],
    };
}
// Wraps one streamed text fragment in an OpenAI chat.completion.chunk
// delta object for SSE transport.
function ChatCompletionStreamWithModel(text, model) {
    const delta = { content: text };
    const choice = { index: 0, delta, finish_reason: null };
    return {
        id: 'chatcmpl-QXlha2FBbmROaXhpZUFyZUF3ZXNvbWUK',
        object: 'chat.completion.chunk',
        created: 0,
        model,
        choices: [choice],
    };
}
// Entry point for POST /v1/chat/completions: parse the OpenAI-format body,
// flatten the messages into rules/content strings, and relay the request to
// the gRPC backend. Any failure is surfaced as an HTTP 500.
async function handleCompletion(request) {
    try {
        const body = await request.json();
        const { model: inputModel, messages, stream, temperature, top_p } = body;
        console.log(inputModel, messages, stream);
        // Split into system rules vs. user/assistant conversation.
        const { rules, message: content } = await messagesProcess(messages);
        console.log(rules, content);
        // Response (status code + converted body) comes from the backend.
        return await GrpcToPieces(inputModel, content, rules, stream, temperature, top_p);
    } catch (err) {
        return error(500, err.message);
    }
}
// Bootstrap: on Cloudflare Workers the router is driven by the runtime's
// fetch event, so only start a Node HTTP server when that global is absent.
(async () => {
    if (typeof addEventListener === 'function') return; // Cloudflare Workers
    // Node.js: bridge the WHATWG fetch-style router onto http.createServer.
    const ittyServer = createServerAdapter(router.fetch);
    console.log(`Listening on http://localhost:${config.PORT}`);
    const httpServer = createServer(ittyServer);
    httpServer.listen(config.PORT);
})();