import grpc from '@huayue/grpc-js'
import protoLoader from '@grpc/proto-loader'
import { AutoRouter, cors, error, json } from 'itty-router'
import dotenv from 'dotenv'
import path, { dirname } from 'path'
import { fileURLToPath } from 'url'
import { createServerAdapter } from '@whatwg-node/server'
import { createServer } from 'http'
// Load environment variables
dotenv.config()
// Resolve the directory of the current file (ESM replacement for __dirname)
const __dirname = dirname(fileURLToPath(import.meta.url))
// Configuration
class Config {
  constructor() {
    // Optional route prefix, e.g. '/api'; empty by default so routes mount at '/v1/...'
    this.API_PREFIX = process.env.API_PREFIX || ''
    this.API_KEY = process.env.API_KEY || ''
    this.MAX_RETRY_COUNT = Number(process.env.MAX_RETRY_COUNT) || 3
    this.RETRY_DELAY = Number(process.env.RETRY_DELAY) || 5000
    this.COMMON_GRPC = 'runtime-native-io-vertex-inference-grpc-service-lmuw6mcn3q-ul.a.run.app'
    this.COMMON_PROTO = path.join(__dirname, '..', 'protos', 'VertexInferenceService.proto')
    this.GPT_GRPC = 'runtime-native-io-gpt-inference-grpc-service-lmuw6mcn3q-ul.a.run.app'
    this.GPT_PROTO = path.join(__dirname, '..', 'protos', 'GPTInferenceService.proto')
    this.PORT = process.env.PORT || 8787
    // Supported model list (override with a comma-separated SUPPORTED_MODELS env var)
    this.SUPPORTED_MODELS = process.env.SUPPORTED_MODELS
      ? process.env.SUPPORTED_MODELS.split(',').map((m) => m.trim())
      : [
          'gpt-4o-mini',
          'gpt-4o',
          'gpt-4-turbo',
          'gpt-4',
          'gpt-3.5-turbo',
          'claude-3-sonnet@20240229',
          'claude-3-opus@20240229',
          'claude-3-haiku@20240307',
          'claude-3-5-sonnet@20240620',
          'gemini-1.5-flash',
          'gemini-1.5-pro',
          'chat-bison',
          'codechat-bison',
        ]
  }
  // Validate a requested model name
  isValidModel(model) {
    // Normalize the dashed Claude date suffix (e.g. -20240229) to the @ form
    const RegexInput = /^(claude-3-(5-sonnet|haiku|sonnet|opus))-(\d{8})$/
    const matchInput = model.match(RegexInput)
    const normalizedModel = matchInput ? `${matchInput[1]}@${matchInput[3]}` : model
    return this.SUPPORTED_MODELS.includes(normalizedModel)
  }
}
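// Example (illustrative): isValidModel('claude-3-5-sonnet-20240620') matches the
// regex above, is normalized to 'claude-3-5-sonnet@20240620', and passes the
// SUPPORTED_MODELS lookup; a name like 'gpt-4o' is checked as-is.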
class GRPCHandler {
  constructor(protoFilePath) {
    // Synchronously load the .proto file at the given path
    this.packageDefinition = protoLoader.loadSync(protoFilePath, {
      keepCase: true,
      longs: String,
      enums: String,
      defaults: true,
      oneofs: true,
    })
  }
}
const config = new Config()
// Middleware
// CORS: allow cross-origin requests from any origin
const { preflight, corsify } = cors({
  origin: '*',
  allowMethods: '*',
  exposeHeaders: '*',
})
// Authentication: require a Bearer token when API_KEY is configured
const withAuth = (request) => {
  if (config.API_KEY) {
    const authHeader = request.headers.get('Authorization')
    if (!authHeader || !authHeader.startsWith('Bearer ')) {
      return error(401, 'Unauthorized: Missing or invalid Authorization header')
    }
    const token = authHeader.substring(7)
    if (token !== config.API_KEY) {
      return error(403, 'Forbidden: Invalid API key')
    }
  }
}
// Stamp the request start time so the logger can report latency
const withTiming = (request) => {
  request.start = Date.now()
}
// Log method, status, URL, and elapsed time for each request
const logger = (res, req) => {
  console.log(req.method, res.status, req.url, Date.now() - req.start, 'ms')
}
const router = AutoRouter({
  before: [withTiming, preflight],
  missing: () => error(404, '404 not found.'),
  finally: [corsify, logger],
})
// Routes
router.get('/', () =>
  json({
    service: 'AI Chat Completion Proxy',
    usage: {
      endpoint: '/v1/chat/completions',
      method: 'POST',
      headers: {
        'Content-Type': 'application/json',
        Authorization: 'Bearer YOUR_API_KEY',
      },
      body: {
        model: 'One of: gpt-4o-mini, gpt-4o, gpt-4-turbo, gpt-4, gpt-3.5-turbo, claude-3-sonnet-20240229, claude-3-opus-20240229, claude-3-haiku-20240307, claude-3-5-sonnet-20240620, gemini-1.5-flash, gemini-1.5-pro, chat-bison, codechat-bison',
        messages: [
          { role: 'system', content: 'You are a helpful assistant.' },
          { role: 'user', content: 'Hello, who are you?' },
        ],
        stream: false,
        temperature: 0.7,
        top_p: 1,
      },
    },
    availableModels: [
      'gpt-4o-mini',
      'gpt-4o',
      'gpt-4-turbo',
      'gpt-4',
      'gpt-3.5-turbo',
      'claude-3-sonnet-20240229',
      'claude-3-opus-20240229',
      'claude-3-haiku-20240307',
      'claude-3-5-sonnet-20240620',
      'gemini-1.5-flash',
      'gemini-1.5-pro',
      'chat-bison',
      'codechat-bison',
    ],
    note: 'Replace YOUR_API_KEY with your actual API key.',
  }),
)
router.get(config.API_PREFIX + '/v1/models', withAuth, () =>
  json({
    object: 'list',
    data: [
      { id: 'gpt-4o-mini', object: 'model', owned_by: 'pieces-os' },
      { id: 'gpt-4o', object: 'model', owned_by: 'pieces-os' },
      { id: 'gpt-4-turbo', object: 'model', owned_by: 'pieces-os' },
      { id: 'gpt-4', object: 'model', owned_by: 'pieces-os' },
      { id: 'gpt-3.5-turbo', object: 'model', owned_by: 'pieces-os' },
      { id: 'claude-3-sonnet-20240229', object: 'model', owned_by: 'pieces-os' },
      { id: 'claude-3-opus-20240229', object: 'model', owned_by: 'pieces-os' },
      { id: 'claude-3-haiku-20240307', object: 'model', owned_by: 'pieces-os' },
      { id: 'claude-3-5-sonnet-20240620', object: 'model', owned_by: 'pieces-os' },
      { id: 'gemini-1.5-flash', object: 'model', owned_by: 'pieces-os' },
      { id: 'gemini-1.5-pro', object: 'model', owned_by: 'pieces-os' },
      { id: 'chat-bison', object: 'model', owned_by: 'pieces-os' },
      { id: 'codechat-bison', object: 'model', owned_by: 'pieces-os' },
    ],
  }),
)
router.post(config.API_PREFIX + '/v1/chat/completions', withAuth, (req) => handleCompletion(req))
async function GrpcToPieces(inputModel, OriginModel, message, rules, stream, temperature, top_p) {
  // temperature and top_p have no effect for non-GPT models
  // Use the system root certificates
  const credentials = grpc.credentials.createSsl()
  let client, request
  if (inputModel.includes('gpt')) {
    // Load the proto definition
    const packageDefinition = new GRPCHandler(config.GPT_PROTO).packageDefinition
    // Build the request message
    request = {
      models: inputModel,
      messages: [
        { role: 0, message: rules }, // system
        { role: 1, message: message }, // user
      ],
      temperature: temperature || 0.1,
      top_p: top_p ?? 1,
    }
    // Resolve the gRPC service constructor
    const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.gpt
    client = new GRPCobjects.GPTInferenceService(config.GPT_GRPC, credentials)
  } else {
    // Load the proto definition
    const packageDefinition = new GRPCHandler(config.COMMON_PROTO).packageDefinition
    // Build the request message
    request = {
      models: inputModel,
      args: {
        messages: {
          unknown: 1,
          message: message,
        },
        rules: rules,
      },
    }
    // Resolve the gRPC service constructor
    const GRPCobjects = grpc.loadPackageDefinition(packageDefinition).runtime.aot.machine_learning.parents.vertex
    client = new GRPCobjects.VertexInferenceService(config.COMMON_GRPC, credentials)
  }
  return await ConvertOpenai(client, request, inputModel, OriginModel, stream)
}
async function messagesProcess(messages) {
  let rules = ''
  let message = ''
  for (const msg of messages) {
    const role = msg.role
    // Flatten array-style content into a single string
    const contentStr = Array.isArray(msg.content)
      ? msg.content
          .filter((item) => item.text)
          .map((item) => item.text)
          .join('') || ''
      : msg.content
    // Route by role: system prompts become rules, the rest becomes dialogue
    if (role === 'system') {
      rules += `system:${contentStr};\r\n`
    } else if (['user', 'assistant'].includes(role)) {
      message += `${role}:${contentStr};\r\n`
    }
  }
  return { rules, message }
}
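// Example (illustrative): [{ role: 'system', content: 'Be brief.' }, { role: 'user', content: 'Hi' }]
// yields rules = 'system:Be brief.;\r\n' and message = 'user:Hi;\r\n'.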
async function ConvertOpenai(client, request, inputModel, OriginModel, stream) {
  const metadata = new grpc.Metadata()
  metadata.set('User-Agent', 'dart-grpc/2.0.0')
  for (let i = 0; i < config.MAX_RETRY_COUNT; i++) {
    try {
      if (stream) {
        const call = client.PredictWithStream(request, metadata)
        const encoder = new TextEncoder()
        const ReturnStream = new ReadableStream({
          start(controller) {
            let finished = false
            // Close exactly once, emitting the OpenAI-style [DONE] sentinel first
            const finish = () => {
              if (finished) return
              finished = true
              controller.enqueue(encoder.encode('data: [DONE]\n\n'))
              controller.close()
            }
            // Handle incoming chunks
            call.on('data', (response) => {
              try {
                const response_code = Number(response.response_code)
                if (response_code === 204) {
                  // 204 signals the end of the stream
                  finish()
                  call.destroy()
                } else if (response_code === 200) {
                  let response_message
                  if (inputModel.includes('gpt')) {
                    // 'message_warpper' (sic) matches the field name in the proto
                    response_message = response.body.message_warpper.message.message
                  } else {
                    response_message = response.args.args.args.message
                  }
                  controller.enqueue(
                    encoder.encode(`data: ${JSON.stringify(ChatCompletionStreamWithModel(response_message, OriginModel))}\n\n`),
                  )
                } else {
                  console.error(`Invalid response code: ${response_code}`)
                  controller.error(new Error(`Invalid response code: ${response_code}`))
                }
              } catch (error) {
                console.error('Error processing stream data:', error)
                controller.error(error)
              }
            })
            // Handle stream errors
            call.on('error', (error) => {
              console.error('Stream error:', error)
              // An INTERNAL error mentioning RST_STREAM can be a normal end of stream
              if (error.code === 13 && error.details.includes('RST_STREAM')) {
                finish()
              } else {
                controller.error(error)
              }
              call.destroy()
            })
            // Handle a clean end of the call
            call.on('end', () => {
              finish()
            })
          },
          // Abort the gRPC call if the client cancels the response stream
          cancel() {
            call.destroy()
          },
        })
        return new Response(ReturnStream, {
          headers: {
            'Content-Type': 'text/event-stream',
            Connection: 'keep-alive',
            'Cache-Control': 'no-cache',
            'Transfer-Encoding': 'chunked',
          },
        })
      } else {
        // Non-streaming call
        const call = await new Promise((resolve, reject) => {
          client.Predict(request, metadata, (err, response) => {
            if (err) reject(err)
            else resolve(response)
          })
        })
        const response_code = Number(call.response_code)
        if (response_code === 200) {
          let response_message
          if (inputModel.includes('gpt')) {
            response_message = call.body.message_warpper.message.message
          } else {
            response_message = call.args.args.args.message
          }
          return new Response(JSON.stringify(ChatCompletionWithModel(response_message, OriginModel)), {
            headers: {
              'Content-Type': 'application/json',
            },
          })
        } else {
          throw new Error(`Invalid response code: ${response_code}`)
        }
      }
    } catch (err) {
      console.error(`Attempt ${i + 1} failed:`, err)
      // Wait before retrying; skip the delay after the final attempt
      if (i < config.MAX_RETRY_COUNT - 1) {
        await new Promise((resolve) => setTimeout(resolve, config.RETRY_DELAY))
      }
    }
  }
  return new Response(
    JSON.stringify({
      error: {
        message: 'An error occurred while processing your request',
        type: 'server_error',
        code: 'internal_error',
        param: null,
      },
    }),
    {
      status: 500,
      headers: {
        'Content-Type': 'application/json',
      },
    },
  )
}
function ChatCompletionWithModel(message, model) {
  return {
    id: 'Chat-Nekohy',
    object: 'chat.completion',
    created: Math.floor(Date.now() / 1000), // OpenAI uses a Unix timestamp in seconds
    model,
    usage: {
      prompt_tokens: 0,
      completion_tokens: 0,
      total_tokens: 0,
    },
    choices: [
      {
        message: {
          content: message,
          role: 'assistant',
        },
        index: 0,
      },
    ],
  }
}
function ChatCompletionStreamWithModel(text, model) {
  return {
    id: 'chatcmpl-Nekohy',
    object: 'chat.completion.chunk',
    created: Math.floor(Date.now() / 1000), // OpenAI uses a Unix timestamp in seconds
    model,
    choices: [
      {
        index: 0,
        delta: {
          content: text,
        },
        finish_reason: null,
      },
    ],
  }
}
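// Example wire format for one stream event (values illustrative):
//   data: {"id":"chatcmpl-Nekohy","object":"chat.completion.chunk","created":1719000000,"model":"gpt-4o","choices":[{"index":0,"delta":{"content":"Hi"},"finish_reason":null}]}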
async function handleCompletion(request) {
  try {
    // Parse the OpenAI-format API request
    const { model: OriginModel, messages, stream, temperature, top_p } = await request.json()
    // Normalize dashed Claude names (e.g. claude-3-opus-20240229) to the @ form
    const RegexInput = /^(claude-3-(5-sonnet|haiku|sonnet|opus))-(\d{8})$/
    const matchInput = OriginModel.match(RegexInput)
    const inputModel = matchInput ? `${matchInput[1]}@${matchInput[3]}` : OriginModel
    // Validate the model
    if (!config.isValidModel(inputModel)) {
      return new Response(
        JSON.stringify({
          error: {
            message: `Model '${OriginModel}' does not exist`,
            type: 'invalid_request_error',
            param: 'model',
            code: 'model_not_found',
          },
        }),
        {
          status: 404,
          headers: {
            'Content-Type': 'application/json',
          },
        },
      )
    }
    console.log(inputModel, messages, stream)
    // Split messages into system rules and user/assistant dialogue
    const { rules, message: content } = await messagesProcess(messages)
    console.log(rules, content)
    // Forward to the gRPC backend and return the converted response
    return await GrpcToPieces(inputModel, OriginModel, content, rules, stream, temperature, top_p)
  } catch (err) {
    return error(500, err.message)
  }
}
;(async () => {
  // On Cloudflare Workers the router is driven by the fetch event, so skip the Node server
  if (typeof addEventListener === 'function') return
  // For Node.js, wrap the router in a WHATWG server adapter and listen
  const ittyServer = createServerAdapter(router.fetch)
  const httpServer = createServer(ittyServer)
  httpServer.listen(config.PORT, () => {
    console.log(`Listening on http://localhost:${config.PORT}`)
  })
})()
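/*
 * Example client call: a minimal sketch assuming the defaults above (PORT 8787,
 * empty API_PREFIX) and YOUR_API_KEY as a placeholder; adjust to your deployment.
 *
 *   curl http://localhost:8787/v1/chat/completions \
 *     -H 'Content-Type: application/json' \
 *     -H 'Authorization: Bearer YOUR_API_KEY' \
 *     -d '{"model": "gpt-4o-mini", "stream": false,
 *          "messages": [{"role": "user", "content": "Hello, who are you?"}]}'
 */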