Spaces:
Sleeping
Sleeping
| """Multi-provider AI engine with smart task routing. | |
| Default runtime chain: Groq -> Cerebras -> Gemini (only when GEMINI_API_KEY is set) -> OpenRouter -> Cohere -> Mistral -> Ollama. | |
| Task hints route to the best model for the job: | |
| - "arabic" → large models (70B+) for Arabic NLP quality | |
| - "code" → code-optimized models | |
| - "fast" → smallest/fastest model available | |
| - "default" → standard free-tier chain | |
| """ | |
| import json, logging, os, re, requests | |
logger = logging.getLogger(__name__)

# Base URL of the local Ollama server.
_OLLAMA_BASE = "http://localhost:11434"

# Chat endpoint for each hosted provider, keyed by provider name.
_PROVIDER_URLS = {
    "groq": "https://api.groq.com/openai/v1/chat/completions",
    "cerebras": "https://api.cerebras.ai/v1/chat/completions",
    "openrouter": "https://openrouter.ai/api/v1/chat/completions",
    "mistral": "https://api.mistral.ai/v1/chat/completions",
    "openai": "https://api.openai.com/v1/chat/completions",
    "cohere": "https://api.cohere.com/v2/chat",
}

# ── Model tiers per provider ──
# Model used for each provider on the free tier.
_FREE_MODELS = {
    "groq": "llama-3.1-8b-instant",
    "cerebras": "llama3.1-8b",
    "openrouter": "google/gemma-3-12b-it:free",
    "mistral": "mistral-small-latest",
    "cohere": "command-r-08-2024",
}

# Model used for each provider on the premium tier.
_PREMIUM_MODELS = {
    "groq": "llama-3.3-70b-versatile",
    "cerebras": "qwen-3-235b-a22b-instruct-2507",
    "openrouter": "google/gemma-3-27b-it:free",
    "mistral": "mistral-medium-latest",
    "openai": "gpt-4o-mini",
    "cohere": "command-r-08-2024",
}

# ── Task-specific model routing ──
# Maps a task hint to the model each provider should use for it.
#   "arabic": large models for Arabic morphology, grammar, dialect awareness.
#   "code":   code-tuned models for test generation, SQL, schema analysis.
#   "fast":   smallest models for quick responses.
# There is no "default" entry here; lookups for it fall back to _FREE_MODELS.
_TASK_MODELS = {
    "arabic": {
        "groq": "llama-3.3-70b-versatile",
        "cerebras": "qwen-3-235b-a22b-instruct-2507",
        "openrouter": "google/gemma-3-27b-it:free",
        "mistral": "mistral-medium-latest",
        "cohere": "command-r7b-arabic-02-2025",
    },
    "code": {
        "groq": "llama-3.3-70b-versatile",
        "cerebras": "qwen-3-235b-a22b-instruct-2507",
        "openrouter": "google/gemma-3-27b-it:free",
        "mistral": "mistral-medium-latest",
        "cohere": "command-r-08-2024",
    },
    "fast": {
        "groq": "llama-3.1-8b-instant",
        "cerebras": "llama3.1-8b",
        "openrouter": "google/gemma-3-12b-it:free",
        "mistral": "mistral-small-latest",
        "cohere": "command-r-08-2024",
    },
}

# ── Task-specific provider priority ──
# Order in which providers are tried for each task hint.
_TASK_PRIORITY = {
    "arabic": ["cerebras", "groq", "openrouter", "cohere", "mistral"],
    "code": ["groq", "cerebras", "openrouter", "cohere", "mistral"],
    "fast": ["cerebras", "groq", "openrouter", "cohere", "mistral"],
    "default": ["groq", "cerebras", "openrouter", "cohere", "mistral"],
}

# Per-provider connection settings: the environment variable holding the API
# key, the request timeout in seconds, and any extra HTTP headers to send.
_CHAIN_CFG = {
    "groq": {"key_env": "GROQ_API_KEY", "timeout": 30, "extra": {}},
    "cerebras": {"key_env": "CEREBRAS_API_KEY", "timeout": 30, "extra": {}},
    "openrouter": {
        "key_env": "OPENROUTER_API_KEY",
        "timeout": 45,
        "extra": {"HTTP-Referer": "https://github.com/Moealsarraj", "X-Title": "AI Tools"},
    },
    "mistral": {"key_env": "MISTRAL_API_KEY", "timeout": 40, "extra": {}},
    "cohere": {"key_env": "COHERE_API_KEY", "timeout": 45, "extra": {}},
}
# Register every configured provider whose API-key environment variable is
# set to a non-empty value; providers without a key are left out entirely.
_AVAILABLE = {}
for _name, _cfg in _CHAIN_CFG.items():
    _k = os.environ.get(_cfg["key_env"], "")
    if not _k:
        continue
    _AVAILABLE[_name] = dict(
        name=_name,
        url=_PROVIDER_URLS[_name],
        key=_k,
        timeout=_cfg["timeout"],
        extra=_cfg["extra"],
    )
# Ollama fallback: probe the local server once at import time. If it answers
# with at least one installed model, the first one listed becomes the model
# used for Ollama requests; otherwise _OLLAMA_PROVIDER stays None.
_OLLAMA_PROVIDER = None
try:
    _r = requests.get(f"{_OLLAMA_BASE}/api/tags", timeout=3)
    if _r.status_code == 200:
        _installed = [m["name"] for m in _r.json().get("models", [])]
        if _installed:
            _OLLAMA_PROVIDER = {"name": "ollama", "model": _installed[0]}
except Exception:
    # The probe is best-effort and must never break the import, so every
    # failure is tolerated; it is logged so a missing Ollama is diagnosable.
    logger.debug("Ollama probe at %s failed; continuing without it",
                 _OLLAMA_BASE, exc_info=True)
# ── Google Gemini (special API format) ──
# Gemini is registered only when GEMINI_API_KEY is set. It is then added to
# every model table and slotted into third position of every priority list.
_GEMINI_KEY = os.environ.get("GEMINI_API_KEY", "")
if _GEMINI_KEY:
    _AVAILABLE["gemini"] = dict(
        name="gemini",
        url="https://generativelanguage.googleapis.com/v1beta/models",
        key=_GEMINI_KEY,
        timeout=60,
        extra={},
    )
    _FREE_MODELS.update(gemini="gemini-2.0-flash")
    _PREMIUM_MODELS.update(gemini="gemini-2.0-flash")
    for task in _TASK_MODELS:
        _TASK_MODELS[task].update(gemini="gemini-2.0-flash")
    for task in _TASK_PRIORITY:
        if "gemini" in _TASK_PRIORITY[task]:
            continue
        _TASK_PRIORITY[task].insert(2, "gemini")

# True when at least one hosted provider or a local Ollama model is usable.
_AI_AVAILABLE = bool(_AVAILABLE or _OLLAMA_PROVIDER)
def _post_gemini(key: str, model: str, messages: list, max_tokens: int, timeout: int = 60) -> str:
    """Call the Google Gemini ``generateContent`` API (non-OpenAI format).

    Args:
        key: Gemini API key.
        model: Gemini model name, interpolated into the request URL.
        messages: OpenAI-style messages (dicts with "role" and "content").
        max_tokens: Sent as ``generationConfig.maxOutputTokens``.
        timeout: Request timeout in seconds.

    Returns:
        The text of the first part of the first candidate, passed through
        ``_clean``.

    Raises:
        requests.exceptions.HTTPError: On a non-2xx response.
        KeyError, IndexError: If the response has no candidate text.
    """
    # Convert OpenAI message format to Gemini format: system messages go into
    # a separate systemInstruction, "user" stays "user", and every other role
    # is sent as "model".
    contents = []
    system_chunks = []
    for msg in messages:
        role = msg["role"]
        if role == "system":
            system_chunks.append(msg["content"])
            continue
        contents.append({
            "role": "user" if role == "user" else "model",
            "parts": [{"text": msg["content"]}],
        })

    body = {
        "contents": contents,
        "generationConfig": {"maxOutputTokens": max_tokens},
    }
    # Keep every system message rather than only the last one seen.
    system_text = "\n\n".join(chunk for chunk in system_chunks if chunk)
    if system_text:
        body["systemInstruction"] = {"parts": [{"text": system_text}]}

    # The key travels in a header, not the query string: requests embeds the
    # URL in HTTPError messages, which would otherwise leak the key into logs.
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
    r = requests.post(url, headers={"x-goog-api-key": key}, json=body, timeout=timeout)
    r.raise_for_status()
    data = r.json()
    return _clean(data["candidates"][0]["content"]["parts"][0]["text"])
def get_available_providers() -> list[dict]:
    """Describe every provider that currently has an API key configured.

    Returns:
        One dict per available provider with its "name" plus the free-tier
        and premium-tier model names ("model_free", "model_premium"); a model
        name is "" when the provider has no entry in that tier.
    """
    return [
        {
            "name": name,
            "model_free": _FREE_MODELS.get(name, ""),
            "model_premium": _PREMIUM_MODELS.get(name, ""),
        }
        for name in _AVAILABLE
    ]
def call_ai_single(provider_name: str, messages: list, system: str = "",
                   max_tokens: int = 2048, use_premium: bool = True) -> str:
    """Send one request to a named provider, with no fallback chain.

    Args:
        provider_name: Key into the available-provider table.
        messages: OpenAI-style chat messages.
        system: Optional system prompt, prepended as a system message.
        max_tokens: Completion token limit forwarded to the provider.
        use_premium: Pick the provider's premium model when True, otherwise
            its free-tier model.

    Returns:
        The cleaned response text.

    Raises:
        ValueError: If ``provider_name`` is not an available provider.
    """
    prov = _AVAILABLE.get(provider_name)
    if prov is None:
        raise ValueError(f"Provider {provider_name!r} not available")

    tier = _PREMIUM_MODELS if use_premium else _FREE_MODELS
    model = tier.get(provider_name, prov.get("model", ""))

    if system:
        messages = [{"role": "system", "content": system}] + messages

    # Gemini and Cohere have their own wire formats; everything else is
    # treated as OpenAI-compatible.
    if provider_name == "gemini":
        return _post_gemini(prov["key"], model, messages, max_tokens, prov["timeout"])
    if provider_name == "cohere":
        return _post_cohere(prov["key"], model, messages, max_tokens, prov["timeout"])
    return _post_openai(
        prov["url"], prov["key"], model,
        messages, max_tokens, prov["extra"], prov["timeout"],
    )
# <think>…</think> reasoning blocks (may span lines).
_RE_THINK = re.compile(r"<think>.*?</think>", re.S)
# A code-fence opener (``` plus optional lowercase language tag) at the start
# of any line.
_RE_OPEN = re.compile(r"^```[a-z]*\n?", re.M)
# A code-fence closer (```) at the end of any line.
_RE_CLOSE = re.compile(r"\n?```$", re.M)


def _clean(raw: str) -> str:
    """Strip reasoning blocks, code-fence markers and outer whitespace.

    The <think> blocks are removed first and the result trimmed, so that a
    fence which follows a reasoning block sits at a line start and is caught
    by the fence patterns.
    """
    text = _RE_THINK.sub("", raw).strip()
    for fence in (_RE_OPEN, _RE_CLOSE):
        text = fence.sub("", text)
    return text.strip()
def _post_openai(url, key, model, messages, max_tokens, extra_headers, timeout=60):
    """POST a chat completion to an OpenAI-compatible endpoint.

    Sends ``model``, ``messages`` and ``max_tokens`` as JSON with a Bearer
    token, merging ``extra_headers`` over the default headers.

    Returns the cleaned content of the first choice. Raises
    ``requests.exceptions.HTTPError`` on a non-2xx response.
    """
    headers = {"Authorization": f"Bearer {key}", "Content-Type": "application/json"}
    headers |= extra_headers
    payload = {"model": model, "messages": messages, "max_tokens": max_tokens}
    response = requests.post(url, headers=headers, json=payload, timeout=timeout)
    response.raise_for_status()
    first_choice = response.json()["choices"][0]
    return _clean(first_choice["message"]["content"])
def _post_cohere(key: str, model: str, messages: list, max_tokens: int, timeout: int = 45) -> str:
    """Call the Cohere V2 Chat API.

    Returns the cleaned text of the first content block. If the response has
    no non-empty content list, the whole decoded response is stringified and
    cleaned instead. Raises ``requests.exceptions.HTTPError`` on a non-2xx
    response.
    """
    payload = {"model": model, "messages": messages, "max_tokens": max_tokens}
    response = requests.post(
        "https://api.cohere.com/v2/chat",
        headers={"Authorization": f"Bearer {key}", "Content-Type": "application/json"},
        json=payload,
        timeout=timeout,
    )
    response.raise_for_status()
    data = response.json()

    # V2 returns content as a list of blocks.
    blocks = data.get("message", {}).get("content", [])
    if isinstance(blocks, list) and blocks:
        return _clean(blocks[0].get("text", ""))
    return _clean(str(data))
def _build_chain(task_hint: str) -> list[dict]:
    """Return the available providers for ``task_hint``, in priority order.

    An unrecognised hint is treated as "default". Each entry is a copy of the
    provider's record with a "model" key added: the task-specific model when
    one is defined, otherwise the provider's free-tier model (or "").
    """
    hint = task_hint if task_hint in _TASK_PRIORITY else "default"
    task_models = _TASK_MODELS.get(hint, _FREE_MODELS)
    return [
        {**_AVAILABLE[name], "model": task_models.get(name, _FREE_MODELS.get(name, ""))}
        for name in _TASK_PRIORITY[hint]
        if name in _AVAILABLE
    ]
def _post_ollama(messages: list) -> str:
    """Send ``messages`` to the local Ollama model chosen at import time."""
    r = requests.post(f"{_OLLAMA_BASE}/api/chat",
                      json={"model": _OLLAMA_PROVIDER["model"], "messages": messages, "stream": False},
                      timeout=120)
    r.raise_for_status()
    return _clean(r.json()["message"]["content"])


def _post_claude(key: str, model: str, messages: list, max_tokens: int, timeout: int = 60) -> str:
    """Call the Anthropic Messages API and return the cleaned first text block."""
    # The Messages API takes the system prompt as a top-level "system" field,
    # so system-role entries are lifted out of the message list.
    system_text = "\n\n".join(
        m["content"] for m in messages if m["role"] == "system" and m["content"]
    )
    body = {
        "model": model,
        "max_tokens": max_tokens,
        "messages": [m for m in messages if m["role"] != "system"],
    }
    if system_text:
        body["system"] = system_text
    r = requests.post("https://api.anthropic.com/v1/messages",
                      headers={"x-api-key": key, "anthropic-version": "2023-06-01",
                               "content-type": "application/json"},
                      json=body,
                      timeout=timeout)
    r.raise_for_status()
    return _clean(r.json()["content"][0]["text"])


def call_ai(messages: list, system: str = "", max_tokens: int = 2048,
            api_key_row: dict | None = None, task_hint: str = "default") -> str:
    """Call AI with smart task-based routing.

    Args:
        messages: OpenAI-style chat messages.
        system: Optional system prompt, prepended as a system message.
        max_tokens: Completion token limit forwarded to the provider.
        api_key_row: Optional caller-supplied credentials. When given, the
            provider chain is bypassed. Reads "key" (required) and the
            optional "provider" (default "openai"), "url" and "model".
        task_hint: "arabic" | "code" | "fast" | "default".

    Returns:
        The cleaned response text.

    Raises:
        ValueError: ``api_key_row`` names a provider with no known endpoint.
        RuntimeError: No provider is configured, or the chain was empty.
        requests.exceptions.RequestException: A non-retryable HTTP error, or
            the last retryable error once every provider has failed.
    """
    if system:
        messages = [{"role": "system", "content": system}] + messages

    # Custom API key path (used by e.g. Wasit/Amin integrations)
    if api_key_row:
        provider = api_key_row.get("provider", "openai")
        key = api_key_row["key"]
        if provider == "claude":
            # Handled before the URL lookup: Claude has a fixed endpoint and
            # no _PROVIDER_URLS entry, so the lookup below would reject it.
            # The row's own model is honoured, with the previous fixed model
            # as the default.
            model = api_key_row.get("model") or "claude-sonnet-4-6"
            return _post_claude(key, model, messages, max_tokens)
        url = api_key_row.get("url") or _PROVIDER_URLS.get(provider, "")
        model = api_key_row.get("model") or _PREMIUM_MODELS.get(provider, "gpt-4o-mini")
        if not url:
            raise ValueError(f"No endpoint known for provider {provider!r}")
        if provider == "cohere" and not api_key_row.get("url"):
            # The default Cohere endpoint is the V2 chat API, whose response
            # is not OpenAI-shaped. A custom URL is still treated as
            # OpenAI-compatible.
            return _post_cohere(key, model, messages, max_tokens)
        return _post_openai(url, key, model, messages, max_tokens, {})

    if not _AI_AVAILABLE:
        raise RuntimeError("No AI provider. Set GROQ_API_KEY or similar in .env")

    # Ollama-only path
    if not _AVAILABLE and _OLLAMA_PROVIDER:
        return _post_ollama(messages)

    # Smart task-routed chain
    chain = _build_chain(task_hint) or _build_chain("default")
    last_exc = None
    for prov in chain:
        try:
            logger.debug("Trying %s/%s for task=%s", prov["name"], prov["model"], task_hint)
            if prov["name"] == "gemini":
                return _post_gemini(prov["key"], prov["model"], messages, max_tokens, prov["timeout"])
            if prov["name"] == "cohere":
                return _post_cohere(prov["key"], prov["model"], messages, max_tokens, prov["timeout"])
            return _post_openai(
                prov["url"], prov["key"], prov["model"],
                messages, max_tokens, prov["extra"], prov["timeout"]
            )
        except requests.exceptions.HTTPError as e:
            status = e.response.status_code if e.response is not None else 0
            # Only payment, rate-limit and upstream-unavailable statuses move
            # on to the next provider; any other HTTP error is raised.
            if status in (402, 429, 503, 502):
                logger.debug("Provider %s returned %s, trying next", prov["name"], status)
                last_exc = e
                continue
            raise
        except (requests.exceptions.ConnectionError,
                requests.exceptions.Timeout) as e:
            last_exc = e
            continue

    # Try Ollama as last resort
    if _OLLAMA_PROVIDER:
        return _post_ollama(messages)
    raise last_exc or RuntimeError("All AI providers failed or rate-limited")
def _repair_json(text: str) -> str:
    """Escape literal control characters inside JSON string values.

    Walks ``text`` tracking whether the cursor is inside a double-quoted
    string. Inside a string, every raw character below U+0020 (which strict
    JSON forbids) is replaced by an escape: ``\\n``, ``\\r`` and ``\\t`` get
    their short form, all others ``\\uXXXX``. Characters outside strings and
    existing backslash escapes are copied through untouched.

    Args:
        text: JSON-like text, possibly with raw control characters in strings.

    Returns:
        The text with in-string control characters escaped.
    """
    short_escapes = {"\n": "\\n", "\r": "\\r", "\t": "\\t"}
    result = []
    in_str = False
    esc = False
    for c in text:
        if esc:
            # Character following a backslash: copy verbatim so an escaped
            # quote does not toggle the in-string state.
            result.append(c)
            esc = False
        elif in_str and c == "\\":
            result.append(c)
            esc = True
        elif c == '"':
            in_str = not in_str
            result.append(c)
        elif in_str and c < " ":
            result.append(short_escapes.get(c) or f"\\u{ord(c):04x}")
        else:
            result.append(c)
    return "".join(result)
def _balanced_end(source: str, start: int, open_ch: str, close_ch: str) -> int:
    """Return the index of the closer matching ``source[start]``, or -1.

    Only ``open_ch``/``close_ch`` are counted, and brackets inside
    double-quoted strings are ignored.
    """
    depth = 0
    in_str = False
    esc = False
    for i in range(start, len(source)):
        c = source[i]
        if esc:
            esc = False
        elif in_str and c == "\\":
            esc = True
        elif c == '"':
            in_str = not in_str
        elif in_str:
            continue
        elif c == open_ch:
            depth += 1
        elif c == close_ch:
            depth -= 1
            if depth == 0:
                return i
    return -1


def _extract_json(raw: str):
    """Try progressively harder to extract valid JSON from raw text.

    Stages, stopping at the first that parses:
      1. Parse the stripped text directly.
      2. Escape raw control characters inside strings, then parse.
      3. In the repaired text, then in the original, take the first ``{`` and
         the first ``[`` and parse the balanced span that starts at each,
         trying whichever opener appears earlier first.

    Trying the earlier opener first matters for a prose-wrapped array of
    objects: starting from ``{`` would return only the array's first element.

    Returns:
        The decoded JSON value.

    Raises:
        ValueError: If no stage yields valid JSON.
    """
    raw = raw.strip()

    # Direct parse
    try:
        return json.loads(raw)
    except json.JSONDecodeError:
        pass

    # Repair literal newlines inside strings then retry
    repaired = _repair_json(raw)
    try:
        return json.loads(repaired)
    except json.JSONDecodeError:
        pass

    # Find the first { and [ and walk each to its matching closer
    for source in (repaired, raw):
        openers = []
        for open_ch, close_ch in (("{", "}"), ("[", "]")):
            pos = source.find(open_ch)
            if pos != -1:
                openers.append((pos, open_ch, close_ch))
        for pos, open_ch, close_ch in sorted(openers):
            end = _balanced_end(source, pos, open_ch, close_ch)
            if end == -1:
                continue
            try:
                return json.loads(source[pos:end + 1])
            except json.JSONDecodeError:
                continue

    raise ValueError(f"AI returned non-JSON: {raw[:200]}")
def call_ai_json(messages: list, system: str = "", max_tokens: int = 2048,
                 api_key_row: dict | None = None, task_hint: str = "default") -> dict | list:
    """Call ``call_ai`` with the same arguments and decode the reply as JSON.

    Raises:
        ValueError: If no JSON can be extracted from the reply.
    """
    return _extract_json(
        call_ai(
            messages,
            system=system,
            max_tokens=max_tokens,
            api_key_row=api_key_row,
            task_hint=task_hint,
        )
    )