Mohammed AL Sarraj committed on
Commit
85a09fa
·
1 Parent(s): 186efee

fix: Cohere V2 API handler, handle 402 errors, fix model names

Browse files
Files changed (1) hide show
  1. app/core/ai.py +24 -4
app/core/ai.py CHANGED
@@ -41,7 +41,7 @@ _PREMIUM_MODELS = {
41
  "openai": "gpt-4o-mini",
42
  "deepseek": "deepseek-chat",
43
  "together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
44
- "cohere": "command-r-plus",
45
  }
46
 
47
  # ── Task-specific model routing ──
@@ -57,7 +57,7 @@ _TASK_MODELS = {
57
  "mistral": "mistral-medium-latest",
58
  "deepseek": "deepseek-chat",
59
  "together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
60
- "cohere": "command-r-plus",
61
  },
62
  "code": {
63
  "groq": "llama-3.3-70b-versatile",
@@ -66,7 +66,7 @@ _TASK_MODELS = {
66
  "mistral": "mistral-medium-latest",
67
  "deepseek": "deepseek-chat",
68
  "together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
69
- "cohere": "command-r-plus",
70
  },
71
  "fast": {
72
  "groq": "llama-3.1-8b-instant",
@@ -196,6 +196,8 @@ def call_ai_single(provider_name: str, messages: list, system: str = "",
196
  messages = [{"role": "system", "content": system}] + messages
197
  if provider_name == "gemini":
198
  return _post_gemini(prov["key"], model, messages, max_tokens, prov["timeout"])
 
 
199
  return _post_openai(
200
  prov["url"], prov["key"], model,
201
  messages, max_tokens, prov["extra"], prov["timeout"]
@@ -221,6 +223,22 @@ def _post_openai(url, key, model, messages, max_tokens, extra_headers, timeout=6
221
  return _clean(r.json()["choices"][0]["message"]["content"])
222
 
223
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
  def _build_chain(task_hint: str) -> list[dict]:
225
  """Build an ordered provider chain for the given task hint."""
226
  hint = task_hint if task_hint in _TASK_PRIORITY else "default"
@@ -281,13 +299,15 @@ def call_ai(messages: list, system: str = "", max_tokens: int = 2048,
281
  logger.debug("Trying %s/%s for task=%s", prov["name"], prov["model"], task_hint)
282
  if prov["name"] == "gemini":
283
  return _post_gemini(prov["key"], prov["model"], messages, max_tokens, prov["timeout"])
 
 
284
  return _post_openai(
285
  prov["url"], prov["key"], prov["model"],
286
  messages, max_tokens, prov["extra"], prov["timeout"]
287
  )
288
  except requests.exceptions.HTTPError as e:
289
  status = e.response.status_code if e.response is not None else 0
290
- if status in (429, 503, 502):
291
  logger.debug("Provider %s returned %s, trying next", prov["name"], status)
292
  last_exc = e
293
  continue
 
41
  "openai": "gpt-4o-mini",
42
  "deepseek": "deepseek-chat",
43
  "together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
44
+ "cohere": "command-r",
45
  }
46
 
47
  # ── Task-specific model routing ──
 
57
  "mistral": "mistral-medium-latest",
58
  "deepseek": "deepseek-chat",
59
  "together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
60
+ "cohere": "command-r",
61
  },
62
  "code": {
63
  "groq": "llama-3.3-70b-versatile",
 
66
  "mistral": "mistral-medium-latest",
67
  "deepseek": "deepseek-chat",
68
  "together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
69
+ "cohere": "command-r",
70
  },
71
  "fast": {
72
  "groq": "llama-3.1-8b-instant",
 
196
  messages = [{"role": "system", "content": system}] + messages
197
  if provider_name == "gemini":
198
  return _post_gemini(prov["key"], model, messages, max_tokens, prov["timeout"])
199
+ if provider_name == "cohere":
200
+ return _post_cohere(prov["key"], model, messages, max_tokens, prov["timeout"])
201
  return _post_openai(
202
  prov["url"], prov["key"], model,
203
  messages, max_tokens, prov["extra"], prov["timeout"]
 
223
  return _clean(r.json()["choices"][0]["message"]["content"])
224
 
225
 
226
def _post_cohere(key: str, model: str, messages: list, max_tokens: int, timeout: int = 45) -> str:
    """Call the Cohere V2 Chat API and return the cleaned reply text.

    Args:
        key: Cohere API key, sent as a Bearer token.
        model: Cohere model identifier to request.
        messages: Chat messages, forwarded unchanged as the ``messages``
            field of the request body.
        max_tokens: Value forwarded as ``max_tokens`` in the request body.
        timeout: Request timeout in seconds.

    Returns:
        The concatenated text of every text block in ``message.content``,
        passed through ``_clean``. If the response has no non-empty content
        list, the stringified payload is returned (also through ``_clean``)
        as a last-resort fallback.

    Raises:
        requests.exceptions.HTTPError: If the API answers with an error
            status (raised by ``raise_for_status``).
    """
    headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
    payload = {"model": model, "messages": messages, "max_tokens": max_tokens}
    r = requests.post(
        "https://api.cohere.com/v2/chat",
        headers=headers,
        json=payload,
        timeout=timeout,
    )
    r.raise_for_status()
    data = r.json()

    # Guard each level: "message" may be missing or null, and the payload
    # itself is not guaranteed to be a dict.
    message = data.get("message") if isinstance(data, dict) else None
    content = message.get("content") if isinstance(message, dict) else None

    # V2 returns content as a list of blocks. Collect every block that
    # carries text instead of only the first one, so a leading non-text
    # block does not blank out the reply and multi-block replies are not
    # truncated.
    if content and isinstance(content, list):
        texts = [
            block["text"]
            for block in content
            if isinstance(block, dict) and isinstance(block.get("text"), str)
        ]
        return _clean("".join(texts))

    # Unexpected response shape: keep the original best-effort fallback.
    return _clean(str(data))
240
+
241
+
242
  def _build_chain(task_hint: str) -> list[dict]:
243
  """Build an ordered provider chain for the given task hint."""
244
  hint = task_hint if task_hint in _TASK_PRIORITY else "default"
 
299
  logger.debug("Trying %s/%s for task=%s", prov["name"], prov["model"], task_hint)
300
  if prov["name"] == "gemini":
301
  return _post_gemini(prov["key"], prov["model"], messages, max_tokens, prov["timeout"])
302
+ if prov["name"] == "cohere":
303
+ return _post_cohere(prov["key"], prov["model"], messages, max_tokens, prov["timeout"])
304
  return _post_openai(
305
  prov["url"], prov["key"], prov["model"],
306
  messages, max_tokens, prov["extra"], prov["timeout"]
307
  )
308
  except requests.exceptions.HTTPError as e:
309
  status = e.response.status_code if e.response is not None else 0
310
+ if status in (402, 429, 503, 502):
311
  logger.debug("Provider %s returned %s, trying next", prov["name"], status)
312
  last_exc = e
313
  continue