Spaces:
Running
Running
| """Behavior tests for the Gemini Live handler.""" | |
| import base64 | |
| import asyncio | |
| from types import SimpleNamespace | |
| from typing import Any, Callable, AsyncIterator | |
| from unittest.mock import AsyncMock, MagicMock, call | |
| import numpy as np | |
| import pytest | |
| from fastrtc import AdditionalOutputs | |
| import reachy_mini_conversation_app.gemini_live as gemini_mod | |
| import reachy_mini_conversation_app.tools.core_tools as ct_mod | |
| from reachy_mini_conversation_app.gemini_live import GeminiLiveHandler | |
| from reachy_mini_conversation_app.tools.core_tools import ToolDependencies | |
| from reachy_mini_conversation_app.tools.tool_constants import ToolState | |
| from reachy_mini_conversation_app.tools.background_tool_manager import ToolNotification | |
| def _server_content(**kwargs: Any) -> SimpleNamespace: | |
| defaults = { | |
| "model_turn": None, | |
| "turn_complete": None, | |
| "interrupted": None, | |
| "grounding_metadata": None, | |
| "generation_complete": None, | |
| "input_transcription": None, | |
| "output_transcription": None, | |
| "url_context_metadata": None, | |
| "turn_complete_reason": None, | |
| "waiting_for_input": None, | |
| } | |
| defaults.update(kwargs) | |
| return SimpleNamespace(**defaults) | |
| def _response(server_content: Any = None, tool_call: Any = None) -> SimpleNamespace: | |
| return SimpleNamespace(server_content=server_content, tool_call=tool_call) | |
| async def _wait_for(predicate: Callable[[], bool], timeout: float = 1.0) -> None: | |
| deadline = asyncio.get_running_loop().time() + timeout | |
| while asyncio.get_running_loop().time() < deadline: | |
| if predicate(): | |
| return | |
| await asyncio.sleep(0.01) | |
| raise AssertionError("Timed out waiting for condition") | |
| class _FakeSession: | |
| def __init__(self, batches: list[list[SimpleNamespace]], stop_event: asyncio.Event) -> None: | |
| self._batches = list(batches) | |
| self._stop_event = stop_event | |
| self.realtime_inputs: list[dict[str, Any]] = [] | |
| self.tool_responses: list[dict[str, Any]] = [] | |
| async def close(self) -> None: | |
| self._stop_event.set() | |
| async def send_realtime_input(self, **kwargs: Any) -> None: | |
| self.realtime_inputs.append(kwargs) | |
| return None | |
| async def send_tool_response(self, **kwargs: Any) -> None: | |
| self.tool_responses.append(kwargs) | |
| return None | |
| async def receive(self) -> AsyncIterator[SimpleNamespace]: | |
| if self._batches: | |
| for response in self._batches.pop(0): | |
| yield response | |
| return | |
| await self._stop_event.wait() | |
| return | |
| yield | |
| class _FakeConnectContext: | |
| def __init__(self, session: _FakeSession): | |
| self._session = session | |
| async def __aenter__(self) -> _FakeSession: | |
| return self._session | |
| async def __aexit__(self, *_args: object) -> bool: | |
| return False | |
| class _FakeLiveClient: | |
| def __init__(self, session: _FakeSession) -> None: | |
| self.aio = SimpleNamespace(live=SimpleNamespace(connect=lambda **_kwargs: _FakeConnectContext(session))) | |
| async def test_gemini_turn_buffers_transcripts_and_schedules_motion_reset( | |
| monkeypatch: pytest.MonkeyPatch, | |
| ) -> None: | |
| """Gemini turns should emit one transcript per role and let the wobbler reset after speech.""" | |
| monkeypatch.setattr(gemini_mod, "get_session_instructions", lambda: "test") | |
| monkeypatch.setattr(gemini_mod, "get_session_voice", lambda: "Kore") | |
| monkeypatch.setattr(gemini_mod, "get_active_tool_specs", lambda _: []) | |
| movement_manager = MagicMock() | |
| movement_manager.is_idle.return_value = False | |
| head_wobbler = MagicMock() | |
| robot = SimpleNamespace(media=SimpleNamespace(audio=None)) | |
| deps = ToolDependencies( | |
| reachy_mini=robot, | |
| movement_manager=movement_manager, | |
| head_wobbler=head_wobbler, | |
| ) | |
| handler = GeminiLiveHandler(deps) | |
| monkeypatch.setattr(type(handler.tool_manager), "start_up", MagicMock()) | |
| monkeypatch.setattr(type(handler.tool_manager), "shutdown", AsyncMock()) | |
| audio_bytes = b"\x00\x00\x10\x00" * 256 | |
| session = _FakeSession( | |
| batches=[ | |
| [ | |
| _response( | |
| _server_content( | |
| input_transcription=SimpleNamespace(text="How's it going, Reachy?"), | |
| ) | |
| ), | |
| _response( | |
| _server_content( | |
| model_turn=SimpleNamespace( | |
| parts=[SimpleNamespace(inline_data=SimpleNamespace(data=audio_bytes))] | |
| ), | |
| ) | |
| ), | |
| _response( | |
| _server_content( | |
| output_transcription=SimpleNamespace(text="Doing"), | |
| ) | |
| ), | |
| _response( | |
| _server_content( | |
| output_transcription=SimpleNamespace(text=" great."), | |
| ) | |
| ), | |
| _response( | |
| _server_content( | |
| turn_complete=True, | |
| ) | |
| ), | |
| ] | |
| ], | |
| stop_event=handler._stop_event, | |
| ) | |
| handler.client = _FakeLiveClient(session) | |
| task = asyncio.create_task(handler._run_live_session()) | |
| await _wait_for( | |
| lambda: head_wobbler.request_reset_after_current_audio.called and handler.output_queue.qsize() >= 3 | |
| ) | |
| handler._stop_event.set() | |
| await asyncio.wait_for(task, timeout=1.0) | |
| outputs = [] | |
| while not handler.output_queue.empty(): | |
| outputs.append(handler.output_queue.get_nowait()) | |
| transcript_messages = [ | |
| message | |
| for output in outputs | |
| if isinstance(output, AdditionalOutputs) | |
| for message in output.args | |
| if isinstance(message.get("content"), str) | |
| ] | |
| assert transcript_messages == [ | |
| {"role": "user", "content": "How's it going, Reachy?"}, | |
| {"role": "assistant", "content": "Doing great."}, | |
| ] | |
| assert any(isinstance(output, tuple) for output in outputs), "audio output was not emitted" | |
| movement_manager.set_listening.assert_has_calls([call(True), call(False)]) | |
| assert movement_manager.set_listening.call_args_list[-1] == call(False) | |
| head_wobbler.feed.assert_not_called() | |
| head_wobbler.request_reset_after_current_audio.assert_called_once() | |
| head_wobbler.reset.assert_not_called() | |
| async def test_gemini_camera_tool_sends_snapshot_and_returns_json_result() -> None: | |
| """Camera tool should push the snapshot via realtime video input and return a JSON-safe tool result.""" | |
| camera_worker = MagicMock() | |
| camera_worker.get_latest_frame.return_value = np.zeros((8, 8, 3), dtype=np.uint8) | |
| deps = ToolDependencies( | |
| reachy_mini=MagicMock(), | |
| movement_manager=MagicMock(), | |
| camera_worker=camera_worker, | |
| ) | |
| handler = GeminiLiveHandler(deps) | |
| session = _FakeSession([], handler._stop_event) | |
| handler.session = session | |
| await handler._handle_tool_result( | |
| ToolNotification( | |
| id="call_camera_1", | |
| tool_name="camera", | |
| is_idle_tool_call=False, | |
| status=ToolState.COMPLETED, | |
| result={"b64_im": base64.b64encode(b"jpeg-bytes").decode("ascii")}, | |
| ) | |
| ) | |
| # Image pushed as a realtime video frame (not embedded in FunctionResponse) | |
| assert len(session.realtime_inputs) == 1 | |
| blob = session.realtime_inputs[0]["video"] | |
| assert blob.data == b"jpeg-bytes" | |
| assert blob.mime_type == "image/jpeg" | |
| # Tool response is plain JSON (no binary, no parts) | |
| assert len(session.tool_responses) == 1 | |
| function_response = session.tool_responses[0]["function_responses"][0] | |
| assert function_response.response == {"status": "image_captured"} | |
| assert not hasattr(function_response, "parts") or function_response.parts is None | |
| # Console output is JSON-safe | |
| outputs = [] | |
| while not handler.output_queue.empty(): | |
| outputs.append(handler.output_queue.get_nowait()) | |
| tool_messages = [ | |
| message | |
| for output in outputs | |
| if isinstance(output, AdditionalOutputs) | |
| for message in output.args | |
| if isinstance(message.get("content"), str) | |
| ] | |
| assert tool_messages == [ | |
| { | |
| "role": "assistant", | |
| "content": '{"status": "image_captured"}', | |
| "metadata": {"title": "🛠️ Used tool camera", "status": "done"}, | |
| } | |
| ] | |
| async def test_apply_personality_preserves_manual_voice_override(monkeypatch) -> None: | |
| """Applying a profile should keep a manually selected Gemini voice active.""" | |
| monkeypatch.setattr(gemini_mod, "get_session_instructions", lambda: "test") | |
| monkeypatch.setattr(gemini_mod, "get_session_voice", lambda: "Kore") | |
| monkeypatch.setattr("reachy_mini_conversation_app.config.set_custom_profile", lambda _profile: None) | |
| handler = GeminiLiveHandler(ToolDependencies(reachy_mini=MagicMock(), movement_manager=MagicMock())) | |
| handler.session = object() | |
| handler._voice_override = "Orus" | |
| restart = AsyncMock() | |
| monkeypatch.setattr(handler, "_restart_session", restart) | |
| status = await handler.apply_personality("example") | |
| assert status == "Applied personality and restarted Gemini session." | |
| assert handler.get_current_voice() == "Orus" | |
| restart.assert_awaited_once() | |
| def test_handler_uses_startup_voice_at_startup() -> None: | |
| """Gemini handler startup should restore a persisted startup voice.""" | |
| handler = GeminiLiveHandler( | |
| ToolDependencies(reachy_mini=MagicMock(), movement_manager=MagicMock()), | |
| startup_voice="Orus", | |
| ) | |
| assert handler.get_current_voice() == "Orus" | |
| def test_copy_preserves_current_voice_override() -> None: | |
| """Copied Gemini handlers should keep the current voice override.""" | |
| handler = GeminiLiveHandler( | |
| ToolDependencies(reachy_mini=MagicMock(), movement_manager=MagicMock()), | |
| startup_voice="Orus", | |
| ) | |
| handler._voice_override = "Zephyr" | |
| copied_handler = handler.copy() | |
| assert copied_handler.get_current_voice() == "Zephyr" | |
| def test_gemini_excludes_head_tracking_when_no_head_tracker(monkeypatch) -> None: | |
| """head_tracking tool must not appear in Gemini session config when head_tracker is not active.""" | |
| monkeypatch.setattr(gemini_mod, "get_session_instructions", lambda: "test") | |
| monkeypatch.setattr(gemini_mod, "get_session_voice", lambda: "Kore") | |
| # mock ALL_TOOL_SPECS to include at least head_tracking and one other tool, to verify that only head_tracking is excluded, not all tools | |
| monkeypatch.setattr( | |
| ct_mod, | |
| "ALL_TOOL_SPECS", | |
| [ | |
| {"type": "function", "name": "head_tracking", "description": "head_tracking", "parameters": {}}, | |
| {"type": "function", "name": "fake_tool", "description": "fake_tool", "parameters": {}}, | |
| ], | |
| ) | |
| # case 1: no camera at all, --no-camera flag passed | |
| deps = ToolDependencies(reachy_mini=MagicMock(), movement_manager=MagicMock(), camera_worker=None) | |
| handler = GeminiLiveHandler(deps) | |
| live_config = handler._build_live_config() | |
| tool_names = [fd.name for fd in live_config.tools[0].function_declarations] if live_config.tools else [] | |
| assert "head_tracking" not in tool_names, "case 1 failed: camera_worker=None" | |
| assert "fake_tool" in tool_names, "case 1 failed: a non-head-tracking tool was unexpectedly excluded" | |
| # case 2: camera is running but --head-tracker flag was not passed | |
| camera_worker = MagicMock() | |
| camera_worker.head_tracker = None | |
| deps = ToolDependencies(reachy_mini=MagicMock(), movement_manager=MagicMock(), camera_worker=camera_worker) | |
| handler = GeminiLiveHandler(deps) | |
| live_config = handler._build_live_config() | |
| tool_names = [fd.name for fd in live_config.tools[0].function_declarations] if live_config.tools else [] | |
| assert "head_tracking" not in tool_names, "case 2 failed: camera_worker.head_tracker=None" | |
| assert "fake_tool" in tool_names, "case 2 failed: a non-head-tracking tool was unexpectedly excluded" | |