smgc committed · Commit 06c2475 · verified · 1 Parent(s): 073567b

Update api/index.js

Files changed (1)
  1. api/index.js +235 -248
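For reference, a minimal usage sketch (not part of the commit) of the OpenAI-compatible endpoint this file exposes, assuming the server runs locally on the default PORT 8787 with API_PREFIX '/' and that YOUR_API_KEY is a placeholder for whatever key the withAuth check expects:

// Hedged sketch: POST a chat completion request to the proxy.
const res = await fetch('http://localhost:8787/v1/chat/completions', {
  method: 'POST',
  headers: {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer YOUR_API_KEY',
  },
  body: JSON.stringify({
    model: 'gpt-4o-mini',
    messages: [
      { role: 'system', content: 'You are a helpful assistant.' },
      { role: 'user', content: 'Hello, who are you?' },
    ],
    stream: false,
    temperature: 0.7,
    top_p: 1,
  }),
});
console.log(await res.json());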
api/index.js CHANGED
@@ -1,4 +1,4 @@
- import grpc from '@grpc/grpc-js';
import protoLoader from '@grpc/proto-loader';
import {AutoRouter, cors, error, json} from 'itty-router';
import dotenv from 'dotenv';
@@ -12,6 +12,7 @@ dotenv.config();
// Get the directory of the current file (ESM style)
const __dirname = dirname(fileURLToPath(import.meta.url));
// Initialize configuration
class Config {
constructor() {
this.API_PREFIX = process.env.API_PREFIX || '/';
@@ -23,6 +24,32 @@ class Config {
this.GPT_GRPC = 'runtime-native-io-gpt-inference-grpc-service-lmuw6mcn3q-ul.a.run.app';
this.GPT_PROTO = path.join(__dirname,'..', 'protos', 'GPTInferenceService.proto')
this.PORT = process.env.PORT || 8787;
}
}
class GRPCHandler {
@@ -59,167 +86,80 @@ const withAuth = (request) => {
}
}
};
-
// Log runtime info
const logger = (res, req) => {
console.log(req.method, res.status, req.url, Date.now() - req.start, 'ms');
};
-
- // Define the model mapping table
- const MODEL_INFO = {
- "claude-3-sonnet-20240229": {
- "provider": "anthropic",
- "mapping": "claude-3-sonnet@20240229"
- },
- "claude-3-opus-20240229": {
- "provider": "anthropic",
- "mapping": "claude-3-opus@20240229"
- },
- "claude-3-haiku-20240307": {
- "provider": "anthropic",
- "mapping": "claude-3-haiku@20240307"
- },
- "claude-3-5-sonnet-20240620": {
- "provider": "anthropic",
- "mapping": "claude-3-5-sonnet@20240620"
- },
- "gpt-4o-mini": {
- "provider": "openai",
- "mapping": "gpt-4o-mini"
- },
- "gpt-4o": {
- "provider": "openai",
- "mapping": "gpt-4o"
- },
- "gpt-4-turbo": {
- "provider": "openai",
- "mapping": "gpt-4-turbo"
- },
- "gpt-4": {
- "provider": "openai",
- "mapping": "gpt-4"
- },
- "gpt-3.5-turbo": {
- "provider": "openai",
- "mapping": "gpt-3.5-turbo"
- },
- "gemini-1.5-pro": {
- "provider": "google",
- "mapping": "gemini-1.5-pro"
- },
- "gemini-1.5-flash": {
- "provider": "google",
- "mapping": "gemini-1.5-flash"
- },
- "chat-bison": {
- "provider": "pieces-os",
- "mapping": "chat-bison"
- },
- "codechat-bison": {
- "provider": "pieces-os",
- "mapping": "codechat-bison"
- }
- };
-
- // Define the routes
const router = AutoRouter({
- before: [preflight], // keep only the CORS preflight check
- missing: () => error(404, '404 not found.'),
- finally: [corsify, logger],
});
-
- // Root route
- router.get('/', () => json({
- service: "AI Chat Completion Proxy",
- usage: {
- endpoint: "/v1/chat/completions",
- method: "POST",
- headers: {
- "Content-Type": "application/json",
- "Authorization": "Bearer YOUR_API_KEY"
- },
- body: {
- model: "One of: " + Object.keys(MODEL_INFO).join(", "),
- messages: [
- { role: "system", content: "You are a helpful assistant." },
- { role: "user", content: "Hello, who are you?" }
- ],
- stream: false,
- temperature: 0.7,
- top_p: 1
- }
- },
- note: "Replace YOUR_API_KEY with your actual API key."
- }));
-
- // models route
- router.get(config.API_PREFIX + '/v1/models', withAuth, () =>
json({
- object: "list",
- data: Object.entries(MODEL_INFO).map(([modelId, info]) => ({
- id: modelId,
- object: "model",
- created: Date.now(),
- owned_by: "pieces-os",
- permission: [],
- root: modelId,
- parent: null,
- mapping: info.mapping,
- provider: info.provider
- }))
})
);
 
- // chat route
- router.post(config.API_PREFIX + '/v1/chat/completions', withAuth, (req) => handleCompletion(req));
-
- async function GrpcToPieces(models, message, rules, stream, temperature, top_p) {
- const credentials = grpc.credentials.createSsl();
-
- function getMetadata() {
- const metadata = new grpc.Metadata();
- metadata.set('user-agent', 'dart-grpc/2.0.0');
- return metadata;
- }
-
- const metadata = getMetadata();
- const options = {
- 'grpc.primary_user_agent': 'dart-grpc/2.0.0'
- };
-
- let client, request;
-
- if (models.includes('gpt')) {
- const packageDefinition = new GRPCHandler(config.GPT_PROTO).packageDefinition;
- request = {
- models: models,
- messages: [
- { role: 0, message: rules }, // system
- { role: 1, message: message } // user
- ],
- temperature: temperature || 0.1,
- top_p: top_p ?? 1,
- };
- const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.gpt;
- client = new GRPCobjects.GPTInferenceService(config.GPT_GRPC, credentials, options);
- } else {
- const packageDefinition = new GRPCHandler(config.COMMON_PROTO).packageDefinition;
- request = {
- models: models,
- args: {
- messages: {
- unknown: 1,
- message: message
- },
- rules: rules
- }
- };
- const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.vertex;
- client = new GRPCobjects.VertexInferenceService(config.COMMON_GRPC, credentials, options);
- }
-
- console.log('Request:', JSON.stringify(request, null, 2));
- return await ConvertOpenai(client, request, models, stream, metadata);
}

async function messagesProcess(messages) {
@@ -246,89 +186,120 @@ async function messagesProcess(messages) {
return { rules, message };
}

- async function ConvertOpenai(client, request, model, stream, metadata) {
- for (let i = 0; i < config.MAX_RETRY_COUNT; i++) {
- try {
- if (stream) {
- const call = client.PredictWithStream(request, metadata);
- const encoder = new TextEncoder();
- const ReturnStream = new ReadableStream({
- start(controller) {
- call.on('data', (response) => {
- console.log('Stream response:', JSON.stringify(response, null, 2));
- let response_code = Number(response.response_code);
- if (response_code === 204) {
- controller.close();
- call.destroy();
- } else if (response_code === 200) {
- let response_message;
- if (model.includes('gpt')) {
- response_message = response.body?.message_warpper?.message?.message;
- } else {
- response_message = response.args?.args?.args?.message;
- }
- if (response_message) {
- controller.enqueue(encoder.encode(`data: ${JSON.stringify(ChatCompletionStreamWithModel(response_message, model))}\n\n`));
}
- } else {
- controller.error(new Error(`Error: stream chunk response code ${response_code}`));
- controller.close();
- }
- });
- call.on('error', (error) => {
- console.error('Stream error:', error);
- controller.error(error);
- controller.close();
- });
- call.on('end', () => {
- controller.close();
- });
- }
- });
- return new Response(ReturnStream, {
- headers: {
- 'Content-Type': 'text/event-stream',
- },
- });
- } else {
- const call = await new Promise((resolve, reject) => {
- client.Predict(request, metadata, (err, response) => {
- if (err) reject(err);
- else resolve(response);
- });
- });
- console.log('Non-stream response:', JSON.stringify(call, null, 2));
- let response_code = Number(call.response_code);
- if (response_code === 200) {
- let response_message;
- if (model.includes('gpt')) {
- response_message = call.body?.message_warpper?.message?.message;
- } else {
- response_message = call.args?.args?.args?.message;
- }
- if (response_message) {
- return new Response(JSON.stringify(ChatCompletionWithModel(response_message, model)), {
- headers: {
- 'Content-Type': 'application/json',
- },
- });
- } else {
- throw new Error('Response message is empty or undefined');
- }
- } else {
- throw new Error(`Error: response code ${response_code}`);
}
- }
- } catch (err) {
- console.error(`Attempt ${i + 1} failed:`, err);
- if (i === config.MAX_RETRY_COUNT - 1) {
- return error(500, `All retry attempts failed. Last error: ${err.message}`);
- }
- await new Promise((resolve) => setTimeout(resolve, config.RETRY_DELAY));
}
- }
}

function ChatCompletionWithModel(message, model) {
return {
id: 'Chat-Nekohy',
@@ -371,25 +342,41 @@ function ChatCompletionStreamWithModel(text, model) {
}

async function handleCompletion(request) {
- try {
- const { model: inputModel, messages, stream, temperature, top_p } = await request.json();
-
- // Look up the model mapping
- const modelInfo = MODEL_INFO[inputModel];
- if (!modelInfo) {
- return error(400, `Unsupported model: ${inputModel}`);
}
-
- const mappedModel = modelInfo.mapping;
-
- // Split out system and user/assistant messages
- const { rules, message: content } = await messagesProcess(messages);
-
- // Use the mapped model name
- return await GrpcToPieces(mappedModel, content, rules, stream, temperature, top_p);
- } catch (err) {
- return error(500, err.message);
- }
}

(async () => {
 
+ import grpc from '@huayue/grpc-js';
import protoLoader from '@grpc/proto-loader';
import {AutoRouter, cors, error, json} from 'itty-router';
import dotenv from 'dotenv';
 
// Get the directory of the current file (ESM style)
const __dirname = dirname(fileURLToPath(import.meta.url));
// Initialize configuration
+ // Initialize configuration
class Config {
constructor() {
this.API_PREFIX = process.env.API_PREFIX || '/';
 
this.GPT_GRPC = 'runtime-native-io-gpt-inference-grpc-service-lmuw6mcn3q-ul.a.run.app';
this.GPT_PROTO = path.join(__dirname,'..', 'protos', 'GPTInferenceService.proto')
this.PORT = process.env.PORT || 8787;
+ // Add the list of supported models
+ this.SUPPORTED_MODELS = process.env.SUPPORTED_MODELS || [
+ "gpt-4o-mini",
+ "gpt-4o",
+ "gpt-4-turbo",
+ "gpt-4",
+ "gpt-3.5-turbo",
+ "claude-3-sonnet@20240229",
+ "claude-3-opus@20240229",
+ "claude-3-haiku@20240307",
+ "claude-3-5-sonnet@20240620",
+ "gemini-1.5-flash",
+ "gemini-1.5-pro",
+ "chat-bison",
+ "codechat-bison"
+ ];
+ }
+
+ // Add a model validation method
+ isValidModel(model) {
+ // Handle the special Claude model-name format
+ const RegexInput = /^(claude-3-(5-sonnet|haiku|sonnet|opus))-(\d{8})$/;
+ const matchInput = model.match(RegexInput);
+ const normalizedModel = matchInput ? `${matchInput[1]}@${matchInput[3]}` : model;
+
+ return this.SUPPORTED_MODELS.includes(normalizedModel);
}
}
  class GRPCHandler {
 
}
}
};
// Log runtime info
const logger = (res, req) => {
console.log(req.method, res.status, req.url, Date.now() - req.start, 'ms');
};
const router = AutoRouter({
+ before: [preflight, withAuth],
+ missing: () => error(404, '404 not found.'),
+ finally: [corsify, logger],
});
+ // Router paths
+ router.get('/', () => json({ message: 'API service is running~' }));
+ router.get('/ping', () => json({ message: 'pong' }));
+ router.get(config.API_PREFIX + '/v1/models', () =>
json({
+ object: 'list',
+ data: [
+ { id: "gpt-4o-mini", object: "model", owned_by: "pieces-os" },
+ { id: "gpt-4o", object: "model", owned_by: "pieces-os" },
+ { id: "gpt-4-turbo", object: "model", owned_by: "pieces-os" },
+ { id: "gpt-4", object: "model", owned_by: "pieces-os" },
+ { id: "gpt-3.5-turbo", object: "model", owned_by: "pieces-os" },
+ { id: "claude-3-sonnet@20240229", object: "model", owned_by: "pieces-os" },
+ { id: "claude-3-opus@20240229", object: "model", owned_by: "pieces-os" },
+ { id: "claude-3-haiku@20240307", object: "model", owned_by: "pieces-os" },
+ { id: "claude-3-5-sonnet@20240620", object: "model", owned_by: "pieces-os" },
+ { id: "gemini-1.5-flash", object: "model", owned_by: "pieces-os" },
+ { id: "gemini-1.5-pro", object: "model", owned_by: "pieces-os" },
+ { id: "chat-bison", object: "model", owned_by: "pieces-os" },
+ { id: "codechat-bison", object: "model", owned_by: "pieces-os" },
+ ],
})
);
+ router.post(config.API_PREFIX + '/v1/chat/completions', (req) => handleCompletion(req));
+ async function GrpcToPieces(inputModel,OriginModel,message, rules, stream, temperature, top_p) {
+ // temperature and top_p have no effect for non-GPT models
+ // Use the system root certificates
+ const credentials = grpc.credentials.createSsl();
+ let client,request;
+ if (inputModel.includes('gpt')){
+ // Load the proto file
+ const packageDefinition = new GRPCHandler(config.GPT_PROTO).packageDefinition;
+ // Build the request message
+ request = {
+ models: inputModel,
+ messages: [
+ {role: 0, message: rules}, // system
+ {role: 1, message: message} // user
+ ],
+ temperature:temperature || 0.1,
+ top_p:top_p ?? 1,
+ }
+ // Get the gRPC service object
+ const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.gpt;
+ client = new GRPCobjects.GPTInferenceService(config.GPT_GRPC, credentials);
+ } else {
+ // Load the proto file
+ const packageDefinition = new GRPCHandler(config.COMMON_PROTO).packageDefinition;
+ // Build the request message
+ request = {
+ models: inputModel,
+ args: {
+ messages: {
+ unknown: 1,
+ message: message
+ },
+ rules: rules
+ }
+ };
+ // Get the gRPC service object
+ const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.vertex;
+ client = new GRPCobjects.VertexInferenceService(config.COMMON_GRPC, credentials);
+ }
+ return await ConvertOpenai(client,request,inputModel,OriginModel,stream);
}

async function messagesProcess(messages) {
 
return { rules, message };
}

+ async function ConvertOpenai(client, request, inputModel, OriginModel, stream) {
+ const metadata = new grpc.Metadata();
+ metadata.set('User-Agent', 'dart-grpc/2.0.0');
+ for (let i = 0; i < config.MAX_RETRY_COUNT; i++) {
+ try {
+ if (stream) {
+ const call = client.PredictWithStream(request,metadata);
+ const encoder = new TextEncoder();
+ const ReturnStream = new ReadableStream({
+ start(controller) {
+ // Handle data
+ call.on('data', (response) => {
+ try {
+ let response_code = Number(response.response_code);
+ if (response_code === 204) {
+ controller.close();
+ call.destroy();
+ } else if (response_code === 200) {
+ let response_message;
+ if (inputModel.includes('gpt')) {
+ response_message = response.body.message_warpper.message.message;
+ } else {
+ response_message = response.args.args.args.message;
+ }
+ controller.enqueue(encoder.encode(`data: ${JSON.stringify(ChatCompletionStreamWithModel(response_message, OriginModel))}\n\n`));
+ } else {
+ console.error(`Invalid response code: ${response_code}`);
+ controller.error(error);
+ }
+ } catch (error) {
+ console.error('Error processing stream data:', error);
+ controller.error(error);
+ }
+ });
+
+ // Handle errors
+ call.on('error', (error) => {
+ console.error('Stream error:', error);
+ // If this is an INTERNAL error containing RST_STREAM, it may be a normal end of stream
+ if (error.code === 13 && error.details.includes('RST_STREAM')) {
+ controller.close();
+ } else {
+ controller.error(error);
+ }
+ call.destroy();
+ });
+
+ // Handle end of stream
+ call.on('end', () => {
+ controller.close();
+ });
+
+ // Handle cancellation
+ return () => {
+ call.destroy();
+ };
+ }
+ });
+
+ return new Response(ReturnStream, {
+ headers: {
+ 'Content-Type': 'text/event-stream',
+ 'Connection': 'keep-alive',
+ 'Cache-Control': 'no-cache',
+ 'Transfer-Encoding': 'chunked'
+ },
+ });
+ } else {
+ // The non-streaming call is unchanged
+ const call = await new Promise((resolve, reject) => {
+ client.Predict(request,metadata, (err, response) => {
+ if (err) reject(err);
+ else resolve(response);
+ });
+ });
+
+ let response_code = Number(call.response_code);
+ if (response_code === 200) {
+ let response_message;
+ if (inputModel.includes('gpt')) {
+ response_message = call.body.message_warpper.message.message;
+ } else {
+ response_message = call.args.args.args.message;
+ }
+ return new Response(JSON.stringify(ChatCompletionWithModel(response_message, OriginModel)), {
+ headers: {
+ 'Content-Type': 'application/json',
+ },
+ });
}
+ }
+ } catch (err) {
+ console.error(`Attempt ${i + 1} failed:`, err);
+ if (i === config.MAX_RETRY_COUNT - 1) {
+ return new Response(JSON.stringify({
+ error: {
+ message: "An error occurred while processing your request",
+ type: "server_error",
+ code: "internal_error",
+ param: null
+ }
+ }), {
+ status: 500,
+ headers: {
+ 'Content-Type': 'application/json'
+ }
+ });
+ }
+ await new Promise((resolve) => setTimeout(resolve, config.RETRY_DELAY));
}
}
}

+
function ChatCompletionWithModel(message, model) {
return {
id: 'Chat-Nekohy',
 
}

async function handleCompletion(request) {
+ try {
+ // TODO: reverse-engineered stream interface
+ // Parse the OpenAI-format API request
+ const { model: OriginModel, messages, stream,temperature,top_p} = await request.json();
+ const RegexInput = /^(claude-3-(5-sonnet|haiku|sonnet|opus))-(\d{8})$/;
+ const matchInput = OriginModel.match(RegexInput);
+ const inputModel = matchInput ? `${matchInput[1]}@${matchInput[3]}` : OriginModel;
+ // Add model validation
+ if (!config.isValidModel(inputModel)) {
+ return new Response(
+ JSON.stringify({
+ error: {
+ message: `Model '${OriginModel}' does not exist`,
+ type: "invalid_request_error",
+ param: "model",
+ code: "model_not_found"
+ }
+ }),
+ {
+ status: 404,
+ headers: {
+ 'Content-Type': 'application/json'
+ }
+ }
+ );
+ }
+ console.log(inputModel,messages,stream)
+ // Split out system and user/assistant messages
+ const { rules, message:content } = await messagesProcess(messages);
+ console.log(rules,content)
+ // Response code and reply message
+ return await GrpcToPieces(inputModel,OriginModel,content, rules, stream, temperature, top_p);
+ } catch (err) {
+ return error(500, err.message);
}
}

(async () => {