import pytest
import asyncio
from unittest.mock import Mock, patch, AsyncMock

from app.models import ChatMessage, ChatRequest
from app.llm_manager import LLMManager


class TestLLMManager:
    """Unit tests covering LLMManager loading, formatting, and streaming."""

    @pytest.fixture
    def llm_manager(self):
        """A fresh manager instance for every test."""
        return LLMManager()

    @pytest.fixture
    def sample_request(self):
        """A minimal two-message chat request."""
        history = [
            ChatMessage(role="system", content="You are helpful."),
            ChatMessage(role="user", content="Hello!"),
        ]
        return ChatRequest(messages=history, max_tokens=50)

    def test_initialization(self, llm_manager):
        """A new manager starts unloaded with sane defaults."""
        assert llm_manager.model_path is not None
        assert llm_manager.model is None
        assert llm_manager.tokenizer is None
        assert llm_manager.model_type == "llama_cpp"
        assert llm_manager.context_window == 2048
        assert llm_manager.is_loaded is False
        assert len(llm_manager.mock_responses) > 0

    def test_custom_model_path(self):
        """An explicit model path is stored verbatim."""
        path = "/custom/path/model.gguf"
        manager = LLMManager(model_path=path)
        assert manager.model_path == path

    @pytest.mark.asyncio
    async def test_load_model_mock_fallback(self, llm_manager):
        """With no backends and no model file, loading falls back to mock."""
        with patch("app.llm_manager.LLAMA_AVAILABLE", False), \
                patch("app.llm_manager.TRANSFORMERS_AVAILABLE", False), \
                patch("app.llm_manager.Path") as fake_path:
            fake_path.return_value.exists.return_value = False
            ok = await llm_manager.load_model()
            assert ok is True
            assert llm_manager.is_loaded is True
            assert llm_manager.model_type == "mock"

    @pytest.mark.asyncio
    async def test_load_llama_model(self, llm_manager):
        """Loading via llama-cpp-python wires up the Llama instance."""
        fake_llama = Mock()
        with patch("app.llm_manager.LLAMA_AVAILABLE", True), \
                patch("app.llm_manager.Path") as fake_path:
            fake_path.return_value.exists.return_value = True
            with patch("app.llm_manager.Llama", return_value=fake_llama), \
                    patch("os.cpu_count", return_value=4):
                ok = await llm_manager.load_model()
                assert ok is True
                assert llm_manager.is_loaded is True
                assert llm_manager.model_type == "llama_cpp"
                assert llm_manager.model == fake_llama

    @pytest.mark.asyncio
    async def test_load_transformers_model(self, llm_manager):
        """Loading via transformers stores both the tokenizer and the model."""
        fake_tokenizer = Mock()
        fake_model = Mock()
        tok_patch = patch(
            "app.llm_manager.AutoTokenizer.from_pretrained",
            return_value=fake_tokenizer,
        )
        model_patch = patch(
            "app.llm_manager.AutoModelForCausalLM.from_pretrained",
            return_value=fake_model,
        )
        cuda_patch = patch(
            "app.llm_manager.torch.cuda.is_available",
            return_value=False,
        )
        with patch("app.llm_manager.LLAMA_AVAILABLE", False), \
                patch("app.llm_manager.TRANSFORMERS_AVAILABLE", True), \
                tok_patch, model_patch, cuda_patch:
            ok = await llm_manager.load_model()
            assert ok is True
            assert llm_manager.is_loaded is True
            assert llm_manager.model_type == "transformers"
            assert llm_manager.tokenizer == fake_tokenizer
            assert llm_manager.model == fake_model

    @pytest.mark.asyncio
    async def test_load_model_failure(self, llm_manager):
        """Even if a backend loader raises, loading still succeeds via mock."""
        with patch("app.llm_manager.LLAMA_AVAILABLE", False), \
                patch("app.llm_manager.TRANSFORMERS_AVAILABLE", False), \
                patch("app.llm_manager.Path") as fake_path:
            fake_path.return_value.exists.return_value = False
            # Force an exception in the mock fallback
            with patch.object(
                llm_manager,
                "_load_transformers_model",
                side_effect=Exception("Load failed"),
            ):
                ok = await llm_manager.load_model()
                # Should still succeed with mock fallback
                assert ok is True
                assert llm_manager.is_loaded is True

    def test_format_messages(self, llm_manager):
        """Messages render into the expected prompt template."""
        history = [
            ChatMessage(role="system", content="You are helpful."),
            ChatMessage(role="user", content="Hello!"),
        ]
        rendered = llm_manager.format_messages(history)
        assert rendered == (
            "<|system|>\nYou are helpful.\n<|/system|>\n"
            "<|user|>\nHello!\n<|/user|>\n<|assistant|>"
        )

    def test_truncate_context_no_tokenizer(self, llm_manager):
        """Without a tokenizer, truncation is a no-op."""
        prompt = "This is a test prompt"
        assert llm_manager.truncate_context(prompt, 100) == prompt

    def test_truncate_context_with_tokenizer(self, llm_manager):
        """With a tokenizer, long prompts are re-encoded and trimmed."""
        fake_tokenizer = Mock()
        fake_tokenizer.encode.return_value = [1, 2, 3, 4, 5] * 500  # long token list
        fake_tokenizer.decode.return_value = "truncated prompt"
        llm_manager.tokenizer = fake_tokenizer

        prompt = "This is a test prompt"
        assert llm_manager.truncate_context(prompt, 100) == "truncated prompt"
        fake_tokenizer.encode.assert_called_once_with(prompt)

    @pytest.mark.asyncio
    async def test_generate_stream_not_loaded(self, llm_manager, sample_request):
        """Streaming before load_model raises RuntimeError."""
        with pytest.raises(RuntimeError, match="Model not loaded"):
            async for _ in llm_manager.generate_stream(sample_request):
                pass

    @pytest.mark.asyncio
    async def test_generate_mock_stream(self, llm_manager, sample_request):
        """Mock streaming yields word chunks followed by a stop signal."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "mock"

        received = [
            chunk async for chunk in llm_manager.generate_stream(sample_request)
        ]

        # Several word chunks plus the final completion signal.
        assert len(received) > 1
        for chunk in received[:-1]:  # every chunk except the terminator
            assert "id" in chunk
            assert "object" in chunk
            assert chunk["object"] == "chat.completion.chunk"
            assert "choices" in chunk
            assert len(chunk["choices"]) == 1
            assert "delta" in chunk["choices"][0]
            assert "content" in chunk["choices"][0]["delta"]
        assert received[-1]["choices"][0]["finish_reason"] == "stop"

    @pytest.mark.asyncio
    async def test_generate_llama_stream(self, llm_manager, sample_request):
        """llama-cpp streaming forwards token deltas and the stop event."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "llama_cpp"
        llm_manager.model = Mock()
        llm_manager.model.return_value = [
            {"choices": [{"delta": {"content": "Hello"}, "finish_reason": None}]},
            {"choices": [{"delta": {"content": " world"}, "finish_reason": None}]},
            {"choices": [{"delta": {}, "finish_reason": "stop"}]},
        ]

        received = [
            chunk async for chunk in llm_manager.generate_stream(sample_request)
        ]
        # One chunk per token plus the completion signal.
        assert len(received) >= 2

        # The underlying llama call must stream with the requested budget.
        llm_manager.model.assert_called_once()
        kwargs = llm_manager.model.call_args[1]
        assert kwargs["stream"] is True
        assert kwargs["max_tokens"] == 50

    @pytest.mark.asyncio
    async def test_generate_transformers_stream(self, llm_manager, sample_request):
        """transformers streaming produces at least one chunk."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "transformers"
        llm_manager.tokenizer = Mock()
        llm_manager.model = Mock()

        llm_manager.tokenizer.encode.return_value = [1, 2, 3]
        llm_manager.tokenizer.decode.return_value = "test"
        llm_manager.tokenizer.eos_token_id = 0

        fake_tensor = Mock()
        fake_tensor.unsqueeze.return_value = fake_tensor
        llm_manager.model.generate.return_value = fake_tensor

        with patch("app.llm_manager.torch") as fake_torch:
            fake_torch.cuda.is_available.return_value = False
            fake_torch.cat.return_value = fake_tensor

            received = []
            async for chunk in llm_manager.generate_stream(sample_request):
                received.append(chunk)
                if len(received) >= 3:  # Limit to avoid infinite loop
                    break

        assert len(received) > 0

    @pytest.mark.asyncio
    async def test_generate_stream_error_handling(self, llm_manager, sample_request):
        """A backend exception surfaces as a single error chunk."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "llama_cpp"
        llm_manager.model = Mock()
        # Mock llama to raise exception
        llm_manager.model.side_effect = Exception("Generation failed")

        received = [
            chunk async for chunk in llm_manager.generate_stream(sample_request)
        ]

        assert len(received) == 1
        assert "error" in received[0]
        assert received[0]["error"]["type"] == "generation_error"

    def test_get_model_info(self, llm_manager):
        """Model info reflects a loaded llama-cpp model."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "llama_cpp"

        info = llm_manager.get_model_info()
        assert info["id"] == "llama-2-7b-chat"
        assert info["object"] == "model"
        assert info["owned_by"] == "huggingface"
        assert info["type"] == "llama_cpp"
        assert info["context_window"] == 2048
        assert info["is_loaded"] is True

    def test_get_model_info_not_loaded(self, llm_manager):
        """Model info reports the unloaded state."""
        assert llm_manager.get_model_info()["is_loaded"] is False


class TestLLMManagerIntegration:
    """End-to-end checks of the manager against its mock backend."""

    @pytest.mark.asyncio
    async def test_full_workflow_mock(self):
        """A request against the mock backend streams chunks then a stop."""
        manager = LLMManager()
        # Force mock mode
        manager.is_loaded = True
        manager.model_type = "mock"

        request = ChatRequest(
            messages=[ChatMessage(role="user", content="Hello, how are you?")],
            max_tokens=20,
        )

        received = [chunk async for chunk in manager.generate_stream(request)]

        assert len(received) > 1
        assert all("choices" in chunk for chunk in received[:-1])
        assert received[-1]["choices"][0]["finish_reason"] == "stop"

    @pytest.mark.asyncio
    async def test_context_truncation_integration(self):
        """Oversized conversations stream without raising, thanks to truncation."""
        manager = LLMManager()
        await manager.load_model()

        filler = "x" * 10000
        request = ChatRequest(
            messages=[
                ChatMessage(role="system", content="You are helpful."),
                ChatMessage(role="user", content=filler),
                ChatMessage(role="assistant", content=filler),
                ChatMessage(role="user", content="Short message"),
            ],
            max_tokens=50,
        )

        # Should not raise exception due to truncation
        received = [chunk async for chunk in manager.generate_stream(request)]
        assert len(received) > 0

    @pytest.mark.asyncio
    async def test_different_model_types(self):
        """get_model_info mirrors whichever backend type is configured."""
        manager = LLMManager()
        for backend in ("llama_cpp", "transformers", "mock"):
            manager.model_type = backend
            assert manager.get_model_info()["type"] == backend