Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import json | |
| from dataclasses import dataclass | |
| from typing import Any | |
| from urllib.parse import urlparse | |
| import httpx | |
| from mcp.shared.exceptions import McpError | |
| from livekit.agents import AgentSession, mcp | |
| from livekit.agents.llm.tool_context import ToolError | |
| from livekit.plugins import openai as openai_plugin | |
| from src.agent.prompts.runtime import MCP_STARTUP_GREETING | |
| from src.core.logger import logger | |
| from src.core.settings import LLMSettings | |
# OpenAI-compatible endpoint used when LLM_PROVIDER=nvidia.
NVIDIA_OPENAI_BASE_URL = "https://integrate.api.nvidia.com/v1"

# Message of the RuntimeError raised by the generate_reply guard in MCP mode.
MCP_GENERATE_REPLY_BLOCK_MESSAGE = (
    "Manual generate_reply is disabled in MCP mode; "
    "use session.say(...) instead."
)

# str.format templates (placeholder: ``{tool_name}``) used to build the
# ToolError raised when an MCP tool call times out or its transport fails.
# Both end with the same "don't retry, answer without it" guidance.
_MCP_TOOL_FAILURE_GUIDANCE = (
    "Do not retry '{tool_name}' again in this turn. "
    "Give the user a brief answer without it."
)
MCP_TOOL_TIMEOUT_MESSAGE = (
    "The external tool '{tool_name}' timed out. " + _MCP_TOOL_FAILURE_GUIDANCE
)
MCP_TOOL_UNAVAILABLE_MESSAGE = (
    "The external tool '{tool_name}' is temporarily unavailable. "
    + _MCP_TOOL_FAILURE_GUIDANCE
)
@dataclass
class MCPRuntimeDecision:
    """Outcome of deciding whether the MCP tool runtime may be used.

    Built with keyword arguments (``enabled=..., reason=...``), so it must be a
    dataclass; without the decorator the class has no matching ``__init__``.
    """

    # True when MCP servers should be attached to the agent session.
    enabled: bool
    # Short reason code, e.g. "mcp_disabled", "missing_nvidia_api_key",
    # "provider_not_supported:<provider>" or "mcp_enabled".
    reason: str
@dataclass
class LLMRuntimeConfig:
    """LLM client plus optional MCP servers produced by build_llm_runtime."""

    # LLM client instance (an openai_plugin.LLM when built by build_llm_runtime).
    llm: Any
    # MCP servers to attach, or None when the MCP runtime is disabled.
    mcp_servers: list[mcp.MCPServerHTTP] | None
    # Normalized (stripped, lower-cased) provider name.
    provider: str
    # Model identifier sent to the provider.
    model: str

    @property
    def mcp_runtime_active(self) -> bool:
        """True when an MCP server list (possibly empty) was configured.

        Exposed as a property so it can be passed straight to the
        ``mcp_runtime_active: bool`` keyword parameters in this module; a
        plain method would be passed as an always-truthy bound method.
        """
        return self.mcp_servers is not None
class ConfiguredMCPServerHTTP(mcp.MCPServerHTTP):
    """MCP HTTP server with one clamped timeout and LLM-friendly tool failures.

    The timeout is clamped to at least 1 second and used for both the HTTP
    transport (``timeout``) and the MCP client session
    (``client_session_timeout_seconds``). When a tool call fails with a timeout
    or transport error, it raises a ToolError carrying "do not retry" guidance
    instead of letting the raw exception propagate.
    """

    def __init__(
        self,
        *,
        url: str,
        timeout_seconds: float,
        headers: dict[str, Any] | None = None,
    ) -> None:
        """Create the server.

        Args:
            url: MCP server endpoint.
            timeout_seconds: Requested timeout in seconds; values below 1.0
                are raised to 1.0.
            headers: Optional extra HTTP headers forwarded to the base class.
        """
        bounded_timeout = _bounded_timeout_seconds(timeout_seconds)
        super().__init__(
            url=url,
            headers=headers,
            timeout=bounded_timeout,
            client_session_timeout_seconds=bounded_timeout,
        )
        self._request_timeout_seconds = bounded_timeout

    @property
    def request_timeout_seconds(self) -> float:
        """Effective (clamped) timeout in seconds passed as ``timeout``."""
        return self._request_timeout_seconds

    @property
    def client_session_timeout_seconds(self) -> float:
        """Client-session timeout as stored by the base class (``_read_timeout``)."""
        return self._read_timeout

    def _make_function_tool(
        self,
        name: str,
        description: str | None,
        input_schema: dict[str, Any],
        meta: dict[str, Any] | None,
    ) -> mcp.MCPTool:
        """Wrap the remote MCP tool *name* as a LiveKit function tool.

        NOTE(review): this appears to override livekit's
        ``MCPServer._make_function_tool`` only to add exception normalization
        around ``call_tool``. Keep it in sync with upstream when upgrading
        livekit-agents.
        """

        async def _tool_called(raw_arguments: dict[str, Any]) -> Any:
            # No live client session (presumably not connected yet, or closed).
            if self._client is None:
                raise ToolError(
                    "Tool invocation failed: internal service is unavailable. "
                    "Please check that the MCPServer is still running."
                )
            try:
                tool_result = await self._client.call_tool(name, raw_arguments)
            except Exception as exc:
                # Timeouts and transport failures become a ToolError with
                # retry guidance. Anything else propagates unchanged.
                normalized = normalize_mcp_tool_exception(tool_name=name, exc=exc)
                if normalized is None:
                    raise
                logger.warning(
                    "MCP tool invocation failed: tool=%s timeout=%s detail=%s",
                    name,
                    is_mcp_timeout_exception(exc),
                    describe_mcp_exception(exc),
                )
                raise normalized from exc

            # The server itself reported a tool-level error.
            if tool_result.isError:
                error_str = "\n".join(str(part) for part in tool_result.content)
                raise ToolError(error_str)
            # A single content item is returned as its JSON; several items are
            # returned as a JSON array.
            if len(tool_result.content) == 1:
                return tool_result.content[0].model_dump_json()
            if len(tool_result.content) > 1:
                return json.dumps([item.model_dump() for item in tool_result.content])
            raise ToolError(
                f"Tool '{name}' completed without producing a result. "
                "This might indicate an issue with internal processing."
            )

        raw_schema = {
            "name": name,
            "description": description,
            "parameters": input_schema,
        }
        if meta:
            raw_schema["meta"] = meta
        return mcp.function_tool(_tool_called, raw_schema=raw_schema)
def resolve_mcp_runtime_mode(
    *,
    mcp_enabled: bool,
    llm_provider: str,
    nvidia_api_key: str | None,
) -> MCPRuntimeDecision:
    """Decide whether the MCP runtime can be enabled.

    The result is disabled, with a reason code, when MCP is switched off, when
    the provider is not "nvidia" or "ollama", or when the provider is "nvidia"
    and no API key is set. The provider name is stripped and lower-cased first.
    """
    normalized_provider = (llm_provider or "").strip().lower()

    if not mcp_enabled:
        reason = "mcp_disabled"
    elif normalized_provider not in ("nvidia", "ollama"):
        reason = f"provider_not_supported:{normalized_provider}"
    elif normalized_provider == "nvidia" and not nvidia_api_key:
        reason = "missing_nvidia_api_key"
    else:
        return MCPRuntimeDecision(enabled=True, reason="mcp_enabled")
    return MCPRuntimeDecision(enabled=False, reason=reason)
def resolve_mcp_server_urls(
    *,
    mcp_server_url: str,
    mcp_extra_server_urls: str,
) -> list[str]:
    """Return the primary MCP URL followed by the comma-separated extra URLs.

    Each entry is whitespace-stripped. Empty entries and repeats are dropped,
    and the first occurrence keeps its position. ``None`` inputs are treated
    as empty.
    """
    raw_entries = [mcp_server_url]
    raw_entries.extend((mcp_extra_server_urls or "").split(","))
    stripped = ((entry or "").strip() for entry in raw_entries)
    # dict keeps insertion order, so fromkeys() is an order-preserving dedupe.
    return list(dict.fromkeys(url for url in stripped if url))
def normalize_mcp_tool_exception(*, tool_name: str, exc: Exception) -> ToolError | None:
    """Translate an MCP tool-call failure into an LLM-facing ToolError.

    Returns a timeout ToolError or an "unavailable" ToolError, or None when
    *exc* is neither kind so the caller can re-raise it unchanged.
    """
    # Timeouts must be checked first: TimeoutError subclasses OSError and
    # httpx.TimeoutException subclasses httpx.RequestError, so both would
    # otherwise be classified as transport failures.
    if is_mcp_timeout_exception(exc):
        template = MCP_TOOL_TIMEOUT_MESSAGE
    elif is_mcp_transport_exception(exc):
        template = MCP_TOOL_UNAVAILABLE_MESSAGE
    else:
        return None
    return ToolError(template.format(tool_name=tool_name))
def is_mcp_timeout_exception(exc: BaseException) -> bool:
    """Return True if anything in *exc*'s cause/context chain is a timeout.

    A link counts as a timeout if it is a TimeoutError, an
    httpx.TimeoutException, or an McpError whose message looks like a timeout.
    """

    def _is_timeout(error: BaseException) -> bool:
        if isinstance(error, (TimeoutError, httpx.TimeoutException)):
            return True
        return isinstance(error, McpError) and looks_like_timeout_message(
            error.error.message
        )

    return any(_is_timeout(error) for error in iter_exception_chain(exc))
def is_mcp_transport_exception(exc: BaseException) -> bool:
    """Return True if any link in *exc*'s chain is an MCP/HTTP/OS-level error."""
    transport_error_types = (McpError, httpx.RequestError, OSError)
    for error in iter_exception_chain(exc):
        if isinstance(error, transport_error_types):
            return True
    return False
def iter_exception_chain(exc: BaseException) -> tuple[BaseException, ...]:
    """Return *exc* followed by its causes/contexts, outermost first.

    At each step the explicit ``__cause__`` is preferred, with a fallback to
    the implicit ``__context__``. ``__context__`` is followed even after
    ``raise ... from None``. The walk stops at the first exception already
    visited (by identity), which guards against cycles.
    """
    visited_ids: set[int] = set()
    collected: list[BaseException] = []
    node: BaseException | None = exc
    while node is not None:
        if id(node) in visited_ids:
            break
        visited_ids.add(id(node))
        collected.append(node)
        node = node.__cause__ or node.__context__
    return tuple(collected)
def looks_like_timeout_message(message: str | None) -> bool:
    """Heuristically decide whether an error message describes a timeout.

    The check is a case-insensitive substring match. ``None`` and blank
    messages are never timeouts.
    """
    haystack = (message or "").lower()
    # "timed out" also covers phrases such as "read timed out".
    markers = ("timed out", "timeout", "deadline exceeded")
    return any(marker in haystack for marker in markers)
def describe_mcp_exception(exc: BaseException) -> str:
    """Return the first non-empty message found along *exc*'s chain.

    For an McpError the wrapped ``error.message`` is used; for any other
    exception, ``str(error)``. If every message is empty, the outer
    exception's class name is returned.
    """
    messages = (
        error.error.message if isinstance(error, McpError) else str(error)
        for error in iter_exception_chain(exc)
    )
    return next((text for text in messages if text), type(exc).__name__)
def _bounded_timeout_seconds(timeout_seconds: float) -> float:
    """Convert *timeout_seconds* to float and raise it to at least 1.0."""
    seconds = float(timeout_seconds)
    return 1.0 if seconds < 1.0 else seconds
def _resolve_llm_endpoint(
    llm_settings: LLMSettings, provider: str
) -> tuple[str, str, str, dict[str, Any]]:
    """Return ``(model, base_url, api_key, extra_body)`` for *provider*.

    Raises:
        ValueError: if *provider* is not "nvidia"/"ollama" or a setting
            required by that provider is missing.
    """
    if provider == "nvidia":
        api_key = llm_settings.NVIDIA_API_KEY
        if not api_key:
            raise ValueError("NVIDIA_API_KEY is required when LLM_PROVIDER=nvidia")
        # Validate the model like the Ollama branch does, rather than sending
        # an empty or whitespace-padded model id to the API.
        model = (llm_settings.NVIDIA_MODEL or "").strip()
        if not model:
            raise ValueError("NVIDIA_MODEL is required when LLM_PROVIDER=nvidia")
        # Passed through to the chat template; the kwarg name indicates it
        # turns off "thinking" output.
        extra_body: dict[str, Any] = {"chat_template_kwargs": {"enable_thinking": False}}
        return model, NVIDIA_OPENAI_BASE_URL, api_key, extra_body

    if provider == "ollama":
        model = (llm_settings.OLLAMA_MODEL or "").strip()
        if not model:
            raise ValueError("OLLAMA_MODEL is required when LLM_PROVIDER=ollama")
        base_url = (llm_settings.OLLAMA_BASE_URL or "").strip()
        if not base_url:
            raise ValueError("OLLAMA_BASE_URL is required when LLM_PROVIDER=ollama")
        validate_ollama_model_for_endpoint(base_url=base_url, model=model)
        api_key = resolve_ollama_api_key(llm_settings.OLLAMA_API_KEY)
        return model, base_url, api_key, {"think": False}

    raise ValueError(
        f"Unknown LLM provider: {provider}. Must be 'nvidia' or 'ollama'"
    )


def build_llm_runtime(
    llm_settings: LLMSettings,
) -> LLMRuntimeConfig:
    """Build the OpenAI-compatible LLM client and the optional MCP server list.

    The provider ("nvidia" or "ollama") selects the model, base URL, API key
    and provider-specific ``extra_body``. MCP servers are created only when
    ``resolve_mcp_runtime_mode`` enables MCP and at least one server URL is
    configured. Otherwise ``mcp_servers`` is None and the MCP runtime is
    inactive.

    Raises:
        ValueError: if the provider is unknown or one of its required
            settings (API key, model, base URL) is missing.
    """
    provider = (llm_settings.LLM_PROVIDER or "").strip().lower()
    llm_timeout = build_mcp_http_timeout(llm_settings.LLM_CONN_TIMEOUT_SEC)
    # Effective (clamped) value, used only for logging.
    llm_timeout_seconds = _bounded_timeout_seconds(llm_settings.LLM_CONN_TIMEOUT_SEC)
    mcp_timeout_seconds = _bounded_timeout_seconds(llm_settings.MCP_CONN_TIMEOUT_SEC)

    mcp_decision = resolve_mcp_runtime_mode(
        mcp_enabled=llm_settings.MCP_ENABLED,
        llm_provider=provider,
        nvidia_api_key=llm_settings.NVIDIA_API_KEY,
    )
    mcp_server_urls: list[str] = []
    if mcp_decision.enabled:
        mcp_server_urls = resolve_mcp_server_urls(
            mcp_server_url=llm_settings.MCP_SERVER_URL,
            mcp_extra_server_urls=llm_settings.MCP_EXTRA_SERVER_URLS,
        )
        if not mcp_server_urls:
            # An empty list is "not None", so LLMRuntimeConfig would report the
            # MCP runtime as active with no tools behind it. Treat it as
            # unavailable so the warning below explains why.
            mcp_decision = MCPRuntimeDecision(
                enabled=False, reason="missing_mcp_server_url"
            )

    mcp_servers: list[mcp.MCPServerHTTP] | None = None
    if mcp_decision.enabled:
        mcp_servers = [
            ConfiguredMCPServerHTTP(url=url, timeout_seconds=mcp_timeout_seconds)
            for url in mcp_server_urls
        ]

    model, base_url, api_key, extra_body = _resolve_llm_endpoint(llm_settings, provider)

    if llm_settings.MCP_ENABLED and not mcp_decision.enabled:
        logger.warning(
            "MCP runtime requested but unavailable: reason=%s provider=%s",
            mcp_decision.reason,
            provider,
        )
    elif mcp_decision.enabled:
        logger.info(
            "MCP runtime enabled: mcp_servers=%s llm_provider=%s llm_model=%s llm_timeout_sec=%.2f mcp_timeout_sec=%.2f",
            mcp_server_urls,
            provider,
            model,
            llm_timeout_seconds,
            mcp_timeout_seconds,
        )
    else:
        logger.info("MCP runtime disabled (MCP_ENABLED=false)")

    llm = openai_plugin.LLM(
        model=model,
        api_key=api_key,
        base_url=base_url,
        temperature=llm_settings.LLM_TEMPERATURE,
        max_completion_tokens=llm_settings.LLM_MAX_TOKENS,
        timeout=llm_timeout,
        _strict_tool_schema=False,
        extra_body=extra_body,
    )
    return LLMRuntimeConfig(
        llm=llm,
        mcp_servers=mcp_servers,
        provider=provider,
        model=model,
    )
def resolve_ollama_api_key(api_key: str | None) -> str:
    """Return the stripped API key, or the placeholder "ollama" when blank.

    Presumably the placeholder exists so the OpenAI-compatible client always
    receives a non-empty key — confirm against the Ollama deployment in use.
    """
    return (api_key or "").strip() or "ollama"
def validate_ollama_model_for_endpoint(*, base_url: str, model: str) -> None:
    """Reject ':cloud' model aliases when targeting Ollama's hosted /v1 API.

    Raises:
        ValueError: if *base_url* is the Ollama cloud OpenAI endpoint and
            *model* ends with ":cloud" (case-insensitive).
    """
    # The endpoint is checked first, so *model* is only inspected for the
    # cloud endpoint.
    if is_ollama_cloud_openai_endpoint(base_url) and model.lower().endswith(":cloud"):
        raise ValueError(
            "OLLAMA_MODEL cannot use ':cloud' aliases with OLLAMA_BASE_URL=https://ollama.com/v1. "
            "Use an exact model ID from https://ollama.com/v1/models (for example, qwen3-next:80b)."
        )
def is_ollama_cloud_openai_endpoint(base_url: str) -> bool:
    """Return True if *base_url* points at Ollama's hosted OpenAI API (``/v1``).

    The host must be ollama.com, www.ollama.com or api.ollama.com
    (case-insensitive), and the path must be ``/v1``, ignoring trailing
    slashes.
    """
    candidate = (base_url or "").strip()
    if not candidate:
        return False
    parts = urlparse(candidate)
    hostname = (parts.hostname or "").lower()
    if hostname not in {"ollama.com", "www.ollama.com", "api.ollama.com"}:
        return False
    return (parts.path or "").rstrip("/") == "/v1"
def build_mcp_http_timeout(timeout_seconds: float) -> httpx.Timeout:
    """Build an httpx.Timeout that applies one value to every phase.

    *timeout_seconds* is raised to at least 1.0 and used for the connect,
    read, write and pool timeouts. Despite the name, build_llm_runtime also
    uses this for the LLM client timeout.
    """
    # A single positional value sets connect, read, write and pool together.
    return httpx.Timeout(max(timeout_seconds, 1.0))
def install_mcp_generate_reply_guard(
    session: AgentSession,
    *,
    mcp_runtime_active: bool,
) -> None:
    """In MCP mode, replace ``session.generate_reply`` with a stub that raises.

    Does nothing when *mcp_runtime_active* is false or the guard is already
    installed on *session*. Otherwise it:
    - sets the ``_open_voice_mcp_generate_reply_guard_installed`` marker;
    - keeps the original bound method as ``_open_voice_original_generate_reply``;
    - installs a stub that raises
      ``RuntimeError(MCP_GENERATE_REPLY_BLOCK_MESSAGE)`` for any arguments.
    """
    if not mcp_runtime_active or getattr(
        session, "_open_voice_mcp_generate_reply_guard_installed", False
    ):
        return

    def _reject_generate_reply(*_args: Any, **_kwargs: Any) -> Any:
        raise RuntimeError(MCP_GENERATE_REPLY_BLOCK_MESSAGE)

    # Order matters only if reading session.generate_reply could fail: the
    # marker is set before the original method is captured.
    session._open_voice_mcp_generate_reply_guard_installed = True
    session._open_voice_original_generate_reply = session.generate_reply
    session.generate_reply = _reject_generate_reply
    logger.info("MCP runtime policy active: manual generate_reply disabled")
def run_startup_greeting(
    session: AgentSession,
    *,
    mcp_runtime_active: bool,
) -> Any | None:
    """Start the opening greeting and return the handle for it.

    In MCP mode the fixed MCP_STARTUP_GREETING is spoken with ``session.say``
    (the MCP guard may have disabled ``generate_reply``). Otherwise the LLM is
    asked to greet the user through ``session.generate_reply``. Exceptions
    raised while starting the greeting are logged and yield None.

    Returns:
        Whatever ``session.say`` / ``session.generate_reply`` returned, or
        None if starting the greeting failed.
    """
    if mcp_runtime_active:
        logger.info("MCP runtime startup greeting via session.say")
        try:
            return session.say(
                MCP_STARTUP_GREETING,
                allow_interruptions=True,
                add_to_chat_ctx=False,
            )
        except Exception as exc:
            logger.warning(f"MCP startup greeting could not start: {exc}")
            return None
    try:
        # Return the handle (as the MCP branch does) so callers can await or
        # interrupt the greeting in either mode.
        return session.generate_reply(
            instructions="Greet the user and offer your assistance."
        )
    except Exception as exc:
        logger.warning(f"Startup greeting via generate_reply failed: {exc}")
        return None