# chat-bot/tests/test_llm_manager.py
# Initial commit: LLM Chat Interface for HF Spaces (author: surahj, commit c2f9396)
import pytest
import asyncio
from unittest.mock import Mock, patch, AsyncMock
from app.models import ChatMessage, ChatRequest
from app.llm_manager import LLMManager
class TestLLMManager:
    """Unit tests for the LLM manager.

    Covers initialization defaults, model loading across the three backends
    (llama-cpp, transformers, and the mock fallback), prompt formatting,
    context truncation, streaming generation for each backend, error
    handling during generation, and model-info reporting.
    """

    @pytest.fixture
    def llm_manager(self):
        """Create a fresh LLM manager instance for each test."""
        return LLMManager()

    @pytest.fixture
    def sample_request(self):
        """Create a sample chat request with a system and a user message."""
        messages = [
            ChatMessage(role="system", content="You are helpful."),
            ChatMessage(role="user", content="Hello!"),
        ]
        return ChatRequest(messages=messages, max_tokens=50)

    def test_initialization(self, llm_manager):
        """Test LLM manager initialization defaults."""
        assert llm_manager.model_path is not None
        assert llm_manager.model is None
        assert llm_manager.tokenizer is None
        assert llm_manager.model_type == "llama_cpp"
        assert llm_manager.context_window == 2048
        assert llm_manager.is_loaded is False
        assert len(llm_manager.mock_responses) > 0

    def test_custom_model_path(self):
        """Test LLM manager with a custom model path."""
        custom_path = "/custom/path/model.gguf"
        llm_manager = LLMManager(model_path=custom_path)
        assert llm_manager.model_path == custom_path

    @pytest.mark.asyncio
    async def test_load_model_mock_fallback(self, llm_manager):
        """Test model loading falls back to mock when no models available."""
        with patch("app.llm_manager.LLAMA_AVAILABLE", False):
            with patch("app.llm_manager.TRANSFORMERS_AVAILABLE", False):
                with patch("app.llm_manager.Path") as mock_path:
                    mock_path.return_value.exists.return_value = False
                    success = await llm_manager.load_model()
                    assert success is True
                    assert llm_manager.is_loaded is True
                    assert llm_manager.model_type == "mock"

    @pytest.mark.asyncio
    async def test_load_llama_model(self, llm_manager):
        """Test loading a model with llama-cpp-python."""
        mock_llama = Mock()
        with patch("app.llm_manager.LLAMA_AVAILABLE", True):
            with patch("app.llm_manager.Path") as mock_path:
                mock_path.return_value.exists.return_value = True
                with patch("app.llm_manager.Llama", return_value=mock_llama):
                    with patch("os.cpu_count", return_value=4):
                        success = await llm_manager.load_model()
                        assert success is True
                        assert llm_manager.is_loaded is True
                        assert llm_manager.model_type == "llama_cpp"
                        assert llm_manager.model == mock_llama

    @pytest.mark.asyncio
    async def test_load_transformers_model(self, llm_manager):
        """Test loading a model with transformers (CPU path)."""
        mock_tokenizer = Mock()
        mock_model = Mock()
        with patch("app.llm_manager.LLAMA_AVAILABLE", False):
            with patch("app.llm_manager.TRANSFORMERS_AVAILABLE", True):
                with patch(
                    "app.llm_manager.AutoTokenizer.from_pretrained",
                    return_value=mock_tokenizer,
                ):
                    with patch(
                        "app.llm_manager.AutoModelForCausalLM.from_pretrained",
                        return_value=mock_model,
                    ):
                        with patch(
                            "app.llm_manager.torch.cuda.is_available",
                            return_value=False,
                        ):
                            success = await llm_manager.load_model()
                            assert success is True
                            assert llm_manager.is_loaded is True
                            assert llm_manager.model_type == "transformers"
                            assert llm_manager.tokenizer == mock_tokenizer
                            assert llm_manager.model == mock_model

    @pytest.mark.asyncio
    async def test_load_model_failure(self, llm_manager):
        """Test model loading failure handling."""
        with patch("app.llm_manager.LLAMA_AVAILABLE", False):
            with patch("app.llm_manager.TRANSFORMERS_AVAILABLE", False):
                with patch("app.llm_manager.Path") as mock_path:
                    mock_path.return_value.exists.return_value = False
                    # Force an exception in the mock fallback
                    with patch.object(
                        llm_manager,
                        "_load_transformers_model",
                        side_effect=Exception("Load failed"),
                    ):
                        success = await llm_manager.load_model()
                        # Should still succeed with mock fallback
                        assert success is True
                        assert llm_manager.is_loaded is True

    def test_format_messages(self, llm_manager):
        """Test message formatting into the chat prompt template."""
        messages = [
            ChatMessage(role="system", content="You are helpful."),
            ChatMessage(role="user", content="Hello!"),
        ]
        result = llm_manager.format_messages(messages)
        expected = "<|system|>\nYou are helpful.\n<|/system|>\n<|user|>\nHello!\n<|/user|>\n<|assistant|>"
        assert result == expected

    def test_truncate_context_no_tokenizer(self, llm_manager):
        """Test context truncation when no tokenizer is available."""
        prompt = "This is a test prompt"
        result = llm_manager.truncate_context(prompt, 100)
        # Without a tokenizer the prompt should pass through unchanged.
        assert result == prompt

    def test_truncate_context_with_tokenizer(self, llm_manager):
        """Test context truncation with a tokenizer."""
        mock_tokenizer = Mock()
        mock_tokenizer.encode.return_value = [1, 2, 3, 4, 5] * 500  # Long token list
        mock_tokenizer.decode.return_value = "truncated prompt"
        llm_manager.tokenizer = mock_tokenizer
        prompt = "This is a test prompt"
        result = llm_manager.truncate_context(prompt, 100)
        assert result == "truncated prompt"
        mock_tokenizer.encode.assert_called_once_with(prompt)

    @pytest.mark.asyncio
    async def test_generate_stream_not_loaded(self, llm_manager, sample_request):
        """Test that generate_stream raises error when model not loaded."""
        with pytest.raises(RuntimeError, match="Model not loaded"):
            async for _ in llm_manager.generate_stream(sample_request):
                pass

    @pytest.mark.asyncio
    async def test_generate_mock_stream(self, llm_manager, sample_request):
        """Test mock streaming generation."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "mock"
        chunks = []
        async for chunk in llm_manager.generate_stream(sample_request):
            chunks.append(chunk)
        # Should have multiple chunks (words) plus completion signal
        assert len(chunks) > 1
        # Check structure of chunks
        for chunk in chunks[:-1]:  # All except last
            assert "id" in chunk
            assert "object" in chunk
            assert chunk["object"] == "chat.completion.chunk"
            assert "choices" in chunk
            assert len(chunk["choices"]) == 1
            assert "delta" in chunk["choices"][0]
            assert "content" in chunk["choices"][0]["delta"]
        # Check completion signal
        last_chunk = chunks[-1]
        assert last_chunk["choices"][0]["finish_reason"] == "stop"

    @pytest.mark.asyncio
    async def test_generate_llama_stream(self, llm_manager, sample_request):
        """Test llama-cpp streaming generation."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "llama_cpp"
        llm_manager.model = Mock()
        # Mock llama response
        mock_response = [
            {"choices": [{"delta": {"content": "Hello"}, "finish_reason": None}]},
            {"choices": [{"delta": {"content": " world"}, "finish_reason": None}]},
            {"choices": [{"delta": {}, "finish_reason": "stop"}]},
        ]
        llm_manager.model.return_value = mock_response
        chunks = []
        async for chunk in llm_manager.generate_stream(sample_request):
            chunks.append(chunk)
        # Should have chunks for each token plus completion
        assert len(chunks) >= 2
        # Check that llama model was called correctly
        llm_manager.model.assert_called_once()
        call_args = llm_manager.model.call_args
        assert call_args[1]["stream"] is True
        assert call_args[1]["max_tokens"] == 50

    @pytest.mark.asyncio
    async def test_generate_transformers_stream(self, llm_manager, sample_request):
        """Test transformers streaming generation."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "transformers"
        llm_manager.tokenizer = Mock()
        llm_manager.model = Mock()
        # Mock tokenizer and model
        llm_manager.tokenizer.encode.return_value = [1, 2, 3]
        llm_manager.tokenizer.decode.return_value = "test"
        llm_manager.tokenizer.eos_token_id = 0
        mock_tensor = Mock()
        mock_tensor.unsqueeze.return_value = mock_tensor
        llm_manager.model.generate.return_value = mock_tensor
        with patch("app.llm_manager.torch") as mock_torch:
            mock_torch.cuda.is_available.return_value = False
            mock_torch.cat.return_value = mock_tensor
            chunks = []
            async for chunk in llm_manager.generate_stream(sample_request):
                chunks.append(chunk)
                if len(chunks) >= 3:  # Limit to avoid infinite loop
                    break
            # Should have some chunks
            assert len(chunks) > 0

    @pytest.mark.asyncio
    async def test_generate_stream_error_handling(self, llm_manager, sample_request):
        """Test error handling in streaming generation."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "llama_cpp"
        llm_manager.model = Mock()
        # Mock llama to raise exception
        llm_manager.model.side_effect = Exception("Generation failed")
        chunks = []
        async for chunk in llm_manager.generate_stream(sample_request):
            chunks.append(chunk)
        # Should have error chunk
        assert len(chunks) == 1
        assert "error" in chunks[0]
        assert chunks[0]["error"]["type"] == "generation_error"

    def test_get_model_info(self, llm_manager):
        """Test getting model information when loaded."""
        llm_manager.is_loaded = True
        llm_manager.model_type = "llama_cpp"
        info = llm_manager.get_model_info()
        assert info["id"] == "llama-2-7b-chat"
        assert info["object"] == "model"
        assert info["owned_by"] == "huggingface"
        assert info["type"] == "llama_cpp"
        assert info["context_window"] == 2048
        assert info["is_loaded"] is True

    def test_get_model_info_not_loaded(self, llm_manager):
        """Test getting model info when not loaded."""
        info = llm_manager.get_model_info()
        assert info["is_loaded"] is False
class TestLLMManagerIntegration:
    """Integration tests for the LLM manager.

    Exercise the full request-to-stream workflow (mock backend), context
    truncation with very long messages, and model-info reporting across
    the supported model types.
    """

    @pytest.mark.asyncio
    async def test_full_workflow_mock(self):
        """Test full workflow with mock model."""
        llm_manager = LLMManager()
        # Force mock mode
        llm_manager.is_loaded = True
        llm_manager.model_type = "mock"
        # Create request
        messages = [ChatMessage(role="user", content="Hello, how are you?")]
        request = ChatRequest(messages=messages, max_tokens=20)
        # Generate response
        chunks = []
        async for chunk in llm_manager.generate_stream(request):
            chunks.append(chunk)
        # Verify response
        assert len(chunks) > 1
        assert all("choices" in chunk for chunk in chunks[:-1])
        assert chunks[-1]["choices"][0]["finish_reason"] == "stop"

    @pytest.mark.asyncio
    async def test_context_truncation_integration(self):
        """Test context truncation in full workflow."""
        llm_manager = LLMManager()
        await llm_manager.load_model()
        # Create very long messages
        long_message = "x" * 10000
        messages = [
            ChatMessage(role="system", content="You are helpful."),
            ChatMessage(role="user", content=long_message),
            ChatMessage(role="assistant", content=long_message),
            ChatMessage(role="user", content="Short message"),
        ]
        request = ChatRequest(messages=messages, max_tokens=50)
        # Should not raise exception due to truncation
        chunks = []
        async for chunk in llm_manager.generate_stream(request):
            chunks.append(chunk)
        assert len(chunks) > 0

    @pytest.mark.asyncio
    async def test_different_model_types(self):
        """Test different model type configurations."""
        llm_manager = LLMManager()
        # Test llama_cpp type
        llm_manager.model_type = "llama_cpp"
        info = llm_manager.get_model_info()
        assert info["type"] == "llama_cpp"
        # Test transformers type
        llm_manager.model_type = "transformers"
        info = llm_manager.get_model_info()
        assert info["type"] == "transformers"
        # Test mock type
        llm_manager.model_type = "mock"
        info = llm_manager.get_model_info()
        assert info["type"] == "mock"