import gc
import tempfile
import unittest

from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from diffusers.models.attention_processor import Attention
from diffusers.utils import is_optimum_quanto_available, is_torch_available
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    backend_reset_peak_memory_stats,
    enable_full_determinism,
    nightly,
    numpy_cosine_similarity_distance,
    require_accelerate,
    require_big_accelerator,
    require_torch_cuda_compatibility,
    torch_device,
)


if is_optimum_quanto_available():
    from optimum.quanto import QLinear

if is_torch_available():
    import torch

from ..utils import LoRALayer, get_memory_consumption_stat


enable_full_determinism()

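# These tests are marked `@nightly` and need a large accelerator plus `accelerate`
# installed. A minimal sketch of the pattern they exercise (using only the ids and
# dtypes that appear in this file):
#
#   config = QuantoConfig(weights_dtype="int8")
#   transformer = FluxTransformer2DModel.from_pretrained(
#       "hf-internal-testing/tiny-flux-transformer",
#       quantization_config=config,
#       torch_dtype=torch.bfloat16,
#   )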
@nightly
@require_big_accelerator
@require_accelerate
class QuantoBaseTesterMixin:
    model_id = None
    pipeline_model_id = None
    model_cls = None
    torch_dtype = torch.bfloat16
    expected_memory_reduction = 0.0
    keep_in_fp32_module = ""
    modules_to_not_convert = ""
    _test_torch_compile = False

    def setUp(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

    def tearDown(self):
        backend_reset_peak_memory_stats(torch_device)
        backend_empty_cache(torch_device)
        gc.collect()

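    # Default quantization kwargs; subclasses override this to pin one of the
    # `weights_dtype` values exercised below ("float8", "int8", "int4", "int2").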
    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "float8"}

    def get_dummy_model_init_kwargs(self):
        return {
            "pretrained_model_name_or_path": self.model_id,
            "torch_dtype": self.torch_dtype,
            "quantization_config": QuantoConfig(**self.get_dummy_init_kwargs()),
        }

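    # Quanto swaps each eligible `torch.nn.Linear` for an `optimum.quanto.QLinear`,
    # so every Linear remaining on the quantized model is expected to be a QLinear.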
    def test_quanto_layers(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                assert isinstance(module, QLinear)

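    # Compare peak memory (via `get_memory_consumption_stat`) of the unquantized
    # bf16 model against its quantized counterpart on identical inputs.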
    def test_quanto_memory_usage(self):
        inputs = self.get_dummy_inputs()
        inputs = {
            k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
        }

        unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
        unquantized_model.to(torch_device)
        unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)

        quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        quantized_model.to(torch_device)
        quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)

        assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction

    def test_keep_modules_in_fp32(self):
        r"""
        A simple test to check that the modules listed in `_keep_in_fp32_modules` are kept in fp32.
        Also ensures that inference works.
        """
        _keep_in_fp32_modules = self.model_cls._keep_in_fp32_modules
        self.model_cls._keep_in_fp32_modules = self.keep_in_fp32_module

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        model.to(torch_device)

        for name, module in model.named_modules():
            if isinstance(module, torch.nn.Linear):
                if name in model._keep_in_fp32_modules:
                    assert module.weight.dtype == torch.float32
        self.model_cls._keep_in_fp32_modules = _keep_in_fp32_modules

    def test_modules_to_not_convert(self):
        init_kwargs = self.get_dummy_model_init_kwargs()

        quantization_config_kwargs = self.get_dummy_init_kwargs()
        quantization_config_kwargs.update({"modules_to_not_convert": self.modules_to_not_convert})
        quantization_config = QuantoConfig(**quantization_config_kwargs)

        init_kwargs.update({"quantization_config": quantization_config})

        model = self.model_cls.from_pretrained(**init_kwargs)
        model.to(torch_device)

        for name, module in model.named_modules():
            if name in self.modules_to_not_convert:
                assert not isinstance(module, QLinear)

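    # Quantized models reject any dtype reassignment; only a plain device move
    # is allowed, as the assertions below verify.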
    def test_dtype_assignment(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())

        with self.assertRaises(ValueError):
            # moving with an explicit `dtype` is not supported for quantized models
            model.to(torch.float16)

        with self.assertRaises(ValueError):
            # moving with a `device` and `dtype` together is rejected as well
            device_0 = f"{torch_device}:0"
            model.to(device=device_0, dtype=torch.float16)

        with self.assertRaises(ValueError):
            # casting via `.float()` is not supported
            model.float()

        with self.assertRaises(ValueError):
            # casting via `.half()` is not supported
            model.half()

        # moving to a device without changing the dtype should work
        model.to(torch_device)

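    # Round-trip a quantized model through `save_pretrained`/`from_pretrained`
    # and check that the reloaded model produces matching outputs.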
    def test_serialization(self):
        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        inputs = self.get_dummy_inputs()

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**inputs)

        with tempfile.TemporaryDirectory() as tmp_dir:
            model.save_pretrained(tmp_dir)
            saved_model = self.model_cls.from_pretrained(
                tmp_dir,
                torch_dtype=torch.bfloat16,
            )

        saved_model.to(torch_device)
        with torch.no_grad():
            saved_model_output = saved_model(**inputs)

        assert torch.allclose(model_output.sample, saved_model_output.sample, rtol=1e-5, atol=1e-5)

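    # Only subclasses that set `_test_torch_compile = True` (the int8 variant
    # below) run this compiled-vs-eager consistency check.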
    def test_torch_compile(self):
        if not self._test_torch_compile:
            return

        model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
        compiled_model = torch.compile(model, mode="max-autotune", fullgraph=True, dynamic=False)

        model.to(torch_device)
        with torch.no_grad():
            model_output = model(**self.get_dummy_inputs()).sample

        compiled_model.to(torch_device)
        with torch.no_grad():
            compiled_model_output = compiled_model(**self.get_dummy_inputs()).sample

        model_output = model_output.detach().float().cpu().numpy()
        compiled_model_output = compiled_model_output.detach().float().cpu().numpy()

        max_diff = numpy_cosine_similarity_distance(model_output.flatten(), compiled_model_output.flatten())
        assert max_diff < 1e-3

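    # Passing a memory-map style `device_map` alongside a Quanto config is
    # expected to raise; only the default device placement is supported here.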
    def test_device_map_error(self):
        with self.assertRaises(ValueError):
            _ = self.model_cls.from_pretrained(
                **self.get_dummy_model_init_kwargs(), device_map={0: "8GB", "cpu": "16GB"}
            )


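# Flux-specific mixin: provides dummy inference/training inputs sized for the
# tiny test checkpoints and adds CPU-offload and LoRA-training smoke tests.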
class FluxTransformerQuantoMixin(QuantoBaseTesterMixin):
    model_id = "hf-internal-testing/tiny-flux-transformer"
    model_cls = FluxTransformer2DModel
    pipeline_cls = FluxPipeline
    torch_dtype = torch.bfloat16
    keep_in_fp32_module = "proj_out"
    modules_to_not_convert = ["proj_out"]
    _test_torch_compile = False

    def get_dummy_inputs(self):
        return {
            "hidden_states": torch.randn((1, 4096, 64), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "encoder_hidden_states": torch.randn(
                (1, 512, 4096),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "pooled_projections": torch.randn(
                (1, 768),
                generator=torch.Generator("cpu").manual_seed(0),
            ).to(torch_device, self.torch_dtype),
            "timestep": torch.tensor([1]).to(torch_device, self.torch_dtype),
            "img_ids": torch.randn((4096, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "txt_ids": torch.randn((512, 3), generator=torch.Generator("cpu").manual_seed(0)).to(
                torch_device, self.torch_dtype
            ),
            "guidance": torch.tensor([3.5]).to(torch_device, self.torch_dtype),
        }

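    # Much smaller shapes than `get_dummy_inputs`; these are sized for the
    # "hf-internal-testing/tiny-flux-pipe" transformer used in `test_training`.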
    def get_dummy_training_inputs(self, device=None, seed: int = 0):
        batch_size = 1
        num_latent_channels = 4
        num_image_channels = 3
        height = width = 4
        sequence_length = 48
        embedding_dim = 32

        torch.manual_seed(seed)
        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(
            device, dtype=torch.bfloat16
        )

        torch.manual_seed(seed)
        pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16)

        torch.manual_seed(seed)
        image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16)

        timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "pooled_projections": pooled_prompt_embeds,
            "txt_ids": text_ids,
            "img_ids": image_ids,
            "timestep": timestep,
        }

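    # Build a full pipeline around the quantized transformer and make sure
    # model CPU offload still runs end-to-end for a short generation.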
    def test_model_cpu_offload(self):
        init_kwargs = self.get_dummy_init_kwargs()
        transformer = self.model_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            quantization_config=QuantoConfig(**init_kwargs),
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        )
        pipe = self.pipeline_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe", transformer=transformer, torch_dtype=torch.bfloat16
        )
        pipe.enable_model_cpu_offload(device=torch_device)
        _ = pipe("a cat holding a sign that says hello", num_inference_steps=2)

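    # LoRA-style training smoke test: freeze the quantized base, wrap the attention
    # projections in `LoRALayer`, and check that gradients reach the adapters.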
    def test_training(self):
        quantization_config = QuantoConfig(**self.get_dummy_init_kwargs())
        quantized_model = self.model_cls.from_pretrained(
            "hf-internal-testing/tiny-flux-pipe",
            subfolder="transformer",
            quantization_config=quantization_config,
            torch_dtype=torch.bfloat16,
        ).to(torch_device)

        for param in quantized_model.parameters():
            # freeze the base model; only the adapter layers added below are trained
            param.requires_grad = False
            if param.ndim == 1:
                # cast small parameters (e.g. layernorm) to fp32 for stability
                param.data = param.data.to(torch.float32)

        for _, module in quantized_model.named_modules():
            if isinstance(module, Attention):
                module.to_q = LoRALayer(module.to_q, rank=4)
                module.to_k = LoRALayer(module.to_k, rank=4)
                module.to_v = LoRALayer(module.to_v, rank=4)

        with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16):
            inputs = self.get_dummy_training_inputs(torch_device)
            output = quantized_model(**inputs)[0]
            output.norm().backward()

        for module in quantized_model.modules():
            if isinstance(module, LoRALayer):
                self.assertTrue(module.adapter[1].weight.grad is not None)


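# One concrete test class per supported `weights_dtype`; the int4/int2 variants
# require CUDA compute capability >= 8.0.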
class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.6

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "float8"}


class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.6
    _test_torch_compile = True

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int8"}


@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt4WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.55

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int4"}


@require_torch_cuda_compatibility(8.0)
class FluxTransformerInt2WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
    expected_memory_reduction = 0.65

    def get_dummy_init_kwargs(self):
        return {"weights_dtype": "int2"}