namish10 committed on
Commit
32149f9
·
verified ·
1 Parent(s): eba543b

Upload app/agents/llm_orchestrator_agent.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app/agents/llm_orchestrator_agent.py +868 -0
app/agents/llm_orchestrator_agent.py ADDED
@@ -0,0 +1,868 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM Orchestrator Agent
3
+
4
+ Multi-model orchestration with rate limit handling:
5
+ - ChatGPT (OpenAI)
6
+ - Gemini (Google AI)
+ - Claude (Anthropic), DeepSeek, Ollama (local) and Groq
7
+ - Automatic retry on rate limits
8
+ - Fallback mechanisms
9
+
10
+ Inspired by GestLLM/GestOS research for gesture-to-LLM integration.
11
+ """
12
+
13
+ import asyncio
14
+ import time
15
+ import json
16
+ from typing import Dict, List, Any, Optional
17
+ from dataclasses import dataclass, field
18
+ from datetime import datetime, timedelta
19
+ from enum import Enum
20
+ import logging
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class LLMProvider(Enum):
    """Identifiers for the LLM backends the orchestrator can talk to.

    Each member's value is the lowercase provider key; the rest of this
    module uses it as the key in rate-limit status reports and when looking
    up per-request model overrides.
    """

    CHATGPT = "chatgpt"    # OpenAI
    GEMINI = "gemini"      # Google AI
    CLAUDE = "claude"      # Anthropic
    DEEPSEEK = "deepseek"
    OLLAMA = "ollama"      # local server
    GROQ = "groq"
32
+
33
+
34
@dataclass
class RateLimitConfig:
    """Quota and retry settings applied to a single provider."""

    # Requests admitted within any rolling one-minute window.
    requests_per_minute: int = 60

    # Requests admitted between two daily resets.
    requests_per_day: int = 500

    # Token budget per minute.
    # NOTE(review): no code visible in this file reads this value yet.
    tokens_per_minute: int = 90000

    # Suggested pause, in seconds, after the provider reports a rate limit.
    retry_after_seconds: int = 60

    # Number of passes the orchestrator makes over its provider list.
    max_retries: int = 5
42
+
43
+
44
def _next_daily_reset() -> datetime:
    """Return the local time one day from now, when the daily counter resets."""
    return datetime.now() + timedelta(days=1)


@dataclass
class RateLimitState:
    """Mutable rate-limit bookkeeping for one provider."""

    # Times of recent requests; used for the per-minute window.
    request_timestamps: List[datetime] = field(default_factory=list)

    # Requests counted since the last daily reset.
    daily_requests: int = 0

    # When ``daily_requests`` should next be zeroed (naive local time).
    daily_reset: datetime = field(default_factory=_next_daily_reset)

    # ``(timestamp, tokens)`` pairs recorded when a request completes.
    token_usage: List[tuple] = field(default_factory=list)

    # True while the provider is considered rate limited.
    is_rate_limited: bool = False

    # Earliest time a new request may be attempted, if a back-off is active.
    retry_after: Optional[datetime] = None
53
+
54
+
55
@dataclass
class LLMRequest:
    """Everything needed to send one prompt to one or more providers."""

    # The user-turn text.
    prompt: str

    # Optional system instruction sent alongside the prompt.
    system_prompt: Optional[str] = None

    # Fallback model name, used when a provider's configuration does not
    # name a model of its own.
    model: str = "gpt-4"

    # Generation limits passed through to the provider payloads.
    max_tokens: int = 2000
    temperature: float = 0.7

    # Providers to use, in order of preference.
    providers: List[LLMProvider] = field(
        default_factory=lambda: [LLMProvider.CHATGPT, LLMProvider.GEMINI]
    )

    # Total HTTP timeout for a single provider call, in seconds.
    timeout: int = 60

    # Caller-supplied identifier carried along with the request.
    user_id: str = "default"

    # Optional per-provider model overrides, keyed by ``LLMProvider.value``.
    models: Optional[Dict[str, str]] = None
67
+
68
+
69
@dataclass
class LLMResponse:
    """Outcome of a single provider call, successful or not."""

    # Generated text; empty on failure.
    content: str

    # Provider that produced (or failed to produce) the response.
    provider: LLMProvider

    # Model name that was requested from the provider.
    model: str

    # Tokens reported by the provider, or 0 when not reported.
    tokens_used: int

    # Wall-clock duration of the call in milliseconds.
    latency_ms: float

    # True when ``content`` holds a real completion.
    success: bool

    # Human-readable failure reason when ``success`` is False.
    error: Optional[str] = None

    # Flag for responses served from a cache.
    # NOTE(review): nothing visible in this file sets it to True.
    cached: bool = False
80
+
81
+
82
class RateLimitHandler:
    """
    Track per-provider request quotas and tell callers when they may send.

    For every ``LLMProvider`` the handler keeps a rolling one-minute window
    of request timestamps and a daily request counter, and compares them
    with that provider's ``RateLimitConfig``.

    Usage contract:
    - ``acquire`` checks the quota and, when it grants permission, records
      the request immediately. Recording at grant time (rather than when the
      call finishes) stops several concurrent callers from all passing the
      check before any of them has been counted.
    - ``release`` is called once the provider call has finished, to record
      token usage and clear an expired back-off.
    - ``set_rate_limited`` applies a back-off reported by the provider.
    """

    def __init__(self, config: Optional[Dict[LLMProvider, RateLimitConfig]] = None):
        """
        Args:
            config: Optional mapping of provider to its limits. When omitted,
                built-in defaults for every provider are used. Providers
                missing from the mapping fall back to ``RateLimitConfig()``.
        """
        self.configs = config or {
            LLMProvider.CHATGPT: RateLimitConfig(
                requests_per_minute=60,
                requests_per_day=500,
                retry_after_seconds=60
            ),
            LLMProvider.GEMINI: RateLimitConfig(
                requests_per_minute=60,
                requests_per_day=150,
                retry_after_seconds=30
            ),
            LLMProvider.CLAUDE: RateLimitConfig(
                requests_per_minute=50,
                requests_per_day=200,
                retry_after_seconds=60
            ),
            LLMProvider.DEEPSEEK: RateLimitConfig(
                requests_per_minute=60,
                requests_per_day=200,
                retry_after_seconds=60
            ),
            LLMProvider.OLLAMA: RateLimitConfig(
                requests_per_minute=1000,
                requests_per_day=100000,
                retry_after_seconds=1
            ),
            LLMProvider.GROQ: RateLimitConfig(
                requests_per_minute=30,
                requests_per_day=10000,
                retry_after_seconds=60
            )
        }

        self.states: Dict[LLMProvider, RateLimitState] = {
            provider: RateLimitState() for provider in LLMProvider
        }

        # Reserved for queued work; no method of this class uses it yet.
        self.request_queue: Dict[LLMProvider, List[asyncio.Task]] = {}
        self._lock = asyncio.Lock()

    async def acquire(self, provider: LLMProvider, priority: int = 0) -> bool:
        """
        Try to reserve one request slot for ``provider``.

        Args:
            provider: Provider the caller wants to contact.
            priority: Accepted for interface compatibility; currently unused.

        Returns:
            True when the request may be sent (it has then already been
            counted against the quota); False when the provider is backing
            off or a quota is exhausted, in which case ``retry_after`` on the
            provider's state says when to try again.
        """
        async with self._lock:
            state = self.states[provider]
            config = self.configs.get(provider, RateLimitConfig())
            now = datetime.now()

            if state.retry_after:
                if now < state.retry_after:
                    return False
                # The back-off has elapsed: drop the stale flag so status
                # reports stop showing the provider as rate limited.
                state.retry_after = None
                state.is_rate_limited = False

            if state.daily_reset < now:
                state.daily_requests = 0
                state.daily_reset = now + timedelta(days=1)

            minute_ago = now - timedelta(minutes=1)
            state.request_timestamps = [
                ts for ts in state.request_timestamps if ts > minute_ago
            ]

            if state.daily_requests >= config.requests_per_day:
                state.is_rate_limited = True
                state.retry_after = state.daily_reset
                return False

            if len(state.request_timestamps) >= config.requests_per_minute:
                state.is_rate_limited = True
                if state.request_timestamps:
                    # A slot frees up once the oldest request leaves the window.
                    state.retry_after = min(state.request_timestamps) + timedelta(minutes=1)
                else:
                    # Only reachable with a per-minute limit of zero.
                    state.retry_after = now + timedelta(seconds=config.retry_after_seconds)
                return False

            # Count the request now so concurrent callers see it.
            state.request_timestamps.append(now)
            state.daily_requests += 1
            return True

    async def release(self, provider: LLMProvider, tokens_used: int = 0):
        """
        Record the completion of a request previously granted by ``acquire``.

        Args:
            provider: Provider the request was sent to.
            tokens_used: Tokens consumed by the request; values <= 0 are
                not recorded.
        """
        async with self._lock:
            state = self.states[provider]
            now = datetime.now()

            # Keep only the last minute of token records so the list cannot
            # grow without bound in a long-running process.
            minute_ago = now - timedelta(minutes=1)
            state.token_usage = [
                entry for entry in state.token_usage if entry[0] > minute_ago
            ]
            if tokens_used > 0:
                state.token_usage.append((now, tokens_used))

            # Leave a still-active back-off in place: another task may have
            # set it while this request was in flight.
            if state.retry_after is None or state.retry_after <= now:
                state.retry_after = None
                state.is_rate_limited = False

    def set_rate_limited(self, provider: LLMProvider, retry_after_seconds: int):
        """
        Mark ``provider`` as rate limited because its API said so.

        Args:
            provider: Provider that rejected a request.
            retry_after_seconds: Seconds to wait before contacting it again.
        """
        state = self.states[provider]
        state.is_rate_limited = True
        state.retry_after = datetime.now() + timedelta(seconds=retry_after_seconds)

        logger.warning(
            f"Rate limited for {provider.value}: retrying after {retry_after_seconds}s"
        )

    async def wait_for_slot(self, provider: LLMProvider, max_wait: int = 120) -> bool:
        """
        Poll ``acquire`` until a slot is granted or ``max_wait`` seconds pass.

        Args:
            provider: Provider to wait for.
            max_wait: Upper bound on the total wait, in seconds.

        Returns:
            True when a slot was reserved, False when the wait timed out.
        """
        deadline = time.monotonic() + max_wait

        while time.monotonic() < deadline:
            if await self.acquire(provider):
                return True

            retry_after = self.states[provider].retry_after
            if retry_after:
                # Re-check at least every 5 seconds in case the state changes.
                pause = min((retry_after - datetime.now()).total_seconds(), 5)
            else:
                pause = 1

            # Never sleep past the deadline; sleep(0) still yields control.
            pause = min(pause, deadline - time.monotonic())
            await asyncio.sleep(max(pause, 0))

        return False

    def get_status(self) -> Dict:
        """
        Summarise current usage for every provider.

        Returns:
            A dict keyed by ``LLMProvider.value``; each entry reports the
            rate-limited flag, per-minute and per-day usage with their
            limits, and the remaining back-off in seconds (0 when none).
        """
        now = datetime.now()
        minute_ago = now - timedelta(minutes=1)
        status = {}

        for provider in LLMProvider:
            state = self.states[provider]
            config = self.configs.get(provider, RateLimitConfig())

            recent_requests = sum(1 for ts in state.request_timestamps if ts > minute_ago)

            status[provider.value] = {
                "rate_limited": state.is_rate_limited,
                "requests_this_minute": recent_requests,
                "requests_per_minute_limit": config.requests_per_minute,
                "requests_today": state.daily_requests,
                "requests_per_day_limit": config.requests_per_day,
                "retry_after_seconds": (
                    (state.retry_after - now).total_seconds()
                    if state.retry_after and state.retry_after > now else 0
                )
            }

        return status
237
+
238
+
239
class LLMOrchestrator:
    """
    Multi-model LLM orchestration with rate-limit handling and fallback.

    Features:
    - Ordered fallback across providers (``query``)
    - Parallel fan-out to several providers (``query_parallel``)
    - Rate-limit aware retries with exponential backoff

    Inspired by:
    - GestLLM: LLM-powered gesture interpretation
    - GestOS: Multi-robot gesture orchestration
    - GestureGPT: Free-form gesture understanding

    ``aiohttp`` is imported lazily by the HTTP helper, so the class can be
    constructed without it installed; it is required to send requests.
    """

    # System message used for chat-completions style providers when the
    # request does not supply one.
    _DEFAULT_SYSTEM_PROMPT = "You are a helpful AI assistant."

    def __init__(self, api_keys: Optional[Dict[str, str]] = None):
        """
        Args:
            api_keys: Mapping of key name to secret. Names read by this
                class: ``"openai"``, ``"gemini"``, ``"claude"`` (or
                ``"anthropic"``), ``"deepseek"`` and ``"groq"``.
        """
        self.api_keys = api_keys or {}
        self.rate_limiter = RateLimitHandler()

        # ``model_name`` is the model actually requested; ``base_url`` is
        # only consulted for Ollama, the other endpoints are fixed.
        self.provider_configs = {
            LLMProvider.CHATGPT: {
                "model": "gpt-4",
                "model_name": "gpt-4o",
                "base_url": "https://api.openai.com/v1",
                "supports_vision": True
            },
            LLMProvider.GEMINI: {
                "model": "gemini-pro",
                "model_name": "gemini-2.0-flash",
                "base_url": "https://generativelanguage.googleapis.com/v1beta",
                "supports_vision": True
            },
            LLMProvider.CLAUDE: {
                "model": "claude-3-opus-20240229",
                "model_name": "claude-3-5-sonnet",
                "base_url": "https://api.anthropic.com/v1",
                "supports_vision": True
            },
            LLMProvider.DEEPSEEK: {
                "model": "deepseek-chat",
                "model_name": "deepseek-chat",
                "base_url": "https://api.deepseek.com/v1",
                "supports_vision": False
            },
            LLMProvider.OLLAMA: {
                "model": "llama3",
                "model_name": "llama3",
                "base_url": "http://localhost:11434/v1",
                "supports_vision": False
            },
            LLMProvider.GROQ: {
                "model": "llama-3.1-70b-versatile",
                "model_name": "llama-3.1-70b-versatile",
                "base_url": "https://api.groq.com/openai/v1",
                "supports_vision": False
            }
        }

        # Declared for response caching; no method of this class reads or
        # writes them yet.
        self.cache: Dict[str, LLMResponse] = {}
        self.cache_ttl = 3600

        self.pending_requests: List[Dict] = []
        self.response_history: List[LLMResponse] = []

    # ------------------------------------------------------------------
    # Public query API
    # ------------------------------------------------------------------

    async def query(
        self,
        request: LLMRequest,
        preferred_provider: Optional[LLMProvider] = None
    ) -> LLMResponse:
        """
        Query providers in order until one returns a successful response.

        Args:
            request: The prompt and generation settings.
            preferred_provider: When given, only this provider is tried;
                otherwise ``request.providers`` is used in order.

        Returns:
            The first successful response, or an unsuccessful
            ``LLMResponse`` describing the last failure. Exceptions raised
            by provider calls are captured, not propagated.
        """
        providers_to_try = (
            [preferred_provider] if preferred_provider
            else list(request.providers or [])
        )
        if not providers_to_try:
            # Previously this indexed an empty list and raised IndexError.
            return self._failure(LLMProvider.CHATGPT, "", "No providers specified")

        max_retries = self.rate_limiter.configs.get(
            providers_to_try[0], RateLimitConfig()
        ).max_retries
        last_error: Optional[Exception] = None

        for attempt in range(max_retries):
            for provider in providers_to_try:
                if not await self.rate_limiter.acquire(provider):
                    continue

                try:
                    response = await self._call_provider(provider, request)
                except RateLimitError as e:
                    await self.rate_limiter.release(provider, 0)
                    self.rate_limiter.set_rate_limited(provider, e.retry_after or 60)
                    last_error = e
                    continue
                except Exception as e:
                    await self.rate_limiter.release(provider, 0)
                    last_error = e
                    logger.error(f"LLM call failed: {e}")
                    continue

                await self.rate_limiter.release(provider, response.tokens_used)
                if response.success:
                    self.response_history.append(response)
                    return response

                # An unsuccessful response (e.g. missing API key) must not
                # short-circuit the fallback chain.
                last_error = RuntimeError(
                    response.error or "Provider returned an unsuccessful response"
                )

            # Only rate limits (or locally exhausted quotas) are worth
            # another pass; any other failure ends the retries.
            if last_error and not isinstance(last_error, RateLimitError):
                break

            if attempt < max_retries - 1:
                await asyncio.sleep(2 ** attempt)

        return self._failure(
            providers_to_try[0],
            "",
            str(last_error) if last_error else "All providers failed"
        )

    async def query_parallel(
        self,
        request: LLMRequest
    ) -> List[LLMResponse]:
        """
        Query every provider in ``request.providers`` concurrently.

        Returns:
            One response per provider, in the same order as
            ``request.providers``; failures are returned as unsuccessful
            responses attributed to the provider that failed.
        """
        providers = list(request.providers or [])
        tasks = []

        for provider in providers:
            provider_request = LLMRequest(
                prompt=request.prompt,
                system_prompt=request.system_prompt,
                model=request.model,
                max_tokens=request.max_tokens,
                temperature=request.temperature,
                providers=[provider],
                timeout=request.timeout,
                user_id=request.user_id,
                # Keep per-provider model overrides for the fan-out copies.
                models=request.models
            )
            tasks.append(self.query(provider_request, provider))

        results = await asyncio.gather(*tasks, return_exceptions=True)

        responses = []
        for provider, result in zip(providers, results):
            if isinstance(result, BaseException):
                responses.append(self._failure(provider, "", str(result)))
            else:
                responses.append(result)

        return responses

    async def query_with_retry(
        self,
        request: LLMRequest,
        max_attempts: int = 3
    ) -> LLMResponse:
        """
        Call ``query`` and repeat it while the failure is a rate limit.

        Args:
            request: The prompt and generation settings.
            max_attempts: Maximum number of ``query`` calls; values below 1
                are treated as 1.

        Returns:
            The first successful response, the first non-rate-limit
            failure, or the failure from the final attempt.
        """
        response = await self.query(request)

        for attempt in range(1, max_attempts):
            if response.success or not self._is_rate_limit_error(response.error):
                break
            # Linear back-off between attempts: 30s, 60s, ...
            await asyncio.sleep(30 * attempt)
            response = await self.query(request)

        return response

    # ------------------------------------------------------------------
    # Provider dispatch
    # ------------------------------------------------------------------

    async def _call_provider(
        self,
        provider: LLMProvider,
        request: LLMRequest
    ) -> LLMResponse:
        """
        Resolve the model for ``provider`` and invoke its call method.

        The model is taken from ``request.models`` (keyed by
        ``provider.value``) when present, else from the provider
        configuration, else from ``request.model``.

        Raises:
            ValueError: If ``provider`` has no call method.
        """
        start_time = time.time()

        config = self.provider_configs.get(provider, {})
        model = config.get("model_name", request.model)

        if request.models and provider.value in request.models:
            model = request.models[provider.value]

        handlers = {
            LLMProvider.CHATGPT: self._call_chatgpt,
            LLMProvider.GEMINI: self._call_gemini,
            LLMProvider.CLAUDE: self._call_claude,
            LLMProvider.DEEPSEEK: self._call_deepseek,
            LLMProvider.OLLAMA: self._call_ollama,
            LLMProvider.GROQ: self._call_groq,
        }
        handler = handlers.get(provider)
        if handler is None:
            raise ValueError(f"Unknown provider: {provider}")

        return await handler(request, model, start_time)

    async def _call_chatgpt(
        self,
        request: LLMRequest,
        model: str,
        start_time: float
    ) -> LLMResponse:
        """Call OpenAI ChatGPT; fails fast when no ``openai`` key is set."""
        api_key = self.api_keys.get("openai", "")
        if not api_key:
            return self._failure(LLMProvider.CHATGPT, model, "API key not configured")

        return await self._call_openai_compatible(
            LLMProvider.CHATGPT,
            "https://api.openai.com/v1/chat/completions",
            api_key, request, model, start_time,
            retry_after=60
        )

    async def _call_gemini(
        self,
        request: LLMRequest,
        model: str,
        start_time: float
    ) -> LLMResponse:
        """Call Google Gemini; fails fast when no ``gemini`` key is set."""
        api_key = self.api_keys.get("gemini", "")
        if not api_key:
            return self._failure(LLMProvider.GEMINI, model, "API key not configured")

        payload = {
            "contents": [{
                "parts": [{"text": request.prompt}]
            }],
            "generationConfig": {
                "maxOutputTokens": request.max_tokens,
                "temperature": request.temperature
            }
        }
        if request.system_prompt:
            payload["systemInstruction"] = {"parts": [{"text": request.system_prompt}]}

        # The key travels in a header rather than the query string so it
        # does not end up in URLs that get logged or echoed in errors.
        headers = {"x-goog-api-key": api_key}

        data = await self._post_json(
            f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent",
            headers, payload, request.timeout,
            retry_after=30
        )

        usage = data.get("usageMetadata") or {}
        return LLMResponse(
            content=data["candidates"][0]["content"]["parts"][0]["text"],
            provider=LLMProvider.GEMINI,
            model=model,
            tokens_used=usage.get("totalTokenCount", 0),
            latency_ms=(time.time() - start_time) * 1000,
            success=True
        )

    async def _call_claude(
        self,
        request: LLMRequest,
        model: str,
        start_time: float
    ) -> LLMResponse:
        """Call Anthropic Claude; fails fast when no key is set."""
        api_key = self.api_keys.get("claude", "") or self.api_keys.get("anthropic", "")
        if not api_key:
            return self._failure(LLMProvider.CLAUDE, model, "API key not configured")

        headers = {
            "x-api-key": api_key,
            "Content-Type": "application/json",
            "anthropic-version": "2023-06-01"
        }

        payload = {
            "model": model,
            "max_tokens": request.max_tokens,
            "temperature": request.temperature,
            "messages": [
                {"role": "user", "content": request.prompt}
            ]
        }
        if request.system_prompt:
            payload["system"] = request.system_prompt

        # The Messages API answers 200 on success; the previous check for
        # 201 turned every successful call into an error.
        data = await self._post_json(
            "https://api.anthropic.com/v1/messages",
            headers, payload, request.timeout,
            retry_after=60
        )

        usage = data.get("usage") or {}
        return LLMResponse(
            content=data["content"][0]["text"],
            provider=LLMProvider.CLAUDE,
            model=model,
            tokens_used=usage.get("input_tokens", 0) + usage.get("output_tokens", 0),
            latency_ms=(time.time() - start_time) * 1000,
            success=True
        )

    async def _call_deepseek(
        self,
        request: LLMRequest,
        model: str,
        start_time: float
    ) -> LLMResponse:
        """Call DeepSeek; fails fast when no ``deepseek`` key is set."""
        api_key = self.api_keys.get("deepseek", "")
        if not api_key:
            return self._failure(LLMProvider.DEEPSEEK, model, "API key not configured")

        return await self._call_openai_compatible(
            LLMProvider.DEEPSEEK,
            "https://api.deepseek.com/chat/completions",
            api_key, request, model, start_time,
            retry_after=60
        )

    async def _call_ollama(
        self,
        request: LLMRequest,
        model: str,
        start_time: float
    ) -> LLMResponse:
        """
        Call Ollama (local) through its OpenAI-compatible endpoint.

        The configured base URL ends in ``/v1``, so both the payload and
        the response use the chat-completions shape, not the native
        ``/api/chat`` shape.
        """
        base_url = self.provider_configs.get(LLMProvider.OLLAMA, {}).get(
            "base_url", "http://localhost:11434/v1"
        )

        return await self._call_openai_compatible(
            LLMProvider.OLLAMA,
            f"{base_url}/chat/completions",
            "", request, model, start_time,
            retry_after=None,
            error_label="Ollama error"
        )

    async def _call_groq(
        self,
        request: LLMRequest,
        model: str,
        start_time: float
    ) -> LLMResponse:
        """Call Groq (fast GPU inference); fails fast when no key is set."""
        api_key = self.api_keys.get("groq", "")
        if not api_key:
            return self._failure(LLMProvider.GROQ, model, "API key not configured")

        return await self._call_openai_compatible(
            LLMProvider.GROQ,
            "https://api.groq.com/openai/v1/chat/completions",
            api_key, request, model, start_time,
            retry_after=30
        )

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    async def _call_openai_compatible(
        self,
        provider: LLMProvider,
        url: str,
        api_key: str,
        request: LLMRequest,
        model: str,
        start_time: float,
        *,
        retry_after: Optional[int],
        error_label: str = "API error"
    ) -> LLMResponse:
        """Send a chat-completions request and convert the reply."""
        headers = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

        payload = {
            "model": model,
            "messages": [
                {"role": "system", "content": request.system_prompt or self._DEFAULT_SYSTEM_PROMPT},
                {"role": "user", "content": request.prompt}
            ],
            "max_tokens": request.max_tokens,
            "temperature": request.temperature
        }

        data = await self._post_json(
            url, headers, payload, request.timeout,
            retry_after=retry_after,
            error_label=error_label
        )

        usage = data.get("usage") or {}
        return LLMResponse(
            content=data["choices"][0]["message"]["content"],
            provider=provider,
            model=model,
            tokens_used=usage.get("total_tokens", 0),
            latency_ms=(time.time() - start_time) * 1000,
            success=True
        )

    async def _post_json(
        self,
        url: str,
        headers: Dict[str, str],
        payload: Dict[str, Any],
        timeout: int,
        *,
        retry_after: Optional[int],
        error_label: str = "API error"
    ) -> Dict[str, Any]:
        """
        POST ``payload`` as JSON and return the decoded JSON body.

        Args:
            url: Endpoint to call.
            headers: HTTP headers to send.
            payload: JSON-serialisable request body.
            timeout: Total request timeout in seconds.
            retry_after: Default back-off for HTTP 429. ``None`` disables
                the dedicated 429 handling, so a 429 is reported like any
                other non-200 status.
            error_label: Prefix for the non-200 error message.

        Raises:
            RateLimitError: On HTTP 429 when ``retry_after`` is not None.
            RuntimeError: On any other non-200 status, network failure or
                timeout.
            ImportError: If ``aiohttp`` is not installed.
        """
        # Imported outside the ``try`` so that a missing dependency surfaces
        # as ImportError instead of being masked while the ``except`` clause
        # evaluates ``aiohttp.ClientError``.
        import aiohttp

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url,
                    headers=headers,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=timeout)
                ) as resp:
                    if resp.status == 429 and retry_after is not None:
                        raise RateLimitError(
                            "Rate limit exceeded",
                            retry_after=self._parse_retry_after(
                                resp.headers.get("Retry-After"), retry_after
                            )
                        )

                    if resp.status != 200:
                        text = await resp.text()
                        raise RuntimeError(f"{error_label}: {resp.status} - {text}")

                    return await resp.json()
        except asyncio.TimeoutError as e:
            # str(TimeoutError()) is empty, which would yield a blank error.
            raise RuntimeError(f"Request timed out after {timeout}s") from e
        except aiohttp.ClientError as e:
            raise RuntimeError(f"Network error: {e}") from e

    @staticmethod
    def _parse_retry_after(header_value: Optional[str], default: int) -> int:
        """Return a ``Retry-After`` header as whole seconds, or ``default``.

        Only the delay-in-seconds form is understood; a missing, HTTP-date
        or non-positive value yields ``default``.
        """
        try:
            seconds = int(float(header_value))
        except (TypeError, ValueError, OverflowError):
            return default
        return seconds if seconds > 0 else default

    @staticmethod
    def _is_rate_limit_error(error: Optional[str]) -> bool:
        """Return True when an error message describes a rate limit."""
        text = (error or "").lower()
        return "rate limit" in text or "rate_limit" in text

    @staticmethod
    def _failure(provider: LLMProvider, model: str, error: str) -> LLMResponse:
        """Build an unsuccessful ``LLMResponse`` carrying ``error``."""
        return LLMResponse(
            content="",
            provider=provider,
            model=model,
            tokens_used=0,
            latency_ms=0,
            success=False,
            error=error
        )

    # ------------------------------------------------------------------
    # Introspection
    # ------------------------------------------------------------------

    def get_rate_limit_status(self) -> Dict:
        """Get current rate limit status for all providers."""
        return self.rate_limiter.get_status()

    def get_pending_requests(self) -> int:
        """Get number of pending requests."""
        return len(self.pending_requests)
808
+
809
+
810
class RateLimitError(Exception):
    """Signal that a provider refused a request because of rate limiting.

    Attributes:
        retry_after: Suggested wait in seconds before trying again, or
            ``None`` when no hint is available.
    """

    def __init__(self, message: str, retry_after: Optional[int] = None):
        self.retry_after = retry_after
        super().__init__(message)
815
+
816
+
817
class LLMSession:
    """Conversation wrapper that sends each user turn to several providers."""

    def __init__(self, orchestrator: LLMOrchestrator, user_id: str):
        """
        Args:
            orchestrator: Orchestrator used to run the provider calls.
            user_id: Identifier attached to every request of this session.
        """
        self.orchestrator = orchestrator
        self.user_id = user_id
        self.messages: List[Dict] = []
        self.system_prompt = "You are a helpful learning assistant."

    def add_message(self, role: str, content: str):
        """Append one ``{"role", "content"}`` entry to the history."""
        entry = {"role": role, "content": content}
        self.messages.append(entry)

    async def send(
        self,
        message: str,
        providers: Optional[List[LLMProvider]] = None
    ) -> List[LLMResponse]:
        """
        Record a user message, query the providers in parallel and store
        every successful reply as an assistant message.

        Args:
            message: The user's text.
            providers: Providers to query; defaults to ChatGPT and Gemini.

        Returns:
            All responses from the parallel query, including failures.
        """
        self.add_message("user", message)

        targets = providers or [LLMProvider.CHATGPT, LLMProvider.GEMINI]
        request = LLMRequest(
            prompt=self._format_conversation(),
            system_prompt=self.system_prompt,
            providers=targets,
            user_id=self.user_id
        )

        responses = await self.orchestrator.query_parallel(request)

        successful = [reply for reply in responses if reply.success]
        for reply in successful:
            self.add_message("assistant", reply.content)

        return responses

    def _format_conversation(self) -> str:
        """Render the ten most recent messages as ``Role: content`` lines."""
        recent = self.messages[-10:]
        return "\n".join(
            f"{entry['role'].capitalize()}: {entry['content']}" for entry in recent
        )

    def clear(self):
        """Discard the conversation history."""
        self.messages = []
864
+
865
+
866
def create_orchestrator(api_keys: Optional[Dict[str, str]] = None) -> LLMOrchestrator:
    """Build and return a fresh ``LLMOrchestrator`` using ``api_keys``."""
    orchestrator = LLMOrchestrator(api_keys)
    return orchestrator