File size: 4,540 Bytes
b386992
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
# Copyright (c) 2025, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from unittest.mock import MagicMock, patch

import pytest
from omegaconf import DictConfig

from nemo.export.quantize.quantizer import QUANT_CFG_CHOICES, Quantizer


@pytest.fixture
def basic_quantization_config():
    """Minimal quantization settings sufficient to construct a Quantizer."""
    settings = {
        'algorithm': 'int8',
        'decoder_type': 'llama',
        'awq_block_size': 128,
        'sq_alpha': 0.5,
        'enable_kv_cache': True,
    }
    return DictConfig(settings)


@pytest.fixture
def basic_export_config():
    """Minimal export settings (dtype, parallelism, output path) for a Quantizer."""
    settings = {
        'dtype': '16',
        'decoder_type': 'llama',
        'inference_tensor_parallel': 1,
        'inference_pipeline_parallel': 1,
        'save_path': '/tmp/model.qnemo',
    }
    return DictConfig(settings)


class TestQuantizer:
    """Unit tests for the Quantizer wrapper around modelopt quantization/export."""

    def test_init_valid_configs(self, basic_quantization_config, basic_export_config):
        """A valid config pair constructs and exposes the expected attributes."""
        quantizer = Quantizer(basic_quantization_config, basic_export_config)

        assert quantizer.quantization_config == basic_quantization_config
        assert quantizer.export_config == basic_export_config
        assert quantizer.quant_cfg == QUANT_CFG_CHOICES['int8']

    def test_init_invalid_algorithm(self, basic_quantization_config, basic_export_config):
        """An unrecognized quantization algorithm is rejected at construction time."""
        basic_quantization_config.algorithm = 'invalid_algo'

        with pytest.raises(AssertionError):
            Quantizer(basic_quantization_config, basic_export_config)

    def test_init_invalid_dtype(self, basic_quantization_config, basic_export_config):
        """An unsupported export dtype is rejected at construction time."""
        basic_export_config.dtype = '32'

        with pytest.raises(AssertionError):
            Quantizer(basic_quantization_config, basic_export_config)

    def test_null_algorithm(self, basic_quantization_config, basic_export_config):
        """A None algorithm is accepted and yields no quantization config."""
        basic_quantization_config.algorithm = None

        quantizer = Quantizer(basic_quantization_config, basic_export_config)

        assert quantizer.quant_cfg is None

    @patch('nemo.export.quantize.quantizer.dist')
    def test_quantize_method(self, mock_dist, basic_quantization_config, basic_export_config):
        """quantize() passes the model, selected config, and forward loop to modelopt."""
        mock_dist.get_rank.return_value = 0

        model = MagicMock()
        forward_loop = MagicMock()
        quantizer = Quantizer(basic_quantization_config, basic_export_config)

        with patch('modelopt.torch.quantization.quantize') as mock_quantize, patch(
            'modelopt.torch.quantization.print_quant_summary'
        ):
            quantizer.quantize(model, forward_loop)

            # The selected quant config must be forwarded unchanged.
            mock_quantize.assert_called_once_with(model, QUANT_CFG_CHOICES['int8'], forward_loop)

    @patch('nemo.export.quantize.quantizer.dist')
    def test_modify_model_config(self, mock_dist):
        """modify_model_config() disables SP/rope fusion and switches to the modelopt spec."""
        result = Quantizer.modify_model_config(DictConfig({'sequence_parallel': True}))

        assert result.sequence_parallel is False
        assert result.name == 'modelopt'
        assert result.apply_rope_fusion is False

    @patch('nemo.export.quantize.quantizer.dist')
    @patch('nemo.export.quantize.quantizer.export_tensorrt_llm_checkpoint')
    def test_export_method(self, mock_export, mock_dist, basic_quantization_config, basic_export_config):
        """export() hands decoder type and parallelism settings to the TRT-LLM exporter."""
        mock_dist.get_rank.return_value = 0

        model = MagicMock()
        model.cfg.megatron_amp_O2 = False
        model.trainer.num_nodes = 1
        quantizer = Quantizer(basic_quantization_config, basic_export_config)

        with patch('nemo.export.quantize.quantizer.save_artifacts'):
            quantizer.export(model)

            mock_export.assert_called_once()
            kwargs = mock_export.call_args[1]
            assert kwargs['decoder_type'] == 'llama'
            assert kwargs['inference_tensor_parallel'] == 1
            assert kwargs['inference_pipeline_parallel'] == 1