open-voice-agent / tests /test_agent_mcp_runtime.py
dvalle08's picture
feat: Introduce MCP connection timeout handling and enhance configuration
8193c41
from __future__ import annotations
import asyncio
import re
from datetime import date
from typing import Any
import httpx
import pytest
from livekit.agents import llm
from livekit.agents.llm.tool_context import ToolError
from mcp.types import ErrorData
from src.agent.models.llm_runtime import (
ConfiguredMCPServerHTTP,
MCP_GENERATE_REPLY_BLOCK_MESSAGE,
build_llm_runtime,
build_mcp_http_timeout,
install_mcp_generate_reply_guard,
resolve_mcp_server_urls,
resolve_mcp_runtime_mode,
run_startup_greeting,
)
from src.agent.prompts.assistant import (
ASSISTANT_INSTRUCTIONS,
build_assistant_instructions,
)
from src.agent.prompts.runtime import MCP_STARTUP_GREETING
from src.agent.runtime.tasks import (
cancel_task_for_shutdown,
monitor_startup_greeting_handle,
run_llm_warmup,
schedule_llm_warmup_task,
schedule_startup_greeting_task,
)
from src.agent.tools.pre_tool_feedback import inject_pre_tool_feedback
from src.core.settings import LLMSettings, Settings
class _FakeSpeechHandle:
def __init__(self, *, speech_id: str = "speech-greeting", block: bool = False) -> None:
self.id = speech_id
self.interrupt_calls: list[bool] = []
self._done = asyncio.Event()
if not block:
self._done.set()
async def wait_for_playout(self) -> None:
await self._done.wait()
def interrupt(self, *, force: bool = False) -> "_FakeSpeechHandle":
self.interrupt_calls.append(force)
self._done.set()
return self
class _FakeSession:
def __init__(self, *, greeting_handle: _FakeSpeechHandle | None = None) -> None:
self.say_calls: list[dict[str, object]] = []
self.generate_reply_calls: list[dict[str, object]] = []
self.greeting_handle = greeting_handle or _FakeSpeechHandle()
def say(self, text: str, **kwargs: object) -> _FakeSpeechHandle:
self.say_calls.append({"text": text, "kwargs": kwargs})
return self.greeting_handle
def generate_reply(self, **kwargs: object) -> str:
self.generate_reply_calls.append(kwargs)
return "reply-handle"
class _FailingSaySession(_FakeSession):
def say(self, text: str, **kwargs: object) -> _FakeSpeechHandle:
self.say_calls.append({"text": text, "kwargs": kwargs})
raise RuntimeError("say failed")
class _FakeLLMStream:
def __init__(
self,
*,
chunks: list[object] | None = None,
iter_error: Exception | None = None,
block_forever: bool = False,
) -> None:
self._chunks = list(chunks or [])
self._iter_error = iter_error
self._block_forever = block_forever
self._idx = 0
self._blocked = asyncio.Event()
self.aclose_calls = 0
self.iterations = 0
self.closed = False
def __aiter__(self) -> "_FakeLLMStream":
return self
async def __anext__(self) -> object:
if self._block_forever:
await self._blocked.wait()
if self._iter_error is not None:
raise self._iter_error
if self._idx >= len(self._chunks):
raise StopAsyncIteration
chunk = self._chunks[self._idx]
self._idx += 1
self.iterations += 1
return chunk
async def aclose(self) -> None:
self.closed = True
self.aclose_calls += 1
self._blocked.set()
class _FakeLLMClient:
def __init__(self, stream: _FakeLLMStream) -> None:
self._stream = stream
self.chat_calls: list[dict[str, object]] = []
def chat(self, **kwargs: object) -> _FakeLLMStream:
self.chat_calls.append(kwargs)
return self._stream
class _FakeMCPClient:
def __init__(self, *, result: Any = None, exc: Exception | None = None) -> None:
self._result = result
self._exc = exc
self.calls: list[dict[str, Any]] = []
async def call_tool(self, name: str, arguments: dict[str, Any] | None = None) -> Any:
self.calls.append({"name": name, "arguments": arguments})
if self._exc is not None:
raise self._exc
return self._result
def _build_llm_settings(**overrides: Any) -> LLMSettings:
defaults: dict[str, Any] = {
"LLM_PROVIDER": "ollama",
"OLLAMA_CLOUD_MODE": False,
"OLLAMA_MODEL": "qwen2.5:7b",
"OLLAMA_API_KEY": "ollama",
"MCP_ENABLED": False,
"MCP_SERVER_URL": "https://huggingface.co/mcp",
"MCP_EXTRA_SERVER_URLS": "https://docs.livekit.io/mcp",
"LLM_TEMPERATURE": 0.7,
"LLM_MAX_TOKENS": 1024,
"LLM_CONN_TIMEOUT_SEC": 12.0,
"MCP_CONN_TIMEOUT_SEC": 20.0,
"NVIDIA_MODEL": "meta/llama-3.1-8b-instruct",
"NVIDIA_API_KEY": None,
}
defaults.update(overrides)
return LLMSettings(**defaults)
class _FakeToolFeedback:
def __init__(self, *, fallback_phrase: str = "Let me check that.") -> None:
self.fallback_phrase = fallback_phrase
self.start_typing_calls = 0
def next_fallback_phrase(self) -> str:
return self.fallback_phrase
async def start_typing_sound(self) -> None:
self.start_typing_calls += 1
class _FakeToolStepStartedCallback:
def __init__(self, *, announce: bool = True) -> None:
self.calls = 0
self.announce = announce
async def __call__(self) -> bool:
self.calls += 1
return self.announce
async def _iter_items(items: list[Any]) -> Any:
for item in items:
yield item
async def _collect(items: Any) -> list[Any]:
out: list[Any] = []
async for item in items:
out.append(item)
return out
def _tool_chunk(
*,
content: str | None,
call_id: str = "call-1",
name: str = "paper_search",
) -> llm.ChatChunk:
return llm.ChatChunk(
id="chunk-1",
delta=llm.ChoiceDelta(
content=content,
tool_calls=[
llm.FunctionToolCall(
name=name,
arguments='{"query":"transformers"}',
call_id=call_id,
)
],
),
)
def test_resolve_mcp_runtime_mode_enables_mcp_with_supported_config() -> None:
decision = resolve_mcp_runtime_mode(
mcp_enabled=True,
llm_provider="NVIDIA",
nvidia_api_key="nvapi-test",
)
assert decision.enabled is True
assert decision.reason == "mcp_enabled"
def test_resolve_mcp_runtime_mode_respects_disabled_flag() -> None:
decision = resolve_mcp_runtime_mode(
mcp_enabled=False,
llm_provider="nvidia",
nvidia_api_key="nvapi-test",
)
assert decision.enabled is False
assert decision.reason == "mcp_disabled"
def test_resolve_mcp_runtime_mode_rejects_unsupported_provider() -> None:
decision = resolve_mcp_runtime_mode(
mcp_enabled=True,
llm_provider="unknown-provider",
nvidia_api_key="nvapi-test",
)
assert decision.enabled is False
assert decision.reason == "provider_not_supported:unknown-provider"
def test_resolve_mcp_runtime_mode_supports_ollama_provider() -> None:
decision = resolve_mcp_runtime_mode(
mcp_enabled=True,
llm_provider="ollama",
nvidia_api_key=None,
)
assert decision.enabled is True
assert decision.reason == "mcp_enabled"
def test_resolve_mcp_runtime_mode_requires_nvidia_api_key() -> None:
decision = resolve_mcp_runtime_mode(
mcp_enabled=True,
llm_provider="nvidia",
nvidia_api_key=None,
)
assert decision.enabled is False
assert decision.reason == "missing_nvidia_api_key"
def testinstall_mcp_generate_reply_guard_blocks_manual_generate_reply() -> None:
session = _FakeSession()
install_mcp_generate_reply_guard(session, mcp_runtime_active=True) # type: ignore[arg-type]
with pytest.raises(RuntimeError, match=re.escape(MCP_GENERATE_REPLY_BLOCK_MESSAGE)):
session.generate_reply(instructions="hello")
def testinstall_mcp_generate_reply_guard_is_noop_when_mcp_disabled() -> None:
session = _FakeSession()
install_mcp_generate_reply_guard(session, mcp_runtime_active=False) # type: ignore[arg-type]
session.generate_reply(instructions="hello")
assert session.generate_reply_calls == [{"instructions": "hello"}]
def testrun_startup_greeting_uses_say_in_mcp_mode() -> None:
session = _FakeSession()
handle = run_startup_greeting(session, mcp_runtime_active=True) # type: ignore[arg-type]
assert handle is session.greeting_handle
assert session.say_calls == [
{
"text": MCP_STARTUP_GREETING,
"kwargs": {
"allow_interruptions": True,
"add_to_chat_ctx": False,
},
}
]
assert session.generate_reply_calls == []
def testrun_startup_greeting_uses_generate_reply_without_mcp() -> None:
session = _FakeSession()
handle = run_startup_greeting(session, mcp_runtime_active=False) # type: ignore[arg-type]
assert handle is None
assert session.generate_reply_calls == [
{"instructions": "Greet the user and offer your assistance."}
]
assert session.say_calls == []
def test_startup_greeting_monitor_times_out_and_interrupts() -> None:
session = _FakeSession(greeting_handle=_FakeSpeechHandle(block=True))
handle = run_startup_greeting(session, mcp_runtime_active=True) # type: ignore[arg-type]
assert handle is not None
asyncio.run(monitor_startup_greeting_handle(handle, timeout_sec=0.01))
assert handle.interrupt_calls == [True]
def test_startup_greeting_monitor_allows_unbounded_playout_when_timeout_disabled() -> None:
async def _run() -> _FakeSpeechHandle:
handle = _FakeSpeechHandle(block=True)
task = asyncio.create_task(
monitor_startup_greeting_handle(handle, timeout_sec=0.0)
)
await asyncio.sleep(0)
assert not task.done()
handle.interrupt(force=False)
await task
return handle
handle = asyncio.run(_run())
assert handle.interrupt_calls == [False]
def testrun_startup_greeting_swallows_say_exception() -> None:
session = _FailingSaySession()
handle = run_startup_greeting(session, mcp_runtime_active=True) # type: ignore[arg-type]
assert handle is None
assert session.say_calls == [{"text": MCP_STARTUP_GREETING, "kwargs": {"allow_interruptions": True, "add_to_chat_ctx": False}}]
assert session.generate_reply_calls == []
def test_startup_greeting_monitor_handles_cancellation() -> None:
async def _run() -> _FakeSpeechHandle:
handle = _FakeSpeechHandle(block=True)
task = asyncio.create_task(
monitor_startup_greeting_handle(handle, timeout_sec=10.0)
)
await asyncio.sleep(0)
task.cancel()
await task
return handle
handle = asyncio.run(_run())
assert handle.interrupt_calls == [True]
def testschedule_startup_greeting_task_is_shutdown_safe() -> None:
async def _run() -> tuple[asyncio.Task[Any], _FakeSpeechHandle]:
handle = _FakeSpeechHandle(block=True)
session = _FakeSession(greeting_handle=handle)
task = schedule_startup_greeting_task( # type: ignore[arg-type]
session,
mcp_runtime_active=True,
)
assert task is not None
assert not task.done()
await cancel_task_for_shutdown(task, task_name="startup greeting", timeout_sec=0.1)
return task, handle
task, handle = asyncio.run(_run())
assert task.done()
assert handle.interrupt_calls == [True]
def testbuild_mcp_http_timeout_uses_runtime_timeout_for_all_phases() -> None:
timeout = build_mcp_http_timeout(12.0)
assert timeout.connect == 12.0
assert timeout.read == 12.0
assert timeout.write == 12.0
assert timeout.pool == 12.0
def testbuild_mcp_http_timeout_enforces_positive_minimum() -> None:
timeout = build_mcp_http_timeout(0.0)
assert timeout.connect == 1.0
assert timeout.read == 1.0
assert timeout.write == 1.0
assert timeout.pool == 1.0
def test_build_llm_runtime_supports_ollama_with_mcp() -> None:
runtime = build_llm_runtime(
_build_llm_settings(
MCP_ENABLED=True,
MCP_CONN_TIMEOUT_SEC=17.0,
OLLAMA_API_KEY=None,
)
)
assert runtime.provider == "ollama"
assert runtime.model == "qwen2.5:7b"
assert runtime.mcp_runtime_active is True
assert runtime.mcp_servers is not None
assert [server.url for server in runtime.mcp_servers] == [
"https://huggingface.co/mcp",
"https://docs.livekit.io/mcp",
]
assert all(isinstance(server, ConfiguredMCPServerHTTP) for server in runtime.mcp_servers)
assert [server.request_timeout_seconds for server in runtime.mcp_servers] == [17.0, 17.0]
assert [server.client_session_timeout_seconds for server in runtime.mcp_servers] == [
17.0,
17.0,
]
def test_resolve_mcp_server_urls_trims_ignores_empty_and_deduplicates() -> None:
urls = resolve_mcp_server_urls(
mcp_server_url=" https://huggingface.co/mcp ",
mcp_extra_server_urls=(
"https://docs.livekit.io/mcp, , https://huggingface.co/mcp, "
"https://docs.livekit.io/mcp, https://example.com/mcp"
),
)
assert urls == [
"https://huggingface.co/mcp",
"https://docs.livekit.io/mcp",
"https://example.com/mcp",
]
def test_build_llm_runtime_rejects_cloud_alias_model_for_ollama_cloud_v1() -> None:
with pytest.raises(
ValueError,
match=r"cannot use ':cloud' aliases.*https://ollama\.com/v1",
):
build_llm_runtime(
_build_llm_settings(
OLLAMA_CLOUD_MODE=True,
OLLAMA_MODEL="qwen3.5:cloud",
OLLAMA_API_KEY="test-key",
MCP_ENABLED=True,
)
)
def test_build_llm_runtime_requires_nvidia_key() -> None:
with pytest.raises(ValueError, match="NVIDIA_API_KEY is required"):
build_llm_runtime(
_build_llm_settings(
LLM_PROVIDER="nvidia",
NVIDIA_API_KEY=None,
)
)
def test_build_llm_runtime_passes_nvidia_disable_thinking_payload(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured_kwargs: dict[str, Any] = {}
fake_llm = object()
def _fake_openai_llm(**kwargs: Any) -> object:
captured_kwargs.update(kwargs)
return fake_llm
monkeypatch.setattr(
"src.agent.models.llm_runtime.openai_plugin.LLM",
_fake_openai_llm,
)
runtime = build_llm_runtime(
_build_llm_settings(
LLM_PROVIDER="nvidia",
NVIDIA_API_KEY="nvapi-test",
NVIDIA_MODEL="qwen/qwen3-next-80b-a3b-instruct",
)
)
assert runtime.llm is fake_llm
assert captured_kwargs["base_url"] == "https://integrate.api.nvidia.com/v1"
assert captured_kwargs["extra_body"] == {
"chat_template_kwargs": {"enable_thinking": False}
}
def test_build_llm_runtime_passes_ollama_disable_thinking_payload(
monkeypatch: pytest.MonkeyPatch,
) -> None:
captured_kwargs: dict[str, Any] = {}
fake_llm = object()
def _fake_openai_llm(**kwargs: Any) -> object:
captured_kwargs.update(kwargs)
return fake_llm
monkeypatch.setattr(
"src.agent.models.llm_runtime.openai_plugin.LLM",
_fake_openai_llm,
)
runtime = build_llm_runtime(_build_llm_settings())
assert runtime.llm is fake_llm
assert captured_kwargs["base_url"] == "http://localhost:11434/v1"
assert captured_kwargs["extra_body"] == {"think": False}
def test_assistant_instructions_enforce_ultra_short_answers() -> None:
assert "Answer as quickly as possible." in ASSISTANT_INSTRUCTIONS
assert "Give the shortest complete answer." in ASSISTANT_INSTRUCTIONS
assert "Default to one short sentence." in ASSISTANT_INSTRUCTIONS
assert "If a few words are enough, use a few words." in ASSISTANT_INSTRUCTIONS
assert "Expand only when the user asks or accuracy requires it." in ASSISTANT_INSTRUCTIONS
assert "The user only speaks to you, not types, so do not ask them to write, type, paste, or use chat." in ASSISTANT_INSTRUCTIONS
def test_assistant_instructions_prioritize_knowledge_before_tools() -> None:
assert "Answer from your own knowledge first, then from the setup summary." in ASSISTANT_INSTRUCTIONS
def test_assistant_instructions_limit_huggingface_and_livekit_tools() -> None:
assert "Do not claim you can generate or edit images." in ASSISTANT_INSTRUCTIONS
assert "Use Hugging Face tools only when needed and only if they are available in the current session." in ASSISTANT_INSTRUCTIONS
assert "Use LiveKit tools only for LiveKit-specific questions not covered by your own knowledge or the setup summary." in ASSISTANT_INSTRUCTIONS
def test_assistant_instructions_avoid_same_turn_retry_after_tool_failure() -> None:
assert (
"If a tool reports timeout or temporary unavailability, do not call that same tool again in the same turn; answer briefly without it when you can."
in ASSISTANT_INSTRUCTIONS
)
def test_assistant_instructions_describe_pipeline_identity() -> None:
assert (
"You are Open Voice Agent, a real-time voice pipeline."
in ASSISTANT_INSTRUCTIONS
)
assert (
"You are a pipeline, not a single model."
in ASSISTANT_INSTRUCTIONS
)
assert (
"For self-description, describe the pipeline briefly and mention VAD, turn detection, STT, LLM, and TTS only when relevant."
in ASSISTANT_INSTRUCTIONS
)
def test_assistant_instructions_handle_setup_questions_from_summary() -> None:
assert "For setup questions, answer from the setup summary." in ASSISTANT_INSTRUCTIONS
def test_build_assistant_instructions_includes_current_date() -> None:
instructions = build_assistant_instructions(current_date=date(2026, 3, 7))
assert "Current date: March 7, 2026." in instructions
def test_assistant_instructions_include_capabilities_and_limitations() -> None:
instructions = build_assistant_instructions(current_date=date(2026, 3, 7))
assert "Priority 1. Identity." in instructions
assert "Priority 2. Response style." in instructions
assert "Priority 3. Tools." in instructions
assert "Priority 4. Safety and setup." in instructions
assert "The user only speaks to you, not types" in instructions
assert "Do not claim you can generate or edit images." in instructions
assert "Only call tools that are available in the current session." in instructions
assert "Never invent tools." in instructions
assert "Use plain voice-friendly text only" in instructions
def test_build_assistant_instructions_include_configuration_summary() -> None:
instructions = build_assistant_instructions(current_date=date(2026, 3, 7))
assert "Setup summary:" in instructions
assert "Voice stack: STT provider=" in instructions
assert "LLM provider=" in instructions
assert "mcp_timeout_sec=" in instructions
assert "TTS provider=" in instructions
assert "LiveKit: runtime agent_name=" in instructions
assert "audio sample_rate=" in instructions
assert "MCP runtime: enabled=" in instructions
assert "MCP runtime:" in instructions
def test_configured_mcp_server_http_converts_timeout_to_retry_guidance() -> None:
async def _run() -> None:
server = ConfiguredMCPServerHTTP(
url="https://docs.livekit.io/mcp",
timeout_seconds=20.0,
)
server._client = _FakeMCPClient(
exc=httpx.ReadTimeout(
"timed out",
request=httpx.Request("POST", "https://docs.livekit.io/mcp"),
)
)
tool = server._make_function_tool(
"code_search",
"Search docs",
{"type": "object"},
None,
)
with pytest.raises(
ToolError,
match=re.escape(
"The external tool 'code_search' timed out. Do not retry 'code_search' again in this turn. Give the user a brief answer without it."
),
):
await tool({"query": "agent session"})
asyncio.run(_run())
def test_configured_mcp_server_http_converts_transport_errors_to_retry_guidance() -> None:
async def _run() -> None:
server = ConfiguredMCPServerHTTP(
url="https://docs.livekit.io/mcp",
timeout_seconds=20.0,
)
server._client = _FakeMCPClient(
exc=OSError("connection refused"),
)
tool = server._make_function_tool(
"code_search",
"Search docs",
{"type": "object"},
None,
)
with pytest.raises(
ToolError,
match=re.escape(
"The external tool 'code_search' is temporarily unavailable. Do not retry 'code_search' again in this turn. Give the user a brief answer without it."
),
):
await tool({"query": "agent session"})
asyncio.run(_run())
def test_configured_mcp_server_http_treats_mcp_timeout_error_as_timeout() -> None:
async def _run() -> None:
from mcp.shared.exceptions import McpError
server = ConfiguredMCPServerHTTP(
url="https://docs.livekit.io/mcp",
timeout_seconds=20.0,
)
server._client = _FakeMCPClient(
exc=McpError(
ErrorData(
code=1,
message="Timed out while waiting for response to ClientRequest. Waited 20.0 seconds.",
)
),
)
tool = server._make_function_tool(
"code_search",
"Search docs",
{"type": "object"},
None,
)
with pytest.raises(ToolError, match="timed out"):
await tool({"query": "agent session"})
asyncio.run(_run())
def test_build_assistant_instructions_hide_pocket_specific_details_for_deepgram() -> None:
settings = Settings()
settings.voice.TTS_PROVIDER = "deepgram"
settings.voice.DEEPGRAM_API_KEY = "deepgram-secret-test-value"
instructions = build_assistant_instructions(
current_date=date(2026, 3, 7),
current_settings=settings,
)
assert "TTS provider=deepgram, model=plugin-default" in instructions
assert "voice=alba" not in instructions
assert "lsd_decode_steps=" not in instructions
def test_build_assistant_instructions_hide_pocket_specific_details_for_nvidia_tts() -> None:
settings = Settings()
settings.voice.TTS_PROVIDER = "nvidia"
settings.voice.NVIDIA_TTS_VOICE = "Magpie-Multilingual.EN-US.Leo"
settings.voice.NVIDIA_TTS_LANGUAGE_CODE = "en-US"
settings.voice.NVIDIA_TTS_SERVER = "grpc.nvcf.nvidia.com:443"
settings.voice.NVIDIA_TTS_USE_SSL = True
settings.llm.NVIDIA_API_KEY = "nvapi-secret-test-value"
instructions = build_assistant_instructions(
current_date=date(2026, 3, 7),
current_settings=settings,
)
assert (
"TTS provider=nvidia, voice=Magpie-Multilingual.EN-US.Leo, "
"language_code=en-US, server=grpc.nvcf.nvidia.com:443, "
"use_ssl=True, timeout_sec=45.0."
) in instructions
assert "language_code=en-US" in instructions
assert "server=grpc.nvcf.nvidia.com:443" in instructions
assert "use_ssl=True" in instructions
assert "voice=alba" not in instructions
assert "lsd_decode_steps=" not in instructions
def test_build_assistant_instructions_include_deepgram_stt_details() -> None:
settings = Settings()
settings.voice.DEEPGRAM_API_KEY = "deepgram-secret-test-value"
settings.stt.STT_PROVIDER = "deepgram"
settings.stt.DEEPGRAM_STT_MODEL = "nova-3"
settings.stt.DEEPGRAM_STT_LANGUAGE = "en-US"
instructions = build_assistant_instructions(
current_date=date(2026, 3, 7),
current_settings=settings,
)
assert "STT provider=deepgram, model=nova-3, language=en-US;" in instructions
assert "usefulsensors/moonshine-streaming-medium" not in instructions
assert "parakeet-1.1b-en-US-asr-streaming-silero-vad-sortformer" not in instructions
def test_build_assistant_instructions_redacts_sensitive_values() -> None:
settings = Settings()
settings.voice.DEEPGRAM_API_KEY = "deepgram-secret-test-value"
settings.voice.NVIDIA_TTS_API_KEY = "nvidia-tts-secret-test-value"
settings.llm.NVIDIA_API_KEY = "nvapi-secret-test-value"
settings.stt.NVIDIA_STT_API_KEY = "nvidia-stt-secret-test-value"
settings.livekit.LIVEKIT_API_KEY = "livekit-api-key-test-value"
settings.livekit.LIVEKIT_API_SECRET = "livekit-api-secret-test-value"
instructions = build_assistant_instructions(
current_date=date(2026, 3, 7),
current_settings=settings,
)
assert "deepgram-secret-test-value" not in instructions
assert "nvidia-tts-secret-test-value" not in instructions
assert "nvapi-secret-test-value" not in instructions
assert "nvidia-stt-secret-test-value" not in instructions
assert "livekit-api-key-test-value" not in instructions
assert "livekit-api-secret-test-value" not in instructions
assert "Credential state (redacted):" in instructions
assert "<redacted>" in instructions
def testrun_llm_warmup_consumes_first_chunk_and_closes_stream() -> None:
stream = _FakeLLMStream(chunks=[{"delta": "OK"}, {"delta": "ignored"}])
llm_client = _FakeLLMClient(stream)
asyncio.run(
run_llm_warmup(
llm_client=llm_client,
conn_options=object(), # type: ignore[arg-type]
provider="nvidia",
model="test-model",
)
)
assert len(llm_client.chat_calls) == 1
call = llm_client.chat_calls[0]
assert "chat_ctx" in call
assert call["tools"] is None
assert stream.iterations == 1
assert stream.closed is True
assert stream.aclose_calls == 1
def testrun_llm_warmup_swallows_stream_errors_and_closes() -> None:
stream = _FakeLLMStream(iter_error=RuntimeError("warmup boom"))
llm_client = _FakeLLMClient(stream)
asyncio.run(
run_llm_warmup(
llm_client=llm_client,
conn_options=object(), # type: ignore[arg-type]
provider="nvidia",
model="test-model",
)
)
assert len(llm_client.chat_calls) == 1
assert stream.closed is True
assert stream.aclose_calls == 1
def testschedule_llm_warmup_task_is_shutdown_safe() -> None:
async def _run() -> tuple[asyncio.Task[Any], _FakeLLMStream]:
stream = _FakeLLMStream(block_forever=True)
llm_client = _FakeLLMClient(stream)
task = schedule_llm_warmup_task(
llm_client=llm_client,
conn_options=object(), # type: ignore[arg-type]
provider="nvidia",
model="test-model",
)
assert not task.done()
await asyncio.sleep(0)
await cancel_task_for_shutdown(task, task_name="llm warm-up", timeout_sec=0.1)
return task, stream
task, stream = asyncio.run(_run())
assert task.done()
assert stream.closed is True
assert stream.aclose_calls == 1
def testinject_pre_tool_feedback_uses_model_leadin_before_tool_call() -> None:
feedback = _FakeToolFeedback()
input_chunk = _tool_chunk(content="I'll check that now.")
output = asyncio.run(
_collect(
inject_pre_tool_feedback(
_iter_items([input_chunk]),
tool_feedback=feedback, # type: ignore[arg-type]
)
)
)
assert output[0] == "I'll check that now."
assert isinstance(output[1], llm.ChatChunk)
assert output[1].delta is not None
assert output[1].delta.content is None
assert len(output[1].delta.tool_calls) == 1
assert feedback.start_typing_calls == 1
def testinject_pre_tool_feedback_uses_fallback_when_model_omits_leadin() -> None:
feedback = _FakeToolFeedback(fallback_phrase="One sec, checking now.")
input_chunk = _tool_chunk(content=None)
output = asyncio.run(
_collect(
inject_pre_tool_feedback(
_iter_items([input_chunk]),
tool_feedback=feedback, # type: ignore[arg-type]
)
)
)
assert output[0] == "One sec, checking now."
assert output[1] == input_chunk
assert feedback.start_typing_calls == 1
def testinject_pre_tool_feedback_announces_once_per_tool_step() -> None:
feedback = _FakeToolFeedback(fallback_phrase="Checking that for you.")
first = _tool_chunk(content=None, call_id="call-1")
second = _tool_chunk(content=None, call_id="call-2")
output = asyncio.run(
_collect(
inject_pre_tool_feedback(
_iter_items([first, second]),
tool_feedback=feedback, # type: ignore[arg-type]
)
)
)
assert output[0] == "Checking that for you."
assert output[1] == first
assert output[2] == second
assert feedback.start_typing_calls == 1
def testinject_pre_tool_feedback_marks_tool_step_once() -> None:
callback = _FakeToolStepStartedCallback()
first = _tool_chunk(content=None, call_id="call-1")
second = _tool_chunk(content=None, call_id="call-2")
_ = asyncio.run(
_collect(
inject_pre_tool_feedback(
_iter_items([first, second]),
tool_feedback=None,
on_tool_step_started=callback,
)
)
)
assert callback.calls == 1
def testinject_pre_tool_feedback_skips_unknown_tool_names() -> None:
feedback = _FakeToolFeedback()
unknown_tool_chunk = _tool_chunk(content=None, name="explain_pipeline")
output = asyncio.run(
_collect(
inject_pre_tool_feedback(
_iter_items([unknown_tool_chunk]),
tool_feedback=feedback, # type: ignore[arg-type]
allowed_tool_names={"paper_search"},
)
)
)
assert output == [unknown_tool_chunk]
assert feedback.start_typing_calls == 0
def testinject_pre_tool_feedback_respects_announcement_suppression() -> None:
feedback = _FakeToolFeedback()
callback = _FakeToolStepStartedCallback(announce=False)
tool_chunk = _tool_chunk(content=None)
output = asyncio.run(
_collect(
inject_pre_tool_feedback(
_iter_items([tool_chunk]),
tool_feedback=feedback, # type: ignore[arg-type]
should_announce_tool_step=callback,
allowed_tool_names={"paper_search"},
)
)
)
assert output == [tool_chunk]
assert callback.calls == 1
assert feedback.start_typing_calls == 0
def testinject_pre_tool_feedback_does_not_modify_non_tool_chunks() -> None:
feedback = _FakeToolFeedback()
non_tool_chunk = llm.ChatChunk(
id="chunk-plain",
delta=llm.ChoiceDelta(content="Hello there."),
)
output = asyncio.run(
_collect(
inject_pre_tool_feedback(
_iter_items([non_tool_chunk]),
tool_feedback=feedback, # type: ignore[arg-type]
)
)
)
assert output == [non_tool_chunk]
assert feedback.start_typing_calls == 0