| | import unittest |
| | import torch |
| | import sys |
| | import os |
| | import json |
| |
|
| | |
| | sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..")) |
| |
|
def has_gpu():
    """Return True when a CUDA-capable device is visible to torch."""
    return bool(torch.cuda.is_available())
| |
|
from comfy.cli_args import args
# Force CPU mode before importing comfy.ops so tests run on GPU-less machines.
if not has_gpu():
    args.cpu = True
| |
|
| | from comfy import ops |
| | from comfy.quant_ops import QuantizedTensor |
| | import comfy.utils |
| |
|
| |
|
class SimpleModel(torch.nn.Module):
    """Three-layer bf16 MLP fixture (10 -> 20 -> 30 -> 40) for mixed-precision tests.

    Layers are constructed through the injected ``operations`` namespace so
    tests can substitute quantization-aware Linear implementations.
    """

    def __init__(self, operations=ops.disable_weight_init):
        super().__init__()
        # All layers live on CPU in bfloat16; weights are filled in by tests.
        self.layer1 = operations.Linear(10, 20, device="cpu", dtype=torch.bfloat16)
        self.layer2 = operations.Linear(20, 30, device="cpu", dtype=torch.bfloat16)
        self.layer3 = operations.Linear(30, 40, device="cpu", dtype=torch.bfloat16)

    def forward(self, x):
        """ReLU-separated pass through the three linear layers."""
        hidden = torch.nn.functional.relu(self.layer1(x))
        hidden = torch.nn.functional.relu(self.layer2(hidden))
        return self.layer3(hidden)
| |
|
| |
|
class TestMixedPrecisionOps(unittest.TestCase):
    """Tests for ops.mixed_precision_ops: loading, saving and running models
    whose layers mix plain bf16 weights with FP8-quantized weights described
    by checkpoint quantization metadata."""

    def test_all_layers_standard(self):
        """Test that model with no quantization works normally"""
        # Empty quant config: every layer should behave like a standard Linear.
        model = SimpleModel(operations=ops.mixed_precision_ops({}))

        # Assign plain bf16 parameters directly — no quantized tensors anywhere.
        model.layer1.weight = torch.nn.Parameter(torch.randn(20, 10, dtype=torch.bfloat16))
        model.layer1.bias = torch.nn.Parameter(torch.randn(20, dtype=torch.bfloat16))
        model.layer2.weight = torch.nn.Parameter(torch.randn(30, 20, dtype=torch.bfloat16))
        model.layer2.bias = torch.nn.Parameter(torch.randn(30, dtype=torch.bfloat16))
        model.layer3.weight = torch.nn.Parameter(torch.randn(40, 30, dtype=torch.bfloat16))
        model.layer3.bias = torch.nn.Parameter(torch.randn(40, dtype=torch.bfloat16))

        # Clear any weight/bias patch hooks so the forward pass is unmodified.
        for layer in [model.layer1, model.layer2, model.layer3]:
            layer.weight_function = []
            layer.bias_function = []

        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        # Output keeps the input batch size, final layer width, and dtype.
        self.assertEqual(output.shape, (5, 40))
        self.assertEqual(output.dtype, torch.bfloat16)

    def test_mixed_precision_load(self):
        """Test loading a mixed precision model from state dict"""
        # layer1 and layer3 are FP8-quantized; layer2 stays plain bf16.
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            },
            "layer3": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }

        fp8_weight1 = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        fp8_weight3 = torch.randn(40, 30, dtype=torch.float32).to(torch.float8_e4m3fn)

        state_dict = {
            # Quantized layer: FP8 weight plus a per-tensor scale entry.
            "layer1.weight": fp8_weight1,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),

            # Standard bf16 layer: no scale entry.
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),

            # Second quantized layer with a different scale.
            "layer3.weight": fp8_weight3,
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
            "layer3.weight_scale": torch.tensor(1.5, dtype=torch.float32),
        }

        # Attach the quant config the way checkpoints carry it: JSON metadata
        # translated into the current in-state-dict form by convert_old_quants.
        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})

        model = SimpleModel(operations=ops.mixed_precision_ops({}))
        model.load_state_dict(state_dict, strict=False)

        # Quantized layers must load as QuantizedTensor with the FP8 layout.
        self.assertIsInstance(model.layer1.weight, QuantizedTensor)
        self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")

        # The unquantized layer stays a plain tensor.
        self.assertNotIsInstance(model.layer2.weight, QuantizedTensor)

        self.assertIsInstance(model.layer3.weight, QuantizedTensor)
        self.assertEqual(model.layer3.weight._layout_cls, "TensorCoreFP8E4M3Layout")

        # Per-layer scales round-trip through loading untouched.
        self.assertEqual(model.layer1.weight._params.scale.item(), 2.0)
        self.assertEqual(model.layer3.weight._params.scale.item(), 1.5)

        # The mixed model still runs end to end.
        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        with torch.inference_mode():
            output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))

    def test_state_dict_quantized_preserved(self):
        """Test that quantized weights are preserved in state_dict()"""
        # Only layer1 is quantized in this scenario.
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }

        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict1 = {
            "layer1.weight": fp8_weight,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(3.0, dtype=torch.float32),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        state_dict1, _ = comfy.utils.convert_old_quants(state_dict1, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
        model = SimpleModel(operations=ops.mixed_precision_ops({}))
        model.load_state_dict(state_dict1, strict=False)

        # Round-trip: export the state dict back out of the model.
        state_dict2 = model.state_dict()

        # The FP8 payload must be byte-identical; compare via uint8 views
        # (NOTE(review): presumably avoids relying on an FP8 equality kernel).
        self.assertTrue(torch.equal(state_dict2["layer1.weight"].view(torch.uint8), fp8_weight.view(torch.uint8)))
        self.assertEqual(state_dict2["layer1.weight_scale"].item(), 3.0)
        self.assertEqual(model.layer1.weight._layout_cls, "TensorCoreFP8E4M3Layout")

        # Unquantized layers export as plain tensors.
        self.assertNotIsInstance(state_dict2["layer2.weight"], QuantizedTensor)
        self.assertNotIsInstance(state_dict2["layer3.weight"], QuantizedTensor)

    def test_weight_function_compatibility(self):
        """Test that weight_function (LoRA) works with quantized layers"""
        layer_quant_config = {
            "layer1": {
                "format": "float8_e4m3fn",
                "params": {}
            }
        }

        fp8_weight = torch.randn(20, 10, dtype=torch.float32).to(torch.float8_e4m3fn)
        state_dict = {
            "layer1.weight": fp8_weight,
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer1.weight_scale": torch.tensor(2.0, dtype=torch.float32),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})
        model = SimpleModel(operations=ops.mixed_precision_ops({}))
        model.load_state_dict(state_dict, strict=False)

        # Minimal LoRA-style patch: perturb the weight at forward time.
        def apply_lora(weight):
            lora_delta = torch.randn_like(weight) * 0.01
            return weight + lora_delta

        # Attaching a weight_function to a quantized layer must not break forward.
        model.layer1.weight_function.append(apply_lora)

        input_tensor = torch.randn(5, 10, dtype=torch.bfloat16)
        output = model(input_tensor)

        self.assertEqual(output.shape, (5, 40))

    def test_error_handling_unknown_format(self):
        """Test that unknown formats raise error"""
        # A format string the quant registry does not know about.
        layer_quant_config = {
            "layer1": {
                "format": "unknown_format_xyz",
                "params": {}
            }
        }

        state_dict = {
            "layer1.weight": torch.randn(20, 10, dtype=torch.bfloat16),
            "layer1.bias": torch.randn(20, dtype=torch.bfloat16),
            "layer2.weight": torch.randn(30, 20, dtype=torch.bfloat16),
            "layer2.bias": torch.randn(30, dtype=torch.bfloat16),
            "layer3.weight": torch.randn(40, 30, dtype=torch.bfloat16),
            "layer3.bias": torch.randn(40, dtype=torch.bfloat16),
        }

        state_dict, _ = comfy.utils.convert_old_quants(state_dict, metadata={"_quantization_metadata": json.dumps({"layers": layer_quant_config})})

        # Loading with an unknown format is expected to surface as a KeyError
        # (registry lookup miss) rather than loading silently.
        model = SimpleModel(operations=ops.mixed_precision_ops({}))
        with self.assertRaises(KeyError):
            model.load_state_dict(state_dict, strict=False)
| |
|
# Allow running this file directly: python test_mixed_precision.py
if __name__ == "__main__":
    unittest.main()
| |
|
| |
|