Cuong2004 committed on
Commit
14208c6
·
1 Parent(s): 4a35a3f

Fix agent logic and add API key rotation

Browse files
.env.example CHANGED
@@ -20,8 +20,11 @@ GOOGLE_API_KEY=your_google_api_key
20
  GOOGLE_CLIENT_ID=your_google_api_key
21
  JWT_SECRET=your-super-secret-jwt-key-change-in-production
22
 
23
- # MegaLLM (OpenAI-compatible - DeepSeek)
24
- MEGALLM_API_KEY=your_megallm_api_key
 
 
 
25
  MEGALLM_BASE_URL=https://ai.megallm.io/v1
26
 
27
  # Brave Social Search
@@ -29,6 +32,3 @@ BRAVE_API_KEY=your_brave_api_key
29
 
30
  # Google OAuth
31
  GOOGLE_CLIENT_ID=your_google_client_id
32
-
33
- # CLIP (optional - for image embeddings)
34
- HUGGINGFACE_API_KEY=your_hf_api_key
 
20
  GOOGLE_CLIENT_ID=your_google_api_key
21
  JWT_SECRET=your-super-secret-jwt-key-change-in-production
22
 
23
+ # MegaLLM (API Key Rotation - add as many as needed)
24
+ # Keys are rotated round-robin to avoid 15 req/min limit per key
25
+ MEGALLM_API_KEY_1=your_first_megallm_api_key
26
+ MEGALLM_API_KEY_2=your_second_megallm_api_key
27
+ MEGALLM_API_KEY_3=your_third_megallm_api_key
28
  MEGALLM_BASE_URL=https://ai.megallm.io/v1
29
 
30
  # Brave Social Search
 
32
 
33
  # Google OAuth
34
  GOOGLE_CLIENT_ID=your_google_client_id
 
 
 
app/agent/mmca_agent.py CHANGED
@@ -10,6 +10,7 @@ Supports multiple LLM providers: Google (Gemini) and MegaLLM (DeepSeek).
10
  """
11
 
12
  import json
 
13
  import time
14
  from dataclasses import dataclass, field
15
  from typing import Any
@@ -83,6 +84,7 @@ class ChatResult:
83
  tools_used: list[str] = field(default_factory=list)
84
  total_duration_ms: float = 0
85
  tool_results: list = field(default_factory=list) # List of ToolCall with results
 
86
 
87
 
88
  class MMCAAgent:
@@ -209,7 +211,7 @@ class MMCAAgent:
209
  agent_logger.workflow_step("Step 3: Synthesize Response")
210
 
211
  llm_start = time.time()
212
- response = await self._synthesize_response(message, tool_results, image_url, history)
213
  llm_duration = (time.time() - llm_start) * 1000
214
 
215
  agent_logger.llm_response(self.provider, response[:100], tokens=None)
@@ -228,14 +230,15 @@ class MMCAAgent:
228
  workflow.total_duration_ms = total_duration
229
 
230
  # Log complete
231
- agent_logger.api_response("/chat", 200, {"response_len": len(response)}, total_duration)
232
 
233
  return ChatResult(
234
  response=response,
235
  workflow=workflow,
236
  tools_used=workflow.tools_used,
237
  total_duration_ms=total_duration,
238
- tool_results=tool_results, # Include tool results for place extraction
 
239
  )
240
 
241
  def _detect_intent(self, message: str, image_url: str | None) -> str:
@@ -269,6 +272,35 @@ class MMCAAgent:
269
  }
270
  return purposes.get(tool_name, tool_name)
271
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
  async def _plan_tool_calls(
273
  self,
274
  message: str,
@@ -278,7 +310,13 @@ class MMCAAgent:
278
  Analyze message and plan which tools to call.
279
 
280
  Returns list of ToolCall objects with tool_name and arguments.
 
281
  """
 
 
 
 
 
282
  tool_calls = []
283
 
284
  # If image is provided, always use visual search
@@ -343,7 +381,6 @@ class MMCAAgent:
343
  arguments={"query": message, "limit": 5},
344
  ))
345
 
346
-
347
  return tool_calls
348
 
349
  async def _execute_tool(
@@ -442,8 +479,39 @@ class MMCAAgent:
442
  tool_results: list[ToolCall],
443
  image_url: str | None = None,
444
  history: str | None = None,
445
- ) -> str:
446
- """Synthesize final response from tool results with conversation history."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
447
  # Build context from tool results
448
  context_parts = []
449
  for tool_call in tool_results:
@@ -452,7 +520,7 @@ class MMCAAgent:
452
  f"Kết quả từ {tool_call.tool_name}:\n{json.dumps(tool_call.result, ensure_ascii=False, indent=2)}"
453
  )
454
 
455
- context = "\n\n".join(context_parts) if context_parts else "Không tìm thấy kết quả phù hợp."
456
 
457
  # Build history section if available
458
  history_section = ""
@@ -463,25 +531,61 @@ class MMCAAgent:
463
  ---
464
  """
465
 
466
- # Generate response using LLM
467
- prompt = f"""{history_section}Dựa trên kết quả tìm kiếm sau, hãy trả lời câu hỏi của người dùng một cách tự nhiên và hữu ích.
468
 
469
  Câu hỏi hiện tại: {message}
470
 
471
  {context}
472
 
473
- Hãy trả lời bằng tiếng Việt, thân thiện. Nếu có nhiều kết quả, hãy giới thiệu top 2-3 địa điểm phù hợp nhất.
 
 
 
 
 
 
 
 
474
  Nếu có lịch sử hội thoại, hãy cân nhắc ngữ cảnh trước đó khi trả lời."""
475
 
476
  agent_logger.llm_call(self.provider, self.model or "default", prompt[:100])
477
 
478
- response = await self.llm_client.generate(
479
  prompt=prompt,
480
  temperature=0.7,
481
  system_instruction=SYSTEM_PROMPT,
482
  )
483
 
484
- return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
485
 
486
  def _extract_location(self, message: str) -> str | None:
487
  """Extract location name from message using pattern matching."""
 
10
  """
11
 
12
  import json
13
+ import re
14
  import time
15
  from dataclasses import dataclass, field
16
  from typing import Any
 
84
  tools_used: list[str] = field(default_factory=list)
85
  total_duration_ms: float = 0
86
  tool_results: list = field(default_factory=list) # List of ToolCall with results
87
+ selected_place_ids: list[str] = field(default_factory=list) # LLM-selected place IDs
88
 
89
 
90
  class MMCAAgent:
 
211
  agent_logger.workflow_step("Step 3: Synthesize Response")
212
 
213
  llm_start = time.time()
214
+ response, selected_place_ids = await self._synthesize_response(message, tool_results, image_url, history)
215
  llm_duration = (time.time() - llm_start) * 1000
216
 
217
  agent_logger.llm_response(self.provider, response[:100], tokens=None)
 
230
  workflow.total_duration_ms = total_duration
231
 
232
  # Log complete
233
+ agent_logger.api_response("/chat", 200, {"response_len": len(response), "places": len(selected_place_ids)}, total_duration)
234
 
235
  return ChatResult(
236
  response=response,
237
  workflow=workflow,
238
  tools_used=workflow.tools_used,
239
  total_duration_ms=total_duration,
240
+ tool_results=tool_results,
241
+ selected_place_ids=selected_place_ids,
242
  )
243
 
244
  def _detect_intent(self, message: str, image_url: str | None) -> str:
 
272
  }
273
  return purposes.get(tool_name, tool_name)
274
 
275
+ def _is_greeting_or_simple_query(self, message: str) -> bool:
276
+ """
277
+ Check if message is a simple greeting/small-talk that doesn't need tools.
278
+
279
+ Returns True for greetings, thanks, simple acknowledgments.
280
+ """
281
+ simple_patterns = [
282
+ # English
283
+ "hello", "hi", "hey", "yo", "sup",
284
+ "thank", "thanks", "bye", "goodbye",
285
+ "ok", "okay", "yes", "no", "good", "great", "nice",
286
+ # Vietnamese
287
+ "xin chào", "chào", "chào bạn", "ê", "alo",
288
+ "cảm ơn", "cám ơn", "thanks", "tạm biệt", "bye",
289
+ "ok", "được", "tốt", "hay", "ừ", "ờ", "vâng", "dạ",
290
+ ]
291
+ msg_lower = message.lower().strip()
292
+
293
+ # Very short messages are likely greetings
294
+ if len(msg_lower) < 15:
295
+ for pattern in simple_patterns:
296
+ if pattern in msg_lower:
297
+ return True
298
+ # Also check if message is just a single word greeting
299
+ if msg_lower in simple_patterns:
300
+ return True
301
+
302
+ return False
303
+
304
  async def _plan_tool_calls(
305
  self,
306
  message: str,
 
310
  Analyze message and plan which tools to call.
311
 
312
  Returns list of ToolCall objects with tool_name and arguments.
313
+ Returns empty list for simple greetings (no tools needed).
314
  """
315
+ # Early exit for greetings - no tools needed
316
+ if self._is_greeting_or_simple_query(message) and not image_url:
317
+ agent_logger.workflow_step("Greeting detected", "Skipping tools")
318
+ return []
319
+
320
  tool_calls = []
321
 
322
  # If image is provided, always use visual search
 
381
  arguments={"query": message, "limit": 5},
382
  ))
383
 
 
384
  return tool_calls
385
 
386
  async def _execute_tool(
 
479
  tool_results: list[ToolCall],
480
  image_url: str | None = None,
481
  history: str | None = None,
482
+ ) -> tuple[str, list[str]]:
483
+ """
484
+ Synthesize final response from tool results with conversation history.
485
+
486
+ Returns:
487
+ Tuple of (response_text, selected_place_ids)
488
+ """
489
+ # Collect all available place_ids from tool results
490
+ all_place_ids = []
491
+ for tool_call in tool_results:
492
+ if tool_call.result:
493
+ for item in tool_call.result:
494
+ if isinstance(item, dict) and 'place_id' in item:
495
+ all_place_ids.append(item['place_id'])
496
+
497
+ # If no tool results (greeting case), return simple response
498
+ if not tool_results:
499
+ # Build history section if available
500
+ history_section = ""
501
+ if history:
502
+ history_section = f"Lịch sử hội thoại:\n{history}\n\n---\n"
503
+
504
+ prompt = f"""{history_section}User nói: "{message}"
505
+
506
+ Hãy trả lời thân thiện bằng tiếng Việt. Đây là lời chào hoặc tin nhắn đơn giản, không cần tìm kiếm địa điểm."""
507
+
508
+ response = await self.llm_client.generate(
509
+ prompt=prompt,
510
+ temperature=0.7,
511
+ system_instruction="Bạn là LocalMate - trợ lý du lịch thân thiện cho Đà Nẵng. Trả lời ngắn gọn, thân thiện.",
512
+ )
513
+ return response, []
514
+
515
  # Build context from tool results
516
  context_parts = []
517
  for tool_call in tool_results:
 
520
  f"Kết quả từ {tool_call.tool_name}:\n{json.dumps(tool_call.result, ensure_ascii=False, indent=2)}"
521
  )
522
 
523
+ context = "\n\n".join(context_parts)
524
 
525
  # Build history section if available
526
  history_section = ""
 
531
  ---
532
  """
533
 
534
+ # Generate response using LLM with JSON format for place selection
535
+ prompt = f"""{history_section}Dựa trên kết quả tìm kiếm sau, hãy trả lời câu hỏi của người dùng.
536
 
537
  Câu hỏi hiện tại: {message}
538
 
539
  {context}
540
 
541
+ **QUAN TRỌNG:** Trả lời theo format JSON:
542
+ ```json
543
+ {{
544
+ "response": "Câu trả lời tiếng Việt, thân thiện. Giới thiệu top 2-3 địa điểm phù hợp nhất.",
545
+ "selected_place_ids": ["place_id_1", "place_id_2", "place_id_3"]
546
+ }}
547
+ ```
548
+
549
+ Chỉ chọn những place_id xuất hiện trong kết quả tìm kiếm ở trên. Nếu không có địa điểm phù hợp, để mảng rỗng.
550
  Nếu có lịch sử hội thoại, hãy cân nhắc ngữ cảnh trước đó khi trả lời."""
551
 
552
  agent_logger.llm_call(self.provider, self.model or "default", prompt[:100])
553
 
554
+ raw_response = await self.llm_client.generate(
555
  prompt=prompt,
556
  temperature=0.7,
557
  system_instruction=SYSTEM_PROMPT,
558
  )
559
 
560
+ # Parse JSON response
561
+ try:
562
+ # Extract JSON from code blocks
563
+ json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', raw_response, re.DOTALL)
564
+ if json_match:
565
+ json_str = json_match.group(1)
566
+ else:
567
+ # Try to find raw JSON
568
+ json_start = raw_response.find('{')
569
+ json_end = raw_response.rfind('}')
570
+ if json_start != -1 and json_end != -1:
571
+ json_str = raw_response[json_start:json_end + 1]
572
+ else:
573
+ # No JSON found, return raw response
574
+ return raw_response, []
575
+
576
+ data = json.loads(json_str)
577
+ text_response = data.get("response", raw_response)
578
+ selected_ids = data.get("selected_place_ids", [])
579
+
580
+ # Validate selected_ids are in available places
581
+ valid_ids = [pid for pid in selected_ids if pid in all_place_ids]
582
+
583
+ return text_response, valid_ids
584
+
585
+ except (json.JSONDecodeError, KeyError) as e:
586
+ agent_logger.error("Failed to parse synthesis JSON", e)
587
+ # Fallback: return raw response with no places
588
+ return raw_response, []
589
 
590
  def _extract_location(self, message: str) -> str | None:
591
  """Extract location name from message using pattern matching."""
app/api/router.py CHANGED
@@ -405,30 +405,20 @@ async def chat(
405
  session_id=session_id,
406
  )
407
 
408
- # Extract places from tool results if available
409
  places = []
410
- if result.tool_results:
411
- # Extract place_ids from ToolCall objects
412
- place_ids = []
413
- distance_map = {} # Store distance info for nearby places
414
  for tool_call in result.tool_results:
415
- # ToolCall has .result attribute which is a list of dicts
416
  if tool_call.result:
417
  for item in tool_call.result:
418
- if isinstance(item, dict) and 'place_id' in item:
419
- pid = item['place_id']
420
- if pid not in place_ids: # Avoid duplicates
421
- place_ids.append(pid)
422
- # Capture distance if available (from find_nearby_places)
423
- if 'distance_km' in item:
424
- distance_map[pid] = item['distance_km']
425
-
426
- if place_ids:
427
- places = await enrich_places_from_ids(place_ids[:5], db) # Limit to top 5
428
- # Add distance info to places
429
- for place in places:
430
- if place.place_id in distance_map:
431
- place.distance_km = distance_map[place.place_id]
432
 
433
  return ChatResponse(
434
  response=result.response,
 
405
  session_id=session_id,
406
  )
407
 
408
+ # Use LLM-selected places (same pattern as ReAct mode)
409
  places = []
410
+ if result.selected_place_ids:
411
+ places = await enrich_places_from_ids(result.selected_place_ids, db)
412
+ # Add distance info if available from tool results
413
+ distance_map = {}
414
  for tool_call in result.tool_results:
 
415
  if tool_call.result:
416
  for item in tool_call.result:
417
+ if isinstance(item, dict) and 'place_id' in item and 'distance_km' in item:
418
+ distance_map[item['place_id']] = item['distance_km']
419
+ for place in places:
420
+ if place.place_id in distance_map:
421
+ place.distance_km = distance_map[place.place_id]
 
 
 
 
 
 
 
 
 
422
 
423
  return ChatResponse(
424
  response=result.response,
app/shared/integrations/key_rotator.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Thread-safe API Key Rotator for load balancing across multiple keys.

This module provides a round-robin key rotation mechanism to distribute
API requests across multiple keys, helping to avoid rate limits.

Usage:
    from app.shared.integrations.key_rotator import megallm_key_rotator

    api_key = megallm_key_rotator.get_next_key()
"""

import logging
import os
import threading
from typing import List

from dotenv import load_dotenv

# Ensure .env is loaded before accessing os.environ
# (import-time side effect: the key loading below reads os.environ directly,
# so .env values must already be merged in by this point).
load_dotenv()

# Module-level logger named after this module, per stdlib logging convention.
logger = logging.getLogger(__name__)
23
+
24
+
25
class KeyRotator:
    """Round-robin API key dispenser that is safe to share across threads.

    Handing out a different key on each call spreads traffic over the key
    pool, keeping every individual key under its per-key rate limit.
    """

    def __init__(self, keys: List[str], name: str = "default"):
        """Create a rotator over *keys*.

        Args:
            keys: API keys to cycle through (must be non-empty)
            name: Label used in log lines to identify this rotator

        Raises:
            ValueError: If keys list is empty
        """
        if not keys:
            raise ValueError("At least one API key is required")

        self._keys = keys
        self._name = name
        self._index = 0                    # next position to hand out
        self._lock = threading.Lock()      # guards index + counter
        self._request_count = 0

        logging.getLogger(__name__).info(
            f"[KeyRotator:{name}] Initialized with {len(keys)} API keys"
        )

    def get_next_key(self) -> str:
        """Return the next API key in round-robin order (thread-safe)."""
        with self._lock:
            position = self._index
            key = self._keys[position]
            self._index = (position + 1) % len(self._keys)
            self._request_count += 1

            # Mask the key in logs: only the last 8 characters are shown.
            shown = key if len(key) <= 8 else f"...{key[-8:]}"
            logging.getLogger(__name__).info(
                f"[KeyRotator:{self._name}] Request #{self._request_count} "
                f"using key {position + 1}/{len(self._keys)} ({shown})"
            )

            return key

    @property
    def total_keys(self) -> int:
        """Number of keys in rotation."""
        return len(self._keys)

    @property
    def request_count(self) -> int:
        """Total number of requests made through this rotator."""
        return self._request_count

    def get_stats(self) -> dict:
        """Snapshot of rotation state, for debugging/monitoring."""
        return {
            "name": self._name,
            "total_keys": len(self._keys),
            "current_index": self._index,
            "total_requests": self._request_count,
        }
97
+
98
+
99
def load_megallm_keys() -> List[str]:
    """Load all MEGALLM_API_KEY_* values from environment variables.

    Recognizes numbered keys (MEGALLM_API_KEY_1, MEGALLM_API_KEY_2, ...)
    and returns them in ascending numeric order. Gaps in the numbering are
    tolerated — previously the sequential scan stopped at the first missing
    or empty variable, silently dropping every key after the gap. Falls
    back to the single legacy MEGALLM_API_KEY for backward compatibility.

    Returns:
        List of API keys found in environment (possibly empty)
    """
    log = logging.getLogger(__name__)
    prefix = "MEGALLM_API_KEY_"

    # Collect every non-empty numbered key, ordered by its numeric suffix.
    numbered: list[tuple[int, str]] = []
    for name, value in os.environ.items():
        if name.startswith(prefix) and value:
            suffix = name[len(prefix):]
            # isdecimal() (not isdigit()) guarantees int() can parse it.
            if suffix.isdecimal():
                numbered.append((int(suffix), value))
    keys = [value for _, value in sorted(numbered)]

    # Fallback to single key for backward compatibility
    if not keys:
        single_key = os.environ.get("MEGALLM_API_KEY")
        if single_key:
            keys = [single_key]
            log.warning(
                "[KeyRotator] Using legacy MEGALLM_API_KEY. "
                "Consider migrating to MEGALLM_API_KEY_1, MEGALLM_API_KEY_2, etc."
            )

    if keys:
        log.info(f"[KeyRotator] Loaded {len(keys)} MegaLLM API key(s)")
    else:
        log.warning("[KeyRotator] No MegaLLM API keys found in environment")

    return keys
135
+
136
+
137
# Singleton instance for MegaLLM key rotation.
# Left as None when no keys are configured so callers can detect the
# unconfigured state and fall back.
_megallm_keys = load_megallm_keys()
megallm_key_rotator: KeyRotator | None = (
    KeyRotator(_megallm_keys, name="MegaLLM") if _megallm_keys else None
)
app/shared/integrations/megallm_client.py CHANGED
@@ -1,8 +1,13 @@
1
- """MegaLLM client using OpenAI-compatible API with retry logic."""
 
 
2
 
3
  import httpx
4
 
5
  from app.core.config import settings
 
 
 
6
 
7
  # Timeout configuration for DeepSeek reasoning models (can take longer)
8
  REQUEST_TIMEOUT = httpx.Timeout(
@@ -14,13 +19,21 @@ REQUEST_TIMEOUT = httpx.Timeout(
14
 
15
 
16
  class MegaLLMClient:
17
- """Client for MegaLLM (OpenAI-compatible API) operations."""
18
 
19
  def __init__(self, model: str | None = None):
20
  """Initialize with optional model override."""
21
  self.model = model or settings.default_megallm_model
22
- self.api_key = settings.megallm_api_key
23
  self.base_url = settings.megallm_base_url
 
 
 
 
 
 
 
 
 
24
 
25
  async def generate(
26
  self,
@@ -41,8 +54,8 @@ class MegaLLMClient:
41
  Returns:
42
  Generated text
43
  """
44
- if not self.api_key:
45
- raise ValueError("MEGALLM_API_KEY is not configured")
46
 
47
  messages = []
48
  if system_instruction:
@@ -56,7 +69,7 @@ class MegaLLMClient:
56
  response = await client.post(
57
  f"{self.base_url}/chat/completions",
58
  headers={
59
- "Authorization": f"Bearer {self.api_key}",
60
  "Content-Type": "application/json",
61
  },
62
  json={
@@ -97,8 +110,8 @@ class MegaLLMClient:
97
  Returns:
98
  Generated text response
99
  """
100
- if not self.api_key:
101
- raise ValueError("MEGALLM_API_KEY is not configured")
102
 
103
  chat_messages = []
104
  if system_instruction:
@@ -114,7 +127,7 @@ class MegaLLMClient:
114
  response = await client.post(
115
  f"{self.base_url}/chat/completions",
116
  headers={
117
- "Authorization": f"Bearer {self.api_key}",
118
  "Content-Type": "application/json",
119
  },
120
  json={
 
1
+ """MegaLLM client using OpenAI-compatible API with retry logic and key rotation."""
2
+
3
+ import logging
4
 
5
  import httpx
6
 
7
  from app.core.config import settings
8
+ from app.shared.integrations.key_rotator import megallm_key_rotator
9
+
10
+ logger = logging.getLogger(__name__)
11
 
12
  # Timeout configuration for DeepSeek reasoning models (can take longer)
13
  REQUEST_TIMEOUT = httpx.Timeout(
 
19
 
20
 
21
  class MegaLLMClient:
22
+ """Client for MegaLLM (OpenAI-compatible API) operations with key rotation."""
23
 
24
  def __init__(self, model: str | None = None):
25
  """Initialize with optional model override."""
26
  self.model = model or settings.default_megallm_model
 
27
  self.base_url = settings.megallm_base_url
28
+
29
def _get_api_key(self) -> str:
    """Return an API key, preferring the round-robin rotator.

    Falls back to the single key from settings for backward compatibility.

    Raises:
        ValueError: When neither the rotator nor settings provide a key.
    """
    rotator = megallm_key_rotator
    if rotator is not None:
        return rotator.get_next_key()
    # Fallback to settings (backward compatibility)
    fallback = settings.megallm_api_key
    if fallback:
        return fallback
    raise ValueError("No MegaLLM API keys configured")
37
 
38
  async def generate(
39
  self,
 
54
  Returns:
55
  Generated text
56
  """
57
+ # Get rotated API key
58
+ api_key = self._get_api_key()
59
 
60
  messages = []
61
  if system_instruction:
 
69
  response = await client.post(
70
  f"{self.base_url}/chat/completions",
71
  headers={
72
+ "Authorization": f"Bearer {api_key}",
73
  "Content-Type": "application/json",
74
  },
75
  json={
 
110
  Returns:
111
  Generated text response
112
  """
113
+ # Get rotated API key
114
+ api_key = self._get_api_key()
115
 
116
  chat_messages = []
117
  if system_instruction:
 
127
  response = await client.post(
128
  f"{self.base_url}/chat/completions",
129
  headers={
130
+ "Authorization": f"Bearer {api_key}",
131
  "Content-Type": "application/json",
132
  },
133
  json={
tests/test_key_rotation.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Unit tests for MegaLLM API Key Rotation.

Run from the project root with:
    python -m pytest tests/test_key_rotation.py -v
"""

import os
import threading
from concurrent.futures import ThreadPoolExecutor
from unittest.mock import patch

import pytest
14
+
15
+
16
class TestKeyRotator:
    """Tests for KeyRotator class."""

    def test_rotation_cycles_through_keys(self):
        """Verify round-robin cycles through all keys in order."""
        from app.shared.integrations.key_rotator import KeyRotator

        rotator = KeyRotator(["key_1", "key_2", "key_3"], name="test")

        # Two full cycles: the sequence must wrap back to the first key.
        observed = [rotator.get_next_key() for _ in range(6)]
        assert observed == ["key_1", "key_2", "key_3"] * 2

        # Verify request count
        assert rotator.request_count == 6

    def test_single_key_always_returns_same(self):
        """Verify single key mode works correctly."""
        from app.shared.integrations.key_rotator import KeyRotator

        rotator = KeyRotator(["only_key"], name="single")

        assert all(rotator.get_next_key() == "only_key" for _ in range(5))
        assert rotator.request_count == 5

    def test_empty_keys_raises_error(self):
        """Verify empty keys list raises ValueError."""
        from app.shared.integrations.key_rotator import KeyRotator

        with pytest.raises(ValueError, match="At least one API key is required"):
            KeyRotator([], name="empty")

    def test_rotation_thread_safety(self):
        """Verify rotation is thread-safe under concurrent access."""
        from app.shared.integrations.key_rotator import KeyRotator

        keys = ["key_1", "key_2", "key_3"]
        rotator = KeyRotator(keys, name="threaded")

        seen = []
        guard = threading.Lock()

        def worker():
            picked = rotator.get_next_key()
            with guard:
                seen.append(picked)

        # Fan 100 concurrent requests across 10 worker threads.
        with ThreadPoolExecutor(max_workers=10) as pool:
            for fut in [pool.submit(worker) for _ in range(100)]:
                fut.result()

        # Should have 100 results
        assert len(seen) == 100
        assert rotator.request_count == 100

        # Each key should be used roughly equally (with some variance due to threading)
        for key in keys:
            count = seen.count(key)
            # Should be approximately 33 each, allow 20% variance
            assert 20 <= count <= 45, f"Key {key} used {count} times (expected ~33)"

    def test_get_stats(self):
        """Verify stats reporting works."""
        from app.shared.integrations.key_rotator import KeyRotator

        rotator = KeyRotator(["key_1", "key_2"], name="stats_test")
        for _ in range(3):
            rotator.get_next_key()

        stats = rotator.get_stats()
        assert stats == {
            "name": "stats_test",
            "total_keys": 2,
            # Index sequence after 3 requests with 2 keys: 0 -> 1 -> 0 -> 1.
            "current_index": 1,
            "total_requests": 3,
        }
105
+
106
+
107
class TestLoadMegaLLMKeys:
    """Tests for environment-based key loading.

    load_megallm_keys() reads os.environ at call time, so there is no
    need to reload the module — patching the environment is sufficient.
    Each test patches with clear=True so ambient MEGALLM_* variables on
    the developer/CI machine cannot leak in and make the tests flaky.
    """

    def test_load_numbered_keys(self):
        """Verify loading MEGALLM_API_KEY_1, _2, _3 format."""
        from app.shared.integrations.key_rotator import load_megallm_keys

        env_vars = {
            "MEGALLM_API_KEY_1": "first_key",
            "MEGALLM_API_KEY_2": "second_key",
            "MEGALLM_API_KEY_3": "third_key",
        }

        with patch.dict(os.environ, env_vars, clear=True):
            assert load_megallm_keys() == ["first_key", "second_key", "third_key"]

    def test_load_fallback_single_key(self):
        """Verify fallback to MEGALLM_API_KEY (legacy format)."""
        from app.shared.integrations.key_rotator import load_megallm_keys

        with patch.dict(os.environ, {"MEGALLM_API_KEY": "legacy_key"}, clear=True):
            assert load_megallm_keys() == ["legacy_key"]

    def test_no_keys_returns_empty(self):
        """Verify an empty list when no MegaLLM keys are configured."""
        from app.shared.integrations.key_rotator import load_megallm_keys

        with patch.dict(os.environ, {}, clear=True):
            assert load_megallm_keys() == []
146
+
147
+
148
# Allow running this file directly: `python tests/test_key_rotation.py`.
if __name__ == "__main__":
    pytest.main([__file__, "-v"])