File size: 5,338 Bytes
adcb9bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""Tests for model loading and inference."""

import pytest
from unittest.mock import patch, MagicMock
import sys
import os

# Add parent directory to path for imports
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from config import estimate_model_size, should_quantize


class TestModelSizeEstimation:
    """Behaviour of config.estimate_model_size."""

    def test_known_model_size(self):
        """Model IDs in the known-models table map to their parameter counts."""
        known = {
            "meta-llama/Llama-3.1-8B-Instruct": 8,
            "meta-llama/Llama-3.1-70B-Instruct": 70,
            "mistralai/Mistral-7B-Instruct-v0.3": 7,
        }
        for model_id, expected_size in known.items():
            assert estimate_model_size(model_id) == expected_size

    def test_extract_size_from_name(self):
        """A '<digits>B' token anywhere in the name yields that size."""
        cases = [
            ("some-org/CustomModel-13B", 13),
            ("another/model-2B-test", 2),
            ("org/Model-32B-Instruct", 32),
        ]
        for model_id, expected_size in cases:
            assert estimate_model_size(model_id) == expected_size

    def test_unknown_model_size(self):
        """Names with no recognizable size token return None."""
        for model_id in ("unknown/model-without-size", "org/mystery-model"):
            assert estimate_model_size(model_id) is None


class TestQuantizationDecision:
    """Behaviour of config.should_quantize."""

    def test_small_model_no_quantization(self):
        """Models under the size threshold load unquantized."""
        for model_id in (
            "meta-llama/Llama-3.1-8B-Instruct",
            "mistralai/Mistral-7B-Instruct-v0.3",
        ):
            assert should_quantize(model_id) == "none"

    def test_large_model_int4_quantization(self):
        """Models at 70B parameters or above are quantized to INT4."""
        for model_id in (
            "meta-llama/Llama-3.1-70B-Instruct",
            "Qwen/Qwen2.5-72B-Instruct",
        ):
            assert should_quantize(model_id) == "int4"

    def test_unknown_model_no_quantization(self):
        """A model whose size cannot be estimated is never auto-quantized."""
        assert should_quantize("unknown/mystery-model") == "none"


class TestModelLoading:
    """Behaviour of models.load_model: construction and caching."""

    # NOTE: @patch decorators apply bottom-up, so the tokenizer mock is the
    # first injected argument; mock_tokenizer / mock_model are pytest fixtures.

    @patch("models.AutoModelForCausalLM")
    @patch("models.AutoTokenizer")
    def test_load_model_creates_loaded_model(
        self, mock_tokenizer_class, mock_model_class, mock_tokenizer, mock_model
    ):
        """load_model returns a LoadedModel populated with model and tokenizer."""
        mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
        mock_model_class.from_pretrained.return_value = mock_model

        from models import load_model, unload_model

        unload_model()  # start from an empty cache

        result = load_model("test-model/test-7B")

        assert result.model_id == "test-model/test-7B"
        assert result.model is not None
        assert result.tokenizer is not None

    @patch("models.AutoModelForCausalLM")
    @patch("models.AutoTokenizer")
    def test_load_model_caches_result(
        self, mock_tokenizer_class, mock_model_class, mock_tokenizer, mock_model
    ):
        """A repeated load of the same model ID must not hit from_pretrained again."""
        mock_tokenizer_class.from_pretrained.return_value = mock_tokenizer
        mock_model_class.from_pretrained.return_value = mock_model

        from models import load_model, unload_model

        unload_model()  # start from an empty cache

        load_model("test-model/test-7B")
        calls_after_first = mock_model_class.from_pretrained.call_count

        load_model("test-model/test-7B")  # should be served from cache
        calls_after_second = mock_model_class.from_pretrained.call_count

        # No additional from_pretrained call on the cached path.
        assert calls_after_second == calls_after_first


class TestChatTemplate:
    """Behaviour of models.apply_chat_template, with and without tokenizer support."""

    @patch("models.load_model")
    def test_apply_chat_template_with_tokenizer_method(self, mock_load_model, mock_tokenizer):
        """When the tokenizer provides apply_chat_template, its output is used."""
        from models import apply_chat_template, LoadedModel

        mock_load_model.return_value = LoadedModel(
            model_id="test-model",
            model=MagicMock(),
            tokenizer=mock_tokenizer,
        )

        rendered = apply_chat_template(
            "test-model",
            [{"role": "user", "content": "Hello!"}],
        )

        for expected_fragment in ("<|user|>", "Hello!", "<|assistant|>"):
            assert expected_fragment in rendered

    @patch("models.load_model")
    def test_apply_chat_template_fallback(self, mock_load_model):
        """Without apply_chat_template, the plain 'Role:' fallback format is used."""
        from models import apply_chat_template, LoadedModel

        # Simulate a tokenizer lacking the method entirely.
        bare_tokenizer = MagicMock()
        del bare_tokenizer.apply_chat_template

        mock_load_model.return_value = LoadedModel(
            model_id="test-model",
            model=MagicMock(),
            tokenizer=bare_tokenizer,
        )

        rendered = apply_chat_template(
            "test-model",
            [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": "Hi!"},
            ],
        )

        for expected_fragment in ("System:", "User:", "Assistant:"):
            assert expected_fragment in rendered