| | import gc |
| | import tempfile |
| | import unittest |
| |
|
| | from diffusers import NVIDIAModelOptConfig, SD3Transformer2DModel, StableDiffusion3Pipeline |
| | from diffusers.utils import is_nvidia_modelopt_available, is_torch_available |
| | from diffusers.utils.testing_utils import ( |
| | backend_empty_cache, |
| | backend_reset_peak_memory_stats, |
| | enable_full_determinism, |
| | nightly, |
| | numpy_cosine_similarity_distance, |
| | require_accelerate, |
| | require_big_accelerator, |
| | require_modelopt_version_greater_or_equal, |
| | require_torch_cuda_compatibility, |
| | torch_device, |
| | ) |
| |
|
| |
|
| | if is_nvidia_modelopt_available(): |
| | import modelopt.torch.quantization as mtq |
| |
|
| | if is_torch_available(): |
| | import torch |
| |
|
| | from ..utils import LoRALayer, get_memory_consumption_stat |
| |
|
| | enable_full_determinism() |
| |
|
| |
|
@nightly
@require_big_accelerator
@require_accelerate
@require_modelopt_version_greater_or_equal("0.33.1")
class ModelOptBaseTesterMixin:
    """Shared test suite for models quantized through ``NVIDIAModelOptConfig``.

    Concrete test cases combine this mixin with ``unittest.TestCase`` and
    override ``get_dummy_init_kwargs`` (plus any of the class attributes below)
    to select the quantization scheme under test.
    """

    # Checkpoint, model class and pipeline class that every test loads from.
    model_id = "hf-internal-testing/tiny-sd3-pipe"
    model_cls = SD3Transformer2DModel
    pipeline_cls = StableDiffusion3Pipeline
    # dtype used when loading models and when building the dummy inputs.
    torch_dtype = torch.bfloat16
    # Lower bound asserted on (unquantized memory / quantized memory).
    expected_memory_reduction = 0.0
    # Value temporarily installed as ``model_cls._keep_in_fp32_modules``.
    keep_in_fp32_module = ""
    # Value passed as ``modules_to_not_convert`` to the quantization config.
    modules_to_not_convert = ""
    # Opt-in switch for ``test_torch_compile``.
    _test_torch_compile = False

    def _reset_accelerator_state(self):
        """Reset peak-memory statistics, empty the backend cache and run the GC."""
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def setUp(self):
        """Start every test from a clean accelerator memory state."""
        super().setUp()
        self._reset_accelerator_state()

    def tearDown(self):
        """Release accelerator memory held by the test that just ran."""
        self._reset_accelerator_state()
        super().tearDown()

    def get_dummy_init_kwargs(self):
        """Return the keyword arguments used to build ``NVIDIAModelOptConfig``."""
        return {"quant_type": "FP8"}

    def get_dummy_model_init_kwargs(self):
        """Return ``from_pretrained`` keyword arguments for a quantized model."""
        return {
            "pretrained_model_name_or_path": self.model_id,
            "torch_dtype": self.torch_dtype,
            "quantization_config": NVIDIAModelOptConfig(**self.get_dummy_init_kwargs()),
            "subfolder": "transformer",
        }

    def get_dummy_inputs(self):
        """Build seeded random forward-pass inputs placed on ``torch_device``.

        Returns:
            dict with ``hidden_states`` of shape (1, 4, 32, 32),
            ``encoder_hidden_states`` of shape (1, 16, 8) and ``timestep`` of
            shape (1,), all cast to ``self.torch_dtype``.
        """
        batch_size = 1
        seq_len = 16
        height = width = 32
        num_latent_channels = 4
        caption_channels = 8

        # Fixed seed so that repeated calls produce identical tensors.
        torch.manual_seed(0)
        hidden_states = torch.randn((batch_size, num_latent_channels, height, width)).to(
            torch_device, dtype=self.torch_dtype
        )
        encoder_hidden_states = torch.randn((batch_size, seq_len, caption_channels)).to(
            torch_device, dtype=self.torch_dtype
        )
        timestep = torch.tensor([1.0]).to(torch_device, dtype=self.torch_dtype).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
        }

    def test_modelopt_layers(self):
        """Every ``torch.nn.Linear`` in the loaded model must report as quantized."""
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                self.assertTrue(
                    mtq.utils.is_quantized(module),
                    f"Linear layer '{name}' is not quantized",
                )

    def test_modelopt_memory_usage(self):
        """Check unquantized/quantized memory against ``expected_memory_reduction``."""
        inputs = self.get_dummy_inputs()
        inputs = {
            k: v.to(device=torch_device, dtype=self.torch_dtype)
            for k, v in inputs.items()
            if not isinstance(v, bool)
        }

        unquantized_model = self.model_cls.from_pretrained(
            self.model_id, torch_dtype=self.torch_dtype, subfolder="transformer"
        )
        unquantized_model.to(torch_device)
        unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)

        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        quantized_model.to(torch_device)
        quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)

        self.assertGreaterEqual(
            unquantized_model_memory / quantized_model_memory,
            self.expected_memory_reduction,
        )

    def test_keep_modules_in_fp32(self):
        """Linear layers named in ``_keep_in_fp32_modules`` must stay in float32."""
        original_keep_in_fp32 = self.model_cls._keep_in_fp32_modules
        self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module
        try:
            model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
            model.to(torch_device)

            for name, module in model.named_modules():
                if isinstance(module, torch.nn.Linear) and name in model._keep_in_fp32_modules:
                    self.assertEqual(
                        module.weight.dtype,
                        torch.float32,
                        f"Module '{name}' was expected to stay in float32",
                    )
        finally:
            # The attribute lives on the shared model class, so restore it even
            # when loading or an assertion fails; otherwise later tests would
            # run against the patched value.
            self.model_cls._keep_in_fp32_modules = original_keep_in_fp32

    def test_modules_to_not_convert(self):
        """Modules listed in ``modules_to_not_convert`` must not be quantized."""
        quantization_config_kwargs = self.get_dummy_init_kwargs()
        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})

        init_kwargs = self.get_dummy_model_init_kwargs()
        init_kwargs.update({"quantization_config": NVIDIAModelOptConfig(**quantization_config_kwargs)})

        model = self.model_cls.from_pretrained(**init_kwargs)
        model.to(torch_device)

        for name, module in model.named_modules():
            if name in self.modules_to_not_convert:
                self.assertFalse(
                    mtq.utils.is_quantized(module),
                    f"Module '{name}' should have been excluded from quantization",
                )

    def test_dtype_assignment(self):
        """Casting a quantized model to another dtype raises; moving devices does not."""
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())

        with self.assertRaises(ValueError):
            model.to(torch.float16)

        # Built outside the ``with`` block so that only the ``.to`` call is
        # covered by the expected-exception check.
        device_0 = f"{torch_device}:0"
        with self.assertRaises(ValueError):
            model.to(device=device_0, dtype=torch.float16)

        with self.assertRaises(ValueError):
            model.float()

        with self.assertRaises(ValueError):
            model.half()

        # A device-only move is expected to succeed.
        model.to(torch_device)

    def test_serialization(self):
        """A saved and reloaded quantized model must reproduce the original output."""
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**inputs)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            saved_model = self.model_cls.from_pretrained(
                tmp_dir,
                torch_dtype=self.torch_dtype,
            )

        saved_model.to(torch_device)
        with torch.no_grad():
            saved_model_output = saved_model(**inputs)

        self.assertTrue(
            torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5),
            "Reloaded model output differs from the original model output",
        )

    def test_torch_compile(self):
        """Compiled and eager outputs must agree (cosine distance below 1e-3)."""
        if not self._test_torch_compile:
            # Report the test as skipped rather than silently passing.
            self.skipTest("torch.compile test is not enabled for this configuration")

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**self.get_dummy_inputs()).sample

        compiled_model.to(torch_device)
        with torch.no_grad():
            compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample

        model_output = model_output.detach().float().cpu().numpy()
        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()

        cosine_distance = numpy_cosine_similarity_distance(
            model_output.flatten(), compiled_model_output.flatten()
        )
        self.assertLess(cosine_distance, 1e-3)

    def test_device_map_error(self):
        """Loading a quantized model with this ``device_map`` must raise ``ValueError``."""
        with self.assertRaises(ValueError):
            _ = self.model_cls.from_pretrained(
                **self.get_dummy_model_init_kwargs(),
                device_map={0: "8GB", "cpu": "16GB"},
            )

    def test_model_cpu_offload(self):
        """A pipeline using the quantized transformer runs with model CPU offload."""
        transformer = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        pipe = self.pipeline_cls.from_pretrained(
            self.model_id, transformer=transformer, torch_dtype=self.torch_dtype
        )
        pipe.enable_model_cpu_offload(device=torch_device)
        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)

    def test_training(self):
        """LoRA adapters wrapped around a frozen quantized model receive gradients."""
        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs()).to(
            torch_device
        )

        # Freeze the base weights; 1-D parameters are additionally upcast to float32.
        for param in quantized_model.parameters():
            param.requires_grad = False
            if param.ndim == 1:
                param.data = param.data.to(torch.float32)

        # Wrap every q/k/v projection found in the model with a rank-4 LoRA layer.
        for module in quantized_model.modules():
            for proj_name in ("to_q", "to_k", "to_v"):
                if hasattr(module, proj_name):
                    setattr(module, proj_name, LoRALayer(getattr(module, proj_name), rank=4))

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_inputs()
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertIsNotNone(module.adapter[1].weight.grad)
| |
|
| |
|
class SanaTransformerFP8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    """Runs the shared ModelOpt test suite with FP8 quantization.

    NOTE(review): the class name mentions Sana, but no model class is set here,
    so the model under test is whatever ``ModelOptBaseTesterMixin`` defines.
    Confirm the naming is intentional.
    """

    # Threshold compared against the memory ratio measured by the mixin.
    expected_memory_reduction = 0.6

    def get_dummy_init_kwargs(self):
        """Return the quantization-config keyword arguments for FP8."""
        return dict(quant_type="FP8")
| |
|
| |
|
class SanaTransformerINT8WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    """Runs the shared ModelOpt test suite with INT8 quantization.

    This configuration also sets ``_test_torch_compile`` to ``True``.
    """

    # Threshold compared against the memory ratio measured by the mixin.
    expected_memory_reduction = 0.6
    _test_torch_compile = True

    def get_dummy_init_kwargs(self):
        """Return the quantization-config keyword arguments for INT8."""
        return dict(quant_type="INT8")
| |
|
| |
|
@require_torch_cuda_compatibility(8.0)
class SanaTransformerINT4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    """Runs the shared ModelOpt test suite with INT4 quantization."""

    # Threshold compared against the memory ratio measured by the mixin.
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        """Return the quantization-config keyword arguments for INT4."""
        return dict(
            quant_type="INT4",
            block_quantize=128,
            channel_quantize=-1,
            disable_conv_quantization=True,
        )
| |
|
| |
|
@require_torch_cuda_compatibility(8.0)
class SanaTransformerNF4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    """Runs the shared ModelOpt test suite with NF4 quantization."""

    # Threshold compared against the memory ratio measured by the mixin.
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        """Return the quantization-config keyword arguments for NF4."""
        return dict(
            quant_type="NF4",
            block_quantize=128,
            channel_quantize=-1,
            scale_block_quantize=8,
            scale_channel_quantize=-1,
            modules_to_not_convert=["conv"],
        )
| |
|
| |
|
@require_torch_cuda_compatibility(8.0)
class SanaTransformerNVFP4WeightsTest(ModelOptBaseTesterMixin, unittest.TestCase):
    """Runs the shared ModelOpt test suite with NVFP4 quantization."""

    # Threshold compared against the memory ratio measured by the mixin.
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        """Return the quantization-config keyword arguments for NVFP4."""
        return dict(
            quant_type="NVFP4",
            block_quantize=128,
            channel_quantize=-1,
            scale_block_quantize=8,
            scale_channel_quantize=-1,
            modules_to_not_convert=["conv"],
        )
| |
|