|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
|
import tempfile |
|
|
import unittest |
|
|
|
|
|
import torch |
|
|
from parameterized import parameterized |
|
|
|
|
|
from diffusers import BitsAndBytesConfig, DiffusionPipeline, QuantoConfig |
|
|
from diffusers.quantizers import PipelineQuantizationConfig |
|
|
from diffusers.utils import logging |
|
|
from diffusers.utils.testing_utils import ( |
|
|
CaptureLogger, |
|
|
is_transformers_available, |
|
|
require_accelerate, |
|
|
require_bitsandbytes_version_greater, |
|
|
require_quanto, |
|
|
require_torch, |
|
|
require_torch_accelerator, |
|
|
slow, |
|
|
torch_device, |
|
|
) |
|
|
|
|
|
|
|
|
# `TranBitsAndBytesConfig` is the transformers flavour of `BitsAndBytesConfig`.
# Start from the "unavailable" placeholder and replace it with the real class
# only when `is_transformers_available()` reports that transformers is present.
TranBitsAndBytesConfig = None
if is_transformers_available():
    from transformers import BitsAndBytesConfig as TranBitsAndBytesConfig
|
|
|
|
|
|
|
|
@require_bitsandbytes_version_greater("0.43.2")
@require_quanto
@require_accelerate
@require_torch
@require_torch_accelerator
@slow
class PipelineQuantizationTests(unittest.TestCase):
    """Tests for quantizing pipeline components through `PipelineQuantizationConfig`.

    The tests either build a `PipelineQuantizationConfig` directly (validation
    tests) or pass one to `DiffusionPipeline.from_pretrained` for `model_name`
    and inspect the `quantization_config` found on each component's config.
    The class is gated by the `require_*` and `slow` decorators applied above.
    """

    model_name = "hf-internal-testing/tiny-flux-pipe"
    prompt = "a beautiful sunset amidst the mountains."
    num_inference_steps = 10
    seed = 0

    # Expected `quant_method` per component for the mapping built by
    # `_mixed_quant_mapping`.
    _MIXED_QUANT_METHODS = {"transformer": "quanto", "text_encoder_2": "bitsandbytes"}

    def _mixed_quant_mapping(self):
        """Return a fresh mapping: quanto int8 for `transformer`, bitsandbytes 4-bit for `text_encoder_2`."""
        # NOTE(review): `compute_dtype` is kept exactly as these tests have always
        # passed it; the kwargs-based test in this class spells the same setting
        # `bnb_4bit_compute_dtype` -- confirm which name the transformers config honours.
        return {
            "transformer": QuantoConfig(weights_dtype="int8"),
            "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
        }

    def _load_pipeline(self, quantization_config):
        """Load `model_name` in bfloat16 with `quantization_config` forwarded to `from_pretrained`."""
        return DiffusionPipeline.from_pretrained(
            self.model_name,
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        )

    def _assert_component_quant_methods(self, pipe, expected_methods):
        """Assert that the listed components of `pipe` carry the expected quantization config.

        Args:
            pipe: Pipeline whose `components` mapping is inspected.
            expected_methods: Mapping of component name to the expected
                `quant_method` value. Components expected to use
                `"bitsandbytes"` must additionally report `load_in_4bit`.

        Components of `pipe` that are not named in `expected_methods` are ignored.
        """
        for name, component in pipe.components.items():
            if name not in expected_methods:
                continue
            quantization_config = getattr(component.config, "quantization_config", None)
            self.assertIsNotNone(quantization_config)
            expected_method = expected_methods[name]
            if expected_method == "bitsandbytes":
                self.assertTrue(quantization_config.load_in_4bit)
            self.assertEqual(quantization_config.quant_method, expected_method)

    def test_quant_config_set_correctly_through_kwargs(self):
        """A backend + kwargs config quantizes every listed component with bitsandbytes 4-bit."""
        components_to_quantize = ["transformer", "text_encoder_2"]
        quant_config = PipelineQuantizationConfig(
            quant_backend="bitsandbytes_4bit",
            quant_kwargs={
                "load_in_4bit": True,
                "bnb_4bit_quant_type": "nf4",
                "bnb_4bit_compute_dtype": torch.bfloat16,
            },
            components_to_quantize=components_to_quantize,
        )
        pipe = self._load_pipeline(quant_config).to(torch_device)
        self._assert_component_quant_methods(pipe, dict.fromkeys(components_to_quantize, "bitsandbytes"))

        # Smoke-check that the quantized pipeline can still be called.
        _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)

    def test_quant_config_set_correctly_through_granular(self):
        """A per-component mapping applies a different quantization backend to each component."""
        quant_config = PipelineQuantizationConfig(quant_mapping=self._mixed_quant_mapping())
        pipe = self._load_pipeline(quant_config).to(torch_device)
        self._assert_component_quant_methods(pipe, self._MIXED_QUANT_METHODS)

        # Smoke-check that the quantized pipeline can still be called.
        _ = pipe(self.prompt, num_inference_steps=self.num_inference_steps)

    def test_raises_error_for_invalid_config(self):
        """Supplying both `quant_mapping` and `quant_backend` raises `ValueError`."""
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_mapping=self._mixed_quant_mapping(),
                quant_backend="bitsandbytes_4bit",
            )

        self.assertEqual(
            str(err_context.exception),
            "Both `quant_backend` and `quant_mapping` cannot be specified at the same time.",
        )

    def test_validation_for_kwargs(self):
        """The `quanto` backend combined with `quant_kwargs` raises a signature-related `ValueError`."""
        components_to_quantize = ["transformer", "text_encoder_2"]
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_backend="quanto",
                quant_kwargs={"weights_dtype": "int8"},
                components_to_quantize=components_to_quantize,
            )

        self.assertIn(
            "The signatures of the __init__ methods of the quantization config classes",
            str(err_context.exception),
        )

    def test_raises_error_for_wrong_config_class(self):
        """Passing a plain dict as `quantization_config` to `from_pretrained` raises `ValueError`."""
        quant_config = self._mixed_quant_mapping()
        with self.assertRaises(ValueError) as err_context:
            _ = self._load_pipeline(quant_config)
        self.assertEqual(
            str(err_context.exception),
            "`quantization_config` must be an instance of `PipelineQuantizationConfig`.",
        )

    def test_validation_for_mapping(self):
        """A mapping value that is not a quantization config raises `ValueError` naming the module."""
        with self.assertRaises(ValueError) as err_context:
            _ = PipelineQuantizationConfig(
                quant_mapping={
                    # Deliberately not a quantization config object.
                    "transformer": DiffusionPipeline(),
                    "text_encoder_2": TranBitsAndBytesConfig(load_in_4bit=True, compute_dtype=torch.bfloat16),
                }
            )

        self.assertIn(
            "Provided config for module_name=transformer could not be found",
            str(err_context.exception),
        )

    def test_saving_loading(self):
        """A quantized pipeline keeps its quantization configs and outputs after a save/load round trip."""
        quant_config = PipelineQuantizationConfig(quant_mapping=self._mixed_quant_mapping())
        pipe = self._load_pipeline(quant_config).to(torch_device)

        pipe_inputs = {"prompt": self.prompt, "num_inference_steps": self.num_inference_steps, "output_type": "latent"}
        output_1 = pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir)
            # No quantization config is passed here: it must come from the saved files.
            loaded_pipe = DiffusionPipeline.from_pretrained(tmpdir, torch_dtype=torch.bfloat16).to(torch_device)
            self._assert_component_quant_methods(loaded_pipe, self._MIXED_QUANT_METHODS)

            # Same seed as the first run so both outputs are directly comparable.
            output_2 = loaded_pipe(**pipe_inputs, generator=torch.manual_seed(self.seed)).images

        self.assertTrue(torch.allclose(output_1, output_2))

    @parameterized.expand(["quant_kwargs", "quant_mapping"])
    def test_warn_invalid_component(self, method):
        """Loading with an unknown component name logs output that mentions that name.

        Args:
            method: `"quant_kwargs"` to build the config from a backend and
                kwargs, anything else to build it from a `quant_mapping`.
        """
        invalid_component = "foo"
        if method == "quant_kwargs":
            components_to_quantize = ["transformer", invalid_component]
            quant_config = PipelineQuantizationConfig(
                quant_backend="bitsandbytes_8bit",
                quant_kwargs={"load_in_8bit": True},
                components_to_quantize=components_to_quantize,
            )
        else:
            quant_config = PipelineQuantizationConfig(
                quant_mapping={
                    "transformer": QuantoConfig("int8"),
                    invalid_component: TranBitsAndBytesConfig(load_in_8bit=True),
                }
            )

        logger = logging.get_logger("diffusers.pipelines.pipeline_loading_utils")
        logger.setLevel(logging.WARNING)
        with CaptureLogger(logger) as cap_logger:
            _ = self._load_pipeline(quant_config)
        self.assertIn(invalid_component, cap_logger.out)

    @parameterized.expand(["quant_kwargs", "quant_mapping"])
    def test_no_quantization_for_all_invalid_components(self, method):
        """When only unknown component names are requested, no module gets a quantization config.

        Args:
            method: `"quant_kwargs"` to build the config from a backend and
                kwargs, anything else to build it from a `quant_mapping`.
        """
        invalid_component = "foo"
        if method == "quant_kwargs":
            components_to_quantize = [invalid_component]
            quant_config = PipelineQuantizationConfig(
                quant_backend="bitsandbytes_8bit",
                quant_kwargs={"load_in_8bit": True},
                components_to_quantize=components_to_quantize,
            )
        else:
            quant_config = PipelineQuantizationConfig(
                quant_mapping={invalid_component: TranBitsAndBytesConfig(load_in_8bit=True)}
            )

        pipe = self._load_pipeline(quant_config)
        for component in pipe.components.values():
            # Only `torch.nn.Module` components are inspected.
            if isinstance(component, torch.nn.Module):
                self.assertFalse(hasattr(component.config, "quantization_config"))

    @parameterized.expand(["quant_kwargs", "quant_mapping"])
    def test_quant_config_repr(self, method):
        """The string form of `pipe.quantization_config` contains the expected bitsandbytes 8-bit JSON.

        Args:
            method: `"quant_kwargs"` to build the config from a backend and
                kwargs, anything else to build it from a `quant_mapping`.
        """
        component_name = "transformer"
        if method == "quant_kwargs":
            components_to_quantize = [component_name]
            quant_config = PipelineQuantizationConfig(
                quant_backend="bitsandbytes_8bit",
                quant_kwargs={"load_in_8bit": True},
                components_to_quantize=components_to_quantize,
            )
        else:
            quant_config = PipelineQuantizationConfig(
                quant_mapping={component_name: BitsAndBytesConfig(load_in_8bit=True)}
            )

        pipe = self._load_pipeline(quant_config)
        self.assertIsNotNone(getattr(pipe, "quantization_config", None))
        retrieved_config = pipe.quantization_config
        expected_config = """
transformer BitsAndBytesConfig {
  "_load_in_4bit": false,
  "_load_in_8bit": true,
  "bnb_4bit_compute_dtype": "float32",
  "bnb_4bit_quant_storage": "uint8",
  "bnb_4bit_quant_type": "fp4",
  "bnb_4bit_use_double_quant": false,
  "llm_int8_enable_fp32_cpu_offload": false,
  "llm_int8_has_fp16_weight": false,
  "llm_int8_skip_modules": null,
  "llm_int8_threshold": 6.0,
  "load_in_4bit": false,
  "load_in_8bit": true,
  "quant_method": "bitsandbytes"
}

"""
        # Compare parsed JSON rather than raw text so whitespace differences do not matter.
        expected_data = self._parse_config_string(expected_config)
        actual_data = self._parse_config_string(str(retrieved_config))
        self.assertEqual(actual_data, expected_data)

    def _parse_config_string(self, config_string: str) -> dict:
        """Parse the JSON object embedded in a config's string representation.

        Everything before the first `{` is discarded and the remainder is
        decoded as JSON.

        Args:
            config_string: Text that contains a JSON object, optionally
                preceded by a non-JSON prefix.

        Returns:
            The decoded JSON data.

        Raises:
            ValueError: If `config_string` contains no `{`. `json.loads` also
                raises `json.JSONDecodeError` (a `ValueError` subclass) when
                the text from the first brace onwards is not valid JSON.
        """
        first_brace = config_string.find("{")
        if first_brace == -1:
            raise ValueError("Could not find opening brace '{' in the string.")

        json_part = config_string[first_brace:]
        data = json.loads(json_part)

        return data
|
|
|