File size: 4,406 Bytes
9281fab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""
Unit tests for the LLM abstraction layer.
"""

import pytest
from unittest.mock import Mock, patch, MagicMock
from pathlib import Path

from coda.core.llm import LLMProvider, GroqLLM, LLMResponse


class TestLLMResponse:
    """Checks for constructing the LLMResponse model."""

    def test_response_creation(self):
        """Fields passed to LLMResponse are readable back unchanged."""
        token_usage = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
        fields = {
            "content": "Test content",
            "model": "test-model",
            "usage": token_usage,
            "finish_reason": "stop",
        }
        response = LLMResponse(**fields)

        # Scalar fields should round-trip exactly.
        for name in ("content", "model"):
            assert getattr(response, name) == fields[name]
        assert response.usage["total_tokens"] == 30
        assert response.finish_reason == fields["finish_reason"]


class TestGroqLLM:
    """Exercise GroqLLM with the Groq SDK client patched out."""

    @staticmethod
    def _fake_completion(
        content,
        model,
        *,
        prompt_tokens=0,
        completion_tokens=0,
        total_tokens=0,
        finish_reason="stop",
    ):
        """Build a MagicMock shaped like a chat-completion response.

        Sets ``choices[0].message.content``, ``choices[0].finish_reason``,
        ``model`` and the three ``usage`` token counters; any other attribute
        read falls through to MagicMock's automatic child creation.
        """
        usage = MagicMock(
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=total_tokens,
        )
        choice = MagicMock(
            message=MagicMock(content=content),
            finish_reason=finish_reason,
        )
        return MagicMock(choices=[choice], model=model, usage=usage)

    @pytest.fixture
    def mock_groq_client(self):
        """Patch ``coda.core.llm.Groq`` and yield the client instance it returns."""
        fake_client = Mock()
        with patch("coda.core.llm.Groq") as groq_cls:
            groq_cls.return_value = fake_client
            yield fake_client

    def test_initialization(self, mock_groq_client):
        """Custom constructor options end up on the matching private attributes."""
        overrides = {
            "default_model": "custom-model",
            "temperature": 0.5,
            "max_tokens": 2048,
        }
        llm = GroqLLM(api_key="test-key", **overrides)

        # Each option ``foo`` is expected to be stored as ``_foo``.
        for option, expected in overrides.items():
            assert getattr(llm, f"_{option}") == expected

    def test_complete_success(self, mock_groq_client):
        """A single successful API call is mapped into the response object."""
        create = mock_groq_client.chat.completions.create
        create.return_value = self._fake_completion(
            "Generated text",
            "llama-3.3-70b-versatile",
            prompt_tokens=10,
            completion_tokens=20,
            total_tokens=30,
        )

        result = GroqLLM(api_key="test-key").complete(
            prompt="Test prompt",
            system_prompt="System prompt",
        )

        assert result.content == "Generated text"
        assert result.finish_reason == "stop"
        create.assert_called_once()

    def test_complete_with_retry(self, mock_groq_client):
        """A failed first call is retried and the second call's result is used."""
        create = mock_groq_client.chat.completions.create
        # First call raises, second call succeeds.
        create.side_effect = [
            Exception("Rate limited"),
            self._fake_completion("Success", "test"),
        ]

        llm = GroqLLM(api_key="test-key", retry_delay=0.01)

        assert llm.complete(prompt="Test").content == "Success"
        assert create.call_count == 2

    def test_build_messages(self, mock_groq_client):
        """A system prompt is emitted first, followed by the user prompt."""
        llm = GroqLLM(api_key="test-key")
        expected = [("system", "System message"), ("user", "User message")]

        messages = llm._build_messages(
            prompt="User message",
            system_prompt="System message",
        )

        assert len(messages) == len(expected)
        for message, (role, content) in zip(messages, expected):
            assert message["role"] == role
            assert message["content"] == content

    def test_build_messages_no_system(self, mock_groq_client):
        """Without a system prompt only the user message is produced."""
        messages = GroqLLM(api_key="test-key")._build_messages(prompt="User message")

        assert len(messages) == 1
        assert messages[0]["role"] == "user"


class TestLLMProviderInterface:
    """Checks on the abstract LLMProvider contract."""

    def test_interface_methods(self):
        """LLMProvider exposes every method a concrete provider must supply."""
        required = ("complete", "complete_with_image")
        for method_name in required:
            assert hasattr(LLMProvider, method_name), method_name