liuw15 commited on
Commit
37a00c9
·
1 Parent(s): ec6e4a1

兼容google的接口(思考模型不能连续调用工具)

Browse files
Files changed (2) hide show
  1. src/server/index.js +214 -1
  2. src/utils/utils.js +78 -0
src/server/index.js CHANGED
@@ -4,7 +4,7 @@ import path from 'path';
4
  import fs from 'fs';
5
  import { fileURLToPath } from 'url';
6
  import { generateAssistantResponse, generateAssistantResponseNoStream, getAvailableModels, generateImageForSD, closeRequester } from '../api/client.js';
7
- import { generateRequestBody, prepareImageRequest } from '../utils/utils.js';
8
  import logger from '../utils/logger.js';
9
  import config from '../config/config.js';
10
  import tokenManager from '../auth/token_manager.js';
@@ -202,6 +202,72 @@ const buildOpenAIErrorPayload = (error, statusCode) => {
202
  };
203
  };
204
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  app.use(cors());
206
  app.use(express.json({ limit: config.security.maxRequestSize }));
207
 
@@ -398,6 +464,153 @@ app.post('/v1/chat/completions', async (req, res) => {
398
  }
399
  });
400
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  const server = app.listen(config.server.port, config.server.host, () => {
402
  logger.info(`服务器已启动: ${config.server.host}:${config.server.port}`);
403
  });
 
4
  import fs from 'fs';
5
  import { fileURLToPath } from 'url';
6
  import { generateAssistantResponse, generateAssistantResponseNoStream, getAvailableModels, generateImageForSD, closeRequester } from '../api/client.js';
7
+ import { generateRequestBody, generateGeminiRequestBody, prepareImageRequest } from '../utils/utils.js';
8
  import logger from '../utils/logger.js';
9
  import config from '../config/config.js';
10
  import tokenManager from '../auth/token_manager.js';
 
202
  };
203
  };
204
 
205
// Build a Gemini-compatible error payload ({ error: { code, message, status } })
// from an internal error and the HTTP status code about to be sent.
const buildGeminiErrorPayload = (error, statusCode) => {
  // Prefer the upstream API's own message when its raw body is parseable JSON.
  let message = error.message || 'Internal server error';
  if (error.isUpstreamApiError && error.rawBody) {
    try {
      const raw = typeof error.rawBody === 'string' ? JSON.parse(error.rawBody) : error.rawBody;
      message = raw.error?.message || raw.message || message;
    } catch {
      // Best effort: keep the original message when rawBody is not JSON.
    }
  }

  // Map common HTTP status codes to google.rpc.Code names (the Gemini REST
  // error format); anything unrecognized falls back to INTERNAL.
  const statusMap = {
    400: 'INVALID_ARGUMENT',
    401: 'UNAUTHENTICATED',
    403: 'PERMISSION_DENIED',
    404: 'NOT_FOUND',
    429: 'RESOURCE_EXHAUSTED',
    500: 'INTERNAL',
    503: 'UNAVAILABLE',
    504: 'DEADLINE_EXCEEDED'
  };

  return {
    error: {
      code: statusCode,
      message: message,
      status: statusMap[statusCode] || 'INTERNAL'
    }
  };
};
224
+
225
// Assemble a Gemini-format response object from assistant output:
// reasoning -> a "thought" part, content -> a text part, each tool call
// -> a functionCall part. Optionally attaches usage metadata.
const createGeminiResponse = (content, reasoning, toolCalls, finishReason, usage) => {
  const parts = [];

  if (reasoning) {
    parts.push({ text: reasoning, thought: true });
  }
  if (content) {
    parts.push({ text: content });
  }

  for (const call of toolCalls ?? []) {
    try {
      const args = JSON.parse(call.function.arguments);
      parts.push({ functionCall: { name: call.function.name, args } });
    } catch (e) {
      // Malformed argument JSON: drop this tool call silently.
    }
  }

  const response = {
    candidates: [{
      content: { parts, role: "model" },
      finishReason: finishReason || "STOP",
      index: 0
    }]
  };

  if (usage) {
    response.usageMetadata = {
      promptTokenCount: usage.prompt_tokens,
      candidatesTokenCount: usage.completion_tokens,
      totalTokenCount: usage.total_tokens
    };
  }

  return response;
};
270
+
271
  app.use(cors());
272
  app.use(express.json({ limit: config.security.maxRequestSize }));
273
 
 
464
  }
465
  });
466
 
467
// Translate an OpenAI /v1/models payload into Gemini's model-list shape.
// Token limits and sampling parameters are fixed defaults since the
// upstream list carries no such metadata.
const convertToGeminiModelList = (openaiModels) => ({
  models: openaiModels.data.map(({ id }) => ({
    name: `models/${id}`,
    version: "001",
    displayName: id,
    description: "Imported model",
    inputTokenLimit: 32768,  // default
    outputTokenLimit: 8192,  // default
    supportedGenerationMethods: ["generateContent", "countTokens"],
    temperature: 0.9,
    topP: 1.0,
    topK: 40
  }))
});
483
+
484
// Gemini API routes
// GET /v1beta/models — list available models in Gemini format.
app.get('/v1beta/models', async (req, res) => {
  try {
    const openaiModels = await getAvailableModels();
    res.json(convertToGeminiModelList(openaiModels));
  } catch (error) {
    logger.error('获取模型列表失败:', error.message);
    res.status(500).json({ error: { code: 500, message: error.message, status: "INTERNAL" } });
  }
});
495
+
496
// GET /v1beta/models/:model — single-model detail in Gemini format.
// Reuses convertToGeminiModelList so the per-model shape cannot drift
// from the list endpoint (previously the object was duplicated inline).
app.get('/v1beta/models/:model', async (req, res) => {
  try {
    // Accept both "gemini-pro" and "models/gemini-pro".
    const modelId = req.params.model.replace(/^models\//, '');
    const openaiModels = await getAvailableModels();
    const geminiModel = convertToGeminiModelList(openaiModels)
      .models.find(m => m.name === `models/${modelId}`);

    if (geminiModel) {
      res.json(geminiModel);
    } else {
      res.status(404).json({ error: { code: 404, message: `Model ${modelId} not found`, status: "NOT_FOUND" } });
    }
  } catch (error) {
    logger.error('获取模型详情失败:', error.message);
    res.status(500).json({ error: { code: 500, message: error.message, status: "INTERNAL" } });
  }
});
524
+
525
// Shared handler for Gemini generateContent / streamGenerateContent.
// Resolves an auth token, converts the Gemini body into the upstream request
// envelope, then either streams chunks as Gemini candidates (isStream) or
// returns a single aggregated response.
// Fix: removed the dead ternaries `hasToolCall ? "STOP" : "STOP"` and
// `toolCalls.length > 0 ? "STOP" : "STOP"` (both branches identical) and the
// now-unused hasToolCall flag — Gemini reports "STOP" for tool-call turns too.
const handleGeminiRequest = async (req, res, modelName, isStream) => {
  try {
    const token = await tokenManager.getToken();
    if (!token) {
      throw new Error('没有可用的token,请运行 npm run login 获取token');
    }

    const requestBody = generateGeminiRequestBody(req.body, modelName, token);
    const maxRetries = Number(config.retryTimes || 0);
    const safeRetries = maxRetries > 0 ? Math.floor(maxRetries) : 0;

    if (isStream) {
      setStreamHeaders(res);
      const heartbeatTimer = createHeartbeat(res);

      try {
        let usageData = null;

        await with429Retry(
          () => generateAssistantResponse(requestBody, token, (data) => {
            if (data.type === 'usage') {
              usageData = data.usage;
            } else if (data.type === 'reasoning') {
              // Thinking content -> "thought" part
              const chunk = createGeminiResponse(null, data.reasoning_content, null, null, null);
              writeStreamData(res, chunk);
            } else if (data.type === 'tool_calls') {
              // Tool calls -> functionCall parts
              const chunk = createGeminiResponse(null, null, data.tool_calls, null, null);
              writeStreamData(res, chunk);
            } else {
              // Plain text
              const chunk = createGeminiResponse(data.content, null, null, null, null);
              writeStreamData(res, chunk);
            }
          }),
          safeRetries,
          'gemini.stream '
        );

        // Final chunk carries finishReason and usage; Gemini uses "STOP"
        // even when the turn ends in tool calls.
        const finalChunk = createGeminiResponse(null, null, null, "STOP", usageData);
        writeStreamData(res, finalChunk);

        clearInterval(heartbeatTimer);
        endStream(res);
      } catch (error) {
        // Always stop the heartbeat before propagating the error.
        clearInterval(heartbeatTimer);
        throw error;
      }
    } else {
      // Non-streaming: disable socket timeouts for long generations.
      req.setTimeout(0);
      res.setTimeout(0);

      const { content, reasoningContent, toolCalls, usage } = await with429Retry(
        () => generateAssistantResponseNoStream(requestBody, token),
        safeRetries,
        'gemini.no_stream '
      );

      // "STOP" covers both text and tool-call completions in Gemini.
      const response = createGeminiResponse(content, reasoningContent, toolCalls, "STOP", usage);
      res.json(response);
    }
  } catch (error) {
    logger.error('Gemini 请求失败:', error.message);
    // Stream already started: we cannot change status/headers anymore.
    if (res.headersSent) return;

    const statusCode = Number(error.status) || 500;
    const errorPayload = buildGeminiErrorPayload(error, statusCode);
    res.status(statusCode).json(errorPayload);
  }
};
602
+
603
// Gemini streaming endpoint; "\\:" escapes the literal colon in the route path.
app.post('/v1beta/models/:model\\:streamGenerateContent', (req, res) => {
  handleGeminiRequest(req, res, req.params.model, true);
});

// Gemini non-streaming endpoint; ?alt=sse upgrades it to a stream.
app.post('/v1beta/models/:model\\:generateContent', (req, res) => {
  handleGeminiRequest(req, res, req.params.model, req.query.alt === 'sse');
});
613
+
614
  const server = app.listen(config.server.port, config.server.host, () => {
615
  logger.info(`服务器已启动: ${config.server.host}:${config.server.port}`);
616
  });
src/utils/utils.js CHANGED
@@ -487,9 +487,87 @@ function getDefaultIp(){
487
  }
488
  return '127.0.0.1';
489
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
490
  export{
491
  generateRequestId,
492
  generateRequestBody,
 
493
  prepareImageRequest,
494
  getDefaultIp
495
  }
 
487
  }
488
  return '127.0.0.1';
489
  }
490
+
491
// Convert an incoming Gemini-format request body into the upstream
// (Antigravity) request envelope for the given model and token.
// - Backfills ids on functionCall / functionResponse parts.
// - Injects a default thinkingConfig for thinking-capable models.
// - Forces candidateCount to 1 and attaches session/project info.
// Fix: the old id matching used a Map keyed only by function name, so when
// the same function was called more than once every functionResponse got the
// LAST call's id. Ids are now queued per name in call order (FIFO).
// Also replaced deprecated String.prototype.substr with slice.
function generateGeminiRequestBody(geminiBody, modelName, token){
  const enableThinking = isEnableThinking(modelName);
  const actualModelName = modelMapping(modelName);

  // Deep-copy the body so the caller's object is never mutated.
  // NOTE: JSON round-trip drops undefined values/functions — assumed fine
  // for a JSON request body.
  const request = JSON.parse(JSON.stringify(geminiBody));

  // Synthesize a reasonably unique call id.
  const newCallId = () => `call_${Date.now()}_${Math.random().toString(36).slice(2, 11)}`;

  if (request.contents && Array.isArray(request.contents)) {
    // Pass 1: ensure every model functionCall has an id, queuing ids per
    // function name in call order.
    const pendingIds = new Map(); // name -> FIFO queue of call ids
    request.contents.forEach(content => {
      if (content.role === 'model' && content.parts && Array.isArray(content.parts)) {
        content.parts.forEach(part => {
          if (part.functionCall) {
            if (!part.functionCall.id) {
              part.functionCall.id = newCallId();
            }
            const queue = pendingIds.get(part.functionCall.name) || [];
            queue.push(part.functionCall.id);
            pendingIds.set(part.functionCall.name, queue);
          }
        });
      }
    });

    // Pass 2: give each user functionResponse the oldest unconsumed call id
    // with the same name; fall back to a fresh id if none is queued.
    request.contents.forEach(content => {
      if (content.role === 'user' && content.parts && Array.isArray(content.parts)) {
        content.parts.forEach(part => {
          if (part.functionResponse && !part.functionResponse.id) {
            const queue = pendingIds.get(part.functionResponse.name);
            const matchedId = queue && queue.length > 0 ? queue.shift() : null;
            part.functionResponse.id = matchedId || newCallId();
          }
        });
      }
    });
  }

  // Ensure generationConfig exists before touching it.
  if (!request.generationConfig) {
    request.generationConfig = {};
  }

  // Thinking models: inject a default thinkingConfig when the client sent none.
  if (enableThinking) {
    const defaultThinkingBudget = config.defaults.thinking_budget ?? 1024;
    if (!request.generationConfig.thinkingConfig) {
      request.generationConfig.thinkingConfig = {
        includeThoughts: true,
        thinkingBudget: defaultThinkingBudget
      };
    }
  }

  // Upstream only supports a single candidate.
  request.generationConfig.candidateCount = 1;

  // Attach the session; safetySettings are stripped before forwarding
  // (presumably unsupported upstream — TODO confirm).
  request.sessionId = token.sessionId;
  delete request.safetySettings;

  // Wrap in the Antigravity request envelope.
  return {
    project: token.projectId,
    requestId: generateRequestId(),
    request: request,
    model: actualModelName,
    userAgent: "antigravity"
  };
}
566
+
567
  export{
568
  generateRequestId,
569
  generateRequestBody,
570
+ generateGeminiRequestBody,
571
  prepareImageRequest,
572
  getDefaultIp
573
  }