Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import logging | |
| from typing import Any, Dict, List, Optional | |
| import aiohttp | |
| from redis.asyncio import Redis | |
| from app.config import ( | |
| AION_LABS_BASE_URL, | |
| AION_LABS_CHAT_PATH, | |
| AION_LABS_DEFAULT_MODEL, | |
| DEFAULT_MAX_TOKENS, | |
| DEFAULT_TEMPERATURE, | |
| DEFAULT_TOP_P, | |
| MEGANOVA_BASE_URL, | |
| MEGANOVA_CHAT_PATH, | |
| MODELS, | |
| OPENROUTER_MIMIKA_BASE_URL, | |
| OPENROUTER_MIMIKA_CHAT_PATH, | |
| OPENROUTER_MIMIKA_MODEL, | |
| get_settings, | |
| ) | |
| from app.utils.json_utils import extract_json_blocks | |
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Settings object fetched once at import time; the functions below read from it.
_settings = get_settings()
# Redis key prefix used for the MegaNova slot bookkeeping keys and script calls.
PREFIX = "ai_lb"
# Redis key prefix used for the AionLabs key list, round-robin pointer and use counters.
AION_PREFIX = "ai_lb:aion"
# Instruction added to the system message when a "json_object" response format is requested.
DEFAULT_JSON_PROMPT = "Return your response as a valid JSON object inside a JSON code block (```json)."
# Key ids ("k0", "k1", ...) and the id -> API key mapping; both are rebuilt by _refresh_keys().
KEY_IDS: List[str] = []
KEY_MAP: Dict[str, str] = {}
def _refresh_keys() -> None:
    """Rebuild the module-level ``KEY_IDS`` / ``KEY_MAP`` from settings.

    ``_settings.ai_api_keys`` is read as a comma-separated string. Blank
    entries are dropped and each remaining key is given the id ``k<position>``.
    When the setting is empty both globals are reset to empty containers.
    """
    global KEY_IDS, KEY_MAP
    raw = _settings.ai_api_keys
    if not raw:
        KEY_IDS = []
        KEY_MAP = {}
        return
    secrets_list = [part.strip() for part in raw.split(",") if part.strip()]
    KEY_MAP = {f"k{pos}": secret for pos, secret in enumerate(secrets_list)}
    # Dicts preserve insertion order, so the ids come out as k0, k1, ...
    KEY_IDS = list(KEY_MAP)


# Populate the key tables once when the module is imported.
_refresh_keys()
def build_default_system_prompt(model_name: Optional[str]) -> str:
    """Return the identity system prompt for *model_name*.

    Args:
        model_name: Name to embed in the prompt; a falsy value falls back to
            ``"agentdeck-1.0"``.

    Returns:
        The four identity sentences joined into one string.
    """
    identity = model_name if model_name else "agentdeck-1.0"
    sentences = (
        f"Role: You are LLM model {identity}. ",
        "You are built by AgentDeck. ",
        "You are a helpful, respectful, and honest assistant. ",
        "Always respond in a concise and accurate manner.",
    )
    return "".join(sentences)
def inject_system_identity(
    messages: List[Dict[str, str]],
    model_name: str,
) -> List[Dict[str, str]]:
    """Return a copy of *messages* that starts with the identity system prompt.

    If the first message already has role ``"system"``, the identity prompt is
    placed in front of its content, separated by a blank line. Otherwise a new
    system message is inserted at position 0.

    The input list is never modified. Mutating it in place meant that passing
    the same list twice (for example when a caller retries a failed request)
    stacked the identity prompt once per call. Callers must use the return
    value.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
            ``None`` is treated as an empty conversation.
        model_name: Model name passed to ``build_default_system_prompt``.

    Returns:
        A new list; message dicts other than the first are shared with the
        input.
    """
    default_prompt = build_default_system_prompt(model_name)
    conversation = list(messages or [])
    if conversation and conversation[0].get("role") == "system":
        # A system message with missing or empty content contributes nothing,
        # so skip the separator in that case.
        existing = conversation[0].get("content") or ""
        merged = f"{default_prompt}\n\n{existing}" if existing else default_prompt
        conversation[0] = {"role": "system", "content": merged}
    else:
        conversation.insert(0, {"role": "system", "content": default_prompt})
    return conversation
def prepare_messages(
    messages: List[Dict[str, str]],
    response_format: Optional[Dict[str, str]],
) -> List[Dict[str, str]]:
    """Add the JSON-output instruction when a JSON object response is requested.

    Args:
        messages: Chat messages as ``{"role": ..., "content": ...}`` dicts.
        response_format: Optional format spec; only ``{"type": "json_object"}``
            triggers any change.

    Returns:
        *messages* itself when no JSON response is requested. Otherwise a new
        list whose leading system message ends with ``DEFAULT_JSON_PROMPT``
        (appended to an existing system message, or added as a new one).
    """
    json_requested = bool(response_format) and response_format.get("type") == "json_object"
    if not json_requested:
        return messages

    if messages and messages[0].get("role") == "system":
        augmented = {
            "role": "system",
            "content": f"{messages[0]['content']}\n\n{DEFAULT_JSON_PROMPT}",
        }
        return [augmented] + list(messages[1:])

    json_instruction = {"role": "system", "content": DEFAULT_JSON_PROMPT}
    return [json_instruction, *messages]
def attach_json_content(
    response_data: Dict[str, Any],
    response_format: Optional[Dict[str, str]],
) -> None:
    """Store JSON extracted from the first choice under ``response_data["parsed"]``.

    Does nothing unless *response_format* is ``{"type": "json_object"}``.
    The content of ``choices[0].message`` is passed to ``extract_json_blocks``.
    A single result is stored as-is and several results are stored as a list.
    Any exception raised while reading or parsing is recorded as
    ``{"error": "<message>"}`` instead of propagating.

    Args:
        response_data: Provider response; modified in place.
        response_format: Optional format spec from the request.
    """
    if not response_format:
        return
    if response_format.get("type") != "json_object":
        return

    try:
        choices = response_data.get("choices", [])
        if not choices:
            return
        message = choices[0].get("message", {})
        text = message.get("content", "")
        if not text:
            return
        blocks = extract_json_blocks(text)
        if not blocks:
            return
        if len(blocks) == 1:
            response_data["parsed"] = blocks[0]
        else:
            response_data["parsed"] = blocks
    except Exception as exc:
        response_data["parsed"] = {"error": str(exc)}
async def call_openrouter_mimika(
    messages: List[Dict[str, str]],
    response_format: Optional[Dict[str, str]],
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
) -> Optional[Dict[str, Any]]:
    """Send one non-streaming chat request to the OpenRouter Mimika endpoint.

    Args:
        messages: Chat messages; passed through ``prepare_messages`` first.
        response_format: Optional format spec forwarded to the helpers.
        max_tokens: Value of the ``max_tokens`` request field.
        temperature: Value of the ``temperature`` request field.
        top_p: Value of the ``top_p`` request field.

    Returns:
        The decoded JSON response (after ``attach_json_content``), or ``None``
        when no API key is configured, the status is not 200, or any exception
        occurs during the request.
    """
    api_key = _settings.openrouter_mimika_api_key
    if not api_key:
        logger.info("OPENROUTER_MIMIKA_API_KEY not set, skipping")
        return None

    body_messages = prepare_messages(messages, response_format)
    request_body = {
        "model": OPENROUTER_MIMIKA_MODEL,
        "messages": body_messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": False,
    }
    logger.info("Calling OpenRouter Mimika...")

    try:
        limit = aiohttp.ClientTimeout(total=_settings.request_timeout_ms / 1000)
        endpoint = f"{OPENROUTER_MIMIKA_BASE_URL}{OPENROUTER_MIMIKA_CHAT_PATH}"
        request_headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        async with aiohttp.ClientSession(timeout=limit) as session, session.post(
            endpoint, json=request_body, headers=request_headers
        ) as resp:
            if resp.status == 200:
                data = await resp.json()
                attach_json_content(data, response_format)
                return data
            logger.warning("OpenRouter Mimika HTTP %d", resp.status)
            return None
    except Exception as exc:
        # Any transport or decoding problem is reported as "no result".
        logger.warning("OpenRouter Mimika failed: %s", exc)
        return None
def _get_meganova_key() -> Optional[str]:
    """Return the first non-blank key from ``_settings.ai_api_keys``.

    The setting is read as a comma-separated string. Returns ``None`` when it
    is empty or contains only blank entries.
    """
    raw = _settings.ai_api_keys
    if not raw:
        return None
    for piece in raw.split(","):
        candidate = piece.strip()
        if candidate:
            return candidate
    return None
async def _acquire_slot(
    redis: Redis,
    scripts: Dict[str, str],
) -> Optional[Dict[str, Any]]:
    """Ask the Redis script ``scripts["acquire_slot_sha"]`` for a key/model slot.

    The script is invoked with one key (``PREFIX``) and these arguments: the
    JSON-encoded key ids, the JSON-encoded model list, the configured rounds
    per model, and the number of key ids.

    Returns:
        ``{"keyId", "model", "modelIndex"}`` built from the script reply, or
        ``None`` when no keys are configured or the reply does not start with
        ``"ok"``.
    """
    _refresh_keys()
    if not KEY_IDS:
        return None

    script_sha = scripts["acquire_slot_sha"]
    script_args = [
        PREFIX,
        json.dumps(KEY_IDS),
        json.dumps(MODELS),
        str(_settings.rounds_per_model),
        str(len(KEY_IDS)),
    ]
    reply = await redis.evalsha(script_sha, 1, *script_args)

    # NOTE(review): the status is compared against a str, so this presumably
    # relies on a client created with decode_responses=True; confirm.
    if reply and reply[0] == "ok":
        return {
            "keyId": reply[1],
            "model": reply[2],
            "modelIndex": int(reply[3]),
        }
    return None
async def _acquire_slot_for_model(
    redis: Redis,
    target_model: str,
) -> Optional[Dict[str, Any]]:
    """Find the first configured key that can still serve *target_model*.

    Keys are checked in ``KEY_IDS`` order. A key is skipped when its lock key
    exists, when its per-model ``failed`` flag equals ``"1"``, or when its
    per-model ``rounds`` counter has reached ``_settings.rounds_per_model``.
    The first key that passes has its ``rounds`` counter incremented.

    Returns:
        ``{"keyId", "model", "modelIndex"}`` for the chosen key, or ``None``
        when no keys are configured, the model is not in ``MODELS``, or every
        key was skipped.
    """
    _refresh_keys()
    if not KEY_IDS:
        return None
    if target_model not in MODELS:
        return None
    model_pos = MODELS.index(target_model)

    for candidate in KEY_IDS:
        key_base = f"{PREFIX}:key:{candidate}"

        if await redis.exists(f"{key_base}:lock"):
            continue

        # NOTE(review): compared against a str, so this presumably relies on
        # a client created with decode_responses=True; confirm.
        if await redis.get(f"{key_base}:m{model_pos}:failed") == "1":
            continue

        counter_key = f"{key_base}:m{model_pos}:rounds"
        consumed = int(await redis.get(counter_key) or "0")
        if consumed >= _settings.rounds_per_model:
            continue

        # NOTE(review): the read above and this INCR are separate commands,
        # so concurrent callers can both pass the check.
        await redis.incr(counter_key)
        return {
            "keyId": candidate,
            "model": target_model,
            "modelIndex": model_pos,
        }
    return None
async def _mark_failure(
    redis: Redis,
    scripts: Dict[str, str],
    key_id: str,
    model_index: int,
) -> str:
    """Run the Redis script ``scripts["lock_key_sha"]`` for a failed key/model.

    The script receives one key (``PREFIX``) and these arguments: the key id,
    the model index, the number of models, and ``_settings.key_lock_ttl``.
    What the script records is defined by the script itself, which is not
    part of this module.

    Returns:
        Whatever the script returns.
    """
    script_sha = scripts["lock_key_sha"]
    script_args = (
        PREFIX,
        key_id,
        str(model_index),
        str(len(MODELS)),
        str(_settings.key_lock_ttl),
    )
    return await redis.evalsha(script_sha, 1, *script_args)
async def call_meganova(
    redis: Redis,
    scripts: Dict[str, str],
    messages: List[Dict[str, str]],
    response_format: Optional[Dict[str, str]],
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    target_model: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Call MegaNova, moving through key/model slots until one succeeds.

    A slot is acquired per attempt: ``_acquire_slot_for_model`` when
    *target_model* is given, otherwise the Redis script behind
    ``_acquire_slot``. Each distinct ``keyId:modelIndex`` pair is tried at
    most once. A response with status >= 400 is reported through
    ``_mark_failure`` before moving on.

    The prepared messages, the endpoint URL, the timeout and the HTTP session
    do not depend on the slot, so they are created once and shared by all
    attempts.

    Args:
        redis: Redis client used for slot bookkeeping.
        scripts: Mapping of script names to SHA digests.
        messages: Chat messages; passed through ``prepare_messages``.
        response_format: Optional format spec forwarded to the helpers.
        max_tokens: Value of the ``max_tokens`` request field.
        temperature: Value of the ``temperature`` request field.
        top_p: Value of the ``top_p`` request field.
        target_model: Restrict attempts to this model; must be in ``MODELS``.

    Returns:
        The decoded JSON response (after ``attach_json_content``), or ``None``
        when the target model is unknown, no slot is available, or every
        attempt failed.
    """
    if target_model and target_model not in MODELS:
        return None

    total_unique_slots = len(MODELS) * len(KEY_IDS)
    max_tries = len(KEY_IDS) if target_model else total_unique_slots
    if max_tries <= 0:
        # Nothing to try; avoid opening a session for no reason.
        return None
    # Upper bound on loop iterations, because slot acquisition may hand out
    # a pair that was already tried.
    hard_cap = total_unique_slots * 2 + 2

    prepared = prepare_messages(messages, response_format)
    url = f"{MEGANOVA_BASE_URL}{MEGANOVA_CHAT_PATH}"
    timeout = aiohttp.ClientTimeout(total=_settings.request_timeout_ms / 1000)

    tried: set = set()
    loops = 0
    async with aiohttp.ClientSession(timeout=timeout) as session:
        while len(tried) < max_tries and loops < hard_cap:
            loops += 1
            if target_model:
                slot = await _acquire_slot_for_model(redis, target_model)
            else:
                slot = await _acquire_slot(redis, scripts)
            if not slot:
                return None

            combo_key = f"{slot['keyId']}:{slot['modelIndex']}"
            if combo_key in tried:
                continue
            tried.add(combo_key)

            api_key = KEY_MAP.get(slot["keyId"])
            if not api_key:
                continue

            payload = {
                "messages": prepared,
                "model": slot["model"],
                "max_tokens": max_tokens,
                "temperature": temperature,
                "top_p": top_p,
                "stream": False,
            }
            try:
                async with session.post(
                    url,
                    json=payload,
                    headers={
                        "Authorization": f"Bearer {api_key}",
                        "Content-Type": "application/json",
                    },
                ) as resp:
                    if resp.status >= 400:
                        logger.warning(
                            "MegaNova HTTP %d for keyId=%s model=%s",
                            resp.status,
                            slot["keyId"],
                            slot["model"],
                        )
                        await _mark_failure(
                            redis, scripts, slot["keyId"], slot["modelIndex"]
                        )
                        continue
                    data = await resp.json()
                    attach_json_content(data, response_format)
                    return data
            except Exception as exc:
                # NOTE(review): transport errors are only logged and are not
                # reported through _mark_failure, so the same slot may be
                # handed out again and skipped via `tried`. Confirm that this
                # is intended.
                logger.warning("MegaNova request failed: %s", exc)
                continue
    return None
async def _get_next_aion_key(redis: Redis) -> Optional[str]:
    """Pick the next AionLabs key in round-robin order.

    The key list is read as JSON from ``<AION_PREFIX>:keys``. The shared
    pointer ``<AION_PREFIX>:rr_ptr`` is incremented and mapped onto the list
    by modulo, and the chosen key's ``uses`` counter is incremented.

    Returns:
        The selected key, or ``None`` when the stored list is missing or empty.
    """
    stored = await redis.get(f"{AION_PREFIX}:keys")
    pool: List[str] = json.loads(stored) if stored else []
    if not pool:
        return None

    pointer = await redis.incr(f"{AION_PREFIX}:rr_ptr")
    # INCR returns the value after incrementing, hence the -1 to start at 0.
    chosen = (pointer - 1) % len(pool)
    await redis.incr(f"{AION_PREFIX}:key:{chosen}:uses")
    logger.info(
        "[aion] Round-robin: ptr=%s, idx=%s, total=%s", pointer, chosen, len(pool)
    )
    return pool[chosen]
async def call_aion_labs(
    redis: Redis,
    messages: List[Dict[str, str]],
    response_format: Optional[Dict[str, str]],
    model: str = AION_LABS_DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
) -> Optional[Dict[str, Any]]:
    """Call AionLabs, rotating through the keys stored in Redis.

    Up to ``min(number of keys, 3)`` attempts are made and each attempt takes
    its key from ``_get_next_aion_key``. The request payload, the endpoint
    URL, the timeout and the HTTP session are the same for every attempt, so
    they are created once and shared.

    Args:
        redis: Redis client holding the key list under ``<AION_PREFIX>:keys``.
        messages: Chat messages; passed through ``prepare_messages``.
        response_format: Optional format spec forwarded to the helpers.
        model: Value of the ``model`` request field.
        max_tokens: Value of the ``max_tokens`` request field.
        temperature: Value of the ``temperature`` request field.
        top_p: Value of the ``top_p`` request field.

    Returns:
        The decoded JSON response (after ``attach_json_content``), or ``None``
        when no keys are stored or every attempt failed.
    """
    keys_json = await redis.get(f"{AION_PREFIX}:keys")
    if not keys_json:
        logger.info("[aion] No keys in Redis, skipping")
        return None
    keys: List[str] = json.loads(keys_json)
    if not keys:
        return None

    total_tries = min(len(keys), 3)
    payload = {
        "model": model,
        "messages": prepare_messages(messages, response_format),
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": False,
    }
    url = f"{AION_LABS_BASE_URL}{AION_LABS_CHAT_PATH}"
    timeout = aiohttp.ClientTimeout(total=_settings.request_timeout_ms / 1000)

    async with aiohttp.ClientSession(timeout=timeout) as session:
        for attempt in range(1, total_tries + 1):
            key = await _get_next_aion_key(redis)
            if not key:
                return None
            logger.info("[aion] Attempt %s/%s", attempt, total_tries)
            try:
                async with session.post(
                    url,
                    json=payload,
                    headers={
                        "Authorization": f"Bearer {key}",
                        "Content-Type": "application/json",
                    },
                ) as resp:
                    if resp.status != 200:
                        logger.warning(
                            "[aion] Attempt %s HTTP %s", attempt, resp.status
                        )
                        continue
                    data = await resp.json()
                    attach_json_content(data, response_format)
                    return data
            except Exception as exc:
                # Best effort: log and move on to the next key.
                logger.warning("[aion] Attempt %s failed: %s", attempt, exc)
                continue

    logger.info("[aion] All attempts exhausted")
    return None
async def chat_completion(
    messages: List[Dict[str, str]],
    model: str = "agentdeck-1.0",
    response_format: Optional[Dict[str, str]] = None,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    provider: Optional[str] = None,
    redis: Optional[Redis] = None,
    scripts: Optional[Dict[str, str]] = None,
) -> Dict[str, Any]:
    """Run a chat completion against the selected provider, with fallbacks.

    Provider selection:
      * ``"aionlabs"``: AionLabs only.
      * ``"openprovider"``: OpenRouter Mimika only.
      * anything else (including ``None``): MegaNova first, then AionLabs,
        then OpenRouter Mimika.

    The Redis-backed variants of the MegaNova and AionLabs calls are used
    only when both *redis* and *scripts* are provided; otherwise the
    ``*_no_redis`` variants are used.

    Args:
        messages: Chat messages. The list is not modified.
        model: Requested model name; also written to ``result["model"]``.
        response_format: Optional format spec forwarded to the providers.
        max_tokens: Value of the ``max_tokens`` request field.
        temperature: Value of the ``temperature`` request field.
        top_p: Value of the ``top_p`` request field.
        provider: Provider selector, see above.
        redis: Optional Redis client.
        scripts: Optional mapping of script names to SHA digests.

    Returns:
        The provider response with ``"model"`` set to *model*.

    Raises:
        RuntimeError: When the selected provider, or every provider in the
            fallback chain, returns no result.
    """
    # Hand the helper a shallow copy so the caller's list is never modified,
    # however the helper builds its result. Otherwise calling this function
    # again with the same list (e.g. a retry) would stack identity prompts.
    messages = inject_system_identity(list(messages), model)
    if not provider:
        provider = "meganova"
    logger.info(
        "Chat completion: model=%s provider=%s messages=%d",
        model,
        provider,
        len(messages),
    )

    target = model if model in MODELS else None
    use_redis = bool(redis and scripts)
    aion_model = model or AION_LABS_DEFAULT_MODEL

    async def _via_meganova() -> Optional[Dict[str, Any]]:
        # MegaNova call, Redis-backed when Redis and scripts are available.
        if use_redis:
            return await call_meganova(
                redis, scripts, messages, response_format,
                max_tokens, temperature, top_p, target,
            )
        return await call_meganova_no_redis(
            messages, response_format, max_tokens, temperature, top_p, target,
        )

    async def _via_aion() -> Optional[Dict[str, Any]]:
        # AionLabs call, Redis-backed when Redis and scripts are available.
        if use_redis:
            return await call_aion_labs(
                redis, messages, response_format, aion_model,
                max_tokens, temperature, top_p,
            )
        return await call_aion_labs_no_redis(
            messages, response_format, aion_model,
            max_tokens, temperature, top_p,
        )

    async def _via_openrouter() -> Optional[Dict[str, Any]]:
        # OpenRouter Mimika call; it has no Redis-backed variant.
        return await call_openrouter_mimika(
            messages, response_format, max_tokens, temperature, top_p,
        )

    def _stamp(result: Dict[str, Any]) -> Dict[str, Any]:
        # Report the model name the caller asked for.
        result["model"] = model
        return result

    if provider == "aionlabs":
        result = await _via_aion()
        if result:
            return _stamp(result)
        raise RuntimeError("AionLabs request failed")

    if provider == "openprovider":
        result = await _via_openrouter()
        if result:
            return _stamp(result)
        raise RuntimeError("OpenRouter Mimika request failed")

    # Default chain: MegaNova -> AionLabs -> OpenRouter Mimika.
    result = await _via_meganova()
    if not result:
        logger.info("MegaNova failed, falling back to AionLabs")
        result = await _via_aion()
    if not result:
        logger.info("AionLabs failed, falling back to OpenRouter Mimika")
        result = await _via_openrouter()
    if not result:
        raise RuntimeError("All AI providers exhausted")
    return _stamp(result)
async def call_meganova_no_redis(
    messages: List[Dict[str, str]],
    response_format: Optional[Dict[str, str]],
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
    target_model: Optional[str] = None,
) -> Optional[Dict[str, Any]]:
    """Send one MegaNova request using the first configured key, without Redis.

    Args:
        messages: Chat messages; passed through ``prepare_messages``.
        response_format: Optional format spec forwarded to the helpers.
        max_tokens: Value of the ``max_tokens`` request field.
        temperature: Value of the ``temperature`` request field.
        top_p: Value of the ``top_p`` request field.
        target_model: Model to request; a falsy value selects ``MODELS[1]``.

    Returns:
        The decoded JSON response (after ``attach_json_content``), or ``None``
        when no key is configured, the status is not 200, or any exception
        occurs during the request.
    """
    api_key = _get_meganova_key()
    if not api_key:
        logger.info("No MegaNova keys configured, skipping")
        return None

    body_messages = prepare_messages(messages, response_format)
    # NOTE(review): the default is the second entry of MODELS; confirm that
    # index 1 is intentional and that MODELS always has two or more entries.
    chosen_model = target_model if target_model else MODELS[1]
    request_body = {
        "model": chosen_model,
        "messages": body_messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": False,
    }
    logger.info("Calling MegaNova (no Redis)...")

    try:
        limit = aiohttp.ClientTimeout(total=_settings.request_timeout_ms / 1000)
        endpoint = f"{MEGANOVA_BASE_URL}{MEGANOVA_CHAT_PATH}"
        request_headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json",
        }
        async with aiohttp.ClientSession(timeout=limit) as session, session.post(
            endpoint, json=request_body, headers=request_headers
        ) as resp:
            if resp.status == 200:
                data = await resp.json()
                attach_json_content(data, response_format)
                return data
            logger.warning("MegaNova HTTP %d", resp.status)
            return None
    except Exception as exc:
        # Any transport or decoding problem is reported as "no result".
        logger.warning("MegaNova failed: %s", exc)
        return None
async def call_aion_labs_no_redis(
    messages: List[Dict[str, str]],
    response_format: Optional[Dict[str, str]],
    model: str = AION_LABS_DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = DEFAULT_TEMPERATURE,
    top_p: float = DEFAULT_TOP_P,
) -> Optional[Dict[str, Any]]:
    """Call AionLabs with keys taken from settings instead of Redis.

    ``_settings.aion_lab_keys`` is read as a comma-separated string and at
    most the first three non-blank keys are tried, in order. The request
    payload, the endpoint URL, the timeout and the HTTP session are the same
    for every attempt, so they are created once and shared.

    Failed attempts (non-200 status or an exception) are logged as warnings
    so that a run of failures is visible in the logs, and the next key is
    tried.

    Args:
        messages: Chat messages; passed through ``prepare_messages``.
        response_format: Optional format spec forwarded to the helpers.
        model: Value of the ``model`` request field.
        max_tokens: Value of the ``max_tokens`` request field.
        temperature: Value of the ``temperature`` request field.
        top_p: Value of the ``top_p`` request field.

    Returns:
        The decoded JSON response (after ``attach_json_content``), or ``None``
        when no keys are configured or every attempt failed.
    """
    keys_str = _settings.aion_lab_keys
    if not keys_str:
        logger.info("No AION keys configured, skipping")
        return None
    keys = [k.strip() for k in keys_str.split(",") if k.strip()]
    if not keys:
        return None

    candidates = keys[:3]
    payload = {
        "model": model,
        "messages": prepare_messages(messages, response_format),
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": False,
    }
    url = f"{AION_LABS_BASE_URL}{AION_LABS_CHAT_PATH}"
    timeout = aiohttp.ClientTimeout(total=_settings.request_timeout_ms / 1000)

    async with aiohttp.ClientSession(timeout=timeout) as session:
        for attempt, key in enumerate(candidates, start=1):
            logger.info(
                "[aion] (no redis) Attempt %s/%s", attempt, len(candidates)
            )
            try:
                async with session.post(
                    url,
                    json=payload,
                    headers={
                        "Authorization": f"Bearer {key}",
                        "Content-Type": "application/json",
                    },
                ) as resp:
                    if resp.status != 200:
                        logger.warning(
                            "[aion] (no redis) Attempt %s HTTP %s",
                            attempt,
                            resp.status,
                        )
                        continue
                    data = await resp.json()
                    attach_json_content(data, response_format)
                    return data
            except Exception as exc:
                # Best effort: log and move on to the next key.
                logger.warning(
                    "[aion] (no redis) Attempt %s failed: %s", attempt, exc
                )
                continue
    return None