# File size: 13,637 Bytes
# c2f9396
import pytest
import asyncio
from unittest.mock import Mock, patch, AsyncMock
from app.models import ChatMessage, ChatRequest
from app.llm_manager import LLMManager
class TestLLMManager:
    """Unit tests for LLMManager: construction, loading, prompts, streaming."""

    @pytest.fixture
    def llm_manager(self):
        """Hand each requesting test its own newly constructed LLMManager."""
        return LLMManager()

    @pytest.fixture
    def sample_request(self):
        """Build a system + user ChatRequest limited to 50 tokens."""
        return ChatRequest(
            messages=[
                ChatMessage(role="system", content="You are helpful."),
                ChatMessage(role="user", content="Hello!"),
            ],
            max_tokens=50,
        )

    def test_initialization(self, llm_manager):
        """A default-constructed manager starts out in the unloaded state."""
        manager = llm_manager

        assert manager.model_path is not None
        assert manager.model is None
        assert manager.tokenizer is None
        assert manager.model_type == "llama_cpp"
        assert manager.context_window == 2048
        assert manager.is_loaded is False
        assert len(manager.mock_responses) > 0

    def test_custom_model_path(self):
        """A model_path passed to the constructor is stored unchanged."""
        wanted_path = "/custom/path/model.gguf"
        assert LLMManager(model_path=wanted_path).model_path == wanted_path

    @pytest.mark.asyncio
    async def test_load_model_mock_fallback(self, llm_manager):
        """Loading ends in mock mode when both backend flags are off."""
        llama_off = patch("app.llm_manager.LLAMA_AVAILABLE", False)
        transformers_off = patch("app.llm_manager.TRANSFORMERS_AVAILABLE", False)
        path_patch = patch("app.llm_manager.Path")

        # Context managers are entered left to right.
        with llama_off, transformers_off, path_patch as path_cls:
            # Objects returned by the patched Path report exists() == False.
            path_cls.return_value.exists.return_value = False

            loaded = await llm_manager.load_model()

            assert loaded is True
            assert llm_manager.is_loaded is True
            assert llm_manager.model_type == "mock"

    @pytest.mark.asyncio
    async def test_load_llama_model(self, llm_manager):
        """Loading uses the patched Llama class when llama-cpp is flagged on."""
        fake_llama = Mock()
        llama_on = patch("app.llm_manager.LLAMA_AVAILABLE", True)
        path_patch = patch("app.llm_manager.Path")
        llama_cls = patch("app.llm_manager.Llama", return_value=fake_llama)
        four_cpus = patch("os.cpu_count", return_value=4)

        with llama_on, path_patch as path_cls, llama_cls, four_cpus:
            # Objects returned by the patched Path report exists() == True.
            path_cls.return_value.exists.return_value = True

            loaded = await llm_manager.load_model()

            assert loaded is True
            assert llm_manager.is_loaded is True
            assert llm_manager.model_type == "llama_cpp"
            assert llm_manager.model == fake_llama

    @pytest.mark.asyncio
    async def test_load_transformers_model(self, llm_manager):
        """Loading uses the patched transformers loaders when llama-cpp is off."""
        fake_tokenizer = Mock()
        fake_model = Mock()
        llama_off = patch("app.llm_manager.LLAMA_AVAILABLE", False)
        transformers_on = patch("app.llm_manager.TRANSFORMERS_AVAILABLE", True)
        tokenizer_loader = patch(
            "app.llm_manager.AutoTokenizer.from_pretrained",
            return_value=fake_tokenizer,
        )
        model_loader = patch(
            "app.llm_manager.AutoModelForCausalLM.from_pretrained",
            return_value=fake_model,
        )
        cuda_off = patch(
            "app.llm_manager.torch.cuda.is_available",
            return_value=False,
        )

        with llama_off, transformers_on, tokenizer_loader, model_loader, cuda_off:
            loaded = await llm_manager.load_model()

            assert loaded is True
            assert llm_manager.is_loaded is True
            assert llm_manager.model_type == "transformers"
            assert llm_manager.tokenizer == fake_tokenizer
            assert llm_manager.model == fake_model

    @pytest.mark.asyncio
    async def test_load_model_failure(self, llm_manager):
        """load_model still reports success when the patched loader would raise."""
        llama_off = patch("app.llm_manager.LLAMA_AVAILABLE", False)
        transformers_off = patch("app.llm_manager.TRANSFORMERS_AVAILABLE", False)
        path_patch = patch("app.llm_manager.Path")
        # Replace the manager's _load_transformers_model with one that raises.
        # NOTE(review): TRANSFORMERS_AVAILABLE is patched to False here, so it
        # is unclear from this file whether that loader is ever invoked --
        # confirm against LLMManager.load_model.
        raising_loader = patch.object(
            llm_manager,
            "_load_transformers_model",
            side_effect=Exception("Load failed"),
        )

        with llama_off, transformers_off, path_patch as path_cls, raising_loader:
            path_cls.return_value.exists.return_value = False

            loaded = await llm_manager.load_model()

            # Expected to succeed anyway (the test relies on a mock fallback).
            assert loaded is True
            assert llm_manager.is_loaded is True

    def test_format_messages(self, llm_manager):
        """format_messages wraps each message in role tags and opens assistant."""
        conversation = [
            ChatMessage(role="system", content="You are helpful."),
            ChatMessage(role="user", content="Hello!"),
        ]
        expected = "<|system|>\nYou are helpful.\n<|/system|>\n<|user|>\nHello!\n<|/user|>\n<|assistant|>"

        assert llm_manager.format_messages(conversation) == expected

    def test_truncate_context_no_tokenizer(self, llm_manager):
        """Without a tokenizer set, truncate_context returns the prompt as is."""
        text = "This is a test prompt"
        assert llm_manager.truncate_context(text, 100) == text

    def test_truncate_context_with_tokenizer(self, llm_manager):
        """With a tokenizer set, truncate_context returns the decoded text."""
        # encode yields 2500 token ids, far more than the limit of 100 used below.
        fake_tokenizer = Mock(
            **{
                "encode.return_value": [1, 2, 3, 4, 5] * 500,
                "decode.return_value": "truncated prompt",
            }
        )
        llm_manager.tokenizer = fake_tokenizer
        text = "This is a test prompt"

        shortened = llm_manager.truncate_context(text, 100)

        assert shortened == "truncated prompt"
        fake_tokenizer.encode.assert_called_once_with(text)

    @pytest.mark.asyncio
    async def test_generate_stream_not_loaded(self, llm_manager, sample_request):
        """Iterating generate_stream on an unloaded manager raises RuntimeError."""
        with pytest.raises(RuntimeError, match="Model not loaded"):
            async for _chunk in llm_manager.generate_stream(sample_request):
                pass

    @pytest.mark.asyncio
    async def test_generate_mock_stream(self, llm_manager, sample_request):
        """Mock mode streams completion chunks followed by a stop chunk."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "mock"

        chunks = [
            piece async for piece in llm_manager.generate_stream(sample_request)
        ]

        # At least one content chunk plus the terminating chunk.
        assert len(chunks) > 1

        *content_chunks, final_chunk = chunks
        for piece in content_chunks:
            assert "id" in piece
            assert "object" in piece
            assert piece["object"] == "chat.completion.chunk"
            assert "choices" in piece
            choices = piece["choices"]
            assert len(choices) == 1
            assert "delta" in choices[0]
            assert "content" in choices[0]["delta"]

        # The last chunk signals the end of the stream.
        assert final_chunk["choices"][0]["finish_reason"] == "stop"

    @pytest.mark.asyncio
    async def test_generate_llama_stream(self, llm_manager, sample_request):
        """llama_cpp mode calls the model once with stream=True, max_tokens=50."""
        # What the fake llama callable returns when invoked.
        streamed = [
            {"choices": [{"delta": {"content": "Hello"}, "finish_reason": None}]},
            {"choices": [{"delta": {"content": " world"}, "finish_reason": None}]},
            {"choices": [{"delta": {}, "finish_reason": "stop"}]},
        ]
        fake_model = Mock(return_value=streamed)
        llm_manager.is_loaded = True
        llm_manager.model_type = "llama_cpp"
        llm_manager.model = fake_model

        chunks = [
            piece async for piece in llm_manager.generate_stream(sample_request)
        ]

        assert len(chunks) >= 2

        # Inspect the keyword arguments of the single model invocation.
        fake_model.assert_called_once()
        _positional, keywords = fake_model.call_args
        assert keywords["stream"] is True
        assert keywords["max_tokens"] == 50

    @pytest.mark.asyncio
    async def test_generate_transformers_stream(self, llm_manager, sample_request):
        """transformers mode yields at least one chunk with torch fully mocked."""
        # Stand-in tensor whose unsqueeze() returns itself.
        fake_tensor = Mock()
        fake_tensor.unsqueeze.return_value = fake_tensor

        llm_manager.is_loaded = True
        llm_manager.model_type = "transformers"
        llm_manager.tokenizer = Mock(
            eos_token_id=0,
            **{
                "encode.return_value": [1, 2, 3],
                "decode.return_value": "test",
            },
        )
        llm_manager.model = Mock(**{"generate.return_value": fake_tensor})

        with patch("app.llm_manager.torch") as torch_mock:
            torch_mock.cuda.is_available.return_value = False
            torch_mock.cat.return_value = fake_tensor

            collected = []
            async for piece in llm_manager.generate_stream(sample_request):
                collected.append(piece)
                # Stop after three chunks so a never-ending stream cannot hang.
                if len(collected) >= 3:
                    break

            assert len(collected) > 0

    @pytest.mark.asyncio
    async def test_generate_stream_error_handling(self, llm_manager, sample_request):
        """A raising model produces exactly one generation_error chunk."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "llama_cpp"
        # Calling the model raises immediately.
        llm_manager.model = Mock(side_effect=Exception("Generation failed"))

        chunks = [
            piece async for piece in llm_manager.generate_stream(sample_request)
        ]

        assert len(chunks) == 1
        error_chunk = chunks[0]
        assert "error" in error_chunk
        assert error_chunk["error"]["type"] == "generation_error"

    def test_get_model_info(self, llm_manager):
        """get_model_info reports the expected fields for a loaded manager."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "llama_cpp"

        info = llm_manager.get_model_info()

        expected_fields = {
            "id": "llama-2-7b-chat",
            "object": "model",
            "owned_by": "huggingface",
            "type": "llama_cpp",
            "context_window": 2048,
        }
        for field, value in expected_fields.items():
            assert info[field] == value
        assert info["is_loaded"] is True

    def test_get_model_info_not_loaded(self, llm_manager):
        """get_model_info reports is_loaded False on a fresh manager."""
        assert llm_manager.get_model_info()["is_loaded"] is False
class TestLLMManagerIntegration:
    """End-to-end style checks that drive LLMManager without a fixture."""

    @pytest.mark.asyncio
    async def test_full_workflow_mock(self):
        """A manager forced into mock mode streams chunks ending in a stop."""
        manager = LLMManager()
        # Skip load_model() entirely by setting the loaded/mock state directly.
        manager.is_loaded = True
        manager.model_type = "mock"

        request = ChatRequest(
            messages=[ChatMessage(role="user", content="Hello, how are you?")],
            max_tokens=20,
        )

        chunks = [piece async for piece in manager.generate_stream(request)]

        assert len(chunks) > 1
        for piece in chunks[:-1]:
            assert "choices" in piece
        assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

    @pytest.mark.asyncio
    async def test_context_truncation_integration(self):
        """Streaming with two 10000-character messages still yields chunks."""
        manager = LLMManager()
        # NOTE(review): load_model() runs unpatched here, so which backend is
        # selected depends on the environment -- confirm this is intended.
        await manager.load_model()

        oversized = "x" * 10000
        request = ChatRequest(
            messages=[
                ChatMessage(role="system", content="You are helpful."),
                ChatMessage(role="user", content=oversized),
                ChatMessage(role="assistant", content=oversized),
                ChatMessage(role="user", content="Short message"),
            ],
            max_tokens=50,
        )

        # Generation is expected to complete without raising.
        chunks = [piece async for piece in manager.generate_stream(request)]

        assert len(chunks) > 0

    @pytest.mark.asyncio
    async def test_different_model_types(self):
        """get_model_info echoes whichever model_type is currently set."""
        manager = LLMManager()

        for kind in ("llama_cpp", "transformers", "mock"):
            manager.model_type = kind
            assert manager.get_model_info()["type"] == kind
|