# SCoDA / tests/test_llm.py
# Last commit by vanishingradient: "Added init files" (9281fab)
"""
Unit tests for the LLM abstraction layer.
"""
import pytest
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path
from coda.core.llm import LLMProvider, GroqLLM, LLMResponse
class TestLLMResponse:
    """Behavioral checks for the ``LLMResponse`` model."""

    def test_response_creation(self):
        """A response built from explicit fields exposes those same fields."""
        token_usage = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
        built = LLMResponse(
            content="Test content",
            model="test-model",
            usage=token_usage,
            finish_reason="stop",
        )

        assert built.content == "Test content"
        assert built.model == "test-model"
        assert built.usage["total_tokens"] == 30
        assert built.finish_reason == "stop"
class TestGroqLLM:
    """Unit tests for ``GroqLLM`` with the Groq SDK client patched out."""

    @pytest.fixture
    def mock_groq_client(self):
        """Patch ``coda.core.llm.Groq`` and yield the client object it returns.

        For the duration of a test, calling the ``Groq`` name inside
        ``coda.core.llm`` returns the yielded ``Mock``, so tests can script
        ``chat.completions.create`` on it.
        """
        with patch("coda.core.llm.Groq") as groq_cls:
            fake_client = Mock()
            groq_cls.return_value = fake_client
            yield fake_client

    def test_initialization(self, mock_groq_client):
        """Constructor keyword arguments are stored on the instance."""
        overrides = dict(default_model="custom-model", temperature=0.5, max_tokens=2048)
        llm = GroqLLM(api_key="test-key", **overrides)

        assert llm._default_model == "custom-model"
        assert llm._temperature == 0.5
        assert llm._max_tokens == 2048

    def test_complete_success(self, mock_groq_client):
        """A single successful API reply is surfaced through ``complete``."""
        choice = MagicMock(
            message=MagicMock(content="Generated text"),
            finish_reason="stop",
        )
        api_reply = MagicMock(
            choices=[choice],
            model="llama-3.3-70b-versatile",
            usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30),
        )
        create = mock_groq_client.chat.completions.create
        create.return_value = api_reply

        result = GroqLLM(api_key="test-key").complete(
            prompt="Test prompt",
            system_prompt="System prompt",
        )

        assert result.content == "Generated text"
        assert result.finish_reason == "stop"
        create.assert_called_once()

    def test_complete_with_retry(self, mock_groq_client):
        """A failed first attempt is retried and the second reply is returned."""
        recovered = MagicMock(
            choices=[MagicMock(message=MagicMock(content="Success"), finish_reason="stop")],
            model="test",
            usage=MagicMock(prompt_tokens=0, completion_tokens=0, total_tokens=0),
        )
        create = mock_groq_client.chat.completions.create
        # Scripted sequence: the first call raises, the second returns `recovered`.
        create.side_effect = [Exception("Rate limited"), recovered]

        # Small retry_delay keeps the test fast (presumably the pause between attempts).
        result = GroqLLM(api_key="test-key", retry_delay=0.01).complete(prompt="Test")

        assert result.content == "Success"
        assert create.call_count == 2

    def test_build_messages(self, mock_groq_client):
        """With a system prompt, messages come out as [system, user] in order."""
        llm = GroqLLM(api_key="test-key")
        messages = llm._build_messages(
            prompt="User message",
            system_prompt="System message",
        )

        assert [(m["role"], m["content"]) for m in messages] == [
            ("system", "System message"),
            ("user", "User message"),
        ]

    def test_build_messages_no_system(self, mock_groq_client):
        """Without a system prompt only the user message is produced."""
        messages = GroqLLM(api_key="test-key")._build_messages(prompt="User message")

        assert [m["role"] for m in messages] == ["user"]
class TestLLMProviderInterface:
    """Checks on the ``LLMProvider`` base interface."""

    def test_interface_methods(self):
        """``LLMProvider`` exposes both ``complete`` and ``complete_with_image``."""
        for method_name in ("complete", "complete_with_image"):
            assert hasattr(LLMProvider, method_name), method_name