# NeMo_Canary / tests /export /test_quantizer.py
# Respair's picture
# Upload folder using huggingface_hub
# b386992 verified
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from unittest.mock import MagicMock, patch
import pytest
from omegaconf import DictConfig
from nemo.export.quantize.quantizer import QUANT_CFG_CHOICES, Quantizer
@pytest.fixture
def basic_quantization_config():
    """Provide a baseline quantization config using the 'int8' algorithm."""
    settings = {
        'algorithm': 'int8',
        'decoder_type': 'llama',
        'awq_block_size': 128,
        'sq_alpha': 0.5,
        'enable_kv_cache': True,
    }
    return DictConfig(settings)
@pytest.fixture
def basic_export_config():
    """Provide a baseline export config: dtype '16', llama decoder, TP=1, PP=1."""
    settings = {}
    settings['dtype'] = '16'
    settings['decoder_type'] = 'llama'
    settings['inference_tensor_parallel'] = 1
    settings['inference_pipeline_parallel'] = 1
    settings['save_path'] = '/tmp/model.qnemo'
    return DictConfig(settings)
class TestQuantizer:
    """Unit tests covering ``Quantizer`` construction, quantization and export."""

    def test_init_valid_configs(self, basic_quantization_config, basic_export_config):
        """Valid configs are stored on the instance and 'int8' selects its preset."""
        instance = Quantizer(basic_quantization_config, basic_export_config)

        assert instance.quantization_config == basic_quantization_config
        assert instance.export_config == basic_export_config
        assert instance.quant_cfg == QUANT_CFG_CHOICES['int8']

    def test_init_invalid_algorithm(self, basic_quantization_config, basic_export_config):
        """An unrecognised algorithm name makes the constructor raise AssertionError."""
        basic_quantization_config.algorithm = 'invalid_algo'

        with pytest.raises(AssertionError):
            Quantizer(basic_quantization_config, basic_export_config)

    def test_init_invalid_dtype(self, basic_quantization_config, basic_export_config):
        """An export dtype of '32' makes the constructor raise AssertionError."""
        basic_export_config.dtype = '32'

        with pytest.raises(AssertionError):
            Quantizer(basic_quantization_config, basic_export_config)

    def test_null_algorithm(self, basic_quantization_config, basic_export_config):
        """With ``algorithm`` set to None the resulting ``quant_cfg`` is None."""
        basic_quantization_config.algorithm = None

        instance = Quantizer(basic_quantization_config, basic_export_config)

        assert instance.quant_cfg is None

    @patch('nemo.export.quantize.quantizer.dist')
    def test_quantize_method(self, mock_dist, basic_quantization_config, basic_export_config):
        """``quantize`` calls modelopt's quantize once with model, int8 config and loop."""
        # Stand-ins for the model and its calibration forward loop.
        fake_model = MagicMock()
        fake_forward_loop = MagicMock()
        mock_dist.get_rank.return_value = 0

        instance = Quantizer(basic_quantization_config, basic_export_config)

        # Both modelopt entry points are replaced; only ``quantize`` is inspected.
        with patch('modelopt.torch.quantization.quantize') as mock_quantize, patch(
            'modelopt.torch.quantization.print_quant_summary'
        ):
            instance.quantize(fake_model, fake_forward_loop)

        mock_quantize.assert_called_once_with(fake_model, QUANT_CFG_CHOICES['int8'], fake_forward_loop)

    @patch('nemo.export.quantize.quantizer.dist')
    def test_modify_model_config(self, mock_dist):
        """``modify_model_config`` yields the expected flags and model name."""
        source_cfg = DictConfig({'sequence_parallel': True})

        result_cfg = Quantizer.modify_model_config(source_cfg)

        assert result_cfg.sequence_parallel is False
        assert result_cfg.name == 'modelopt'
        assert result_cfg.apply_rope_fusion is False

    @patch('nemo.export.quantize.quantizer.dist')
    @patch('nemo.export.quantize.quantizer.export_tensorrt_llm_checkpoint')
    def test_export_method(self, mock_export, mock_dist, basic_quantization_config, basic_export_config):
        """``export`` calls the checkpoint exporter once with the export config values."""
        mock_dist.get_rank.return_value = 0

        fake_model = MagicMock()
        fake_model.cfg.megatron_amp_O2 = False
        fake_model.trainer.num_nodes = 1

        instance = Quantizer(basic_quantization_config, basic_export_config)

        with patch('nemo.export.quantize.quantizer.save_artifacts'):
            instance.export(fake_model)

        mock_export.assert_called_once()

        # ``call_args`` unpacks to (positional args, keyword args); only kwargs are checked.
        _, export_kwargs = mock_export.call_args
        expected_kwargs = {
            'decoder_type': 'llama',
            'inference_tensor_parallel': 1,
            'inference_pipeline_parallel': 1,
        }
        observed_kwargs = {key: export_kwargs[key] for key in expected_kwargs}
        assert observed_kwargs == expected_kwargs