| """ |
| Unit tests for ModelOptModelLoader class. |
| |
| This test module verifies the functionality of ModelOptModelLoader, which |
| applies NVIDIA Model Optimizer quantization to models during loading. |
| """ |
|
|
| import unittest |
| from unittest.mock import MagicMock, patch |
|
|
| import torch.nn as nn |
|
|
| from sglang.srt.configs.device_config import DeviceConfig |
| from sglang.srt.configs.load_config import LoadConfig |
| from sglang.srt.configs.model_config import ModelConfig |
| from sglang.srt.layers.modelopt_utils import QUANT_CFG_CHOICES |
| from sglang.srt.model_loader.loader import ModelOptModelLoader |
| from sglang.srt.utils import get_device |
| from sglang.test.ci.ci_register import register_cuda_ci |
| from sglang.test.test_utils import CustomTestCase |
|
|
| |
|
|
| |
| CALIBRATION_BATCH_SIZE = 36 |
| CALIBRATION_NUM_SAMPLES = 512 |
| DEFAULT_DEVICE = "cuda:0" |
|
|
| register_cuda_ci(est_time=11, suite="stage-b-test-small-1-gpu") |
|
|
|
|
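# Test strategy: none of these tests exercise real ModelOpt quantization.
# The distributed helpers are patched in setUp, and the `modelopt` package
# is substituted with MagicMock modules via sys.modules, so no quantization
# kernels or calibration runs are actually performed.
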
class TestModelOptModelLoader(CustomTestCase):
    """Test cases for ModelOptModelLoader functionality."""

    def setUp(self):
        """Set up test fixtures."""
        # Pretend to be rank 0 of an already-initialized tensor-parallel
        # group so the loader does not need a real distributed environment.
        self.mock_tp_rank = patch(
            "sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank",
            return_value=0,
        )
        self.mock_tp_rank.start()

        # Silence logging from the loader module.
        self.mock_rank0_log = patch("sglang.srt.model_loader.loader.rank0_log")
        self.mock_rank0_log.start()

        self.mock_logger = patch("sglang.srt.model_loader.loader.logger")
        self.mock_logger.start()

        self.mock_get_tp_group = patch(
            "sglang.srt.distributed.parallel_state.get_tp_group"
        )
        self.mock_get_tp_group.start()

        self.mock_mp_is_initialized = patch(
            "sglang.srt.distributed.parallel_state.model_parallel_is_initialized",
            return_value=True,
        )
        self.mock_mp_is_initialized.start()

        self.model_path = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
        self.load_config = LoadConfig()
        self.device_config = DeviceConfig(device=get_device())

        self.model_config = ModelConfig(
            model_path=self.model_path,
            quantization="modelopt_fp8",
        )

        # The same model configured via the unified quantization flag.
        self.unified_model_config = ModelConfig(
            model_path=self.model_path, quantization="modelopt_fp8"
        )

        # A stand-in base model whose eval() returns itself, mirroring
        # nn.Module semantics.
        self.mock_base_model = MagicMock(spec=nn.Module)
        self.mock_base_model.eval.return_value = self.mock_base_model
        self.mock_base_model.device = DEFAULT_DEVICE

    def tearDown(self):
        """Clean up test fixtures."""
        self.mock_tp_rank.stop()
        self.mock_rank0_log.stop()
        self.mock_logger.stop()
        self.mock_get_tp_group.stop()
        self.mock_mp_is_initialized.stop()

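    # The test below assumes QUANT_CFG_CHOICES maps the unified quant type
    # "fp8" to the mtq config attribute name "FP8_DEFAULT_CFG" (the final
    # assertions depend on that mapping); the mocked load flow mirrors the
    # loader's quantization steps without importing modelopt.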
| @patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES) |
| @patch("sglang.srt.model_loader.loader.logger") |
| def test_successful_fp8_quantization(self, mock_logger): |
| """Test successful FP8 quantization workflow.""" |
|
|
| |
| loader = ModelOptModelLoader(self.load_config) |
|
|
| |
| mock_mtq = MagicMock() |
|
|
| |
| mock_fp8_cfg = MagicMock() |
| mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg |
| mock_mtq.quantize.return_value = self.mock_base_model |
| mock_mtq.print_quant_summary = MagicMock() |
|
|
| |
| def mock_load_model(*, model_config, device_config): |
| mock_logger.info("ModelOptModelLoader: Loading base model...") |
|
|
| |
| model = self.mock_base_model |
|
|
| |
| quant_choice_str = model_config._get_modelopt_quant_type() |
| quant_cfg_name = QUANT_CFG_CHOICES.get(quant_choice_str) |
|
|
| if not quant_cfg_name: |
| raise ValueError(f"Invalid modelopt_quant choice: '{quant_choice_str}'") |
|
|
| |
| if quant_cfg_name == "FP8_DEFAULT_CFG": |
| quant_cfg = mock_fp8_cfg |
|
|
| mock_logger.info( |
| f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}" |
| ) |
|
|
| |
| quantized_model = mock_mtq.quantize(model, quant_cfg, forward_loop=None) |
| mock_logger.info("Model successfully quantized with ModelOpt.") |
|
|
| |
| mock_mtq.print_quant_summary(quantized_model) |
|
|
| return quantized_model.eval() |
|
|
| return model.eval() |
|
|
| |
| with patch.object(loader, "load_model", side_effect=mock_load_model): |
| |
| result_model = loader.load_model( |
| model_config=self.model_config, device_config=self.device_config |
| ) |
|
|
| |
| mock_mtq.quantize.assert_called_once_with( |
| self.mock_base_model, mock_fp8_cfg, forward_loop=None |
| ) |
|
|
| |
| mock_logger.info.assert_any_call( |
| "ModelOptModelLoader: Loading base model..." |
| ) |
| mock_logger.info.assert_any_call( |
| "Quantizing model with ModelOpt using config attribute: mtq.FP8_DEFAULT_CFG" |
| ) |
| mock_logger.info.assert_any_call( |
| "Model successfully quantized with ModelOpt." |
| ) |
|
|
| |
| mock_mtq.print_quant_summary.assert_called_once_with(self.mock_base_model) |
|
|
| |
| self.mock_base_model.eval.assert_called() |
|
|
| |
| self.assertEqual(result_model, self.mock_base_model) |
|
|
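    # A missing modelopt installation is simulated by wrapping
    # builtins.__import__ so that only imports of modelopt (and its
    # submodules) raise ImportError; all other imports pass through.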
| @patch("sglang.srt.model_loader.loader.logger") |
| def test_missing_modelopt_import(self, mock_logger): |
| """Test error handling when modelopt library is not available.""" |
|
|
| loader = ModelOptModelLoader(self.load_config) |
|
|
| |
| with patch.object( |
| loader, "_load_modelopt_base_model", return_value=self.mock_base_model |
| ): |
| |
| original_import = __import__ |
|
|
| def mock_import(name, *args, **kwargs): |
| if name.startswith("modelopt"): |
| raise ImportError("No module named 'modelopt'") |
| |
| return original_import(name, *args, **kwargs) |
|
|
| with patch("builtins.__import__", side_effect=mock_import): |
| |
| with self.assertRaises(ImportError): |
| loader.load_model( |
| model_config=self.model_config, device_config=self.device_config |
| ) |
|
|
| |
| mock_logger.error.assert_called_with( |
| "NVIDIA Model Optimizer (modelopt) library not found. " |
| "Please install it to use ModelOpt quantization." |
| ) |
|
|
| @patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES) |
| @patch("sglang.srt.model_loader.loader.AutoTokenizer") |
| @patch("sglang.srt.model_loader.loader.logger") |
| def test_calibration_workflow_integration(self, mock_logger, mock_auto_tokenizer): |
| """Test end-to-end calibration workflow integration.""" |
|
|
| loader = ModelOptModelLoader(self.load_config) |
|
|
| |
| mock_tokenizer = MagicMock() |
| mock_tokenizer.padding_side = "right" |
| mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer |
|
|
| |
| mock_mtq = MagicMock() |
| mock_mto = MagicMock() |
| mock_dataset_utils = MagicMock() |
|
|
| |
| mock_fp8_cfg = MagicMock() |
| mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg |
|
|
| |
| mock_calib_dataloader = MagicMock() |
| mock_calibrate_loop = MagicMock() |
| mock_dataset_utils.get_dataset_dataloader.return_value = mock_calib_dataloader |
| mock_dataset_utils.create_forward_loop.return_value = mock_calibrate_loop |
|
|
| |
| mock_is_quantized = MagicMock(return_value=False) |
|
|
| with patch.object( |
| loader, "_load_modelopt_base_model", return_value=self.mock_base_model |
| ): |
| with patch.dict( |
| "sys.modules", |
| { |
| "modelopt": MagicMock(), |
| "modelopt.torch": MagicMock(), |
| "modelopt.torch.opt": mock_mto, |
| "modelopt.torch.quantization": mock_mtq, |
| "modelopt.torch.quantization.utils": MagicMock( |
| is_quantized=mock_is_quantized |
| ), |
| "modelopt.torch.utils": MagicMock(), |
| "modelopt.torch.utils.dataset_utils": mock_dataset_utils, |
| }, |
| ): |
| |
| result_model = loader.load_model( |
| model_config=self.model_config, device_config=self.device_config |
| ) |
|
|
| |
| self.assertEqual(result_model, self.mock_base_model) |
|
|
| |
| |
| |
|
|
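    # The next two tests cover the checkpoint round trip: restoring a
    # previously quantized model (skipping calibration) and saving a freshly
    # calibrated one, both via modelopt.torch.opt (mto) restore/save.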
| @patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES) |
| @patch("sglang.srt.model_loader.loader.AutoTokenizer") |
| @patch("sglang.srt.model_loader.loader.logger") |
| def test_quantized_checkpoint_restore(self, mock_logger, mock_auto_tokenizer): |
| """Test restoring from a quantized checkpoint.""" |
|
|
| |
| config_with_restore = ModelConfig( |
| model_path=self.model_path, |
| quantization="modelopt_fp8", |
| ) |
|
|
| |
| load_config_with_restore = LoadConfig( |
| modelopt_checkpoint_restore_path="/path/to/quantized/checkpoint" |
| ) |
|
|
| loader = ModelOptModelLoader(load_config_with_restore) |
|
|
| |
| mock_tokenizer = MagicMock() |
| mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer |
|
|
| |
| mock_mtq = MagicMock() |
| mock_mto = MagicMock() |
|
|
| |
| mock_fp8_cfg = MagicMock() |
| mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg |
|
|
| |
| mock_is_quantized = MagicMock(return_value=False) |
|
|
| with patch.object( |
| loader, "_load_modelopt_base_model", return_value=self.mock_base_model |
| ): |
| with patch.dict( |
| "sys.modules", |
| { |
| "modelopt": MagicMock(), |
| "modelopt.torch": MagicMock(), |
| "modelopt.torch.opt": mock_mto, |
| "modelopt.torch.quantization": mock_mtq, |
| "modelopt.torch.quantization.utils": MagicMock( |
| is_quantized=mock_is_quantized |
| ), |
| }, |
| ): |
| with patch.object(loader, "_setup_modelopt_quantization") as mock_setup: |
| |
| def mock_setup_quantization( |
| model, |
| tokenizer, |
| quant_cfg, |
| quantized_ckpt_restore_path=None, |
| **kwargs, |
| ): |
| if quantized_ckpt_restore_path: |
| mock_mto.restore(model, quantized_ckpt_restore_path) |
| print( |
| f"Restored quantized model from {quantized_ckpt_restore_path}" |
| ) |
| return |
|
|
| mock_setup.side_effect = mock_setup_quantization |
|
|
| |
| result_model = loader.load_model( |
| model_config=config_with_restore, |
| device_config=self.device_config, |
| ) |
|
|
| |
| mock_setup.assert_called_once() |
| call_args = mock_setup.call_args |
| |
| self.assertIn("quantized_ckpt_restore_path", call_args[1]) |
| self.assertEqual( |
| call_args[1]["quantized_ckpt_restore_path"], |
| "/path/to/quantized/checkpoint", |
| ) |
|
|
| |
| mock_mto.restore.assert_called_once_with( |
| self.mock_base_model, "/path/to/quantized/checkpoint" |
| ) |
|
|
| |
| self.assertEqual(result_model, self.mock_base_model) |
|
|
| @patch("sglang.srt.model_loader.loader.QUANT_CFG_CHOICES", QUANT_CFG_CHOICES) |
| @patch("sglang.srt.model_loader.loader.AutoTokenizer") |
| @patch("sglang.srt.model_loader.loader.logger") |
| def test_quantized_checkpoint_save(self, mock_logger, mock_auto_tokenizer): |
| """Test saving quantized checkpoint after calibration.""" |
|
|
| |
| config_with_save = ModelConfig( |
| model_path=self.model_path, |
| quantization="modelopt_fp8", |
| ) |
|
|
| |
| load_config_with_save = LoadConfig( |
| modelopt_checkpoint_save_path="/path/to/save/checkpoint" |
| ) |
|
|
| loader = ModelOptModelLoader(load_config_with_save) |
|
|
| |
| mock_tokenizer = MagicMock() |
| mock_auto_tokenizer.from_pretrained.return_value = mock_tokenizer |
|
|
| |
| mock_mtq = MagicMock() |
| mock_mto = MagicMock() |
| mock_dataset_utils = MagicMock() |
|
|
| |
| mock_fp8_cfg = MagicMock() |
| mock_mtq.FP8_DEFAULT_CFG = mock_fp8_cfg |
|
|
| |
| mock_is_quantized = MagicMock(return_value=False) |
|
|
| with patch.object( |
| loader, "_load_modelopt_base_model", return_value=self.mock_base_model |
| ): |
| with patch.dict( |
| "sys.modules", |
| { |
| "modelopt": MagicMock(), |
| "modelopt.torch": MagicMock(), |
| "modelopt.torch.opt": mock_mto, |
| "modelopt.torch.quantization": mock_mtq, |
| "modelopt.torch.quantization.utils": MagicMock( |
| is_quantized=mock_is_quantized |
| ), |
| "modelopt.torch.utils": MagicMock(), |
| "modelopt.torch.utils.dataset_utils": mock_dataset_utils, |
| }, |
| ): |
| with patch.object(loader, "_setup_modelopt_quantization") as mock_setup: |
| |
| def mock_setup_quantization( |
| model, |
| tokenizer, |
| quant_cfg, |
| quantized_ckpt_save_path=None, |
| **kwargs, |
| ): |
| |
| mock_mtq.quantize(model, quant_cfg, forward_loop=MagicMock()) |
| mock_mtq.print_quant_summary(model) |
|
|
| |
| if quantized_ckpt_save_path: |
| mock_mto.save(model, quantized_ckpt_save_path) |
| print( |
| f"Quantized model saved to {quantized_ckpt_save_path}" |
| ) |
|
|
| mock_setup.side_effect = mock_setup_quantization |
|
|
| |
| result_model = loader.load_model( |
| model_config=config_with_save, device_config=self.device_config |
| ) |
|
|
| |
| mock_setup.assert_called_once() |
| call_args = mock_setup.call_args |
| |
| self.assertIn("quantized_ckpt_save_path", call_args[1]) |
| self.assertEqual( |
| call_args[1]["quantized_ckpt_save_path"], |
| "/path/to/save/checkpoint", |
| ) |
|
|
| |
| mock_mto.save.assert_called_once_with( |
| self.mock_base_model, "/path/to/save/checkpoint" |
| ) |
|
|
| |
| self.assertEqual(result_model, self.mock_base_model) |
|
|
    def test_unified_quantization_flag_support(self):
        """Test that ModelOptModelLoader supports unified quantization flags."""
        config_fp8 = ModelConfig(model_path=self.model_path, quantization="modelopt_fp8")
        self.assertEqual(config_fp8._get_modelopt_quant_type(), "fp8")

        config_fp4 = ModelConfig(model_path=self.model_path, quantization="modelopt_fp4")
        self.assertEqual(config_fp4._get_modelopt_quant_type(), "nvfp4")

        # Bare "modelopt" defaults to the fp8 quant type.
        config_auto = ModelConfig(model_path=self.model_path, quantization="modelopt")
        self.assertEqual(config_auto._get_modelopt_quant_type(), "fp8")

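# These integration tests never construct a real Engine: Engine.__init__ and
# get_model_loader are mocked, so they only verify that the modelopt_quant
# option flows through ServerArgs.
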
class TestModelOptLoaderIntegration(CustomTestCase):
    """Integration tests for ModelOptModelLoader with the Engine API."""

    @patch("sglang.srt.model_loader.loader.get_model_loader")
    @patch("sglang.srt.entrypoints.engine.Engine.__init__")
    def test_engine_with_modelopt_quant_parameter(
        self, mock_engine_init, mock_get_model_loader
    ):
        """Test that Engine properly handles the modelopt_quant parameter."""
        mock_engine_init.return_value = None

        mock_loader = MagicMock(spec=ModelOptModelLoader)
        mock_get_model_loader.return_value = mock_loader

        try:
            engine_args = {
                "model_path": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
                "modelopt_quant": "fp8",
                "log_level": "error",
            }

            from sglang.srt.server_args import ServerArgs

            server_args = ServerArgs(**engine_args)

            self.assertEqual(server_args.modelopt_quant, "fp8")

        except Exception as e:
            # Failures unrelated to modelopt_quant (e.g., missing model
            # files) are tolerated here; fail only if the parameter itself
            # was rejected.
            if "modelopt_quant" in str(e):
                self.fail(f"modelopt_quant parameter not properly handled: {e}")

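    # The CLI path goes through argparse: ServerArgs.add_cli_args registers
    # --modelopt-quant on the parser, and ServerArgs.from_cli_args rebuilds
    # the dataclass from the parsed namespace.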
| @patch("sglang.srt.model_loader.loader.get_model_loader") |
| @patch("sglang.srt.entrypoints.engine.Engine.__init__") |
| def test_engine_with_modelopt_quant_cli_argument( |
| self, mock_engine_init, mock_get_model_loader |
| ): |
| """Test that CLI argument --modelopt-quant is properly parsed.""" |
|
|
| |
| mock_engine_init.return_value = None |
|
|
| |
| mock_loader = MagicMock(spec=ModelOptModelLoader) |
| mock_get_model_loader.return_value = mock_loader |
|
|
| |
| import argparse |
|
|
| from sglang.srt.server_args import ServerArgs |
|
|
| |
| parser = argparse.ArgumentParser() |
| ServerArgs.add_cli_args(parser) |
|
|
| |
| args = parser.parse_args( |
| [ |
| "--model-path", |
| "TinyLlama/TinyLlama-1.1B-Chat-v1.0", |
| "--modelopt-quant", |
| "fp8", |
| ] |
| ) |
|
|
| |
| server_args = ServerArgs.from_cli_args(args) |
|
|
| |
| self.assertEqual(server_args.modelopt_quant, "fp8") |
| self.assertEqual(server_args.model_path, "TinyLlama/TinyLlama-1.1B-Chat-v1.0") |
|
|
|
|
class TestParseQuantHfConfig(CustomTestCase):
    """Tests for _parse_quant_hf_config and _parse_modelopt_quant_config.

    Regression tests for the fix where quant_method='modelopt' previously
    ignored quant_algo.
    """

    # (hf quantization_config dict, expected parsed quant_method) pairs.
    _MODELOPT_CASES = [
        ({"quant_method": "modelopt", "quant_algo": "FP8"}, "modelopt_fp8"),
        ({"quant_method": "modelopt", "quant_algo": "FP4"}, "modelopt_fp4"),
        ({"quant_method": "modelopt", "quant_algo": "NVFP4"}, "modelopt_fp4"),
        ({"quant_method": "modelopt", "quant_algo": "MIXED_PRECISION"}, "w4afp8"),
        ({"quant_algo": "FP8"}, "modelopt_fp8"),
        ({"quant_algo": "FP4"}, "modelopt_fp4"),
        ({"quant_algo": "MIXED_PRECISION"}, "w4afp8"),
        ({"quant_method": "modelopt"}, "modelopt"),
    ]

    def setUp(self):
        """Set up a real ModelConfig using TinyLlama (already used elsewhere)."""
        self.mock_tp_rank = patch(
            "sglang.srt.distributed.parallel_state.get_tensor_model_parallel_rank",
            return_value=0,
        )
        self.mock_tp_rank.start()

        self.mock_mp_is_initialized = patch(
            "sglang.srt.distributed.parallel_state.model_parallel_is_initialized",
            return_value=True,
        )
        self.mock_mp_is_initialized.start()

        self.model_config = ModelConfig(
            model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        )

    def tearDown(self):
        self.mock_tp_rank.stop()
        self.mock_mp_is_initialized.stop()

    def test_modelopt_quant_parsing(self):
        """ModelOpt quant configs must resolve to the correct quant_method."""
        for quant_cfg_input, expected in self._MODELOPT_CASES:
            with self.subTest(quant_cfg=quant_cfg_input):
                self.model_config.hf_config.quantization_config = dict(quant_cfg_input)
                result = self.model_config._parse_quant_hf_config()
                self.assertEqual(result["quant_method"], expected)

    def test_non_modelopt_quant_method_unchanged(self):
        """A non-modelopt quant_method (e.g. 'gptq') must NOT enter the modelopt path."""
        self.model_config.hf_config.quantization_config = {
            "quant_method": "gptq",
            "bits": 4,
        }
        result = self.model_config._parse_quant_hf_config()
        self.assertEqual(result["quant_method"], "gptq")
        self.assertNotIn("quant_algo", result)

if __name__ == "__main__":
    unittest.main()