PYAE1994 committed on
Commit
b84e91f
·
verified ·
1 Parent(s): 75e14b2

Fix: update ai_router/router.py

Browse files
Files changed (1) hide show
  1. ai_router/router.py +180 -178
ai_router/router.py CHANGED
@@ -1,8 +1,7 @@
1
  """
2
- 🚀 GOD MODE+ LLM Router — Unified AI Gateway
3
- Primary: Cloudflare AI Gateway → Groq → OpenAI → HF Inference
4
- ALL LLM calls MUST go through LLMRouter.ask()
5
- No direct API calls allowed anywhere else.
6
  """
7
 
8
  import asyncio
@@ -16,247 +15,250 @@ import structlog
16
 
17
  log = structlog.get_logger()
18
 
19
- # ─── Gateway Config ────────────────────────────────────────────────────────────
20
- CF_GATEWAY_URL = os.environ.get(
21
- "CF_GATEWAY_URL",
22
- "https://gateway.pyaesone-gtckglay.workers.dev/v1/chat/completions"
23
- )
24
- CF_GATEWAY_KEY = os.environ.get("CF_GATEWAY_KEY", "")
25
-
26
- GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
27
- OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY", "")
28
- HF_TOKEN = os.environ.get("HF_TOKEN", "")
29
-
30
- DEFAULT_MODEL = os.environ.get("DEFAULT_MODEL", "meta-llama/llama-4-scout-17b-16e-instruct")
31
- FALLBACK_MODEL = "llama-3.3-70b-versatile"
32
-
33
-
34
- # ─── Provider Fallback Chain ───────────────────────────────────────────────────
35
- PROVIDER_CHAIN = [
36
  {
37
- "name": "cloudflare_gateway",
38
- "url": CF_GATEWAY_URL,
39
- "key_fn": lambda: CF_GATEWAY_KEY or OPENAI_API_KEY or GROQ_API_KEY,
40
- "model": DEFAULT_MODEL,
41
- "enabled": lambda: bool(CF_GATEWAY_KEY or OPENAI_API_KEY or GROQ_API_KEY),
42
  },
43
  {
44
- "name": "groq",
45
- "url": "https://api.groq.com/openai/v1/chat/completions",
46
- "key_fn": lambda: GROQ_API_KEY,
47
- "model": FALLBACK_MODEL,
48
- "enabled": lambda: bool(GROQ_API_KEY),
49
  },
50
  {
51
- "name": "openai",
52
- "url": os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1") + "/chat/completions",
53
- "key_fn": lambda: OPENAI_API_KEY,
54
- "model": "gpt-4o",
55
- "enabled": lambda: bool(OPENAI_API_KEY),
56
  },
57
  {
58
- "name": "hf_inference",
59
- "url": "https://api-inference.huggingface.co/models/meta-llama/Meta-Llama-3-8B-Instruct/v1/chat/completions",
60
- "key_fn": lambda: HF_TOKEN,
61
- "model": "meta-llama/Meta-Llama-3-8B-Instruct",
62
- "enabled": lambda: bool(HF_TOKEN),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  },
64
  ]
65
 
66
 
67
- class LLMRouter:
68
  """
69
- Unified LLM Router — all agents call LLMRouter.ask()
70
- Automatic failover: CF Gateway → Groq → OpenAI → HF
71
- Supports streaming via WebSocket emit.
72
  """
73
 
74
  def __init__(self, ws_manager=None):
75
  self.ws = ws_manager
76
- self._stats: Dict[str, Dict] = {
77
- p["name"]: {"calls": 0, "errors": 0, "latency_ms": []}
78
- for p in PROVIDER_CHAIN
79
- }
80
 
81
- # ─── PRIMARY ENTRY POINT ──────────────────────────────────────────────────
 
82
 
83
- async def ask(
 
 
 
 
 
 
84
  self,
85
  messages: List[Dict],
86
  task_id: str = "",
87
  session_id: str = "",
88
  temperature: float = 0.7,
89
  max_tokens: int = 4096,
90
- model: str = "",
91
  stream: bool = True,
92
  ) -> str:
93
- """
94
- Route LLM call through provider chain with automatic failover.
95
- Returns full response text.
96
- """
97
- active_providers = [p for p in PROVIDER_CHAIN if p["enabled"]()]
98
 
99
- if not active_providers:
100
- log.warning("No LLM providers available β€” returning demo response")
101
- return await self._demo_response(messages, task_id, session_id)
102
 
103
  last_error = None
104
- for provider in active_providers:
105
  try:
106
  start = time.time()
107
- result = await self._call_provider(
108
- provider=provider,
109
- messages=messages,
110
- task_id=task_id,
111
- session_id=session_id,
112
- temperature=temperature,
113
- max_tokens=max_tokens,
114
- model_override=model,
115
- stream=stream,
116
- )
117
- elapsed_ms = round((time.time() - start) * 1000)
118
  self._stats[provider["name"]]["calls"] += 1
119
- self._stats[provider["name"]]["latency_ms"].append(elapsed_ms)
120
- log.info("LLMRouter success", provider=provider["name"], ms=elapsed_ms, chars=len(result))
121
  return result
122
-
123
  except Exception as e:
124
  last_error = e
125
  self._stats[provider["name"]]["errors"] += 1
126
- log.warning("LLMRouter failover", provider=provider["name"], error=str(e)[:200])
127
  continue
128
 
129
- log.error("All LLM providers failed", last_error=str(last_error))
130
- return await self._demo_response(messages, task_id, session_id)
131
 
132
- # ─── Provider Call ─────────────────────────────────────────────────────────
133
 
134
- async def _call_provider(
135
- self,
136
- provider: Dict,
137
- messages: List[Dict],
138
- task_id: str,
139
- session_id: str,
140
- temperature: float,
141
- max_tokens: int,
142
- model_override: str,
143
- stream: bool,
144
  ) -> str:
145
- key = provider["key_fn"]()
146
- model = model_override or provider["model"]
147
- url = provider["url"]
148
-
149
- headers = {
150
- "Authorization": f"Bearer {key}",
151
- "Content-Type": "application/json",
152
- }
153
  payload = {
154
  "model": model,
155
  "messages": messages,
 
156
  "temperature": temperature,
157
  "max_tokens": max_tokens,
158
- "stream": stream,
159
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
160
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  full_text = ""
162
- async with httpx.AsyncClient(timeout=120.0) as client:
163
- if stream:
164
- async with client.stream("POST", url, headers=headers, json=payload) as resp:
165
- resp.raise_for_status()
166
- async for line in resp.aiter_lines():
167
- if not line.startswith("data:"):
168
- continue
169
- chunk_str = line[5:].strip()
170
- if chunk_str == "[DONE]":
171
- break
172
- try:
173
- data = json.loads(chunk_str)
174
- delta = (
175
- data.get("choices", [{}])[0]
176
- .get("delta", {})
177
- .get("content", "")
178
- )
179
  if delta:
180
  full_text += delta
181
- await self._emit(delta, task_id, session_id)
182
- except Exception:
183
- pass
184
- else:
185
- resp = await client.post(url, headers=headers, json={**payload, "stream": False})
186
- resp.raise_for_status()
187
- data = resp.json()
188
- full_text = (
189
- data.get("choices", [{}])[0]
190
- .get("message", {})
191
- .get("content", "")
192
- )
193
-
194
  return full_text
195
 
196
- # ─── Demo Response ─────────────────────────────────────────────────────────
197
 
198
- async def _demo_response(self, messages: List[Dict], task_id: str, session_id: str) -> str:
199
  last_user = next(
200
  (m["content"] for m in reversed(messages) if m["role"] == "user"), "Hello"
201
  )
202
- text = (
203
- f"πŸ€– **God Mode+ AI** (No API Key β€” Demo Mode)\n\n"
204
- f"Received: *{last_user[:120]}*\n\n"
205
- f"To enable real AI, set one of:\n"
206
- f"- `CF_GATEWAY_KEY` (Cloudflare Gateway β€” recommended)\n"
207
- f"- `GROQ_API_KEY` (Groq Llama 3.3 70B β€” free)\n"
208
  f"- `OPENAI_API_KEY` (GPT-4o)\n"
209
- f"- `HF_TOKEN` (HuggingFace)\n\n"
210
- f"**System Status:** All 10 agents online βœ…"
 
 
 
 
 
 
 
 
211
  )
212
- full = ""
213
- for word in text.split():
214
  chunk = word + " "
215
- full += chunk
216
- await asyncio.sleep(0.015)
217
- await self._emit(chunk, task_id, session_id, demo=True)
218
- return full
219
 
220
- # ─── Emit helper ──────────────────────────────────────────────────────────
221
 
222
- async def _emit(self, chunk: str, task_id: str, session_id: str, demo: bool = False):
223
  if not self.ws:
224
  return
225
  payload = {"chunk": chunk, "demo": demo}
226
- try:
227
- if task_id:
228
- await self.ws.emit(task_id, "llm_chunk", payload, session_id=session_id)
229
- elif session_id:
230
- await self.ws.emit_chat(session_id, "llm_chunk", payload)
231
- except Exception:
232
- pass
233
 
234
- # ─── Stats ─────────────────────────────────────────────────────────────────
235
 
236
  def get_stats(self) -> Dict:
237
- result = {}
238
  for name, s in self._stats.items():
239
- lats = s["latency_ms"][-20:]
240
- avg = round(sum(lats) / max(len(lats), 1), 1)
241
- provider = next((p for p in PROVIDER_CHAIN if p["name"] == name), None)
242
- result[name] = {
243
- "calls": s["calls"],
244
- "errors": s["errors"],
245
- "avg_latency_ms": avg,
246
- "available": bool(provider and provider["enabled"]()),
247
  }
248
- return result
249
-
250
- def get_active_provider(self) -> str:
251
- for p in PROVIDER_CHAIN:
252
- if p["enabled"]():
253
- return p["name"]
254
- return "demo"
255
-
256
-
257
- # ─── Singleton alias for easy import ──────────────────────────────────────────
258
- # Usage: from ai_router.router import LLMRouter
259
- # In agents: result = await self.router.ask(messages, task_id=..., session_id=...)
260
-
261
- # Legacy AIRouter alias so existing imports don't break
262
- AIRouter = LLMRouter
 
1
  """
2
+ Multi-Model AI Router — Phase 9
3
+ Supports: OpenAI, Groq, Cerebras, OpenRouter, HuggingFace
4
+ Automatic failover chain: OpenAI → Groq → Cerebras → OpenRouter → HF
 
5
  """
6
 
7
  import asyncio
 
15
 
16
  log = structlog.get_logger()
17
 
18
+ # ─── Provider Config ──────────────────────────────────────────────────────────
19
+ PROVIDERS = [
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  {
21
+ "name": "openai",
22
+ "key_env": "OPENAI_API_KEY",
23
+ "base_url": os.environ.get("OPENAI_BASE_URL", "https://api.openai.com/v1"),
24
+ "default_model": os.environ.get("DEFAULT_MODEL", "gpt-4o"),
25
+ "headers_fn": lambda k: {"Authorization": f"Bearer {k}", "Content-Type": "application/json"},
26
  },
27
  {
28
+ "name": "groq",
29
+ "key_env": "GROQ_API_KEY",
30
+ "base_url": "https://api.groq.com/openai/v1",
31
+ "default_model": "llama-3.3-70b-versatile",
32
+ "headers_fn": lambda k: {"Authorization": f"Bearer {k}", "Content-Type": "application/json"},
33
  },
34
  {
35
+ "name": "cerebras",
36
+ "key_env": "CEREBRAS_API_KEY",
37
+ "base_url": "https://api.cerebras.ai/v1",
38
+ "default_model": "llama3.1-70b",
39
+ "headers_fn": lambda k: {"Authorization": f"Bearer {k}", "Content-Type": "application/json"},
40
  },
41
  {
42
+ "name": "openrouter",
43
+ "key_env": "OPENROUTER_API_KEY",
44
+ "base_url": "https://openrouter.ai/api/v1",
45
+ "default_model": "meta-llama/llama-3.3-70b-instruct:free",
46
+ "headers_fn": lambda k: {
47
+ "Authorization": f"Bearer {k}",
48
+ "Content-Type": "application/json",
49
+ "HTTP-Referer": "https://god-agent.ai",
50
+ "X-Title": "God Agent Platform",
51
+ },
52
+ },
53
+ {
54
+ "name": "anthropic",
55
+ "key_env": "ANTHROPIC_API_KEY",
56
+ "base_url": "https://api.anthropic.com/v1",
57
+ "default_model": "claude-3-5-sonnet-20241022",
58
+ "headers_fn": lambda k: {
59
+ "x-api-key": k,
60
+ "anthropic-version": "2023-06-01",
61
+ "Content-Type": "application/json",
62
+ },
63
  },
64
  ]
65
 
66
 
67
+ class AIRouter:
68
  """
69
+ God Mode AI Router — automatically routes and fails over across providers.
70
+ Supports streaming token output via WebSocket.
 
71
  """
72
 
73
  def __init__(self, ws_manager=None):
74
  self.ws = ws_manager
75
+ self._stats: Dict[str, Dict] = {p["name"]: {"calls": 0, "errors": 0, "latency": []} for p in PROVIDERS}
 
 
 
76
 
77
+ def _get_provider(self, name: str) -> Optional[Dict]:
78
+ return next((p for p in PROVIDERS if p["name"] == name), None)
79
 
80
+ def _available_providers(self) -> List[Dict]:
81
+ """Return providers with valid API keys, in priority order."""
82
+ return [p for p in PROVIDERS if os.environ.get(p["key_env"], "")]
83
+
84
+ # ─── Main Entry Point ─────────────────────────────────────────────────────
85
+
86
+ async def complete(
87
  self,
88
  messages: List[Dict],
89
  task_id: str = "",
90
  session_id: str = "",
91
  temperature: float = 0.7,
92
  max_tokens: int = 4096,
93
+ preferred_model: str = "",
94
  stream: bool = True,
95
  ) -> str:
96
+ """Route completion through available providers with failover."""
97
+ providers = self._available_providers()
 
 
 
98
 
99
+ if not providers:
100
+ return await self._demo_stream(messages, task_id, session_id)
 
101
 
102
  last_error = None
103
+ for provider in providers:
104
  try:
105
  start = time.time()
106
+ if provider["name"] == "anthropic":
107
+ result = await self._anthropic_stream(
108
+ provider, messages, task_id, session_id, temperature, max_tokens
109
+ )
110
+ else:
111
+ result = await self._openai_compat_stream(
112
+ provider, messages, task_id, session_id, temperature, max_tokens, preferred_model
113
+ )
114
+ elapsed = time.time() - start
 
 
115
  self._stats[provider["name"]]["calls"] += 1
116
+ self._stats[provider["name"]]["latency"].append(elapsed)
117
+ log.info("AI Router success", provider=provider["name"], ms=round(elapsed * 1000))
118
  return result
 
119
  except Exception as e:
120
  last_error = e
121
  self._stats[provider["name"]]["errors"] += 1
122
+ log.warning("AI Router failover", provider=provider["name"], error=str(e))
123
  continue
124
 
125
+ log.error("All AI providers failed", last_error=str(last_error))
126
+ return await self._demo_stream(messages, task_id, session_id)
127
 
128
+ # ─── OpenAI-compatible Stream (OpenAI, Groq, Cerebras, OpenRouter) ────────
129
 
130
+ async def _openai_compat_stream(
131
+ self, provider, messages, task_id, session_id, temperature, max_tokens, preferred_model
 
 
 
 
 
 
 
 
132
  ) -> str:
133
+ key = os.environ.get(provider["key_env"], "")
134
+ model = preferred_model or provider["default_model"]
135
+ headers = provider["headers_fn"](key)
 
 
 
 
 
136
  payload = {
137
  "model": model,
138
  "messages": messages,
139
+ "stream": True,
140
  "temperature": temperature,
141
  "max_tokens": max_tokens,
 
142
  }
143
+ full_text = ""
144
+ async with httpx.AsyncClient(timeout=120) as client:
145
+ async with client.stream(
146
+ "POST", f"{provider['base_url']}/chat/completions",
147
+ headers=headers, json=payload
148
+ ) as resp:
149
+ resp.raise_for_status()
150
+ async for line in resp.aiter_lines():
151
+ if not line.startswith("data:"):
152
+ continue
153
+ chunk = line[6:].strip()
154
+ if chunk == "[DONE]":
155
+ break
156
+ try:
157
+ data = json.loads(chunk)
158
+ delta = data["choices"][0]["delta"].get("content", "")
159
+ if delta:
160
+ full_text += delta
161
+ await self._emit_chunk(delta, task_id, session_id)
162
+ except Exception:
163
+ pass
164
+ return full_text
165
+
166
+ # ─── Anthropic Stream ─────────────────────────────────────────────────────
167
 
168
+ async def _anthropic_stream(
169
+ self, provider, messages, task_id, session_id, temperature, max_tokens
170
+ ) -> str:
171
+ key = os.environ.get(provider["key_env"], "")
172
+ headers = provider["headers_fn"](key)
173
+ system = ""
174
+ filtered = []
175
+ for m in messages:
176
+ if m["role"] == "system":
177
+ system = m["content"]
178
+ else:
179
+ filtered.append(m)
180
+ payload = {
181
+ "model": provider["default_model"],
182
+ "max_tokens": max_tokens,
183
+ "messages": filtered,
184
+ "stream": True,
185
+ }
186
+ if system:
187
+ payload["system"] = system
188
  full_text = ""
189
+ async with httpx.AsyncClient(timeout=120) as client:
190
+ async with client.stream(
191
+ "POST", f"{provider['base_url']}/messages",
192
+ headers=headers, json=payload
193
+ ) as resp:
194
+ resp.raise_for_status()
195
+ async for line in resp.aiter_lines():
196
+ if not line.startswith("data:"):
197
+ continue
198
+ try:
199
+ data = json.loads(line[5:].strip())
200
+ if data.get("type") == "content_block_delta":
201
+ delta = data["delta"].get("text", "")
 
 
 
 
202
  if delta:
203
  full_text += delta
204
+ await self._emit_chunk(delta, task_id, session_id)
205
+ except Exception:
206
+ pass
 
 
 
 
 
 
 
 
 
 
207
  return full_text
208
 
209
+ # ─── Demo Stream ──────────────────────────────────────────────────────────
210
 
211
+ async def _demo_stream(self, messages, task_id, session_id) -> str:
212
  last_user = next(
213
  (m["content"] for m in reversed(messages) if m["role"] == "user"), "Hello"
214
  )
215
+ response = (
216
+ f"πŸ€– **God Agent** (Demo Mode)\n\n"
217
+ f"Received: *{last_user[:100]}*\n\n"
218
+ f"To enable real AI, set one of these env vars:\n"
 
 
219
  f"- `OPENAI_API_KEY` (GPT-4o)\n"
220
+ f"- `GROQ_API_KEY` (Llama 3.3 70B β€” Free)\n"
221
+ f"- `OPENROUTER_API_KEY` (Multi-model)\n"
222
+ f"- `ANTHROPIC_API_KEY` (Claude 3.5)\n\n"
223
+ f"**God Mode+ Capabilities Active:**\n"
224
+ f"- ⚑ Multi-agent orchestration\n"
225
+ f"- πŸ”§ Autonomous coding & debugging\n"
226
+ f"- 🧠 Persistent memory system\n"
227
+ f"- πŸ”Œ Connector ecosystem\n"
228
+ f"- πŸ“‘ Real-time streaming\n"
229
+ f"- 🌐 Multi-model failover\n"
230
  )
231
+ full_text = ""
232
+ for word in response.split():
233
  chunk = word + " "
234
+ full_text += chunk
235
+ await asyncio.sleep(0.02)
236
+ await self._emit_chunk(chunk, task_id, session_id, demo=True)
237
+ return full_text
238
 
239
+ # ─── Emit Helper ──────────────────────────────────────────────────────────
240
 
241
+ async def _emit_chunk(self, chunk: str, task_id: str, session_id: str, demo: bool = False):
242
  if not self.ws:
243
  return
244
  payload = {"chunk": chunk, "demo": demo}
245
+ if task_id:
246
+ await self.ws.emit(task_id, "llm_chunk", payload, session_id=session_id)
247
+ if session_id and not task_id:
248
+ await self.ws.emit_chat(session_id, "llm_chunk", payload)
 
 
 
249
 
250
+ # ─── Stats ────────────────────────────────────────────────────────────────
251
 
252
  def get_stats(self) -> Dict:
253
+ stats = {}
254
  for name, s in self._stats.items():
255
+ avg_lat = round(sum(s["latency"][-20:]) / max(len(s["latency"][-20:]), 1) * 1000, 1)
256
+ stats[name] = {
257
+ "calls": s["calls"],
258
+ "errors": s["errors"],
259
+ "avg_latency_ms": avg_lat,
260
+ "available": bool(os.environ.get(
261
+ next((p["key_env"] for p in PROVIDERS if p["name"] == name), ""), ""
262
+ )),
263
  }
264
+ return stats