File size: 5,475 Bytes
dc14519
 
f6fdf6a
dc14519
 
f6fdf6a
dc14519
f6fdf6a
 
 
 
 
dc14519
f6fdf6a
dc14519
 
 
 
 
f6fdf6a
 
 
dc14519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f6fdf6a
 
dc14519
 
 
 
 
 
 
f6fdf6a
 
 
dc14519
f6fdf6a
dc14519
 
 
 
 
 
 
 
 
 
 
 
 
f6fdf6a
dc14519
 
 
 
f6fdf6a
 
dc14519
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64c014e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
"""Tests for Transformers provider."""

import pytest
from unittest.mock import patch, MagicMock, AsyncMock
import torch

from app.providers.transformers_provider import list_models, chat, is_model_ready, TransformersProvider


@pytest.mark.asyncio
async def test_list_models_success():
    """Check that list_models() yields a non-empty "list" payload of models."""
    listing = await list_models()

    # Top-level envelope: an "object" marker and a "data" collection.
    assert "object" in listing
    assert listing["object"] == "list"
    assert "data" in listing

    entries = listing["data"]
    assert len(entries) > 0
    assert entries[0]["object"] == "model"


@pytest.mark.asyncio
async def test_list_models_structure():
    """Check the keys carried by the first entry returned by list_models()."""
    listing = await list_models()
    first_entry = listing["data"][0]

    # Every model entry is expected to expose these keys.
    for required_key in ("id", "object", "owned_by"):
        assert required_key in first_entry
    assert first_entry["object"] == "model"


@pytest.mark.asyncio
async def test_chat_with_mock_model():
    """Test non-streaming chat completion with a mocked model and tokenizer.

    The provider module's ``model``, ``tokenizer``, ``is_model_ready`` and
    ``_initialized`` attributes are patched for the duration of the call, and
    the returned payload is checked for the chat-completion envelope keys.
    """
    payload = {
        "model": "test-model",
        "messages": [{"role": "user", "content": "hello"}],
        "temperature": 0.7,
        "max_tokens": 100
    }

    # Mock the global tokenizer.
    mock_tokenizer = MagicMock()
    mock_tokenizer.apply_chat_template.return_value = "formatted prompt"
    mock_tokenizer.encode.return_value = [1, 2, 3]
    mock_tokenizer.decode.return_value = "test response"
    # The result of calling a mock is configured through ``return_value`` on
    # the mock itself. ``mock.__call__`` resolves to the real bound method,
    # which cannot carry a ``return_value`` attribute.
    mock_tokenizer.return_value = {
        "input_ids": torch.tensor([[1, 2, 3]]),
        "attention_mask": torch.tensor([[1, 1, 1]])
    }

    # Mock the global model.
    mock_model = MagicMock()
    mock_outputs = MagicMock()
    # Item assignment on a MagicMock goes to its mocked ``__setitem__`` and is
    # discarded, so indexing has to be configured through ``__getitem__``.
    mock_outputs.__getitem__.return_value = torch.tensor([[1, 2, 3, 4, 5]])
    mock_model.generate.return_value = mock_outputs
    mock_model.get_input_embeddings.return_value.num_embeddings = 1000

    with patch('app.providers.transformers_provider.model', mock_model), \
         patch('app.providers.transformers_provider.tokenizer', mock_tokenizer), \
         patch('app.providers.transformers_provider.is_model_ready', return_value=True), \
         patch('app.providers.transformers_provider._initialized', True):

        result = await chat(payload, stream=False)

        assert "id" in result
        assert "object" in result
        assert result["object"] == "chat.completion"
        assert "choices" in result
        assert len(result["choices"]) > 0
        assert "usage" in result


@pytest.mark.asyncio
async def test_chat_streaming():
    """Test that chat() with ``stream=True`` returns an async-iterable result.

    The provider module's ``model``, ``tokenizer``, ``is_model_ready`` and
    ``_initialized`` attributes are patched for the duration of the call.
    The stream itself is not consumed; only the presence of ``__aiter__``
    on the returned object is checked.
    """
    payload = {
        "model": "test-model",
        "messages": [{"role": "user", "content": "hello"}],
        "stream": True
    }

    # Mock tokenizer for streaming.
    mock_tokenizer = MagicMock()
    mock_tokenizer.apply_chat_template.return_value = "formatted prompt"
    # The result of calling a mock is configured through ``return_value`` on
    # the mock itself. ``mock.__call__`` resolves to the real bound method,
    # which cannot carry a ``return_value`` attribute.
    mock_tokenizer.return_value = {
        "input_ids": torch.tensor([[1, 2, 3]]),
        "attention_mask": torch.tensor([[1, 1, 1]])
    }

    with patch('app.providers.transformers_provider.model', MagicMock()), \
         patch('app.providers.transformers_provider.tokenizer', mock_tokenizer), \
         patch('app.providers.transformers_provider.is_model_ready', return_value=True), \
         patch('app.providers.transformers_provider._initialized', True):

        result = await chat(payload, stream=True)

        # Should return an async iterator
        assert hasattr(result, '__aiter__')


def test_is_model_ready_false_when_not_initialized():
    """is_model_ready() reports False when nothing has been initialized."""
    # Patchers are built up front; they only take effect inside the `with`.
    initialized_patch = patch('app.providers.transformers_provider._initialized', False)
    model_patch = patch('app.providers.transformers_provider.model', None)
    tokenizer_patch = patch('app.providers.transformers_provider.tokenizer', None)

    with initialized_patch, model_patch, tokenizer_patch:
        ready = is_model_ready()
        assert ready is False


def test_is_model_ready_true_when_initialized():
    """is_model_ready() reports True once flag, model and tokenizer are set."""
    fake_model = MagicMock()
    fake_tokenizer = MagicMock()

    # Patchers are built up front; they only take effect inside the `with`.
    initialized_patch = patch('app.providers.transformers_provider._initialized', True)
    model_patch = patch('app.providers.transformers_provider.model', fake_model)
    tokenizer_patch = patch('app.providers.transformers_provider.tokenizer', fake_tokenizer)

    with initialized_patch, model_patch, tokenizer_patch:
        ready = is_model_ready()
        assert ready is True


def test_provider_format_tools_for_prompt():
    """The formatted tool prompt contains the tool name and expected markers."""
    tool_spec = {
        "function": {
            "name": "test_tool",
            "description": "A test tool",
            "parameters": {"type": "object", "properties": {}}
        }
    }

    prompt = TransformersProvider()._format_tools_for_prompt([tool_spec])

    # Tool name first, then the two fixed markers, in the original order.
    for expected_fragment in ("test_tool", "CRITICAL", "<tool_call>"):
        assert expected_fragment in prompt


def test_provider_remove_reasoning_tags():
    """The <think> tag is removed while the trailing answer text is kept."""
    raw_text = "<think>Some reasoning</think>Actual answer"

    cleaned = TransformersProvider()._remove_reasoning_tags(raw_text)

    assert "<think>" not in cleaned
    assert "Actual answer" in cleaned


def test_provider_extract_json_by_brace_matching():
    """A JSON object embedded in surrounding text is extracted and parsed."""
    text = "Some text {\"key\": \"value\"} more text"

    extracted = TransformersProvider()._extract_json_by_brace_matching(text)

    assert extracted is not None
    assert extracted.get("key") == "value"