File size: 6,007 Bytes
6172a47 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 | import asyncio
import contextlib
import logging
import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from config.settings import Settings
# Test environment, applied while this conftest module is imported, i.e.
# ahead of every fixture and test defined below.
# These two only fill gaps: values already present in the environment win.
for _env_name, _env_value in (
    ("NVIDIA_NIM_API_KEY", "test_key"),
    ("MODEL", "nvidia_nim/test-model"),
):
    os.environ.setdefault(_env_name, _env_value)
del _env_name, _env_value

# These are forced unconditionally. An empty ANTHROPIC_AUTH_TOKEN keeps a
# server API key from the repo .env out of the tests, which expect endpoints
# to be unauthenticated by default.
os.environ.update({"PTB_TIMEDELTA": "1", "ANTHROPIC_AUTH_TOKEN": ""})

# Stop Settings from loading the .env file at all.
Settings.model_config = dict(Settings.model_config, env_file=None)
@pytest.fixture(autouse=True)
def _isolate_from_dotenv(monkeypatch):
    """Keep Pydantic BaseSettings from reading the .env file in every test.

    The override goes through ``monkeypatch``, so the previous
    ``Settings.model_config`` is restored when the test finishes.
    """
    config_without_dotenv = dict(Settings.model_config, env_file=None)
    monkeypatch.setattr(Settings, "model_config", config_without_dotenv)
@pytest.fixture
def provider_config():
    """Build the baseline ``ProviderConfig`` used by the provider fixtures.

    It carries a dummy API key, a fake NVIDIA endpoint, ``rate_limit=10`` and
    ``rate_window=60`` (units presumably seconds; confirm in providers.base).
    """
    from providers.base import ProviderConfig

    config_kwargs = {
        "api_key": "test_key",
        "base_url": "https://test.api.nvidia.com/v1",
        "rate_limit": 10,
        "rate_window": 60,
    }
    return ProviderConfig(**config_kwargs)
@pytest.fixture
def nim_provider(provider_config):
    """Return an ``NvidiaNimProvider`` built from the shared config and default ``NimSettings``."""
    from config.nim import NimSettings
    from providers.nvidia_nim import NvidiaNimProvider

    default_nim_settings = NimSettings()
    return NvidiaNimProvider(provider_config, nim_settings=default_nim_settings)
@pytest.fixture
def open_router_provider(provider_config):
    """Return an ``OpenRouterProvider`` configured with the shared ``provider_config``."""
    from providers.open_router import OpenRouterProvider

    provider = OpenRouterProvider(provider_config)
    return provider
@pytest.fixture
def lmstudio_provider(provider_config):
    """Return an ``LMStudioProvider`` aimed at a local LM Studio server.

    The API key and base URL are LM Studio's local values. Only the rate-limit
    settings are copied from ``provider_config``.
    """
    from providers.base import ProviderConfig
    from providers.lmstudio import LMStudioProvider

    inherited_limits = {
        "rate_limit": provider_config.rate_limit,
        "rate_window": provider_config.rate_window,
    }
    return LMStudioProvider(
        ProviderConfig(
            api_key="lm-studio",
            base_url="http://localhost:1234/v1",
            **inherited_limits,
        )
    )
@pytest.fixture
def llamacpp_provider(provider_config):
    """Return a ``LlamaCppProvider`` aimed at a local llama.cpp server.

    The API key and base URL are local llama.cpp values. Rate limiting is
    taken from ``provider_config`` rather than hard-coded, so adjusting the
    shared fixture's limits applies here as well.
    """
    from providers.base import ProviderConfig
    from providers.llamacpp import LlamaCppProvider

    llamacpp_config = ProviderConfig(
        api_key="llamacpp",
        base_url="http://localhost:8080/v1",
        # Previously hard-coded to 10/60, which silently diverged from
        # provider_config whenever the shared limits changed.
        rate_limit=provider_config.rate_limit,
        rate_window=provider_config.rate_window,
    )
    return LlamaCppProvider(llamacpp_config)
@pytest.fixture
def mock_cli_session():
    """Return a ``CLISession`` mock that reports itself as not busy."""
    from messaging.platforms.base import CLISession

    session = MagicMock(spec=CLISession)
    session.is_busy = False
    # Plain MagicMock: calling it yields another MagicMock, not an async
    # generator, unless a test sets return_value/side_effect itself.
    session.start_task = MagicMock()
    return session
@pytest.fixture
def mock_cli_manager():
    """Return a ``SessionManagerInterface`` mock with canned async results.

    ``register_real_session_id`` and ``remove_session`` resolve to ``True``;
    ``get_stats`` reports zero active sessions.
    """
    from messaging.platforms.base import SessionManagerInterface

    manager = MagicMock(spec=SessionManagerInterface)
    manager.configure_mock(
        get_or_create_session=AsyncMock(),
        register_real_session_id=AsyncMock(return_value=True),
        stop_all=AsyncMock(),
        remove_session=AsyncMock(return_value=True),
        get_stats=MagicMock(return_value={"active_sessions": 0}),
    )
    return manager
@pytest.fixture
def mock_platform():
    """Return a ``MessagingPlatform`` mock whose messaging methods are ``AsyncMock``s.

    ``send_message`` and ``queue_send_message`` resolve to ``"msg_123"``. The
    edit/delete variants resolve to ``AsyncMock``'s default result.
    ``fire_and_forget`` schedules coroutines as tasks on the running event
    loop and returns the task, or ``None`` for non-coroutine arguments.
    """
    from messaging.platforms.base import MessagingPlatform

    platform = MagicMock(spec=MessagingPlatform)
    platform.send_message = AsyncMock(return_value="msg_123")
    platform.edit_message = AsyncMock()
    platform.delete_message = AsyncMock()
    platform.queue_send_message = AsyncMock(return_value="msg_123")
    platform.queue_edit_message = AsyncMock()
    platform.queue_delete_message = AsyncMock()

    # The event loop keeps only weak references to tasks, and callers of
    # fire_and_forget discard the return value, so hold strong references
    # here until each task completes.
    pending_tasks: set[asyncio.Task[Any]] = set()

    def _fire_and_forget(task):
        if not asyncio.iscoroutine(task):
            return None
        try:
            # Scheduling avoids a "coroutine was never awaited" warning.
            scheduled = asyncio.create_task(task)
        except RuntimeError:
            # No running loop: create_task raised before taking ownership of
            # the coroutine. Close it so it does not warn, then re-raise.
            task.close()
            raise
        pending_tasks.add(scheduled)
        scheduled.add_done_callback(pending_tasks.discard)
        return scheduled

    platform.fire_and_forget = MagicMock(side_effect=_fire_and_forget)
    return platform
@pytest.fixture
def mock_session_store():
    """Return a ``SessionStore`` mock with stub methods and empty lookups.

    ``get_tree`` returns ``None`` and ``get_message_ids_for_chat`` returns
    ``[]``. The remaining stubbed methods return fresh ``MagicMock``s.
    """
    from messaging.session import SessionStore

    store = MagicMock(spec=SessionStore)
    for method_name in (
        "save_tree",
        "register_node",
        "clear_all",
        "record_message_id",
    ):
        setattr(store, method_name, MagicMock())
    store.get_tree = MagicMock(return_value=None)
    store.get_message_ids_for_chat = MagicMock(return_value=[])
    return store
@pytest.fixture
def incoming_message_factory():
    """Return a builder that creates ``IncomingMessage`` objects with test defaults.

    The builder takes keyword overrides on top of a default telegram message
    (``text="hello"``, ``chat_id="chat_1"``, ``user_id="user_1"``,
    ``message_id="msg_1"``).
    - A string ``timestamp`` is parsed with ``datetime.fromisoformat``.
    - Keywords outside the known field set are dropped silently rather than
      passed to ``IncomingMessage``.
    """
    accepted_fields = frozenset(
        (
            "text",
            "chat_id",
            "user_id",
            "message_id",
            "platform",
            "reply_to_message_id",
            "message_thread_id",
            "username",
            "timestamp",
            "raw_event",
            "status_message_id",
        )
    )

    def _create(**kwargs):
        from datetime import datetime

        from messaging.models import IncomingMessage

        fields: dict[str, Any] = {
            "text": "hello",
            "chat_id": "chat_1",
            "user_id": "user_1",
            "message_id": "msg_1",
            "platform": "telegram",
            **kwargs,
        }
        raw_timestamp = fields.get("timestamp")
        if isinstance(raw_timestamp, str):
            fields["timestamp"] = datetime.fromisoformat(raw_timestamp)
        known = {
            name: value for name, value in fields.items() if name in accepted_fields
        }
        return IncomingMessage(**known)

    return _create
@pytest.fixture(autouse=True)
def _propagate_loguru_to_caplog():
    """Mirror loguru records into stdlib ``logging`` so pytest's caplog sees them.

    Each record is re-emitted on ``logging.getLogger(<record name>)`` at the
    same numeric level, capped at ``logging.CRITICAL``. The sink is removed
    after the test.
    """
    from loguru import logger as loguru_logger

    class _PropagateHandler:
        # loguru passes each formatted message to write(). The message
        # carries the raw record dict on ``.record``.
        def write(self, message):
            entry = message.record
            capped_level = min(entry["level"].no, logging.CRITICAL)
            logging.getLogger(entry["name"]).log(capped_level, entry["message"])

    sink_id = loguru_logger.add(_PropagateHandler(), format="{message}")
    yield
    # The sink may already be gone (e.g. removed by test_logging_config);
    # removing it again raises ValueError, which is safe to ignore.
    with contextlib.suppress(ValueError):
        loguru_logger.remove(sink_id)
|