|
|
import { randomUUID } from "crypto";

import axios from "axios";
import { Request, RequestHandler, Router, Response, NextFunction } from "express";
import { v4 } from "uuid";

import { config } from "../config";
import { GoogleAIKey, keyPool } from "../shared/key-management";
import { ipLimiter } from "./rate-limit";
import {
  createPreprocessorMiddleware,
  finalizeSignedRequest,
} from "./middleware/request";
import { addGoogleAIKey } from "./middleware/request/mutators/add-google-ai-key";
import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
import { ProxyResHandlerWithBody } from "./middleware/response";
|
|
|
|
|
// Cached OpenAI-style model list response (built by getModelsResponse).
let modelsCache: any = null;
// Epoch ms when modelsCache was last populated; 0 = never populated.
let modelsCacheTime = 0;

// Cached Google-native model list response (built by getNativeModelsResponse).
let nativeModelsCache: any = null;
// Epoch ms when nativeModelsCache was last populated; 0 = never populated.
let nativeModelsCacheTime = 0;
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const getModelsResponse = () => { |
|
|
if (new Date().getTime() - modelsCacheTime < 1000 * 60) { |
|
|
return modelsCache; |
|
|
} |
|
|
|
|
|
if (!config.googleAIKey) return { object: "list", data: [] }; |
|
|
|
|
|
const keys = keyPool |
|
|
.list() |
|
|
.filter((k) => k.service === "google-ai") as GoogleAIKey[]; |
|
|
if (keys.length === 0) { |
|
|
modelsCache = { object: "list", data: [] }; |
|
|
modelsCacheTime = new Date().getTime(); |
|
|
return modelsCache; |
|
|
} |
|
|
|
|
|
|
|
|
const modelIds = Array.from( |
|
|
new Set(keys.map((k) => k.modelIds).flat()) |
|
|
).filter((id) => id.startsWith("models/") && !id.includes("bard")); |
|
|
|
|
|
|
|
|
const models = modelIds.map((id) => ({ |
|
|
|
|
|
id: id.startsWith("models/") ? id.slice("models/".length) : id, |
|
|
object: "model", |
|
|
created: new Date().getTime(), |
|
|
owned_by: "google", |
|
|
permission: [], |
|
|
root: "google", |
|
|
parent: null, |
|
|
})); |
|
|
|
|
|
modelsCache = { object: "list", data: models }; |
|
|
modelsCacheTime = new Date().getTime(); |
|
|
|
|
|
return modelsCache; |
|
|
}; |
|
|
|
|
|
|
|
|
const getNativeModelsResponse = async () => { |
|
|
|
|
|
if (new Date().getTime() - nativeModelsCacheTime < 1000 * 60) { |
|
|
return nativeModelsCache; |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const openaiStyle = getModelsResponse(); |
|
|
const models = (openaiStyle.data || []).map((m: any) => ({ |
|
|
|
|
|
name: `models/${m.id}`, |
|
|
supportedGenerationMethods: ["generateContent"], |
|
|
})); |
|
|
|
|
|
nativeModelsCache = { models }; |
|
|
nativeModelsCacheTime = new Date().getTime(); |
|
|
return nativeModelsCache; |
|
|
}; |
|
|
|
|
|
const handleModelRequest: RequestHandler = (_req: Request, res: any) => { |
|
|
res.status(200).json(getModelsResponse()); |
|
|
}; |
|
|
|
|
|
|
|
|
const handleNativeModelRequest: RequestHandler = async (_req: Request, res: any) => { |
|
|
try { |
|
|
const modelsResponse = await getNativeModelsResponse(); |
|
|
res.status(200).json(modelsResponse); |
|
|
} catch (error) { |
|
|
console.error("Error in handleNativeModelRequest:", error); |
|
|
res.status(500).json({ error: "Failed to fetch models" }); |
|
|
} |
|
|
}; |
|
|
|
|
|
const googleAIBlockingResponseHandler: ProxyResHandlerWithBody = async ( |
|
|
_proxyRes, |
|
|
req, |
|
|
res, |
|
|
body |
|
|
) => { |
|
|
if (typeof body !== "object") { |
|
|
throw new Error("Expected body to be an object"); |
|
|
} |
|
|
|
|
|
let newBody = body; |
|
|
if (req.inboundApi === "openai") { |
|
|
req.log.info("Transforming Google AI response to OpenAI format"); |
|
|
newBody = transformGoogleAIResponse(body, req); |
|
|
} |
|
|
|
|
|
res.status(200).json({ ...newBody, proxy: body.proxy }); |
|
|
}; |
|
|
|
|
|
function transformGoogleAIResponse( |
|
|
resBody: Record<string, any>, |
|
|
req: Request |
|
|
): Record<string, any> { |
|
|
const totalTokens = (req.promptTokens ?? 0) + (req.outputTokens ?? 0); |
|
|
|
|
|
|
|
|
let content = ""; |
|
|
|
|
|
|
|
|
if (resBody.candidates && resBody.candidates[0]) { |
|
|
const candidate = resBody.candidates[0]; |
|
|
|
|
|
|
|
|
if (candidate.content?.parts && candidate.content.parts[0]?.text) { |
|
|
|
|
|
content = candidate.content.parts[0].text; |
|
|
} else if (candidate.content?.text) { |
|
|
|
|
|
content = candidate.content.text; |
|
|
} else if (typeof candidate.content?.parts?.[0] === 'string') { |
|
|
|
|
|
content = candidate.content.parts[0]; |
|
|
} |
|
|
|
|
|
|
|
|
content = content.replace(/^(.{0,50}?): /, () => ""); |
|
|
} |
|
|
|
|
|
return { |
|
|
id: "goo-" + v4(), |
|
|
object: "chat.completion", |
|
|
created: Date.now(), |
|
|
model: req.body.model, |
|
|
usage: { |
|
|
prompt_tokens: req.promptTokens, |
|
|
completion_tokens: req.outputTokens, |
|
|
total_tokens: totalTokens, |
|
|
}, |
|
|
choices: [ |
|
|
{ |
|
|
message: { role: "assistant", content }, |
|
|
finish_reason: resBody.candidates?.[0]?.finishReason || "STOP", |
|
|
index: 0, |
|
|
}, |
|
|
], |
|
|
}; |
|
|
} |
|
|
|
|
|
const googleAIProxy = createQueuedProxyMiddleware({ |
|
|
target: ({ signedRequest }: { signedRequest: any }) => { |
|
|
if (!signedRequest) throw new Error("Must sign request before proxying"); |
|
|
const { protocol, hostname} = signedRequest; |
|
|
return `${protocol}//${hostname}`; |
|
|
}, |
|
|
mutations: [addGoogleAIKey, finalizeSignedRequest], |
|
|
blockingResponseHandler: googleAIBlockingResponseHandler, |
|
|
}); |
|
|
|
|
|
// Router exposed as this module's public entry point (see export below).
const googleAIRouter = Router();
// OpenAI-style model listing.
googleAIRouter.get("/v1/models", handleModelRequest);
// Google-native model listing (v1alpha/v1beta).
googleAIRouter.get("/:apiVersion(v1alpha|v1beta)/models", handleNativeModelRequest);
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function processThinkingBudget(req: Request) { |
|
|
if (req.body.generationConfig?.thinkingConfig?.thinkingBudget !== undefined) { |
|
|
|
|
|
const budget = req.body.generationConfig.thinkingConfig.thinkingBudget; |
|
|
|
|
|
|
|
|
if (typeof budget === 'number') { |
|
|
req.body.generationConfig.thinkingConfig.thinkingBudget = |
|
|
Math.max(0, Math.min(budget, 24576)); |
|
|
} |
|
|
|
|
|
} |
|
|
} |
|
|
|
|
|
function setStreamFlag(req: Request) { |
|
|
const isStreaming = req.url.includes("streamGenerateContent"); |
|
|
if (isStreaming) { |
|
|
req.body.stream = true; |
|
|
req.isStreaming = true; |
|
|
} else { |
|
|
req.body.stream = false; |
|
|
req.isStreaming = false; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function maybeReassignModel(req: Request) { |
|
|
|
|
|
const model = req.body.model || req.url.split("/").pop()?.split(":").shift(); |
|
|
if (!model) { |
|
|
throw new Error("You must specify a model with your request."); |
|
|
} |
|
|
req.body.model = model; |
|
|
|
|
|
|
|
|
if (model.startsWith("models/")) { |
|
|
req.body.model = model.slice("models/".length); |
|
|
req.log.info({ originalModel: model, updatedModel: req.body.model }, "Stripped 'models/' prefix from model ID"); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
function checkAndBlockExperimentalModels(req: Request) { |
|
|
const modelId = req.body.model as string | undefined; |
|
|
|
|
|
|
|
|
if (modelId && modelId.toLowerCase().includes("exp")) { |
|
|
req.log.warn({ modelId }, "Blocking request to experimental Google AI model."); |
|
|
const err: any = new Error("Experimental models are too unstable to be supported in proxy code. Please use preview models instead."); |
|
|
err.statusCode = 400; |
|
|
throw err; |
|
|
} |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
// Native Google AI generate endpoint:
//   POST /(v1alpha|v1beta)/models/<modelId>:(generateContent|streamGenerateContent)
// IP rate limiting, then google-ai -> google-ai preprocessing, then proxy.
googleAIRouter.post(
  "/:apiVersion(v1alpha|v1beta)/models/:modelId:(generateContent|streamGenerateContent)",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "google-ai", outApi: "google-ai", service: "google-ai" },
    {
      // Normalize the model ID before the request transform runs.
      beforeTransform: [maybeReassignModel],
      // After transform: block experimental models, flag streaming based on
      // the URL, and clamp any thinking budget.
      afterTransform: [checkAndBlockExperimentalModels, setStreamFlag, processThinkingBudget]
    }
  ),
  googleAIProxy
);
|
|
|
|
|
|
|
|
// OpenAI-compatible chat completions endpoint, translated to Google AI
// (inApi "openai" -> outApi "google-ai").
googleAIRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "openai", outApi: "google-ai", service: "google-ai" },
    {
      // Runs after the OpenAI -> Google AI transform. Unlike the native
      // route, setStreamFlag is not applied here.
      afterTransform: [maybeReassignModel, checkAndBlockExperimentalModels, processThinkingBudget]
    }
  ),
  googleAIProxy
);

// Public router export for mounting by the proxy server.
export const googleAI = googleAIRouter;
|
|
|