Mohammed AL Sarraj committed on
Commit
186efee
·
1 Parent(s): e85a23d

feat: add DeepSeek, Gemini, Together AI, Cohere to AI provider stack

Browse files
.env.example CHANGED
@@ -2,4 +2,8 @@ GROQ_API_KEY=
2
  CEREBRAS_API_KEY=
3
  OPENROUTER_API_KEY=
4
  MISTRAL_API_KEY=
 
 
 
 
5
  SECRET_KEY=change-me
 
2
  CEREBRAS_API_KEY=
3
  OPENROUTER_API_KEY=
4
  MISTRAL_API_KEY=
5
+ DEEPSEEK_API_KEY=
6
+ TOGETHER_API_KEY=
7
+ COHERE_API_KEY=
8
+ GEMINI_API_KEY=
9
  SECRET_KEY=change-me
app/core/__pycache__/ai.cpython-314.pyc CHANGED
Binary files a/app/core/__pycache__/ai.cpython-314.pyc and b/app/core/__pycache__/ai.cpython-314.pyc differ
 
app/core/ai.py CHANGED
@@ -1,4 +1,12 @@
1
- """Multi-provider AI engine. Runtime chain: Groq -> Cerebras -> OpenRouter -> Mistral -> Ollama."""
 
 
 
 
 
 
 
 
2
  import json, logging, os, re, requests
3
 
4
  logger = logging.getLogger(__name__)
@@ -10,12 +18,20 @@ _PROVIDER_URLS = {
10
  "openrouter": "https://openrouter.ai/api/v1/chat/completions",
11
  "mistral": "https://api.mistral.ai/v1/chat/completions",
12
  "openai": "https://api.openai.com/v1/chat/completions",
 
 
 
13
  }
 
 
14
  _FREE_MODELS = {
15
  "groq": "llama-3.1-8b-instant",
16
  "cerebras": "llama3.1-8b",
17
  "openrouter": "google/gemma-3-12b-it:free",
18
  "mistral": "mistral-small-latest",
 
 
 
19
  }
20
  _PREMIUM_MODELS = {
21
  "groq": "llama-3.3-70b-versatile",
@@ -23,28 +39,77 @@ _PREMIUM_MODELS = {
23
  "openrouter": "google/gemma-3-27b-it:free",
24
  "mistral": "mistral-medium-latest",
25
  "openai": "gpt-4o-mini",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  }
27
- _CHAIN_CFG = [
28
- {"name": "groq", "key_env": "GROQ_API_KEY", "timeout": 30, "extra": {}},
29
- {"name": "cerebras", "key_env": "CEREBRAS_API_KEY", "timeout": 30, "extra": {}},
30
- {"name": "openrouter", "key_env": "OPENROUTER_API_KEY", "timeout": 45,
31
- "extra": {"HTTP-Referer": "https://github.com/Moealsarraj", "X-Title": "AI Tools"}},
32
- {"name": "mistral", "key_env": "MISTRAL_API_KEY", "timeout": 40, "extra": {}},
33
- ]
34
-
35
- # Build the runtime provider list — all providers with valid keys
36
- _PROVIDERS = []
37
- for _p in _CHAIN_CFG:
38
- _k = os.environ.get(_p["key_env"], "")
39
  if _k:
40
- _PROVIDERS.append({
41
- "name": _p["name"],
42
- "url": _PROVIDER_URLS[_p["name"]],
43
- "model": _FREE_MODELS[_p["name"]],
44
  "key": _k,
45
- "timeout": _p["timeout"],
46
- "extra": _p["extra"],
47
- })
48
 
49
  # Ollama fallback
50
  _OLLAMA_PROVIDER = None
@@ -57,7 +122,85 @@ try:
57
  except Exception:
58
  pass
59
 
60
- _AI_AVAILABLE = bool(_PROVIDERS or _OLLAMA_PROVIDER)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
  _RE_THINK = re.compile(r"<think>.*?</think>", re.DOTALL)
63
  _RE_OPEN = re.compile(r"^```[a-z]*\n?", re.MULTILINE)
@@ -77,8 +220,28 @@ def _post_openai(url, key, model, messages, max_tokens, extra_headers, timeout=6
77
  r.raise_for_status()
78
  return _clean(r.json()["choices"][0]["message"]["content"])
79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  def call_ai(messages: list, system: str = "", max_tokens: int = 2048,
81
- api_key_row: dict | None = None) -> str:
 
 
 
 
82
  if system:
83
  messages = [{"role": "system", "content": system}] + messages
84
  # Custom API key path (used by e.g. Wasit/Amin integrations)
@@ -101,16 +264,23 @@ def call_ai(messages: list, system: str = "", max_tokens: int = 2048,
101
  if not _AI_AVAILABLE:
102
  raise RuntimeError("No AI provider. Set GROQ_API_KEY or similar in .env")
103
  # Ollama-only path
104
- if not _PROVIDERS and _OLLAMA_PROVIDER:
105
  r = requests.post(f"{_OLLAMA_BASE}/api/chat",
106
  json={"model": _OLLAMA_PROVIDER["model"], "messages": messages, "stream": False},
107
  timeout=120)
108
  r.raise_for_status()
109
  return _clean(r.json()["message"]["content"])
110
- # Runtime chain: try each provider, fall back on 429 or transient errors
 
 
 
 
111
  last_exc = None
112
- for prov in _PROVIDERS:
113
  try:
 
 
 
114
  return _post_openai(
115
  prov["url"], prov["key"], prov["model"],
116
  messages, max_tokens, prov["extra"], prov["timeout"]
@@ -214,6 +384,7 @@ def _extract_json(raw: str):
214
  raise ValueError(f"AI returned non-JSON: {raw[:200]}")
215
 
216
  def call_ai_json(messages: list, system: str = "", max_tokens: int = 2048,
217
- api_key_row: dict | None = None) -> dict | list:
218
- raw = call_ai(messages, system=system, max_tokens=max_tokens, api_key_row=api_key_row)
 
219
  return _extract_json(raw)
 
1
+ """Multi-provider AI engine with smart task routing.
2
+
3
+ Default chain: Groq -> Cerebras -> (Gemini, if configured) -> DeepSeek -> Together AI -> OpenRouter -> Cohere -> Mistral, with Ollama as the local fallback.
4
+ Task hints route to the best model for the job:
5
+ - "arabic" → large models (70B+) for Arabic NLP quality
6
+ - "code" → code-optimized models
7
+ - "fast" → smallest/fastest model available
8
+ - "default" → standard free-tier chain
9
+ """
10
  import json, logging, os, re, requests
11
 
12
  logger = logging.getLogger(__name__)
 
18
  "openrouter": "https://openrouter.ai/api/v1/chat/completions",
19
  "mistral": "https://api.mistral.ai/v1/chat/completions",
20
  "openai": "https://api.openai.com/v1/chat/completions",
21
+ "deepseek": "https://api.deepseek.com/chat/completions",
22
+ "together": "https://api.together.xyz/v1/chat/completions",
23
+ "cohere": "https://api.cohere.com/v2/chat",
24
  }
25
+
26
+ # ── Model tiers per provider ──
27
  _FREE_MODELS = {
28
  "groq": "llama-3.1-8b-instant",
29
  "cerebras": "llama3.1-8b",
30
  "openrouter": "google/gemma-3-12b-it:free",
31
  "mistral": "mistral-small-latest",
32
+ "deepseek": "deepseek-chat",
33
+ "together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
34
+ "cohere": "command-r",
35
  }
36
  _PREMIUM_MODELS = {
37
  "groq": "llama-3.3-70b-versatile",
 
39
  "openrouter": "google/gemma-3-27b-it:free",
40
  "mistral": "mistral-medium-latest",
41
  "openai": "gpt-4o-mini",
42
+ "deepseek": "deepseek-chat",
43
+ "together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
44
+ "cohere": "command-r-plus",
45
+ }
46
+
47
+ # ── Task-specific model routing ──
48
+ # Maps task hints to the best model per provider.
49
+ # "arabic" needs large models for Arabic morphology, grammar, dialect awareness.
50
+ # "code" needs code-tuned models for test generation, SQL, schema analysis.
51
+ # "fast" uses smallest models for quick responses.
52
+ _TASK_MODELS = {
53
+ "arabic": {
54
+ "groq": "llama-3.3-70b-versatile",
55
+ "cerebras": "qwen-3-235b-a22b-instruct-2507",
56
+ "openrouter": "google/gemma-3-27b-it:free",
57
+ "mistral": "mistral-medium-latest",
58
+ "deepseek": "deepseek-chat",
59
+ "together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
60
+ "cohere": "command-r-plus",
61
+ },
62
+ "code": {
63
+ "groq": "llama-3.3-70b-versatile",
64
+ "cerebras": "qwen-3-235b-a22b-instruct-2507",
65
+ "openrouter": "google/gemma-3-27b-it:free",
66
+ "mistral": "mistral-medium-latest",
67
+ "deepseek": "deepseek-chat",
68
+ "together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
69
+ "cohere": "command-r-plus",
70
+ },
71
+ "fast": {
72
+ "groq": "llama-3.1-8b-instant",
73
+ "cerebras": "llama3.1-8b",
74
+ "openrouter": "google/gemma-3-12b-it:free",
75
+ "mistral": "mistral-small-latest",
76
+ "deepseek": "deepseek-chat",
77
+ "together": "meta-llama/Llama-3.3-70B-Instruct-Turbo-Free",
78
+ "cohere": "command-r",
79
+ },
80
+ }
81
+
82
+ # ── Task-specific provider priority ──
83
+ _TASK_PRIORITY = {
84
+ "arabic": ["cerebras", "deepseek", "groq", "together", "openrouter", "cohere", "mistral"],
85
+ "code": ["deepseek", "groq", "cerebras", "together", "openrouter", "cohere", "mistral"],
86
+ "fast": ["cerebras", "groq", "together", "deepseek", "openrouter", "cohere", "mistral"],
87
+ "default": ["groq", "cerebras", "deepseek", "together", "openrouter", "cohere", "mistral"],
88
+ }
89
+
90
+ _CHAIN_CFG = {
91
+ "groq": {"key_env": "GROQ_API_KEY", "timeout": 30, "extra": {}},
92
+ "cerebras": {"key_env": "CEREBRAS_API_KEY", "timeout": 30, "extra": {}},
93
+ "openrouter": {"key_env": "OPENROUTER_API_KEY", "timeout": 45,
94
+ "extra": {"HTTP-Referer": "https://github.com/Moealsarraj", "X-Title": "AI Tools"}},
95
+ "mistral": {"key_env": "MISTRAL_API_KEY", "timeout": 40, "extra": {}},
96
+ "deepseek": {"key_env": "DEEPSEEK_API_KEY", "timeout": 60, "extra": {}},
97
+ "together": {"key_env": "TOGETHER_API_KEY", "timeout": 45, "extra": {}},
98
+ "cohere": {"key_env": "COHERE_API_KEY", "timeout": 45, "extra": {}},
99
  }
100
+
101
+ # Build available providers (those with valid keys)
102
+ _AVAILABLE = {}
103
+ for _name, _cfg in _CHAIN_CFG.items():
104
+ _k = os.environ.get(_cfg["key_env"], "")
 
 
 
 
 
 
 
105
  if _k:
106
+ _AVAILABLE[_name] = {
107
+ "name": _name,
108
+ "url": _PROVIDER_URLS[_name],
 
109
  "key": _k,
110
+ "timeout": _cfg["timeout"],
111
+ "extra": _cfg["extra"],
112
+ }
113
 
114
  # Ollama fallback
115
  _OLLAMA_PROVIDER = None
 
122
  except Exception:
123
  pass
124
 
125
# ── Google Gemini (special API format) ──
# Gemini does not use the OpenAI chat-completions wire format, so it is
# registered here separately; callers dispatch it to _post_gemini() by name.
_GEMINI_KEY = os.environ.get("GEMINI_API_KEY", "")
_GEMINI_MODEL = "gemini-2.0-flash"  # same model is used for every tier/task
if _GEMINI_KEY:
    _AVAILABLE["gemini"] = {
        "name": "gemini",
        "url": "https://generativelanguage.googleapis.com/v1beta/models",
        "key": _GEMINI_KEY,
        "timeout": 60,
        "extra": {},
    }
    _FREE_MODELS["gemini"] = _GEMINI_MODEL
    _PREMIUM_MODELS["gemini"] = _GEMINI_MODEL
    # Underscore-prefixed loop names keep these temporaries out of the
    # module's public namespace, matching the provider-build loop above.
    for _task_models in _TASK_MODELS.values():
        _task_models["gemini"] = _GEMINI_MODEL
    for _task_order in _TASK_PRIORITY.values():
        # Slot Gemini in as the third choice of every task's priority list.
        if "gemini" not in _task_order:
            _task_order.insert(2, "gemini")

# True when at least one cloud provider or the Ollama fallback is usable.
_AI_AVAILABLE = bool(_AVAILABLE or _OLLAMA_PROVIDER)
144
+
145
+
146
def _post_gemini(key: str, model: str, messages: list, max_tokens: int, timeout: int = 60) -> str:
    """Call the Google Gemini ``generateContent`` endpoint (non-OpenAI format).

    Args:
        key: Gemini API key.
        model: Gemini model id (e.g. ``"gemini-2.0-flash"``).
        messages: OpenAI-style message dicts with ``"role"`` and ``"content"``.
        max_tokens: Upper bound on generated tokens (``maxOutputTokens``).
        timeout: Request timeout in seconds.

    Returns:
        The text of the first part of the first candidate, passed through
        ``_clean``.

    Raises:
        requests.HTTPError: If the API answers with a non-2xx status.
        KeyError, IndexError: If the response carries no candidate text
            (for example when the reply was blocked).
    """
    # Convert OpenAI message format to Gemini format.
    contents = []
    system_parts = []
    for msg in messages:
        role = msg["role"]
        if role == "system":
            # Collect every non-empty system message: callers may prepend a
            # system prompt to a history that already contains one, and
            # keeping only the last would silently drop the others.
            if msg["content"]:
                system_parts.append(msg["content"])
            continue
        contents.append({
            # Gemini only knows "user" and "model"; any other role maps to "model".
            "role": "user" if role == "user" else "model",
            "parts": [{"text": msg["content"]}],
        })

    body = {
        "contents": contents,
        "generationConfig": {"maxOutputTokens": max_tokens},
    }
    if system_parts:
        body["systemInstruction"] = {"parts": [{"text": "\n\n".join(system_parts)}]}

    # Send the key in a header rather than the query string so it does not
    # end up in URLs that are logged or embedded in HTTPError messages.
    url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent"
    r = requests.post(url, json=body, headers={"x-goog-api-key": key}, timeout=timeout)
    r.raise_for_status()
    data = r.json()
    return _clean(data["candidates"][0]["content"]["parts"][0]["text"])
173
+
174
+
175
def get_available_providers() -> list[dict]:
    """Return a summary of every provider registered in ``_AVAILABLE``.

    Returns:
        One dict per provider with keys ``"name"``, ``"model_free"`` and
        ``"model_premium"``. A model entry is ``""`` when the provider has
        no model configured for that tier.
    """
    return [
        {
            "name": name,
            "model_free": _FREE_MODELS.get(name, ""),
            "model_premium": _PREMIUM_MODELS.get(name, ""),
        }
        for name in _AVAILABLE
    ]
185
+
186
+
187
def call_ai_single(provider_name: str, messages: list, system: str = "",
                   max_tokens: int = 2048, use_premium: bool = True) -> str:
    """Call one specific provider directly (no fallback chain).

    Args:
        provider_name: Key into ``_AVAILABLE`` (e.g. ``"groq"``, ``"gemini"``).
        messages: OpenAI-style message dicts.
        system: Optional system prompt, prepended as a ``"system"`` message.
        max_tokens: Upper bound on generated tokens.
        use_premium: Prefer the provider's premium model when True,
            otherwise its free-tier model.

    Returns:
        The provider's reply text.

    Raises:
        ValueError: If the provider is not available, or no model is
            configured for it in either tier.
    """
    if provider_name not in _AVAILABLE:
        raise ValueError(f"Provider {provider_name!r} not available")
    prov = _AVAILABLE[provider_name]

    # Prefer the requested tier but fall back to the other one, so a provider
    # that is only listed in one table is not called with an empty model name.
    if use_premium:
        preferred, fallback = _PREMIUM_MODELS, _FREE_MODELS
    else:
        preferred, fallback = _FREE_MODELS, _PREMIUM_MODELS
    model = (preferred.get(provider_name)
             or fallback.get(provider_name)
             or prov.get("model", ""))
    if not model:
        raise ValueError(f"No model configured for provider {provider_name!r}")

    if system:
        messages = [{"role": "system", "content": system}] + messages
    # Gemini uses its own request format; everything else is OpenAI-compatible.
    if provider_name == "gemini":
        return _post_gemini(prov["key"], model, messages, max_tokens, prov["timeout"])
    return _post_openai(
        prov["url"], prov["key"], model,
        messages, max_tokens, prov["extra"], prov["timeout"]
    )
203
+
204
 
205
  _RE_THINK = re.compile(r"<think>.*?</think>", re.DOTALL)
206
  _RE_OPEN = re.compile(r"^```[a-z]*\n?", re.MULTILINE)
 
220
  r.raise_for_status()
221
  return _clean(r.json()["choices"][0]["message"]["content"])
222
 
223
+
224
def _build_chain(task_hint: str) -> list[dict]:
    """Return the ordered list of usable providers for ``task_hint``.

    Unknown hints are treated as ``"default"``. Each entry is a shallow copy
    of the provider's ``_AVAILABLE`` record with a ``"model"`` key added:
    the task-specific model when one is configured, otherwise the provider's
    free-tier model, otherwise ``""``. Providers that are not in
    ``_AVAILABLE`` are left out.
    """
    hint = task_hint if task_hint in _TASK_PRIORITY else "default"
    # "default" has no entry in _TASK_MODELS, so it resolves to the free tier.
    model_table = _TASK_MODELS.get(hint, _FREE_MODELS)
    return [
        {
            **_AVAILABLE[provider],
            "model": model_table.get(provider, _FREE_MODELS.get(provider, "")),
        }
        for provider in _TASK_PRIORITY[hint]
        if provider in _AVAILABLE
    ]
237
+
238
+
239
  def call_ai(messages: list, system: str = "", max_tokens: int = 2048,
240
+ api_key_row: dict | None = None, task_hint: str = "default") -> str:
241
+ """Call AI with smart task-based routing.
242
+
243
+ task_hint: "arabic" | "code" | "fast" | "default"
244
+ """
245
  if system:
246
  messages = [{"role": "system", "content": system}] + messages
247
  # Custom API key path (used by e.g. Wasit/Amin integrations)
 
264
  if not _AI_AVAILABLE:
265
  raise RuntimeError("No AI provider. Set GROQ_API_KEY or similar in .env")
266
  # Ollama-only path
267
+ if not _AVAILABLE and _OLLAMA_PROVIDER:
268
  r = requests.post(f"{_OLLAMA_BASE}/api/chat",
269
  json={"model": _OLLAMA_PROVIDER["model"], "messages": messages, "stream": False},
270
  timeout=120)
271
  r.raise_for_status()
272
  return _clean(r.json()["message"]["content"])
273
+ # Smart task-routed chain
274
+ chain = _build_chain(task_hint)
275
+ if not chain:
276
+ chain = _build_chain("default")
277
+
278
  last_exc = None
279
+ for prov in chain:
280
  try:
281
+ logger.debug("Trying %s/%s for task=%s", prov["name"], prov["model"], task_hint)
282
+ if prov["name"] == "gemini":
283
+ return _post_gemini(prov["key"], prov["model"], messages, max_tokens, prov["timeout"])
284
  return _post_openai(
285
  prov["url"], prov["key"], prov["model"],
286
  messages, max_tokens, prov["extra"], prov["timeout"]
 
384
  raise ValueError(f"AI returned non-JSON: {raw[:200]}")
385
 
386
def call_ai_json(messages: list, system: str = "", max_tokens: int = 2048,
                 api_key_row: dict | None = None, task_hint: str = "default") -> dict | list:
    """Run ``call_ai`` with the same arguments and parse its reply as JSON.

    All parameters are forwarded unchanged to ``call_ai``; the raw reply is
    then handed to ``_extract_json``, whose result is returned.
    """
    return _extract_json(
        call_ai(
            messages,
            system=system,
            max_tokens=max_tokens,
            api_key_row=api_key_row,
            task_hint=task_hint,
        )
    )