|
|
import { Request, RequestHandler, Router } from "express"; |
|
|
import { createPreprocessorMiddleware } from "./middleware/request"; |
|
|
import { ipLimiter } from "./rate-limit"; |
|
|
import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory"; |
|
|
import { addKey, finalizeBody } from "./middleware/request"; |
|
|
import { ProxyResHandlerWithBody } from "./middleware/response"; |
|
|
import axios from "axios"; |
|
|
import { CohereKey, keyPool } from "../shared/key-management"; |
|
|
import { isCohereModel, normalizeMessages } from "../shared/api-schemas/cohere"; |
|
|
import { logger } from "../logger"; |
|
|
|
|
|
/** Module-scoped logger tagged with this proxy's module and service names. */
const log = logger.child({ module: "proxy", service: "cohere" });

/**
 * Last successful `/v1/models` payload and the epoch-millisecond timestamp at
 * which it was stored. The model-listing code serves the cached payload while
 * it is younger than its TTL. Typed `unknown` rather than `any` so that nothing
 * can read fields off the cached value without narrowing it first.
 */
let modelsCache: unknown = null;
let modelsCacheTime = 0;
|
|
|
|
|
/**
 * Blocking response handler for Cohere's OpenAI-compatible API.
 *
 * The upstream body is forwarded to the client with a 200 status. All upstream
 * fields are copied, and the `proxy` field is re-emitted explicitly.
 *
 * @throws Error("Expected body to be an object") when the upstream body is not
 *   a non-null object (e.g. `null` or a raw string).
 */
const cohereResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  // `typeof null === "object"`, so null must be rejected explicitly; otherwise
  // `body.proxy` below would fail with an opaque TypeError instead.
  if (typeof body !== "object" || body === null) {
    throw new Error("Expected body to be an object");
  }

  res.status(200).json({ ...body, proxy: body.proxy });
};
|
|
|
|
|
/** Subset of a Cohere `GET /v1/models` entry that this proxy reads. */
interface CohereModelEntry {
  name: string;
  endpoints?: string[];
  context_window_size?: number;
}

/** How long a successful model listing is served from cache, in milliseconds. */
const MODELS_CACHE_TTL_MS = 1000 * 60;

/**
 * Cohere paginates `GET /v1/models` (20 entries per page by default). Request
 * the maximum page size so chat models are not silently truncated.
 */
const COHERE_MODELS_PAGE_SIZE = 1000;

/** Abort the upstream listing instead of hanging the client indefinitely. */
const COHERE_MODELS_TIMEOUT_MS = 10000;

/**
 * Fetches Cohere's model catalogue and keeps only entries whose `endpoints`
 * include "chat".
 *
 * @throws Error when the response has no `models` array.
 */
async function fetchCohereChatModels(apiKey: string): Promise<CohereModelEntry[]> {
  const response = await axios.get("https://api.cohere.com/v1/models", {
    headers: {
      "Content-Type": "application/json",
      "Authorization": `Bearer ${apiKey}`,
      "Cohere-Version": "2022-12-06",
    },
    params: { page_size: COHERE_MODELS_PAGE_SIZE },
    timeout: COHERE_MODELS_TIMEOUT_MS,
  });

  const models: unknown = response.data?.models;
  if (!Array.isArray(models)) {
    throw new Error("Unexpected response format from Cohere API");
  }

  return models.filter((model: CohereModelEntry) => model.endpoints?.includes("chat"));
}

/** Converts Cohere chat models into an OpenAI-style `GET /v1/models` payload. */
function toOpenAIModelList(models: CohereModelEntry[]) {
  // One timestamp for the whole listing rather than one per entry.
  const created = Math.floor(Date.now() / 1000);
  return {
    object: "list",
    data: models.map((model) => ({
      id: model.name,
      object: "model",
      created,
      owned_by: "cohere",
      permission: [],
      root: model.name,
      parent: null,
      // Fall back to 4096 when Cohere omits (or zeroes) the field.
      context_length: model.context_window_size || 4096,
    })),
  };
}

/**
 * Returns Cohere's chat-capable models as an OpenAI-compatible model list.
 *
 * Successful results are cached in module state for `MODELS_CACHE_TTL_MS`.
 * This function never throws. On any failure (no usable key, upstream error,
 * unexpected payload) it logs the error and returns an empty list. Nothing is
 * cached in that case, so the next call retries upstream.
 */
const getModelsResponse = async () => {
  if (Date.now() - modelsCacheTime < MODELS_CACHE_TTL_MS) {
    return modelsCache;
  }

  try {
    // keyPool.get needs a model name to select a key; "command" is used only
    // for that selection.
    const cohereKey = keyPool.get("command", "cohere") as CohereKey;
    if (!cohereKey || !cohereKey.key) {
      log.warn("No valid Cohere key available for model listing");
      throw new Error("No valid Cohere API key available");
    }

    const chatModels = await fetchCohereChatModels(cohereKey.key);
    const models = toOpenAIModelList(chatModels);

    log.debug(
      { modelCount: chatModels.length, models: chatModels.map((m) => m.name) },
      "Retrieved models from Cohere API"
    );

    modelsCache = models;
    modelsCacheTime = Date.now();
    return models;
  } catch (error) {
    if (error instanceof Error) {
      log.error(
        { errorMessage: error.message, stack: error.stack },
        "Error fetching Cohere models"
      );
    } else {
      log.error({ error }, "Unknown error fetching Cohere models");
    }

    // Degrade to an empty listing rather than failing the client request.
    return {
      object: "list",
      data: [],
    };
  }
};
|
|
|
|
|
/**
 * Express handler for `GET /v1/models`.
 *
 * Responds 200 with whatever `getModelsResponse` resolves to. If building or
 * sending that response throws, the error is logged and a 500 JSON error is
 * returned instead.
 */
const handleModelRequest: RequestHandler = async (_req, res) => {
  try {
    const listing = await getModelsResponse();
    res.status(200).json(listing);
  } catch (err) {
    if (err instanceof Error) {
      const { message: errorMessage, stack } = err;
      log.error({ errorMessage, stack }, "Error handling model request");
    } else {
      log.error({ error: err }, "Unknown error handling model request");
    }

    res.status(500).json({ error: "Failed to fetch models" });
  }
};
|
|
|
|
|
|
|
|
/**
 * Rewrites `req.body.messages` in place using `normalizeMessages`, but only
 * when it is an array. Any other value, including a missing field, is left
 * untouched.
 */
function prepareMessages(req: Request) {
  const { messages } = req.body;
  if (!Array.isArray(messages)) {
    return;
  }
  req.body.messages = normalizeMessages(messages);
}
|
|
|
|
|
|
|
|
/** OpenAI request fields that are stripped before forwarding to Cohere. */
const UNSUPPORTED_OPENAI_PARAMS = ["logit_bias", "top_logprobs", "max_completion_tokens"] as const;

/**
 * Mutates `req.body` in place so it can be forwarded to Cohere.
 *
 * - If the client set only `max_completion_tokens` (not `max_tokens`), that
 *   limit is copied to `max_tokens`, which is not stripped. Without this the
 *   client's output cap would be silently dropped.
 * - Deletes `logit_bias`, `top_logprobs` and `max_completion_tokens`.
 * - Replaces a `response_format` that carries a `schema` with
 *   `{ type: "json_object", schema }`. Any other fields on the original
 *   `response_format` are discarded.
 *
 * Outside production, the resulting body is logged at debug level.
 */
function removeUnsupportedParameters(req: Request) {
  const body = req.body;

  if (body.max_tokens == null && typeof body.max_completion_tokens === "number") {
    body.max_tokens = body.max_completion_tokens;
  }

  for (const param of UNSUPPORTED_OPENAI_PARAMS) {
    if (body[param] !== undefined) {
      delete body[param];
    }
  }

  if (body.response_format && body.response_format.schema) {
    const jsonSchema = body.response_format.schema;
    body.response_format = {
      type: "json_object",
      schema: jsonSchema,
    };
  }

  // The full body may contain user prompts; only dump it outside production.
  if (process.env.NODE_ENV !== 'production') {
    log.debug({ body: req.body }, "Request after parameter cleanup");
  }
}
|
|
|
|
|
|
|
|
/**
 * Debug-logs `req.promptTokens` for requests targeting a Cohere model.
 *
 * Nothing is logged when the model is not recognised by `isCohereModel`, or
 * when `req.promptTokens` is falsy (absent or zero). This function only logs
 * an existing count; it does not compute one.
 */
function countCohereTokens(req: Request) {
  const { model } = req.body;
  if (!isCohereModel(model) || !req.promptTokens) {
    return;
  }

  req.log.debug(
    { tokens: req.promptTokens },
    "Estimated token count for Cohere prompt"
  );
}
|
|
|
|
|
/**
 * Queued proxy middleware targeting Cohere's OpenAI-compatible API.
 *
 * The `mutations` list attaches an API key, pins the `Cohere-Version` header,
 * and finalizes the request body (presumably applied in the listed order —
 * confirm against createQueuedProxyMiddleware). Upstream bodies handled on the
 * blocking path are passed to `cohereResponseHandler`.
 */
const cohereProxy = createQueuedProxyMiddleware({
  target: "https://api.cohere.ai/compatibility",
  blockingResponseHandler: cohereResponseHandler,
  mutations: [
    addKey,
    (manager) => {
      manager.setHeader("Cohere-Version", "2022-12-06");
    },
    finalizeBody,
  ],
});
|
|
|
|
|
/**
 * Routes for the Cohere service.
 *
 * - POST /v1/chat/completions: IP rate-limited, preprocessed as
 *   OpenAI -> OpenAI. Messages are normalized, unsupported parameters
 *   stripped, and the token count logged before proxying.
 * - POST /v1/embeddings: IP rate-limited, preprocessed as OpenAI -> OpenAI
 *   with no extra transforms, then proxied.
 * - GET /v1/models: served locally by `handleModelRequest`.
 */
const cohereRouter = Router()
  .post(
    "/v1/chat/completions",
    ipLimiter,
    createPreprocessorMiddleware(
      { inApi: "openai", outApi: "openai", service: "cohere" },
      { afterTransform: [prepareMessages, removeUnsupportedParameters, countCohereTokens] }
    ),
    cohereProxy
  )
  .post(
    "/v1/embeddings",
    ipLimiter,
    createPreprocessorMiddleware(
      { inApi: "openai", outApi: "openai", service: "cohere" },
      { afterTransform: [] }
    ),
    cohereProxy
  )
  .get("/v1/models", handleModelRequest);

export const cohere = cohereRouter;
|
|
|