|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import gc |
|
|
import tempfile |
|
|
import unittest |
|
|
from typing import List |
|
|
|
|
|
import numpy as np |
|
|
from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel |
|
|
|
|
|
from diffusers import ( |
|
|
AutoencoderKL, |
|
|
FlowMatchEulerDiscreteScheduler, |
|
|
FluxPipeline, |
|
|
FluxTransformer2DModel, |
|
|
TorchAoConfig, |
|
|
) |
|
|
from diffusers.models.attention_processor import Attention |
|
|
from diffusers.utils.testing_utils import ( |
|
|
enable_full_determinism, |
|
|
is_torch_available, |
|
|
is_torchao_available, |
|
|
nightly, |
|
|
numpy_cosine_similarity_distance, |
|
|
require_torch, |
|
|
require_torch_gpu, |
|
|
require_torchao_version_greater_or_equal, |
|
|
slow, |
|
|
torch_device, |
|
|
) |
|
|
|
|
|
|
|
|
enable_full_determinism() |
|
|
|
|
|
|
|
|
if is_torch_available(): |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from ..utils import LoRALayer, get_memory_consumption_stat |
|
|
|
|
|
|
|
|
if is_torchao_available(): |
|
|
from torchao.dtypes import AffineQuantizedTensor |
|
|
from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor |
|
|
from torchao.quantization.quant_primitives import MappingType |
|
|
from torchao.utils import get_model_size_in_bytes |
|
|
|
|
|
|
|
|
@require_torch |
|
|
@require_torch_gpu |
|
|
@require_torchao_version_greater_or_equal("0.7.0") |
|
|
class TorchAoConfigTest(unittest.TestCase): |
|
|
def test_to_dict(self): |
|
|
""" |
|
|
Makes sure the config format is properly set |
|
|
""" |
|
|
quantization_config = TorchAoConfig("int4_weight_only") |
|
|
torchao_orig_config = quantization_config.to_dict() |
|
|
|
|
|
for key in torchao_orig_config: |
|
|
self.assertEqual(getattr(quantization_config, key), torchao_orig_config[key]) |
|
|
|
|
|
def test_post_init_check(self): |
|
|
""" |
|
|
Test kwargs validations in TorchAoConfig |
|
|
""" |
|
|
_ = TorchAoConfig("int4_weight_only") |
|
|
with self.assertRaisesRegex(ValueError, "is not supported yet"): |
|
|
_ = TorchAoConfig("uint8") |
|
|
|
|
|
with self.assertRaisesRegex(ValueError, "does not support the following keyword arguments"): |
|
|
_ = TorchAoConfig("int4_weight_only", group_size1=32) |
|
|
|
|
|
def test_repr(self): |
|
|
""" |
|
|
Check that there is no error in the repr |
|
|
""" |
|
|
quantization_config = TorchAoConfig("int4_weight_only", modules_to_not_convert=["conv"], group_size=8) |
|
|
expected_repr = """TorchAoConfig { |
|
|
"modules_to_not_convert": [ |
|
|
"conv" |
|
|
], |
|
|
"quant_method": "torchao", |
|
|
"quant_type": "int4_weight_only", |
|
|
"quant_type_kwargs": { |
|
|
"group_size": 8 |
|
|
} |
|
|
}""".replace(" ", "").replace("\n", "") |
|
|
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "") |
|
|
self.assertEqual(quantization_repr, expected_repr) |
|
|
|
|
|
quantization_config = TorchAoConfig("int4dq", group_size=64, act_mapping_type=MappingType.SYMMETRIC) |
|
|
expected_repr = """TorchAoConfig { |
|
|
"modules_to_not_convert": null, |
|
|
"quant_method": "torchao", |
|
|
"quant_type": "int4dq", |
|
|
"quant_type_kwargs": { |
|
|
"act_mapping_type": "SYMMETRIC", |
|
|
"group_size": 64 |
|
|
} |
|
|
}""".replace(" ", "").replace("\n", "") |
|
|
quantization_repr = repr(quantization_config).replace(" ", "").replace("\n", "") |
|
|
self.assertEqual(quantization_repr, expected_repr) |
|
|
|
|
|
|
|
|
|
|
|
@require_torch |
|
|
@require_torch_gpu |
|
|
@require_torchao_version_greater_or_equal("0.7.0") |
|
|
class TorchAoTest(unittest.TestCase): |
|
|
def tearDown(self): |
|
|
gc.collect() |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
def get_dummy_components( |
|
|
self, quantization_config: TorchAoConfig, model_id: str = "hf-internal-testing/tiny-flux-pipe" |
|
|
): |
|
|
transformer = FluxTransformer2DModel.from_pretrained( |
|
|
model_id, |
|
|
subfolder="transformer", |
|
|
quantization_config=quantization_config, |
|
|
torch_dtype=torch.bfloat16, |
|
|
) |
|
|
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16) |
|
|
text_encoder_2 = T5EncoderModel.from_pretrained( |
|
|
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16 |
|
|
) |
|
|
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer") |
|
|
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2") |
|
|
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16) |
|
|
scheduler = FlowMatchEulerDiscreteScheduler() |
|
|
|
|
|
return { |
|
|
"scheduler": scheduler, |
|
|
"text_encoder": text_encoder, |
|
|
"text_encoder_2": text_encoder_2, |
|
|
"tokenizer": tokenizer, |
|
|
"tokenizer_2": tokenizer_2, |
|
|
"transformer": transformer, |
|
|
"vae": vae, |
|
|
} |
|
|
|
|
|
def get_dummy_inputs(self, device: torch.device, seed: int = 0): |
|
|
if str(device).startswith("mps"): |
|
|
generator = torch.manual_seed(seed) |
|
|
else: |
|
|
generator = torch.Generator().manual_seed(seed) |
|
|
|
|
|
inputs = { |
|
|
"prompt": "an astronaut riding a horse in space", |
|
|
"height": 32, |
|
|
"width": 32, |
|
|
"num_inference_steps": 2, |
|
|
"output_type": "np", |
|
|
"generator": generator, |
|
|
} |
|
|
|
|
|
return inputs |
|
|
|
|
|
def get_dummy_tensor_inputs(self, device=None, seed: int = 0): |
|
|
batch_size = 1 |
|
|
num_latent_channels = 4 |
|
|
num_image_channels = 3 |
|
|
height = width = 4 |
|
|
sequence_length = 48 |
|
|
embedding_dim = 32 |
|
|
|
|
|
torch.manual_seed(seed) |
|
|
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) |
|
|
|
|
|
torch.manual_seed(seed) |
|
|
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( |
|
|
device, dtype=torch.bfloat16 |
|
|
) |
|
|
|
|
|
torch.manual_seed(seed) |
|
|
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) |
|
|
|
|
|
torch.manual_seed(seed) |
|
|
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) |
|
|
|
|
|
torch.manual_seed(seed) |
|
|
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) |
|
|
|
|
|
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) |
|
|
|
|
|
return { |
|
|
"hidden_states": hidden_states, |
|
|
"encoder_hidden_states": encoder_hidden_states, |
|
|
"pooled_projections": pooled_prompt_embeds, |
|
|
"txt_ids": text_ids, |
|
|
"img_ids": image_ids, |
|
|
"timestep": timestep, |
|
|
} |
|
|
|
|
|
def _test_quant_type(self, quantization_config: TorchAoConfig, expected_slice: List[float], model_id: str): |
|
|
components = self.get_dummy_components(quantization_config, model_id) |
|
|
pipe = FluxPipeline(**components) |
|
|
pipe.to(device=torch_device) |
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device) |
|
|
output = pipe(**inputs)[0] |
|
|
output_slice = output[-1, -1, -3:, -3:].flatten() |
|
|
|
|
|
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) |
|
|
|
|
|
def test_quantization(self): |
|
|
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: |
|
|
|
|
|
QUANTIZATION_TYPES_TO_TEST = [ |
|
|
("int4wo", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6445, 0.4336, 0.4531, 0.5625])), |
|
|
("int4dq", np.array([0.4688, 0.5195, 0.5547, 0.418, 0.4414, 0.6406, 0.4336, 0.4531, 0.5625])), |
|
|
("int8wo", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), |
|
|
("int8dq", np.array([0.4648, 0.5195, 0.5547, 0.4199, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), |
|
|
("uint4wo", np.array([0.4609, 0.5234, 0.5508, 0.4199, 0.4336, 0.6406, 0.4316, 0.4531, 0.5625])), |
|
|
("uint7wo", np.array([0.4648, 0.5195, 0.5547, 0.4219, 0.4414, 0.6445, 0.4316, 0.4531, 0.5625])), |
|
|
] |
|
|
|
|
|
if TorchAoConfig._is_cuda_capability_atleast_8_9(): |
|
|
QUANTIZATION_TYPES_TO_TEST.extend([ |
|
|
("float8wo_e5m2", np.array([0.4590, 0.5273, 0.5547, 0.4219, 0.4375, 0.6406, 0.4316, 0.4512, 0.5625])), |
|
|
("float8wo_e4m3", np.array([0.4648, 0.5234, 0.5547, 0.4219, 0.4414, 0.6406, 0.4316, 0.4531, 0.5625])), |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
("fp4", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), |
|
|
("fp6", np.array([0.4668, 0.5195, 0.5547, 0.4199, 0.4434, 0.6445, 0.4316, 0.4531, 0.5625])), |
|
|
]) |
|
|
|
|
|
|
|
|
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: |
|
|
quant_kwargs = {} |
|
|
if quantization_name in ["uint4wo", "uint7wo"]: |
|
|
|
|
|
quant_kwargs.update({"group_size": 16}) |
|
|
quantization_config = TorchAoConfig( |
|
|
quant_type=quantization_name, modules_to_not_convert=["x_embedder"], **quant_kwargs |
|
|
) |
|
|
self._test_quant_type(quantization_config, expected_slice, model_id) |
|
|
|
|
|
def test_int4wo_quant_bfloat16_conversion(self): |
|
|
""" |
|
|
Tests whether the dtype of model will be modified to bfloat16 for int4 weight-only quantization. |
|
|
""" |
|
|
quantization_config = TorchAoConfig("int4_weight_only", group_size=64) |
|
|
quantized_model = FluxTransformer2DModel.from_pretrained( |
|
|
"hf-internal-testing/tiny-flux-pipe", |
|
|
subfolder="transformer", |
|
|
quantization_config=quantization_config, |
|
|
torch_dtype=torch.bfloat16, |
|
|
) |
|
|
|
|
|
weight = quantized_model.transformer_blocks[0].ff.net[2].weight |
|
|
self.assertTrue(isinstance(weight, AffineQuantizedTensor)) |
|
|
self.assertEqual(weight.quant_min, 0) |
|
|
self.assertEqual(weight.quant_max, 15) |
|
|
|
|
|
def test_device_map(self): |
|
|
""" |
|
|
Test if the quantized model int4 weight-only is working properly with "auto" and custom device maps. |
|
|
The custom device map performs cpu/disk offloading as well. Also verifies that the device map is |
|
|
correctly set (in the `hf_device_map` attribute of the model). |
|
|
""" |
|
|
custom_device_map_dict = { |
|
|
"time_text_embed": torch_device, |
|
|
"context_embedder": torch_device, |
|
|
"x_embedder": torch_device, |
|
|
"transformer_blocks.0": "cpu", |
|
|
"single_transformer_blocks.0": "disk", |
|
|
"norm_out": torch_device, |
|
|
"proj_out": "cpu", |
|
|
} |
|
|
device_maps = ["auto", custom_device_map_dict] |
|
|
|
|
|
inputs = self.get_dummy_tensor_inputs(torch_device) |
|
|
|
|
|
expected_slice_auto = np.array( |
|
|
[ |
|
|
0.34179688, |
|
|
-0.03613281, |
|
|
0.01428223, |
|
|
-0.22949219, |
|
|
-0.49609375, |
|
|
0.4375, |
|
|
-0.1640625, |
|
|
-0.66015625, |
|
|
0.43164062, |
|
|
] |
|
|
) |
|
|
expected_slice_offload = np.array( |
|
|
[0.34375, -0.03515625, 0.0123291, -0.22753906, -0.49414062, 0.4375, -0.16308594, -0.66015625, 0.43554688] |
|
|
) |
|
|
for device_map in device_maps: |
|
|
if device_map == "auto": |
|
|
expected_slice = expected_slice_auto |
|
|
else: |
|
|
expected_slice = expected_slice_offload |
|
|
with tempfile.TemporaryDirectory() as offload_folder: |
|
|
quantization_config = TorchAoConfig("int4_weight_only", group_size=64) |
|
|
quantized_model = FluxTransformer2DModel.from_pretrained( |
|
|
"hf-internal-testing/tiny-flux-pipe", |
|
|
subfolder="transformer", |
|
|
quantization_config=quantization_config, |
|
|
device_map=device_map, |
|
|
torch_dtype=torch.bfloat16, |
|
|
offload_folder=offload_folder, |
|
|
) |
|
|
|
|
|
weight = quantized_model.transformer_blocks[0].ff.net[2].weight |
|
|
|
|
|
|
|
|
|
|
|
if "transformer_blocks.0" in device_map: |
|
|
self.assertTrue(isinstance(weight, nn.Parameter)) |
|
|
else: |
|
|
self.assertTrue(isinstance(weight, AffineQuantizedTensor)) |
|
|
|
|
|
output = quantized_model(**inputs)[0] |
|
|
output_slice = output.flatten()[-9:].detach().float().cpu().numpy() |
|
|
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) |
|
|
|
|
|
with tempfile.TemporaryDirectory() as offload_folder: |
|
|
quantization_config = TorchAoConfig("int4_weight_only", group_size=64) |
|
|
quantized_model = FluxTransformer2DModel.from_pretrained( |
|
|
"hf-internal-testing/tiny-flux-sharded", |
|
|
subfolder="transformer", |
|
|
quantization_config=quantization_config, |
|
|
device_map=device_map, |
|
|
torch_dtype=torch.bfloat16, |
|
|
offload_folder=offload_folder, |
|
|
) |
|
|
|
|
|
weight = quantized_model.transformer_blocks[0].ff.net[2].weight |
|
|
if "transformer_blocks.0" in device_map: |
|
|
self.assertTrue(isinstance(weight, nn.Parameter)) |
|
|
else: |
|
|
self.assertTrue(isinstance(weight, AffineQuantizedTensor)) |
|
|
|
|
|
output = quantized_model(**inputs)[0] |
|
|
output_slice = output.flatten()[-9:].detach().float().cpu().numpy() |
|
|
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) |
|
|
|
|
|
def test_modules_to_not_convert(self): |
|
|
quantization_config = TorchAoConfig("int8_weight_only", modules_to_not_convert=["transformer_blocks.0"]) |
|
|
quantized_model_with_not_convert = FluxTransformer2DModel.from_pretrained( |
|
|
"hf-internal-testing/tiny-flux-pipe", |
|
|
subfolder="transformer", |
|
|
quantization_config=quantization_config, |
|
|
torch_dtype=torch.bfloat16, |
|
|
) |
|
|
|
|
|
unquantized_layer = quantized_model_with_not_convert.transformer_blocks[0].ff.net[2] |
|
|
self.assertTrue(isinstance(unquantized_layer, torch.nn.Linear)) |
|
|
self.assertFalse(isinstance(unquantized_layer.weight, AffineQuantizedTensor)) |
|
|
self.assertEqual(unquantized_layer.weight.dtype, torch.bfloat16) |
|
|
|
|
|
quantized_layer = quantized_model_with_not_convert.proj_out |
|
|
self.assertTrue(isinstance(quantized_layer.weight, AffineQuantizedTensor)) |
|
|
|
|
|
quantization_config = TorchAoConfig("int8_weight_only") |
|
|
quantized_model = FluxTransformer2DModel.from_pretrained( |
|
|
"hf-internal-testing/tiny-flux-pipe", |
|
|
subfolder="transformer", |
|
|
quantization_config=quantization_config, |
|
|
torch_dtype=torch.bfloat16, |
|
|
) |
|
|
|
|
|
size_quantized_with_not_convert = get_model_size_in_bytes(quantized_model_with_not_convert) |
|
|
size_quantized = get_model_size_in_bytes(quantized_model) |
|
|
|
|
|
self.assertTrue(size_quantized < size_quantized_with_not_convert) |
|
|
|
|
|
def test_training(self): |
|
|
quantization_config = TorchAoConfig("int8_weight_only") |
|
|
quantized_model = FluxTransformer2DModel.from_pretrained( |
|
|
"hf-internal-testing/tiny-flux-pipe", |
|
|
subfolder="transformer", |
|
|
quantization_config=quantization_config, |
|
|
torch_dtype=torch.bfloat16, |
|
|
).to(torch_device) |
|
|
|
|
|
for param in quantized_model.parameters(): |
|
|
|
|
|
param.requires_grad = False |
|
|
if param.ndim == 1: |
|
|
param.data = param.data.to(torch.float32) |
|
|
|
|
|
for _, module in quantized_model.named_modules(): |
|
|
if isinstance(module, Attention): |
|
|
module.to_q = LoRALayer(module.to_q, rank=4) |
|
|
module.to_k = LoRALayer(module.to_k, rank=4) |
|
|
module.to_v = LoRALayer(module.to_v, rank=4) |
|
|
|
|
|
with torch.amp.autocast(str(torch_device), dtype=torch.bfloat16): |
|
|
inputs = self.get_dummy_tensor_inputs(torch_device) |
|
|
output = quantized_model(**inputs)[0] |
|
|
output.norm().backward() |
|
|
|
|
|
for module in quantized_model.modules(): |
|
|
if isinstance(module, LoRALayer): |
|
|
self.assertTrue(module.adapter[1].weight.grad is not None) |
|
|
self.assertTrue(module.adapter[1].weight.grad.norm().item() > 0) |
|
|
|
|
|
@nightly |
|
|
def test_torch_compile(self): |
|
|
r"""Test that verifies if torch.compile works with torchao quantization.""" |
|
|
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: |
|
|
quantization_config = TorchAoConfig("int8_weight_only") |
|
|
components = self.get_dummy_components(quantization_config, model_id=model_id) |
|
|
pipe = FluxPipeline(**components) |
|
|
pipe.to(device=torch_device) |
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device) |
|
|
normal_output = pipe(**inputs)[0].flatten()[-32:] |
|
|
|
|
|
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False) |
|
|
inputs = self.get_dummy_inputs(torch_device) |
|
|
compile_output = pipe(**inputs)[0].flatten()[-32:] |
|
|
|
|
|
|
|
|
self.assertTrue(np.allclose(normal_output, compile_output, atol=1e-2, rtol=1e-3)) |
|
|
|
|
|
def test_memory_footprint(self): |
|
|
r""" |
|
|
A simple test to check if the model conversion has been done correctly by checking on the |
|
|
memory footprint of the converted model and the class type of the linear layers of the converted models |
|
|
""" |
|
|
for model_id in ["hf-internal-testing/tiny-flux-pipe", "hf-internal-testing/tiny-flux-sharded"]: |
|
|
transformer_int4wo = self.get_dummy_components(TorchAoConfig("int4wo"), model_id=model_id)["transformer"] |
|
|
transformer_int4wo_gs32 = self.get_dummy_components( |
|
|
TorchAoConfig("int4wo", group_size=32), model_id=model_id |
|
|
)["transformer"] |
|
|
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"] |
|
|
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"] |
|
|
|
|
|
|
|
|
for block in transformer_int4wo.transformer_blocks: |
|
|
self.assertTrue(isinstance(block.ff.net[2].weight, AffineQuantizedTensor)) |
|
|
self.assertTrue(isinstance(block.ff_context.net[2].weight, AffineQuantizedTensor)) |
|
|
|
|
|
|
|
|
for name, module in transformer_int4wo_gs32.named_modules(): |
|
|
if isinstance(module, nn.Linear) and name not in ["x_embedder"]: |
|
|
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) |
|
|
|
|
|
|
|
|
for module in transformer_int8wo.modules(): |
|
|
if isinstance(module, nn.Linear): |
|
|
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) |
|
|
|
|
|
total_int4wo = get_model_size_in_bytes(transformer_int4wo) |
|
|
total_int4wo_gs32 = get_model_size_in_bytes(transformer_int4wo_gs32) |
|
|
total_int8wo = get_model_size_in_bytes(transformer_int8wo) |
|
|
total_bf16 = get_model_size_in_bytes(transformer_bf16) |
|
|
|
|
|
|
|
|
|
|
|
self.assertTrue(total_int4wo < total_int4wo_gs32) |
|
|
|
|
|
self.assertTrue(total_int8wo < total_int4wo) |
|
|
|
|
|
|
|
|
self.assertTrue(total_bf16 < total_int4wo) |
|
|
|
|
|
def test_model_memory_usage(self): |
|
|
model_id = "hf-internal-testing/tiny-flux-pipe" |
|
|
expected_memory_saving_ratio = 2.0 |
|
|
|
|
|
inputs = self.get_dummy_tensor_inputs(device=torch_device) |
|
|
|
|
|
transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"] |
|
|
transformer_bf16.to(torch_device) |
|
|
unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs) |
|
|
del transformer_bf16 |
|
|
|
|
|
transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"] |
|
|
transformer_int8wo.to(torch_device) |
|
|
quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs) |
|
|
assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio |
|
|
|
|
|
def test_wrong_config(self): |
|
|
with self.assertRaises(ValueError): |
|
|
self.get_dummy_components(TorchAoConfig("int42")) |
|
|
|
|
|
def test_sequential_cpu_offload(self): |
|
|
r""" |
|
|
A test that checks if inference runs as expected when sequential cpu offloading is enabled. |
|
|
""" |
|
|
quantization_config = TorchAoConfig("int8wo") |
|
|
components = self.get_dummy_components(quantization_config) |
|
|
pipe = FluxPipeline(**components) |
|
|
pipe.enable_sequential_cpu_offload() |
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device) |
|
|
_ = pipe(**inputs) |
|
|
|
|
|
|
|
|
|
|
|
@require_torch |
|
|
@require_torch_gpu |
|
|
@require_torchao_version_greater_or_equal("0.7.0") |
|
|
class TorchAoSerializationTest(unittest.TestCase): |
|
|
model_name = "hf-internal-testing/tiny-flux-pipe" |
|
|
|
|
|
def tearDown(self): |
|
|
gc.collect() |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
def get_dummy_model(self, quant_method, quant_method_kwargs, device=None): |
|
|
quantization_config = TorchAoConfig(quant_method, **quant_method_kwargs) |
|
|
quantized_model = FluxTransformer2DModel.from_pretrained( |
|
|
self.model_name, |
|
|
subfolder="transformer", |
|
|
quantization_config=quantization_config, |
|
|
torch_dtype=torch.bfloat16, |
|
|
) |
|
|
return quantized_model.to(device) |
|
|
|
|
|
def get_dummy_tensor_inputs(self, device=None, seed: int = 0): |
|
|
batch_size = 1 |
|
|
num_latent_channels = 4 |
|
|
num_image_channels = 3 |
|
|
height = width = 4 |
|
|
sequence_length = 48 |
|
|
embedding_dim = 32 |
|
|
|
|
|
torch.manual_seed(seed) |
|
|
hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device, dtype=torch.bfloat16) |
|
|
encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to( |
|
|
device, dtype=torch.bfloat16 |
|
|
) |
|
|
pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device, dtype=torch.bfloat16) |
|
|
text_ids = torch.randn((sequence_length, num_image_channels)).to(device, dtype=torch.bfloat16) |
|
|
image_ids = torch.randn((height * width, num_image_channels)).to(device, dtype=torch.bfloat16) |
|
|
timestep = torch.tensor([1.0]).to(device, dtype=torch.bfloat16).expand(batch_size) |
|
|
|
|
|
return { |
|
|
"hidden_states": hidden_states, |
|
|
"encoder_hidden_states": encoder_hidden_states, |
|
|
"pooled_projections": pooled_prompt_embeds, |
|
|
"txt_ids": text_ids, |
|
|
"img_ids": image_ids, |
|
|
"timestep": timestep, |
|
|
} |
|
|
|
|
|
def _test_original_model_expected_slice(self, quant_method, quant_method_kwargs, expected_slice): |
|
|
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, torch_device) |
|
|
inputs = self.get_dummy_tensor_inputs(torch_device) |
|
|
output = quantized_model(**inputs)[0] |
|
|
output_slice = output.flatten()[-9:].detach().float().cpu().numpy() |
|
|
weight = quantized_model.transformer_blocks[0].ff.net[2].weight |
|
|
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) |
|
|
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) |
|
|
|
|
|
def _check_serialization_expected_slice(self, quant_method, quant_method_kwargs, expected_slice, device): |
|
|
quantized_model = self.get_dummy_model(quant_method, quant_method_kwargs, device) |
|
|
|
|
|
with tempfile.TemporaryDirectory() as tmp_dir: |
|
|
quantized_model.save_pretrained(tmp_dir, safe_serialization=False) |
|
|
loaded_quantized_model = FluxTransformer2DModel.from_pretrained( |
|
|
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False |
|
|
).to(device=torch_device) |
|
|
|
|
|
inputs = self.get_dummy_tensor_inputs(torch_device) |
|
|
output = loaded_quantized_model(**inputs)[0] |
|
|
|
|
|
output_slice = output.flatten()[-9:].detach().float().cpu().numpy() |
|
|
self.assertTrue( |
|
|
isinstance( |
|
|
loaded_quantized_model.proj_out.weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor) |
|
|
) |
|
|
) |
|
|
self.assertTrue(numpy_cosine_similarity_distance(output_slice, expected_slice) < 1e-3) |
|
|
|
|
|
def test_int_a8w8_cuda(self): |
|
|
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} |
|
|
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) |
|
|
device = "cuda" |
|
|
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) |
|
|
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) |
|
|
|
|
|
def test_int_a16w8_cuda(self): |
|
|
quant_method, quant_method_kwargs = "int8_weight_only", {} |
|
|
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) |
|
|
device = "cuda" |
|
|
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) |
|
|
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) |
|
|
|
|
|
def test_int_a8w8_cpu(self): |
|
|
quant_method, quant_method_kwargs = "int8_dynamic_activation_int8_weight", {} |
|
|
expected_slice = np.array([0.3633, -0.1357, -0.0188, -0.249, -0.4688, 0.5078, -0.1289, -0.6914, 0.4551]) |
|
|
device = "cpu" |
|
|
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) |
|
|
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) |
|
|
|
|
|
def test_int_a16w8_cpu(self): |
|
|
quant_method, quant_method_kwargs = "int8_weight_only", {} |
|
|
expected_slice = np.array([0.3613, -0.127, -0.0223, -0.2539, -0.459, 0.4961, -0.1357, -0.6992, 0.4551]) |
|
|
device = "cpu" |
|
|
self._test_original_model_expected_slice(quant_method, quant_method_kwargs, expected_slice) |
|
|
self._check_serialization_expected_slice(quant_method, quant_method_kwargs, expected_slice, device) |
|
|
|
|
|
|
|
|
|
|
|
@require_torch |
|
|
@require_torch_gpu |
|
|
@require_torchao_version_greater_or_equal("0.7.0") |
|
|
@slow |
|
|
@nightly |
|
|
class SlowTorchAoTests(unittest.TestCase): |
|
|
def tearDown(self): |
|
|
gc.collect() |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
def get_dummy_components(self, quantization_config: TorchAoConfig): |
|
|
|
|
|
cache_dir = None |
|
|
model_id = "black-forest-labs/FLUX.1-dev" |
|
|
transformer = FluxTransformer2DModel.from_pretrained( |
|
|
model_id, |
|
|
subfolder="transformer", |
|
|
quantization_config=quantization_config, |
|
|
torch_dtype=torch.bfloat16, |
|
|
cache_dir=cache_dir, |
|
|
) |
|
|
text_encoder = CLIPTextModel.from_pretrained( |
|
|
model_id, subfolder="text_encoder", torch_dtype=torch.bfloat16, cache_dir=cache_dir |
|
|
) |
|
|
text_encoder_2 = T5EncoderModel.from_pretrained( |
|
|
model_id, subfolder="text_encoder_2", torch_dtype=torch.bfloat16, cache_dir=cache_dir |
|
|
) |
|
|
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer", cache_dir=cache_dir) |
|
|
tokenizer_2 = AutoTokenizer.from_pretrained(model_id, subfolder="tokenizer_2", cache_dir=cache_dir) |
|
|
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.bfloat16, cache_dir=cache_dir) |
|
|
scheduler = FlowMatchEulerDiscreteScheduler() |
|
|
|
|
|
return { |
|
|
"scheduler": scheduler, |
|
|
"text_encoder": text_encoder, |
|
|
"text_encoder_2": text_encoder_2, |
|
|
"tokenizer": tokenizer, |
|
|
"tokenizer_2": tokenizer_2, |
|
|
"transformer": transformer, |
|
|
"vae": vae, |
|
|
} |
|
|
|
|
|
def get_dummy_inputs(self, device: torch.device, seed: int = 0): |
|
|
if str(device).startswith("mps"): |
|
|
generator = torch.manual_seed(seed) |
|
|
else: |
|
|
generator = torch.Generator().manual_seed(seed) |
|
|
|
|
|
inputs = { |
|
|
"prompt": "an astronaut riding a horse in space", |
|
|
"height": 512, |
|
|
"width": 512, |
|
|
"num_inference_steps": 20, |
|
|
"output_type": "np", |
|
|
"generator": generator, |
|
|
} |
|
|
|
|
|
return inputs |
|
|
|
|
|
def _test_quant_type(self, quantization_config, expected_slice): |
|
|
components = self.get_dummy_components(quantization_config) |
|
|
pipe = FluxPipeline(**components) |
|
|
pipe.enable_model_cpu_offload() |
|
|
|
|
|
weight = pipe.transformer.transformer_blocks[0].ff.net[2].weight |
|
|
self.assertTrue(isinstance(weight, (AffineQuantizedTensor, LinearActivationQuantizedTensor))) |
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device) |
|
|
output = pipe(**inputs)[0].flatten() |
|
|
output_slice = np.concatenate((output[:16], output[-16:])) |
|
|
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) |
|
|
|
|
|
def test_quantization(self): |
|
|
|
|
|
QUANTIZATION_TYPES_TO_TEST = [ |
|
|
("int8wo", np.array([0.0505, 0.0742, 0.1367, 0.0429, 0.0585, 0.1386, 0.0585, 0.0703, 0.1367, 0.0566, 0.0703, 0.1464, 0.0546, 0.0703, 0.1425, 0.0546, 0.3535, 0.7578, 0.5000, 0.4062, 0.7656, 0.5117, 0.4121, 0.7656, 0.5117, 0.3984, 0.7578, 0.5234, 0.4023, 0.7382, 0.5390, 0.4570])), |
|
|
("int8dq", np.array([0.0546, 0.0761, 0.1386, 0.0488, 0.0644, 0.1425, 0.0605, 0.0742, 0.1406, 0.0625, 0.0722, 0.1523, 0.0625, 0.0742, 0.1503, 0.0605, 0.3886, 0.7968, 0.5507, 0.4492, 0.7890, 0.5351, 0.4316, 0.8007, 0.5390, 0.4179, 0.8281, 0.5820, 0.4531, 0.7812, 0.5703, 0.4921])), |
|
|
] |
|
|
|
|
|
if TorchAoConfig._is_cuda_capability_atleast_8_9(): |
|
|
QUANTIZATION_TYPES_TO_TEST.extend([ |
|
|
("float8wo_e4m3", np.array([0.0546, 0.0722, 0.1328, 0.0468, 0.0585, 0.1367, 0.0605, 0.0703, 0.1328, 0.0625, 0.0703, 0.1445, 0.0585, 0.0703, 0.1406, 0.0605, 0.3496, 0.7109, 0.4843, 0.4042, 0.7226, 0.5000, 0.4160, 0.7031, 0.4824, 0.3886, 0.6757, 0.4667, 0.3710, 0.6679, 0.4902, 0.4238])), |
|
|
("fp5_e3m1", np.array([0.0527, 0.0762, 0.1309, 0.0449, 0.0645, 0.1328, 0.0566, 0.0723, 0.125, 0.0566, 0.0703, 0.1328, 0.0566, 0.0742, 0.1348, 0.0566, 0.3633, 0.7617, 0.5273, 0.4277, 0.7891, 0.5469, 0.4375, 0.8008, 0.5586, 0.4336, 0.7383, 0.5156, 0.3906, 0.6992, 0.5156, 0.4375])), |
|
|
]) |
|
|
|
|
|
|
|
|
for quantization_name, expected_slice in QUANTIZATION_TYPES_TO_TEST: |
|
|
quantization_config = TorchAoConfig(quant_type=quantization_name, modules_to_not_convert=["x_embedder"]) |
|
|
self._test_quant_type(quantization_config, expected_slice) |
|
|
gc.collect() |
|
|
torch.cuda.empty_cache() |
|
|
torch.cuda.synchronize() |
|
|
|
|
|
def test_serialization_int8wo(self): |
|
|
quantization_config = TorchAoConfig("int8wo") |
|
|
components = self.get_dummy_components(quantization_config) |
|
|
pipe = FluxPipeline(**components) |
|
|
pipe.enable_model_cpu_offload() |
|
|
|
|
|
weight = pipe.transformer.x_embedder.weight |
|
|
self.assertTrue(isinstance(weight, AffineQuantizedTensor)) |
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device) |
|
|
output = pipe(**inputs)[0].flatten()[:128] |
|
|
|
|
|
with tempfile.TemporaryDirectory() as tmp_dir: |
|
|
pipe.transformer.save_pretrained(tmp_dir, safe_serialization=False) |
|
|
pipe.remove_all_hooks() |
|
|
del pipe.transformer |
|
|
gc.collect() |
|
|
torch.cuda.empty_cache() |
|
|
torch.cuda.synchronize() |
|
|
transformer = FluxTransformer2DModel.from_pretrained( |
|
|
tmp_dir, torch_dtype=torch.bfloat16, use_safetensors=False |
|
|
) |
|
|
pipe.transformer = transformer |
|
|
pipe.enable_model_cpu_offload() |
|
|
|
|
|
weight = transformer.x_embedder.weight |
|
|
self.assertTrue(isinstance(weight, AffineQuantizedTensor)) |
|
|
|
|
|
loaded_output = pipe(**inputs)[0].flatten()[:128] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.assertTrue(np.allclose(output, loaded_output, atol=0.06)) |
|
|
|
|
|
def test_memory_footprint_int4wo(self): |
|
|
|
|
|
expected_memory_in_gb = 6.0 |
|
|
quantization_config = TorchAoConfig("int4wo") |
|
|
cache_dir = None |
|
|
transformer = FluxTransformer2DModel.from_pretrained( |
|
|
"black-forest-labs/FLUX.1-dev", |
|
|
subfolder="transformer", |
|
|
quantization_config=quantization_config, |
|
|
torch_dtype=torch.bfloat16, |
|
|
cache_dir=cache_dir, |
|
|
) |
|
|
int4wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3 |
|
|
self.assertTrue(int4wo_memory_in_gb < expected_memory_in_gb) |
|
|
|
|
|
def test_memory_footprint_int8wo(self): |
|
|
|
|
|
expected_memory_in_gb = 12.0 |
|
|
quantization_config = TorchAoConfig("int8wo") |
|
|
cache_dir = None |
|
|
transformer = FluxTransformer2DModel.from_pretrained( |
|
|
"black-forest-labs/FLUX.1-dev", |
|
|
subfolder="transformer", |
|
|
quantization_config=quantization_config, |
|
|
torch_dtype=torch.bfloat16, |
|
|
cache_dir=cache_dir, |
|
|
) |
|
|
int8wo_memory_in_gb = get_model_size_in_bytes(transformer) / 1024**3 |
|
|
self.assertTrue(int8wo_memory_in_gb < expected_memory_in_gb) |
|
|
|
|
|
|
|
|
@require_torch |
|
|
@require_torch_gpu |
|
|
@require_torchao_version_greater_or_equal("0.7.0") |
|
|
@slow |
|
|
@nightly |
|
|
class SlowTorchAoPreserializedModelTests(unittest.TestCase): |
|
|
def tearDown(self): |
|
|
gc.collect() |
|
|
torch.cuda.empty_cache() |
|
|
|
|
|
def get_dummy_inputs(self, device: torch.device, seed: int = 0): |
|
|
if str(device).startswith("mps"): |
|
|
generator = torch.manual_seed(seed) |
|
|
else: |
|
|
generator = torch.Generator().manual_seed(seed) |
|
|
|
|
|
inputs = { |
|
|
"prompt": "an astronaut riding a horse in space", |
|
|
"height": 512, |
|
|
"width": 512, |
|
|
"num_inference_steps": 20, |
|
|
"output_type": "np", |
|
|
"generator": generator, |
|
|
} |
|
|
|
|
|
return inputs |
|
|
|
|
|
def test_transformer_int8wo(self): |
|
|
|
|
|
expected_slice = np.array([0.0566, 0.0781, 0.1426, 0.0488, 0.0684, 0.1504, 0.0625, 0.0781, 0.1445, 0.0625, 0.0781, 0.1562, 0.0547, 0.0723, 0.1484, 0.0566, 0.5703, 0.8867, 0.7266, 0.5742, 0.875, 0.7148, 0.5586, 0.875, 0.7148, 0.5547, 0.8633, 0.7109, 0.5469, 0.8398, 0.6992, 0.5703]) |
|
|
|
|
|
|
|
|
|
|
|
cache_dir = None |
|
|
transformer = FluxTransformer2DModel.from_pretrained( |
|
|
"hf-internal-testing/FLUX.1-Dev-TorchAO-int8wo-transformer", |
|
|
torch_dtype=torch.bfloat16, |
|
|
use_safetensors=False, |
|
|
cache_dir=cache_dir, |
|
|
) |
|
|
pipe = FluxPipeline.from_pretrained( |
|
|
"black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16, cache_dir=cache_dir |
|
|
) |
|
|
pipe.enable_model_cpu_offload() |
|
|
|
|
|
|
|
|
for name, module in pipe.transformer.named_modules(): |
|
|
if isinstance(module, nn.Linear): |
|
|
self.assertTrue(isinstance(module.weight, AffineQuantizedTensor)) |
|
|
|
|
|
|
|
|
inputs = self.get_dummy_inputs(torch_device) |
|
|
output = pipe(**inputs)[0].flatten() |
|
|
output_slice = np.concatenate((output[:16], output[-16:])) |
|
|
self.assertTrue(np.allclose(output_slice, expected_slice, atol=1e-3, rtol=1e-3)) |
|
|
|