File size: 6,541 Bytes
5c5b371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
import { Request, RequestHandler, Router } from "express";
import { createPreprocessorMiddleware } from "./middleware/request";
import { ipLimiter } from "./rate-limit";
import { createQueuedProxyMiddleware } from "./middleware/request/proxy-middleware-factory";
import { addKey, finalizeBody } from "./middleware/request";
import { ProxyResHandlerWithBody } from "./middleware/response";
import axios from "axios";
import { CohereKey, keyPool } from "../shared/key-management";
import { isCohereModel, normalizeMessages } from "../shared/api-schemas/cohere";
import { logger } from "../logger";

// Module-scoped child logger tagging all output from this proxy service.
const log = logger.child({ module: "proxy", service: "cohere" });
// Cached /v1/models response and the epoch-ms timestamp it was stored at;
// refreshed at most once per minute (see getModelsResponse below).
let modelsCache: any = null;
let modelsCacheTime = 0;

/**
 * Blocking (non-streaming) response handler for proxied Cohere requests.
 * Relays the upstream JSON body to the client unchanged with a 200 status.
 *
 * @throws {Error} If the upstream body is not a JSON object.
 */
const cohereResponseHandler: ProxyResHandlerWithBody = async (
  _proxyRes,
  req,
  res,
  body
) => {
  // `typeof null === "object"`, so check null explicitly; the original code
  // let null through the guard and then threw a raw TypeError when reading
  // `body.proxy`.
  if (body === null || typeof body !== "object") {
    throw new Error("Expected body to be an object");
  }

  // Spreading `body` already copies every own property (including `proxy`
  // when present), so the former `proxy: body.proxy` entry was redundant.
  res.status(200).json({ ...body });
};

/**
 * Builds an OpenAI-compatible model list by querying the Cohere models API.
 * Results are cached for one minute; on any failure an empty list is
 * returned so the /v1/models route degrades gracefully instead of erroring.
 */
const getModelsResponse = async () => {
  // Serve from cache if it is less than one minute old.
  if (Date.now() - modelsCacheTime < 1000 * 60) {
    return modelsCache;
  }

  try {
    // Borrow any Cohere key from the pool purely for authentication; the
    // model name only drives key selection.
    const cohereKey = keyPool.get("command", "cohere") as CohereKey;

    if (!cohereKey?.key) {
      log.warn("No valid Cohere key available for model listing");
      throw new Error("No valid Cohere API key available");
    }

    // Fetch models directly from Cohere API.
    const response = await axios.get("https://api.cohere.com/v1/models", {
      headers: {
        "Content-Type": "application/json",
        "Authorization": `Bearer ${cohereKey.key}`,
        "Cohere-Version": "2022-12-06",
      },
    });

    if (!response.data?.models) {
      throw new Error("Unexpected response format from Cohere API");
    }

    // Keep only models that expose the chat endpoint, mapped straight into
    // OpenAI's list-entry shape. (The previous two-pass filter/map/map was
    // redundant: the intermediate `max_tokens` field was computed and then
    // discarded, and modelCount was debug-logged twice.)
    const created = Math.floor(Date.now() / 1000);
    const data = response.data.models
      .filter((model: any) => model.endpoints?.includes("chat"))
      .map((model: any) => ({
        id: model.name,
        object: "model",
        created,
        owned_by: "cohere",
        permission: [],
        root: model.name,
        parent: null,
        context_length: model.context_window_size || 4096,
      }));

    log.debug(
      { modelCount: data.length, models: data.map((m: any) => m.id) },
      "Retrieved chat-capable models from Cohere API"
    );

    const models = { object: "list", data };

    // Cache the successful response.
    modelsCache = models;
    modelsCacheTime = Date.now();
    return models;
  } catch (error) {
    // Provide detailed logging for better troubleshooting.
    if (error instanceof Error) {
      log.error(
        { errorMessage: error.message, stack: error.stack },
        "Error fetching Cohere models"
      );
    } else {
      log.error({ error }, "Unknown error fetching Cohere models");
    }

    // Fall back to an empty OpenAI-style list; failures are not cached, so
    // the next request retries immediately.
    return {
      object: "list",
      data: [],
    };
  }
};

/**
 * GET /v1/models handler. Serves the (possibly cached) OpenAI-compatible
 * model list; responds 500 with a generic error body on failure.
 */
const handleModelRequest: RequestHandler = async (_req, res) => {
  try {
    res.status(200).json(await getModelsResponse());
  } catch (error) {
    if (error instanceof Error) {
      log.error(
        { errorMessage: error.message, stack: error.stack },
        "Error handling model request"
      );
    } else {
      log.error({ error }, "Unknown error handling model request");
    }
    res.status(500).json({ error: "Failed to fetch models" });
  }
};

/**
 * Runs OpenAI-style chat messages through Cohere's message normalizer.
 * Mutates `req.body.messages` in place; no-ops when the body carries no
 * message array.
 */
function prepareMessages(req: Request) {
  const { messages } = req.body;
  // Array.isArray alone suffices: no falsy value is an array.
  if (Array.isArray(messages)) {
    req.body.messages = normalizeMessages(messages);
  }
}

/**
 * Strips OpenAI parameters that Cohere's compatibility API rejects and
 * rewrites `response_format` schemas into Cohere's expected shape.
 * Mutates `req.body` in place. (The former unused `model` local has been
 * removed.)
 */
function removeUnsupportedParameters(req: Request) {
  // Parameters Cohere doesn't support; delete whichever are present.
  for (const param of ["logit_bias", "top_logprobs", "max_completion_tokens"]) {
    if (req.body[param] !== undefined) {
      delete req.body[param];
    }
  }

  // Re-wrap structured output requests in Cohere's json_object format.
  if (req.body.response_format?.schema) {
    req.body.response_format = {
      type: "json_object",
      schema: req.body.response_format.schema,
    };
  }

  // Dump the cleaned body outside production for debugging.
  if (process.env.NODE_ENV !== 'production') {
    log.debug({ body: req.body }, "Request after parameter cleanup");
  }
}

/**
 * Logs the estimated prompt token count for requests targeting Cohere
 * models. Purely informational; does not modify the request.
 */
function countCohereTokens(req: Request) {
  if (!isCohereModel(req.body.model)) return;

  // promptTokens is populated earlier in the preprocessor pipeline.
  if (req.promptTokens) {
    req.log.debug(
      { tokens: req.promptTokens },
      "Estimated token count for Cohere prompt"
    );
  }
}

// Queued proxy middleware that forwards requests to Cohere's OpenAI
// compatibility endpoint, attaching an API key and version header.
// NOTE(review): this target uses the api.cohere.ai domain while the model
// listing above calls api.cohere.com — presumably both resolve to the same
// API, but confirm the split is intentional.
const cohereProxy = createQueuedProxyMiddleware({
  mutations: [
    addKey,
    // Add Cohere-Version header to every request
    (manager) => {
      manager.setHeader("Cohere-Version", "2022-12-06");
    },
    finalizeBody
  ],
  target: "https://api.cohere.ai/compatibility",
  blockingResponseHandler: cohereResponseHandler,
});

const cohereRouter = Router();

// Chat completions: accepts OpenAI-format requests, normalizes messages,
// strips unsupported parameters, logs a token estimate, then proxies to
// Cohere's compatibility endpoint.
cohereRouter.post(
  "/v1/chat/completions",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "openai", outApi: "openai", service: "cohere" },
    { afterTransform: [ prepareMessages, removeUnsupportedParameters, countCohereTokens ] }
  ),
  cohereProxy
);

// Embeddings: proxied with only the standard preprocessor (no extra
// afterTransform steps).
cohereRouter.post(
  "/v1/embeddings",
  ipLimiter,
  createPreprocessorMiddleware(
    { inApi: "openai", outApi: "openai", service: "cohere" },
    { afterTransform: [] }
  ),
  cohereProxy
);

// Model listing is served locally (with a one-minute cache) rather than
// proxied per-request.
cohereRouter.get("/v1/models", handleModelRequest);

export const cohere = cohereRouter;