Spaces:
Sleeping
Sleeping
Mohammed AL Sarraj commited on
Commit ·
85a09fa
1
Parent(s): 186efee
fix: Cohere V2 API handler, handle 402 errors, fix model names
Browse files- app/core/ai.py +24 -4
app/core/ai.py
CHANGED
|
@@ -41,7 +41,7 @@ _PREMIUM_MODELS = {
|
|
| 41 |
"openai": "gpt-4o-mini",
|
| 42 |
"deepseek": "deepseek-chat",
|
| 43 |
"together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
|
| 44 |
-
"cohere": "command-r
|
| 45 |
}
|
| 46 |
|
| 47 |
# ── Task-specific model routing ──
|
|
@@ -57,7 +57,7 @@ _TASK_MODELS = {
|
|
| 57 |
"mistral": "mistral-medium-latest",
|
| 58 |
"deepseek": "deepseek-chat",
|
| 59 |
"together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
|
| 60 |
-
"cohere": "command-r
|
| 61 |
},
|
| 62 |
"code": {
|
| 63 |
"groq": "llama-3.3-70b-versatile",
|
|
@@ -66,7 +66,7 @@ _TASK_MODELS = {
|
|
| 66 |
"mistral": "mistral-medium-latest",
|
| 67 |
"deepseek": "deepseek-chat",
|
| 68 |
"together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
|
| 69 |
-
"cohere": "command-r
|
| 70 |
},
|
| 71 |
"fast": {
|
| 72 |
"groq": "llama-3.1-8b-instant",
|
|
@@ -196,6 +196,8 @@ def call_ai_single(provider_name: str, messages: list, system: str = "",
|
|
| 196 |
messages = [{"role": "system", "content": system}] + messages
|
| 197 |
if provider_name == "gemini":
|
| 198 |
return _post_gemini(prov["key"], model, messages, max_tokens, prov["timeout"])
|
|
|
|
|
|
|
| 199 |
return _post_openai(
|
| 200 |
prov["url"], prov["key"], model,
|
| 201 |
messages, max_tokens, prov["extra"], prov["timeout"]
|
|
@@ -221,6 +223,22 @@ def _post_openai(url, key, model, messages, max_tokens, extra_headers, timeout=6
|
|
| 221 |
return _clean(r.json()["choices"][0]["message"]["content"])
|
| 222 |
|
| 223 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
def _build_chain(task_hint: str) -> list[dict]:
|
| 225 |
"""Build an ordered provider chain for the given task hint."""
|
| 226 |
hint = task_hint if task_hint in _TASK_PRIORITY else "default"
|
|
@@ -281,13 +299,15 @@ def call_ai(messages: list, system: str = "", max_tokens: int = 2048,
|
|
| 281 |
logger.debug("Trying %s/%s for task=%s", prov["name"], prov["model"], task_hint)
|
| 282 |
if prov["name"] == "gemini":
|
| 283 |
return _post_gemini(prov["key"], prov["model"], messages, max_tokens, prov["timeout"])
|
|
|
|
|
|
|
| 284 |
return _post_openai(
|
| 285 |
prov["url"], prov["key"], prov["model"],
|
| 286 |
messages, max_tokens, prov["extra"], prov["timeout"]
|
| 287 |
)
|
| 288 |
except requests.exceptions.HTTPError as e:
|
| 289 |
status = e.response.status_code if e.response is not None else 0
|
| 290 |
-
if status in (429, 503, 502):
|
| 291 |
logger.debug("Provider %s returned %s, trying next", prov["name"], status)
|
| 292 |
last_exc = e
|
| 293 |
continue
|
|
|
|
| 41 |
"openai": "gpt-4o-mini",
|
| 42 |
"deepseek": "deepseek-chat",
|
| 43 |
"together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
|
| 44 |
+
"cohere": "command-r",
|
| 45 |
}
|
| 46 |
|
| 47 |
# ── Task-specific model routing ──
|
|
|
|
| 57 |
"mistral": "mistral-medium-latest",
|
| 58 |
"deepseek": "deepseek-chat",
|
| 59 |
"together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
|
| 60 |
+
"cohere": "command-r",
|
| 61 |
},
|
| 62 |
"code": {
|
| 63 |
"groq": "llama-3.3-70b-versatile",
|
|
|
|
| 66 |
"mistral": "mistral-medium-latest",
|
| 67 |
"deepseek": "deepseek-chat",
|
| 68 |
"together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
|
| 69 |
+
"cohere": "command-r",
|
| 70 |
},
|
| 71 |
"fast": {
|
| 72 |
"groq": "llama-3.1-8b-instant",
|
|
|
|
| 196 |
messages = [{"role": "system", "content": system}] + messages
|
| 197 |
if provider_name == "gemini":
|
| 198 |
return _post_gemini(prov["key"], model, messages, max_tokens, prov["timeout"])
|
| 199 |
+
if provider_name == "cohere":
|
| 200 |
+
return _post_cohere(prov["key"], model, messages, max_tokens, prov["timeout"])
|
| 201 |
return _post_openai(
|
| 202 |
prov["url"], prov["key"], model,
|
| 203 |
messages, max_tokens, prov["extra"], prov["timeout"]
|
|
|
|
| 223 |
return _clean(r.json()["choices"][0]["message"]["content"])
|
| 224 |
|
| 225 |
|
| 226 |
+
def _post_cohere(key: str, model: str, messages: list, max_tokens: int, timeout: int = 45) -> str:
|
| 227 |
+
"""Call Cohere V2 Chat API."""
|
| 228 |
+
headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
|
| 229 |
+
r = requests.post("https://api.cohere.com/v2/chat",
|
| 230 |
+
headers=headers,
|
| 231 |
+
json={"model": model, "messages": messages, "max_tokens": max_tokens},
|
| 232 |
+
timeout=timeout)
|
| 233 |
+
r.raise_for_status()
|
| 234 |
+
data = r.json()
|
| 235 |
+
# V2 returns content as list of blocks
|
| 236 |
+
content = data.get("message", {}).get("content", [])
|
| 237 |
+
if content and isinstance(content, list):
|
| 238 |
+
return _clean(content[0].get("text", ""))
|
| 239 |
+
return _clean(str(data))
|
| 240 |
+
|
| 241 |
+
|
| 242 |
def _build_chain(task_hint: str) -> list[dict]:
|
| 243 |
"""Build an ordered provider chain for the given task hint."""
|
| 244 |
hint = task_hint if task_hint in _TASK_PRIORITY else "default"
|
|
|
|
| 299 |
logger.debug("Trying %s/%s for task=%s", prov["name"], prov["model"], task_hint)
|
| 300 |
if prov["name"] == "gemini":
|
| 301 |
return _post_gemini(prov["key"], prov["model"], messages, max_tokens, prov["timeout"])
|
| 302 |
+
if prov["name"] == "cohere":
|
| 303 |
+
return _post_cohere(prov["key"], prov["model"], messages, max_tokens, prov["timeout"])
|
| 304 |
return _post_openai(
|
| 305 |
prov["url"], prov["key"], prov["model"],
|
| 306 |
messages, max_tokens, prov["extra"], prov["timeout"]
|
| 307 |
)
|
| 308 |
except requests.exceptions.HTTPError as e:
|
| 309 |
status = e.response.status_code if e.response is not None else 0
|
| 310 |
+
if status in (402, 429, 503, 502):
|
| 311 |
logger.debug("Provider %s returned %s, trying next", prov["name"], status)
|
| 312 |
last_exc = e
|
| 313 |
continue
|