Fix the logic agent and add API key rotation
Browse files- .env.example +5 -5
- app/agent/mmca_agent.py +116 -12
- app/api/router.py +10 -20
- app/shared/integrations/key_rotator.py +142 -0
- app/shared/integrations/megallm_client.py +22 -9
- tests/test_key_rotation.py +149 -0
.env.example
CHANGED
|
@@ -20,8 +20,11 @@ GOOGLE_API_KEY=your_google_api_key
|
|
| 20 |
GOOGLE_CLIENT_ID=your_google_api_key
|
| 21 |
JWT_SECRET=your-super-secret-jwt-key-change-in-production
|
| 22 |
|
| 23 |
-
# MegaLLM (
|
| 24 |
-
|
|
|
|
|
|
|
|
|
|
| 25 |
MEGALLM_BASE_URL=https://ai.megallm.io/v1
|
| 26 |
|
| 27 |
# Brave Social Search
|
|
@@ -29,6 +32,3 @@ BRAVE_API_KEY=your_brave_api_key
|
|
| 29 |
|
| 30 |
# Google OAuth
|
| 31 |
GOOGLE_CLIENT_ID=your_google_client_id
|
| 32 |
-
|
| 33 |
-
# CLIP (optional - for image embeddings)
|
| 34 |
-
HUGGINGFACE_API_KEY=your_hf_api_key
|
|
|
|
| 20 |
GOOGLE_CLIENT_ID=your_google_api_key
|
| 21 |
JWT_SECRET=your-super-secret-jwt-key-change-in-production
|
| 22 |
|
| 23 |
+
# MegaLLM (API Key Rotation - add as many as needed)
|
| 24 |
+
# Keys are rotated round-robin to avoid 15 req/min limit per key
|
| 25 |
+
MEGALLM_API_KEY_1=your_first_megallm_api_key
|
| 26 |
+
MEGALLM_API_KEY_2=your_second_megallm_api_key
|
| 27 |
+
MEGALLM_API_KEY_3=your_third_megallm_api_key
|
| 28 |
MEGALLM_BASE_URL=https://ai.megallm.io/v1
|
| 29 |
|
| 30 |
# Brave Social Search
|
|
|
|
| 32 |
|
| 33 |
# Google OAuth
|
| 34 |
GOOGLE_CLIENT_ID=your_google_client_id
|
|
|
|
|
|
|
|
|
app/agent/mmca_agent.py
CHANGED
|
@@ -10,6 +10,7 @@ Supports multiple LLM providers: Google (Gemini) and MegaLLM (DeepSeek).
|
|
| 10 |
"""
|
| 11 |
|
| 12 |
import json
|
|
|
|
| 13 |
import time
|
| 14 |
from dataclasses import dataclass, field
|
| 15 |
from typing import Any
|
|
@@ -83,6 +84,7 @@ class ChatResult:
|
|
| 83 |
tools_used: list[str] = field(default_factory=list)
|
| 84 |
total_duration_ms: float = 0
|
| 85 |
tool_results: list = field(default_factory=list) # List of ToolCall with results
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
class MMCAAgent:
|
|
@@ -209,7 +211,7 @@ class MMCAAgent:
|
|
| 209 |
agent_logger.workflow_step("Step 3: Synthesize Response")
|
| 210 |
|
| 211 |
llm_start = time.time()
|
| 212 |
-
response = await self._synthesize_response(message, tool_results, image_url, history)
|
| 213 |
llm_duration = (time.time() - llm_start) * 1000
|
| 214 |
|
| 215 |
agent_logger.llm_response(self.provider, response[:100], tokens=None)
|
|
@@ -228,14 +230,15 @@ class MMCAAgent:
|
|
| 228 |
workflow.total_duration_ms = total_duration
|
| 229 |
|
| 230 |
# Log complete
|
| 231 |
-
agent_logger.api_response("/chat", 200, {"response_len": len(response)}, total_duration)
|
| 232 |
|
| 233 |
return ChatResult(
|
| 234 |
response=response,
|
| 235 |
workflow=workflow,
|
| 236 |
tools_used=workflow.tools_used,
|
| 237 |
total_duration_ms=total_duration,
|
| 238 |
-
tool_results=tool_results,
|
|
|
|
| 239 |
)
|
| 240 |
|
| 241 |
def _detect_intent(self, message: str, image_url: str | None) -> str:
|
|
@@ -269,6 +272,35 @@ class MMCAAgent:
|
|
| 269 |
}
|
| 270 |
return purposes.get(tool_name, tool_name)
|
| 271 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
async def _plan_tool_calls(
|
| 273 |
self,
|
| 274 |
message: str,
|
|
@@ -278,7 +310,13 @@ class MMCAAgent:
|
|
| 278 |
Analyze message and plan which tools to call.
|
| 279 |
|
| 280 |
Returns list of ToolCall objects with tool_name and arguments.
|
|
|
|
| 281 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
tool_calls = []
|
| 283 |
|
| 284 |
# If image is provided, always use visual search
|
|
@@ -343,7 +381,6 @@ class MMCAAgent:
|
|
| 343 |
arguments={"query": message, "limit": 5},
|
| 344 |
))
|
| 345 |
|
| 346 |
-
|
| 347 |
return tool_calls
|
| 348 |
|
| 349 |
async def _execute_tool(
|
|
@@ -442,8 +479,39 @@ class MMCAAgent:
|
|
| 442 |
tool_results: list[ToolCall],
|
| 443 |
image_url: str | None = None,
|
| 444 |
history: str | None = None,
|
| 445 |
-
) -> str:
|
| 446 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 447 |
# Build context from tool results
|
| 448 |
context_parts = []
|
| 449 |
for tool_call in tool_results:
|
|
@@ -452,7 +520,7 @@ class MMCAAgent:
|
|
| 452 |
f"Kết quả từ {tool_call.tool_name}:\n{json.dumps(tool_call.result, ensure_ascii=False, indent=2)}"
|
| 453 |
)
|
| 454 |
|
| 455 |
-
context = "\n\n".join(context_parts)
|
| 456 |
|
| 457 |
# Build history section if available
|
| 458 |
history_section = ""
|
|
@@ -463,25 +531,61 @@ class MMCAAgent:
|
|
| 463 |
---
|
| 464 |
"""
|
| 465 |
|
| 466 |
-
# Generate response using LLM
|
| 467 |
-
prompt = f"""{history_section}Dựa trên kết quả tìm kiếm sau, hãy trả lời câu hỏi của người dùng
|
| 468 |
|
| 469 |
Câu hỏi hiện tại: {message}
|
| 470 |
|
| 471 |
{context}
|
| 472 |
|
| 473 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 474 |
Nếu có lịch sử hội thoại, hãy cân nhắc ngữ cảnh trước đó khi trả lời."""
|
| 475 |
|
| 476 |
agent_logger.llm_call(self.provider, self.model or "default", prompt[:100])
|
| 477 |
|
| 478 |
-
|
| 479 |
prompt=prompt,
|
| 480 |
temperature=0.7,
|
| 481 |
system_instruction=SYSTEM_PROMPT,
|
| 482 |
)
|
| 483 |
|
| 484 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 485 |
|
| 486 |
def _extract_location(self, message: str) -> str | None:
|
| 487 |
"""Extract location name from message using pattern matching."""
|
|
|
|
| 10 |
"""
|
| 11 |
|
| 12 |
import json
|
| 13 |
+
import re
|
| 14 |
import time
|
| 15 |
from dataclasses import dataclass, field
|
| 16 |
from typing import Any
|
|
|
|
| 84 |
tools_used: list[str] = field(default_factory=list)
|
| 85 |
total_duration_ms: float = 0
|
| 86 |
tool_results: list = field(default_factory=list) # List of ToolCall with results
|
| 87 |
+
selected_place_ids: list[str] = field(default_factory=list) # LLM-selected place IDs
|
| 88 |
|
| 89 |
|
| 90 |
class MMCAAgent:
|
|
|
|
| 211 |
agent_logger.workflow_step("Step 3: Synthesize Response")
|
| 212 |
|
| 213 |
llm_start = time.time()
|
| 214 |
+
response, selected_place_ids = await self._synthesize_response(message, tool_results, image_url, history)
|
| 215 |
llm_duration = (time.time() - llm_start) * 1000
|
| 216 |
|
| 217 |
agent_logger.llm_response(self.provider, response[:100], tokens=None)
|
|
|
|
| 230 |
workflow.total_duration_ms = total_duration
|
| 231 |
|
| 232 |
# Log complete
|
| 233 |
+
agent_logger.api_response("/chat", 200, {"response_len": len(response), "places": len(selected_place_ids)}, total_duration)
|
| 234 |
|
| 235 |
return ChatResult(
|
| 236 |
response=response,
|
| 237 |
workflow=workflow,
|
| 238 |
tools_used=workflow.tools_used,
|
| 239 |
total_duration_ms=total_duration,
|
| 240 |
+
tool_results=tool_results,
|
| 241 |
+
selected_place_ids=selected_place_ids,
|
| 242 |
)
|
| 243 |
|
| 244 |
def _detect_intent(self, message: str, image_url: str | None) -> str:
|
|
|
|
| 272 |
}
|
| 273 |
return purposes.get(tool_name, tool_name)
|
| 274 |
|
| 275 |
+
def _is_greeting_or_simple_query(self, message: str) -> bool:
|
| 276 |
+
"""
|
| 277 |
+
Check if message is a simple greeting/small-talk that doesn't need tools.
|
| 278 |
+
|
| 279 |
+
Returns True for greetings, thanks, simple acknowledgments.
|
| 280 |
+
"""
|
| 281 |
+
simple_patterns = [
|
| 282 |
+
# English
|
| 283 |
+
"hello", "hi", "hey", "yo", "sup",
|
| 284 |
+
"thank", "thanks", "bye", "goodbye",
|
| 285 |
+
"ok", "okay", "yes", "no", "good", "great", "nice",
|
| 286 |
+
# Vietnamese
|
| 287 |
+
"xin chào", "chào", "chào bạn", "ê", "alo",
|
| 288 |
+
"cảm ơn", "cám ơn", "thanks", "tạm biệt", "bye",
|
| 289 |
+
"ok", "được", "tốt", "hay", "ừ", "ờ", "vâng", "dạ",
|
| 290 |
+
]
|
| 291 |
+
msg_lower = message.lower().strip()
|
| 292 |
+
|
| 293 |
+
# Very short messages are likely greetings
|
| 294 |
+
if len(msg_lower) < 15:
|
| 295 |
+
for pattern in simple_patterns:
|
| 296 |
+
if pattern in msg_lower:
|
| 297 |
+
return True
|
| 298 |
+
# Also check if message is just a single word greeting
|
| 299 |
+
if msg_lower in simple_patterns:
|
| 300 |
+
return True
|
| 301 |
+
|
| 302 |
+
return False
|
| 303 |
+
|
| 304 |
async def _plan_tool_calls(
|
| 305 |
self,
|
| 306 |
message: str,
|
|
|
|
| 310 |
Analyze message and plan which tools to call.
|
| 311 |
|
| 312 |
Returns list of ToolCall objects with tool_name and arguments.
|
| 313 |
+
Returns empty list for simple greetings (no tools needed).
|
| 314 |
"""
|
| 315 |
+
# Early exit for greetings - no tools needed
|
| 316 |
+
if self._is_greeting_or_simple_query(message) and not image_url:
|
| 317 |
+
agent_logger.workflow_step("Greeting detected", "Skipping tools")
|
| 318 |
+
return []
|
| 319 |
+
|
| 320 |
tool_calls = []
|
| 321 |
|
| 322 |
# If image is provided, always use visual search
|
|
|
|
| 381 |
arguments={"query": message, "limit": 5},
|
| 382 |
))
|
| 383 |
|
|
|
|
| 384 |
return tool_calls
|
| 385 |
|
| 386 |
async def _execute_tool(
|
|
|
|
| 479 |
tool_results: list[ToolCall],
|
| 480 |
image_url: str | None = None,
|
| 481 |
history: str | None = None,
|
| 482 |
+
) -> tuple[str, list[str]]:
|
| 483 |
+
"""
|
| 484 |
+
Synthesize final response from tool results with conversation history.
|
| 485 |
+
|
| 486 |
+
Returns:
|
| 487 |
+
Tuple of (response_text, selected_place_ids)
|
| 488 |
+
"""
|
| 489 |
+
# Collect all available place_ids from tool results
|
| 490 |
+
all_place_ids = []
|
| 491 |
+
for tool_call in tool_results:
|
| 492 |
+
if tool_call.result:
|
| 493 |
+
for item in tool_call.result:
|
| 494 |
+
if isinstance(item, dict) and 'place_id' in item:
|
| 495 |
+
all_place_ids.append(item['place_id'])
|
| 496 |
+
|
| 497 |
+
# If no tool results (greeting case), return simple response
|
| 498 |
+
if not tool_results:
|
| 499 |
+
# Build history section if available
|
| 500 |
+
history_section = ""
|
| 501 |
+
if history:
|
| 502 |
+
history_section = f"Lịch sử hội thoại:\n{history}\n\n---\n"
|
| 503 |
+
|
| 504 |
+
prompt = f"""{history_section}User nói: "{message}"
|
| 505 |
+
|
| 506 |
+
Hãy trả lời thân thiện bằng tiếng Việt. Đây là lời chào hoặc tin nhắn đơn giản, không cần tìm kiếm địa điểm."""
|
| 507 |
+
|
| 508 |
+
response = await self.llm_client.generate(
|
| 509 |
+
prompt=prompt,
|
| 510 |
+
temperature=0.7,
|
| 511 |
+
system_instruction="Bạn là LocalMate - trợ lý du lịch thân thiện cho Đà Nẵng. Trả lời ngắn gọn, thân thiện.",
|
| 512 |
+
)
|
| 513 |
+
return response, []
|
| 514 |
+
|
| 515 |
# Build context from tool results
|
| 516 |
context_parts = []
|
| 517 |
for tool_call in tool_results:
|
|
|
|
| 520 |
f"Kết quả từ {tool_call.tool_name}:\n{json.dumps(tool_call.result, ensure_ascii=False, indent=2)}"
|
| 521 |
)
|
| 522 |
|
| 523 |
+
context = "\n\n".join(context_parts)
|
| 524 |
|
| 525 |
# Build history section if available
|
| 526 |
history_section = ""
|
|
|
|
| 531 |
---
|
| 532 |
"""
|
| 533 |
|
| 534 |
+
# Generate response using LLM with JSON format for place selection
|
| 535 |
+
prompt = f"""{history_section}Dựa trên kết quả tìm kiếm sau, hãy trả lời câu hỏi của người dùng.
|
| 536 |
|
| 537 |
Câu hỏi hiện tại: {message}
|
| 538 |
|
| 539 |
{context}
|
| 540 |
|
| 541 |
+
**QUAN TRỌNG:** Trả lời theo format JSON:
|
| 542 |
+
```json
|
| 543 |
+
{{
|
| 544 |
+
"response": "Câu trả lời tiếng Việt, thân thiện. Giới thiệu top 2-3 địa điểm phù hợp nhất.",
|
| 545 |
+
"selected_place_ids": ["place_id_1", "place_id_2", "place_id_3"]
|
| 546 |
+
}}
|
| 547 |
+
```
|
| 548 |
+
|
| 549 |
+
Chỉ chọn những place_id xuất hiện trong kết quả tìm kiếm ở trên. Nếu không có địa điểm phù hợp, để mảng rỗng.
|
| 550 |
Nếu có lịch sử hội thoại, hãy cân nhắc ngữ cảnh trước đó khi trả lời."""
|
| 551 |
|
| 552 |
agent_logger.llm_call(self.provider, self.model or "default", prompt[:100])
|
| 553 |
|
| 554 |
+
raw_response = await self.llm_client.generate(
|
| 555 |
prompt=prompt,
|
| 556 |
temperature=0.7,
|
| 557 |
system_instruction=SYSTEM_PROMPT,
|
| 558 |
)
|
| 559 |
|
| 560 |
+
# Parse JSON response
|
| 561 |
+
try:
|
| 562 |
+
# Extract JSON from code blocks
|
| 563 |
+
json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', raw_response, re.DOTALL)
|
| 564 |
+
if json_match:
|
| 565 |
+
json_str = json_match.group(1)
|
| 566 |
+
else:
|
| 567 |
+
# Try to find raw JSON
|
| 568 |
+
json_start = raw_response.find('{')
|
| 569 |
+
json_end = raw_response.rfind('}')
|
| 570 |
+
if json_start != -1 and json_end != -1:
|
| 571 |
+
json_str = raw_response[json_start:json_end + 1]
|
| 572 |
+
else:
|
| 573 |
+
# No JSON found, return raw response
|
| 574 |
+
return raw_response, []
|
| 575 |
+
|
| 576 |
+
data = json.loads(json_str)
|
| 577 |
+
text_response = data.get("response", raw_response)
|
| 578 |
+
selected_ids = data.get("selected_place_ids", [])
|
| 579 |
+
|
| 580 |
+
# Validate selected_ids are in available places
|
| 581 |
+
valid_ids = [pid for pid in selected_ids if pid in all_place_ids]
|
| 582 |
+
|
| 583 |
+
return text_response, valid_ids
|
| 584 |
+
|
| 585 |
+
except (json.JSONDecodeError, KeyError) as e:
|
| 586 |
+
agent_logger.error("Failed to parse synthesis JSON", e)
|
| 587 |
+
# Fallback: return raw response with no places
|
| 588 |
+
return raw_response, []
|
| 589 |
|
| 590 |
def _extract_location(self, message: str) -> str | None:
|
| 591 |
"""Extract location name from message using pattern matching."""
|
app/api/router.py
CHANGED
|
@@ -405,30 +405,20 @@ async def chat(
|
|
| 405 |
session_id=session_id,
|
| 406 |
)
|
| 407 |
|
| 408 |
-
#
|
| 409 |
places = []
|
| 410 |
-
if result.
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
distance_map = {}
|
| 414 |
for tool_call in result.tool_results:
|
| 415 |
-
# ToolCall has .result attribute which is a list of dicts
|
| 416 |
if tool_call.result:
|
| 417 |
for item in tool_call.result:
|
| 418 |
-
if isinstance(item, dict) and 'place_id' in item:
|
| 419 |
-
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
|
| 423 |
-
if 'distance_km' in item:
|
| 424 |
-
distance_map[pid] = item['distance_km']
|
| 425 |
-
|
| 426 |
-
if place_ids:
|
| 427 |
-
places = await enrich_places_from_ids(place_ids[:5], db) # Limit to top 5
|
| 428 |
-
# Add distance info to places
|
| 429 |
-
for place in places:
|
| 430 |
-
if place.place_id in distance_map:
|
| 431 |
-
place.distance_km = distance_map[place.place_id]
|
| 432 |
|
| 433 |
return ChatResponse(
|
| 434 |
response=result.response,
|
|
|
|
| 405 |
session_id=session_id,
|
| 406 |
)
|
| 407 |
|
| 408 |
+
# Use LLM-selected places (same pattern as ReAct mode)
|
| 409 |
places = []
|
| 410 |
+
if result.selected_place_ids:
|
| 411 |
+
places = await enrich_places_from_ids(result.selected_place_ids, db)
|
| 412 |
+
# Add distance info if available from tool results
|
| 413 |
+
distance_map = {}
|
| 414 |
for tool_call in result.tool_results:
|
|
|
|
| 415 |
if tool_call.result:
|
| 416 |
for item in tool_call.result:
|
| 417 |
+
if isinstance(item, dict) and 'place_id' in item and 'distance_km' in item:
|
| 418 |
+
distance_map[item['place_id']] = item['distance_km']
|
| 419 |
+
for place in places:
|
| 420 |
+
if place.place_id in distance_map:
|
| 421 |
+
place.distance_km = distance_map[place.place_id]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 422 |
|
| 423 |
return ChatResponse(
|
| 424 |
response=result.response,
|
app/shared/integrations/key_rotator.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Thread-safe API Key Rotator for load balancing across multiple keys.
|
| 2 |
+
|
| 3 |
+
This module provides a round-robin key rotation mechanism to distribute
|
| 4 |
+
API requests across multiple keys, helping to avoid rate limits.
|
| 5 |
+
|
| 6 |
+
Usage:
|
| 7 |
+
from app.shared.integrations.key_rotator import megallm_key_rotator
|
| 8 |
+
|
| 9 |
+
api_key = megallm_key_rotator.get_next_key()
|
| 10 |
+
"""
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
import os
|
| 14 |
+
import threading
|
| 15 |
+
from typing import List
|
| 16 |
+
|
| 17 |
+
from dotenv import load_dotenv
|
| 18 |
+
|
| 19 |
+
# Ensure .env is loaded before accessing os.environ
|
| 20 |
+
load_dotenv()
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger(__name__)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class KeyRotator:
|
| 26 |
+
"""Thread-safe round-robin API key rotator.
|
| 27 |
+
|
| 28 |
+
Distributes API calls across multiple keys to avoid per-key rate limits.
|
| 29 |
+
Each call to get_next_key() returns the next key in rotation.
|
| 30 |
+
|
| 31 |
+
Attributes:
|
| 32 |
+
_keys: List of API keys to rotate through
|
| 33 |
+
_index: Current position in rotation
|
| 34 |
+
_lock: Thread lock for safe concurrent access
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(self, keys: List[str], name: str = "default"):
|
| 38 |
+
"""Initialize the key rotator.
|
| 39 |
+
|
| 40 |
+
Args:
|
| 41 |
+
keys: List of API keys (must have at least one)
|
| 42 |
+
name: Name for logging identification
|
| 43 |
+
|
| 44 |
+
Raises:
|
| 45 |
+
ValueError: If keys list is empty
|
| 46 |
+
"""
|
| 47 |
+
if not keys:
|
| 48 |
+
raise ValueError("At least one API key is required")
|
| 49 |
+
|
| 50 |
+
self._keys = keys
|
| 51 |
+
self._name = name
|
| 52 |
+
self._index = 0
|
| 53 |
+
self._lock = threading.Lock()
|
| 54 |
+
self._request_count = 0
|
| 55 |
+
|
| 56 |
+
logger.info(f"[KeyRotator:{name}] Initialized with {len(keys)} API keys")
|
| 57 |
+
|
| 58 |
+
def get_next_key(self) -> str:
|
| 59 |
+
"""Get next API key in rotation (thread-safe).
|
| 60 |
+
|
| 61 |
+
Returns:
|
| 62 |
+
The next API key in round-robin order
|
| 63 |
+
"""
|
| 64 |
+
with self._lock:
|
| 65 |
+
key = self._keys[self._index]
|
| 66 |
+
key_index = self._index + 1 # 1-based for logging
|
| 67 |
+
self._index = (self._index + 1) % len(self._keys)
|
| 68 |
+
self._request_count += 1
|
| 69 |
+
|
| 70 |
+
# Log rotation (mask key for security, only show last 8 chars)
|
| 71 |
+
masked_key = f"...{key[-8:]}" if len(key) > 8 else key
|
| 72 |
+
logger.info(
|
| 73 |
+
f"[KeyRotator:{self._name}] Request #{self._request_count} "
|
| 74 |
+
f"using key {key_index}/{len(self._keys)} ({masked_key})"
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
return key
|
| 78 |
+
|
| 79 |
+
@property
|
| 80 |
+
def total_keys(self) -> int:
|
| 81 |
+
"""Number of keys in rotation."""
|
| 82 |
+
return len(self._keys)
|
| 83 |
+
|
| 84 |
+
@property
|
| 85 |
+
def request_count(self) -> int:
|
| 86 |
+
"""Total number of requests made through this rotator."""
|
| 87 |
+
return self._request_count
|
| 88 |
+
|
| 89 |
+
def get_stats(self) -> dict:
|
| 90 |
+
"""Get rotation statistics for debugging."""
|
| 91 |
+
return {
|
| 92 |
+
"name": self._name,
|
| 93 |
+
"total_keys": len(self._keys),
|
| 94 |
+
"current_index": self._index,
|
| 95 |
+
"total_requests": self._request_count,
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def load_megallm_keys() -> List[str]:
|
| 100 |
+
"""Load all MEGALLM_API_KEY_* from environment variables.
|
| 101 |
+
|
| 102 |
+
Looks for keys in format: MEGALLM_API_KEY_1, MEGALLM_API_KEY_2, etc.
|
| 103 |
+
Falls back to single MEGALLM_API_KEY for backward compatibility.
|
| 104 |
+
|
| 105 |
+
Returns:
|
| 106 |
+
List of API keys found in environment
|
| 107 |
+
"""
|
| 108 |
+
keys = []
|
| 109 |
+
i = 1
|
| 110 |
+
|
| 111 |
+
# Load numbered keys (MEGALLM_API_KEY_1, MEGALLM_API_KEY_2, ...)
|
| 112 |
+
while True:
|
| 113 |
+
key = os.environ.get(f"MEGALLM_API_KEY_{i}")
|
| 114 |
+
if not key:
|
| 115 |
+
break
|
| 116 |
+
keys.append(key)
|
| 117 |
+
i += 1
|
| 118 |
+
|
| 119 |
+
# Fallback to single key for backward compatibility
|
| 120 |
+
if not keys:
|
| 121 |
+
single_key = os.environ.get("MEGALLM_API_KEY")
|
| 122 |
+
if single_key:
|
| 123 |
+
keys = [single_key]
|
| 124 |
+
logger.warning(
|
| 125 |
+
"[KeyRotator] Using legacy MEGALLM_API_KEY. "
|
| 126 |
+
"Consider migrating to MEGALLM_API_KEY_1, MEGALLM_API_KEY_2, etc."
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
if keys:
|
| 130 |
+
logger.info(f"[KeyRotator] Loaded {len(keys)} MegaLLM API key(s)")
|
| 131 |
+
else:
|
| 132 |
+
logger.warning("[KeyRotator] No MegaLLM API keys found in environment")
|
| 133 |
+
|
| 134 |
+
return keys
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
# Singleton instance for MegaLLM key rotation
|
| 138 |
+
_megallm_keys = load_megallm_keys()
|
| 139 |
+
megallm_key_rotator: KeyRotator | None = None
|
| 140 |
+
|
| 141 |
+
if _megallm_keys:
|
| 142 |
+
megallm_key_rotator = KeyRotator(_megallm_keys, name="MegaLLM")
|
app/shared/integrations/megallm_client.py
CHANGED
|
@@ -1,8 +1,13 @@
|
|
| 1 |
-
"""MegaLLM client using OpenAI-compatible API with retry logic."""
|
|
|
|
|
|
|
| 2 |
|
| 3 |
import httpx
|
| 4 |
|
| 5 |
from app.core.config import settings
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
# Timeout configuration for DeepSeek reasoning models (can take longer)
|
| 8 |
REQUEST_TIMEOUT = httpx.Timeout(
|
|
@@ -14,13 +19,21 @@ REQUEST_TIMEOUT = httpx.Timeout(
|
|
| 14 |
|
| 15 |
|
| 16 |
class MegaLLMClient:
|
| 17 |
-
"""Client for MegaLLM (OpenAI-compatible API) operations."""
|
| 18 |
|
| 19 |
def __init__(self, model: str | None = None):
|
| 20 |
"""Initialize with optional model override."""
|
| 21 |
self.model = model or settings.default_megallm_model
|
| 22 |
-
self.api_key = settings.megallm_api_key
|
| 23 |
self.base_url = settings.megallm_base_url
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
async def generate(
|
| 26 |
self,
|
|
@@ -41,8 +54,8 @@ class MegaLLMClient:
|
|
| 41 |
Returns:
|
| 42 |
Generated text
|
| 43 |
"""
|
| 44 |
-
|
| 45 |
-
|
| 46 |
|
| 47 |
messages = []
|
| 48 |
if system_instruction:
|
|
@@ -56,7 +69,7 @@ class MegaLLMClient:
|
|
| 56 |
response = await client.post(
|
| 57 |
f"{self.base_url}/chat/completions",
|
| 58 |
headers={
|
| 59 |
-
"Authorization": f"Bearer {
|
| 60 |
"Content-Type": "application/json",
|
| 61 |
},
|
| 62 |
json={
|
|
@@ -97,8 +110,8 @@ class MegaLLMClient:
|
|
| 97 |
Returns:
|
| 98 |
Generated text response
|
| 99 |
"""
|
| 100 |
-
|
| 101 |
-
|
| 102 |
|
| 103 |
chat_messages = []
|
| 104 |
if system_instruction:
|
|
@@ -114,7 +127,7 @@ class MegaLLMClient:
|
|
| 114 |
response = await client.post(
|
| 115 |
f"{self.base_url}/chat/completions",
|
| 116 |
headers={
|
| 117 |
-
"Authorization": f"Bearer {
|
| 118 |
"Content-Type": "application/json",
|
| 119 |
},
|
| 120 |
json={
|
|
|
|
| 1 |
+
"""MegaLLM client using OpenAI-compatible API with retry logic and key rotation."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
|
| 5 |
import httpx
|
| 6 |
|
| 7 |
from app.core.config import settings
|
| 8 |
+
from app.shared.integrations.key_rotator import megallm_key_rotator
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
# Timeout configuration for DeepSeek reasoning models (can take longer)
|
| 13 |
REQUEST_TIMEOUT = httpx.Timeout(
|
|
|
|
| 19 |
|
| 20 |
|
| 21 |
class MegaLLMClient:
|
| 22 |
+
"""Client for MegaLLM (OpenAI-compatible API) operations with key rotation."""
|
| 23 |
|
| 24 |
def __init__(self, model: str | None = None):
|
| 25 |
"""Initialize with optional model override."""
|
| 26 |
self.model = model or settings.default_megallm_model
|
|
|
|
| 27 |
self.base_url = settings.megallm_base_url
|
| 28 |
+
|
| 29 |
+
def _get_api_key(self) -> str:
    """Resolve the API key for this request.

    Prefers the round-robin rotator when one is configured; otherwise falls
    back to the single legacy key from settings.

    Returns:
        An API key string.

    Raises:
        ValueError: If neither the rotator nor settings provide a key.
    """
    if megallm_key_rotator is not None:
        return megallm_key_rotator.get_next_key()

    # Fallback to settings (backward compatibility)
    legacy_key = settings.megallm_api_key
    if legacy_key:
        return legacy_key

    raise ValueError("No MegaLLM API keys configured")
|
| 37 |
|
| 38 |
async def generate(
|
| 39 |
self,
|
|
|
|
| 54 |
Returns:
|
| 55 |
Generated text
|
| 56 |
"""
|
| 57 |
+
# Get rotated API key
|
| 58 |
+
api_key = self._get_api_key()
|
| 59 |
|
| 60 |
messages = []
|
| 61 |
if system_instruction:
|
|
|
|
| 69 |
response = await client.post(
|
| 70 |
f"{self.base_url}/chat/completions",
|
| 71 |
headers={
|
| 72 |
+
"Authorization": f"Bearer {api_key}",
|
| 73 |
"Content-Type": "application/json",
|
| 74 |
},
|
| 75 |
json={
|
|
|
|
| 110 |
Returns:
|
| 111 |
Generated text response
|
| 112 |
"""
|
| 113 |
+
# Get rotated API key
|
| 114 |
+
api_key = self._get_api_key()
|
| 115 |
|
| 116 |
chat_messages = []
|
| 117 |
if system_instruction:
|
|
|
|
| 127 |
response = await client.post(
|
| 128 |
f"{self.base_url}/chat/completions",
|
| 129 |
headers={
|
| 130 |
+
"Authorization": f"Bearer {api_key}",
|
| 131 |
"Content-Type": "application/json",
|
| 132 |
},
|
| 133 |
json={
|
tests/test_key_rotation.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Unit tests for MegaLLM API Key Rotation.
|
| 2 |
+
|
| 3 |
+
Run with:
|
| 4 |
+
cd /Volumes/WorkSpace/Project/LocalMate/localmate-danang-backend-v2
|
| 5 |
+
python -m pytest tests/test_key_rotation.py -v
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import threading
|
| 10 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 11 |
+
from unittest.mock import patch
|
| 12 |
+
|
| 13 |
+
import pytest
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class TestKeyRotator:
|
| 17 |
+
"""Tests for KeyRotator class."""
|
| 18 |
+
|
| 19 |
+
def test_rotation_cycles_through_keys(self):
|
| 20 |
+
"""Verify round-robin cycles through all keys in order."""
|
| 21 |
+
from app.shared.integrations.key_rotator import KeyRotator
|
| 22 |
+
|
| 23 |
+
keys = ["key_1", "key_2", "key_3"]
|
| 24 |
+
rotator = KeyRotator(keys, name="test")
|
| 25 |
+
|
| 26 |
+
# First cycle
|
| 27 |
+
assert rotator.get_next_key() == "key_1"
|
| 28 |
+
assert rotator.get_next_key() == "key_2"
|
| 29 |
+
assert rotator.get_next_key() == "key_3"
|
| 30 |
+
|
| 31 |
+
# Second cycle (should loop back)
|
| 32 |
+
assert rotator.get_next_key() == "key_1"
|
| 33 |
+
assert rotator.get_next_key() == "key_2"
|
| 34 |
+
assert rotator.get_next_key() == "key_3"
|
| 35 |
+
|
| 36 |
+
# Verify request count
|
| 37 |
+
assert rotator.request_count == 6
|
| 38 |
+
|
| 39 |
+
def test_single_key_always_returns_same(self):
|
| 40 |
+
"""Verify single key mode works correctly."""
|
| 41 |
+
from app.shared.integrations.key_rotator import KeyRotator
|
| 42 |
+
|
| 43 |
+
keys = ["only_key"]
|
| 44 |
+
rotator = KeyRotator(keys, name="single")
|
| 45 |
+
|
| 46 |
+
for _ in range(5):
|
| 47 |
+
assert rotator.get_next_key() == "only_key"
|
| 48 |
+
|
| 49 |
+
assert rotator.request_count == 5
|
| 50 |
+
|
| 51 |
+
def test_empty_keys_raises_error(self):
|
| 52 |
+
"""Verify empty keys list raises ValueError."""
|
| 53 |
+
from app.shared.integrations.key_rotator import KeyRotator
|
| 54 |
+
|
| 55 |
+
with pytest.raises(ValueError, match="At least one API key is required"):
|
| 56 |
+
KeyRotator([], name="empty")
|
| 57 |
+
|
| 58 |
+
def test_rotation_thread_safety(self):
|
| 59 |
+
"""Verify rotation is thread-safe under concurrent access."""
|
| 60 |
+
from app.shared.integrations.key_rotator import KeyRotator
|
| 61 |
+
|
| 62 |
+
keys = ["key_1", "key_2", "key_3"]
|
| 63 |
+
rotator = KeyRotator(keys, name="threaded")
|
| 64 |
+
|
| 65 |
+
results = []
|
| 66 |
+
lock = threading.Lock()
|
| 67 |
+
|
| 68 |
+
def get_key():
|
| 69 |
+
key = rotator.get_next_key()
|
| 70 |
+
with lock:
|
| 71 |
+
results.append(key)
|
| 72 |
+
|
| 73 |
+
# Run 100 concurrent requests
|
| 74 |
+
with ThreadPoolExecutor(max_workers=10) as executor:
|
| 75 |
+
futures = [executor.submit(get_key) for _ in range(100)]
|
| 76 |
+
for future in futures:
|
| 77 |
+
future.result()
|
| 78 |
+
|
| 79 |
+
# Should have 100 results
|
| 80 |
+
assert len(results) == 100
|
| 81 |
+
assert rotator.request_count == 100
|
| 82 |
+
|
| 83 |
+
# Each key should be used roughly equally (with some variance due to threading)
|
| 84 |
+
for key in keys:
|
| 85 |
+
count = results.count(key)
|
| 86 |
+
# Should be approximately 33 each, allow 20% variance
|
| 87 |
+
assert 20 <= count <= 45, f"Key {key} used {count} times (expected ~33)"
|
| 88 |
+
|
| 89 |
+
def test_get_stats(self):
|
| 90 |
+
"""Verify stats reporting works."""
|
| 91 |
+
from app.shared.integrations.key_rotator import KeyRotator
|
| 92 |
+
|
| 93 |
+
keys = ["key_1", "key_2"]
|
| 94 |
+
rotator = KeyRotator(keys, name="stats_test")
|
| 95 |
+
|
| 96 |
+
rotator.get_next_key()
|
| 97 |
+
rotator.get_next_key()
|
| 98 |
+
rotator.get_next_key()
|
| 99 |
+
|
| 100 |
+
stats = rotator.get_stats()
|
| 101 |
+
assert stats["name"] == "stats_test"
|
| 102 |
+
assert stats["total_keys"] == 2
|
| 103 |
+
assert stats["total_requests"] == 3
|
| 104 |
+
assert stats["current_index"] == 1 # After 3 requests: 0->1->0->1
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
class TestLoadMegaLLMKeys:
|
| 108 |
+
"""Tests for environment-based key loading."""
|
| 109 |
+
|
| 110 |
+
def test_load_numbered_keys(self):
|
| 111 |
+
"""Verify loading MEGALLM_API_KEY_1, _2, _3 format."""
|
| 112 |
+
env_vars = {
|
| 113 |
+
"MEGALLM_API_KEY_1": "first_key",
|
| 114 |
+
"MEGALLM_API_KEY_2": "second_key",
|
| 115 |
+
"MEGALLM_API_KEY_3": "third_key",
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
with patch.dict(os.environ, env_vars, clear=False):
|
| 119 |
+
from importlib import reload
|
| 120 |
+
from app.shared.integrations import key_rotator
|
| 121 |
+
reload(key_rotator)
|
| 122 |
+
|
| 123 |
+
keys = key_rotator.load_megallm_keys()
|
| 124 |
+
assert keys == ["first_key", "second_key", "third_key"]
|
| 125 |
+
|
| 126 |
+
def test_load_fallback_single_key(self):
|
| 127 |
+
"""Verify fallback to MEGALLM_API_KEY (legacy format)."""
|
| 128 |
+
# Clear any numbered keys
|
| 129 |
+
env_vars = {
|
| 130 |
+
"MEGALLM_API_KEY": "legacy_key",
|
| 131 |
+
}
|
| 132 |
+
|
| 133 |
+
# Remove any numbered keys that might exist
|
| 134 |
+
for i in range(1, 10):
|
| 135 |
+
env_vars[f"MEGALLM_API_KEY_{i}"] = ""
|
| 136 |
+
|
| 137 |
+
with patch.dict(os.environ, env_vars, clear=False):
|
| 138 |
+
from importlib import reload
|
| 139 |
+
from app.shared.integrations import key_rotator
|
| 140 |
+
reload(key_rotator)
|
| 141 |
+
|
| 142 |
+
# Note: This test may need adjustment based on environment state
|
| 143 |
+
keys = key_rotator.load_megallm_keys()
|
| 144 |
+
# Should have at least the legacy key if no numbered keys
|
| 145 |
+
assert len(keys) >= 1
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
if __name__ == "__main__":
|
| 149 |
+
pytest.main([__file__, "-v"])
|