devkit / app /core /ai.py
Mohammed AL Sarraj
fix: remove paid providers (DeepSeek/Together), fix Cohere model name
009e621
"""Multi-provider AI engine with smart task routing.
Runtime chain: Groq -> Cerebras -> OpenRouter -> Mistral -> Ollama.
Task hints route to the best model for the job:
- "arabic" → large models (70B+) for Arabic NLP quality
- "code" → code-optimized models
- "fast" → smallest/fastest model available
- "default" → standard free-tier chain
"""
import json, logging, os, re, requests
logger = logging.getLogger(__name__)
_OLLAMA_BASE = "http://localhost:11434"
_PROVIDER_URLS = {
"groq": "https://api.groq.com/openai/v1/chat/completions",
"cerebras": "https://api.cerebras.ai/v1/chat/completions",
"openrouter": "https://openrouter.ai/api/v1/chat/completions",
"mistral": "https://api.mistral.ai/v1/chat/completions",
"openai": "https://api.openai.com/v1/chat/completions",
"cohere": "https://api.cohere.com/v2/chat",
}
# ── Model tiers per provider ──
_FREE_MODELS = {
"groq": "llama-3.1-8b-instant",
"cerebras": "llama3.1-8b",
"openrouter": "google/gemma-3-12b-it:free",
"mistral": "mistral-small-latest",
"cohere": "command-r-08-2024",
}
_PREMIUM_MODELS = {
"groq": "llama-3.3-70b-versatile",
"cerebras": "qwen-3-235b-a22b-instruct-2507",
"openrouter": "google/gemma-3-27b-it:free",
"mistral": "mistral-medium-latest",
"openai": "gpt-4o-mini",
"cohere": "command-r-08-2024",
}
# ── Task-specific model routing ──
# Maps task hints to the best model per provider.
# "arabic" needs large models for Arabic morphology, grammar, dialect awareness.
# "code" needs code-tuned models for test generation, SQL, schema analysis.
# "fast" uses smallest models for quick responses.
_TASK_MODELS = {
"arabic": {
"groq": "llama-3.3-70b-versatile",
"cerebras": "qwen-3-235b-a22b-instruct-2507",
"openrouter": "google/gemma-3-27b-it:free",
"mistral": "mistral-medium-latest",
"cohere": "command-r7b-arabic-02-2025",
},
"code": {
"groq": "llama-3.3-70b-versatile",
"cerebras": "qwen-3-235b-a22b-instruct-2507",
"openrouter": "google/gemma-3-27b-it:free",
"mistral": "mistral-medium-latest",
"cohere": "command-r-08-2024",
},
"fast": {
"groq": "llama-3.1-8b-instant",
"cerebras": "llama3.1-8b",
"openrouter": "google/gemma-3-12b-it:free",
"mistral": "mistral-small-latest",
"cohere": "command-r-08-2024",
},
}
# ── Task-specific provider priority ──
_TASK_PRIORITY = {
"arabic": ["cerebras", "groq", "openrouter", "cohere", "mistral"],
"code": ["groq", "cerebras", "openrouter", "cohere", "mistral"],
"fast": ["cerebras", "groq", "openrouter", "cohere", "mistral"],
"default": ["groq", "cerebras", "openrouter", "cohere", "mistral"],
}
_CHAIN_CFG = {
"groq": {"key_env": "GROQ_API_KEY", "timeout": 30, "extra": {}},
"cerebras": {"key_env": "CEREBRAS_API_KEY", "timeout": 30, "extra": {}},
"openrouter": {"key_env": "OPENROUTER_API_KEY", "timeout": 45,
"extra": {"HTTP-Referer": "https://github.com/Moealsarraj", "X-Title": "AI Tools"}},
"mistral": {"key_env": "MISTRAL_API_KEY", "timeout": 40, "extra": {}},
"cohere": {"key_env": "COHERE_API_KEY", "timeout": 45, "extra": {}},
}
# Build available providers (those with valid keys)
_AVAILABLE = {}
for _name, _cfg in _CHAIN_CFG.items():
_k = os.environ.get(_cfg["key_env"], "")
if _k:
_AVAILABLE[_name] = {
"name": _name,
"url": _PROVIDER_URLS[_name],
"key": _k,
"timeout": _cfg["timeout"],
"extra": _cfg["extra"],
}
# Ollama fallback
_OLLAMA_PROVIDER = None
try:
_r = requests.get(f"{_OLLAMA_BASE}/api/tags", timeout=3)
if _r.status_code == 200:
_installed = [m["name"] for m in _r.json().get("models", [])]
if _installed:
_OLLAMA_PROVIDER = {"name": "ollama", "model": _installed[0]}
except Exception:
pass
# ── Google Gemini (special API format) ──
_GEMINI_KEY = os.environ.get("GEMINI_API_KEY", "")
if _GEMINI_KEY:
_AVAILABLE["gemini"] = {
"name": "gemini",
"url": "https://generativelanguage.googleapis.com/v1beta/models",
"key": _GEMINI_KEY,
"timeout": 60,
"extra": {},
}
_FREE_MODELS["gemini"] = "gemini-2.0-flash"
_PREMIUM_MODELS["gemini"] = "gemini-2.0-flash"
for task in _TASK_MODELS:
_TASK_MODELS[task]["gemini"] = "gemini-2.0-flash"
for task in _TASK_PRIORITY:
if "gemini" not in _TASK_PRIORITY[task]:
_TASK_PRIORITY[task].insert(2, "gemini")
_AI_AVAILABLE = bool(_AVAILABLE or _OLLAMA_PROVIDER)
def _post_gemini(key: str, model: str, messages: list, max_tokens: int, timeout: int = 60) -> str:
"""Call Google Gemini API (non-OpenAI format)."""
# Convert OpenAI message format to Gemini format
contents = []
system_text = ""
for msg in messages:
role = msg["role"]
if role == "system":
system_text = msg["content"]
continue
contents.append({
"role": "user" if role == "user" else "model",
"parts": [{"text": msg["content"]}],
})
body = {
"contents": contents,
"generationConfig": {"maxOutputTokens": max_tokens},
}
if system_text:
body["systemInstruction"] = {"parts": [{"text": system_text}]}
url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={key}"
r = requests.post(url, json=body, timeout=timeout)
r.raise_for_status()
data = r.json()
return _clean(data["candidates"][0]["content"]["parts"][0]["text"])
def get_available_providers() -> list[dict]:
"""Return list of available providers with their model info."""
providers = []
for name, prov in _AVAILABLE.items():
providers.append({
"name": name,
"model_free": _FREE_MODELS.get(name, ""),
"model_premium": _PREMIUM_MODELS.get(name, ""),
})
return providers
def call_ai_single(provider_name: str, messages: list, system: str = "",
max_tokens: int = 2048, use_premium: bool = True) -> str:
"""Call a specific provider directly (no fallback chain)."""
if provider_name not in _AVAILABLE:
raise ValueError(f"Provider {provider_name!r} not available")
prov = _AVAILABLE[provider_name]
models = _PREMIUM_MODELS if use_premium else _FREE_MODELS
model = models.get(provider_name, prov.get("model", ""))
if system:
messages = [{"role": "system", "content": system}] + messages
if provider_name == "gemini":
return _post_gemini(prov["key"], model, messages, max_tokens, prov["timeout"])
if provider_name == "cohere":
return _post_cohere(prov["key"], model, messages, max_tokens, prov["timeout"])
return _post_openai(
prov["url"], prov["key"], model,
messages, max_tokens, prov["extra"], prov["timeout"]
)
_RE_THINK = re.compile(r"<think>.*?</think>", re.DOTALL)
_RE_OPEN = re.compile(r"^```[a-z]*\n?", re.MULTILINE)
_RE_CLOSE = re.compile(r"\n?```$", re.MULTILINE)
def _clean(raw: str) -> str:
raw = _RE_THINK.sub("", raw).strip()
raw = _RE_OPEN.sub("", raw)
return _RE_CLOSE.sub("", raw).strip()
def _post_openai(url, key, model, messages, max_tokens, extra_headers, timeout=60):
headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
headers.update(extra_headers)
r = requests.post(url, headers=headers,
json={"model": model, "messages": messages, "max_tokens": max_tokens},
timeout=timeout)
r.raise_for_status()
return _clean(r.json()["choices"][0]["message"]["content"])
def _post_cohere(key: str, model: str, messages: list, max_tokens: int, timeout: int = 45) -> str:
"""Call Cohere V2 Chat API."""
headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
r = requests.post("https://api.cohere.com/v2/chat",
headers=headers,
json={"model": model, "messages": messages, "max_tokens": max_tokens},
timeout=timeout)
r.raise_for_status()
data = r.json()
# V2 returns content as list of blocks
content = data.get("message", {}).get("content", [])
if content and isinstance(content, list):
return _clean(content[0].get("text", ""))
return _clean(str(data))
def _build_chain(task_hint: str) -> list[dict]:
"""Build an ordered provider chain for the given task hint."""
hint = task_hint if task_hint in _TASK_PRIORITY else "default"
priority = _TASK_PRIORITY[hint]
models = _TASK_MODELS.get(hint, _FREE_MODELS)
chain = []
for name in priority:
if name in _AVAILABLE:
prov = _AVAILABLE[name].copy()
prov["model"] = models.get(name, _FREE_MODELS.get(name, ""))
chain.append(prov)
return chain
def call_ai(messages: list, system: str = "", max_tokens: int = 2048,
api_key_row: dict | None = None, task_hint: str = "default") -> str:
"""Call AI with smart task-based routing.
task_hint: "arabic" | "code" | "fast" | "default"
"""
if system:
messages = [{"role": "system", "content": system}] + messages
# Custom API key path (used by e.g. Wasit/Amin integrations)
if api_key_row:
provider = api_key_row.get("provider", "openai")
key = api_key_row["key"]
url = api_key_row.get("url") or _PROVIDER_URLS.get(provider, "")
model = api_key_row.get("model") or _PREMIUM_MODELS.get(provider, "gpt-4o-mini")
if not url:
raise ValueError(f"No endpoint known for provider {provider!r}")
if provider == "claude":
r = requests.post("https://api.anthropic.com/v1/messages",
headers={"x-api-key": key, "anthropic-version": "2023-06-01",
"content-type": "application/json"},
json={"model": "claude-sonnet-4-6", "max_tokens": max_tokens, "messages": messages},
timeout=60)
r.raise_for_status()
return _clean(r.json()["content"][0]["text"])
return _post_openai(url, key, model, messages, max_tokens, {})
if not _AI_AVAILABLE:
raise RuntimeError("No AI provider. Set GROQ_API_KEY or similar in .env")
# Ollama-only path
if not _AVAILABLE and _OLLAMA_PROVIDER:
r = requests.post(f"{_OLLAMA_BASE}/api/chat",
json={"model": _OLLAMA_PROVIDER["model"], "messages": messages, "stream": False},
timeout=120)
r.raise_for_status()
return _clean(r.json()["message"]["content"])
# Smart task-routed chain
chain = _build_chain(task_hint)
if not chain:
chain = _build_chain("default")
last_exc = None
for prov in chain:
try:
logger.debug("Trying %s/%s for task=%s", prov["name"], prov["model"], task_hint)
if prov["name"] == "gemini":
return _post_gemini(prov["key"], prov["model"], messages, max_tokens, prov["timeout"])
if prov["name"] == "cohere":
return _post_cohere(prov["key"], prov["model"], messages, max_tokens, prov["timeout"])
return _post_openai(
prov["url"], prov["key"], prov["model"],
messages, max_tokens, prov["extra"], prov["timeout"]
)
except requests.exceptions.HTTPError as e:
status = e.response.status_code if e.response is not None else 0
if status in (402, 429, 503, 502):
logger.debug("Provider %s returned %s, trying next", prov["name"], status)
last_exc = e
continue
raise
except (requests.exceptions.ConnectionError,
requests.exceptions.Timeout) as e:
last_exc = e
continue
# Try Ollama as last resort
if _OLLAMA_PROVIDER:
r = requests.post(f"{_OLLAMA_BASE}/api/chat",
json={"model": _OLLAMA_PROVIDER["model"], "messages": messages, "stream": False},
timeout=120)
r.raise_for_status()
return _clean(r.json()["message"]["content"])
raise last_exc or RuntimeError("All AI providers failed or rate-limited")
def _repair_json(text: str) -> str:
"""Escape literal control characters inside JSON string values."""
result = []
in_str = False
esc = False
for c in text:
if esc:
result.append(c)
esc = False
continue
if c == '\\' and in_str:
result.append(c)
esc = True
continue
if c == '"':
in_str = not in_str
result.append(c)
continue
if in_str and c == '\n':
result.append('\\n')
continue
if in_str and c == '\r':
result.append('\\r')
continue
if in_str and c == '\t':
result.append('\\t')
continue
result.append(c)
return ''.join(result)
def _extract_json(raw: str):
"""Try progressively harder to extract valid JSON from raw text."""
raw = raw.strip()
# Direct parse
try:
return json.loads(raw)
except json.JSONDecodeError:
pass
# Repair literal newlines inside strings then retry
repaired = _repair_json(raw)
try:
return json.loads(repaired)
except json.JSONDecodeError:
pass
# Find first { or [ then walk to find matching closer
for source in (repaired, raw):
for start_ch, end_ch in [('{', '}'), ('[', ']')]:
idx = source.find(start_ch)
if idx == -1:
continue
depth = 0
in_str = False
esc = False
for i in range(idx, len(source)):
c = source[i]
if esc:
esc = False
continue
if c == '\\' and in_str:
esc = True
continue
if c == '"':
in_str = not in_str
continue
if in_str:
continue
if c == start_ch:
depth += 1
elif c == end_ch:
depth -= 1
if depth == 0:
candidate = source[idx:i+1]
try:
return json.loads(candidate)
except json.JSONDecodeError:
break
raise ValueError(f"AI returned non-JSON: {raw[:200]}")
def call_ai_json(messages: list, system: str = "", max_tokens: int = 2048,
api_key_row: dict | None = None, task_hint: str = "default") -> dict | list:
raw = call_ai(messages, system=system, max_tokens=max_tokens,
api_key_row=api_key_row, task_hint=task_hint)
return _extract_json(raw)