import unittest
import torch
import sys
import os

# Add comfy to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", ".."))


def has_gpu():
    return torch.cuda.is_available()


from comfy.cli_args import args

if not has_gpu():
    args.cpu = True

from comfy.quant_ops import QuantizedTensor, TensorCoreFP8Layout
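
# A minimal sketch of the quantize/dequantize round trip these tests assume.
# _reference_roundtrip is a hypothetical helper for illustration only; it is
# not part of comfy.quant_ops, whose kernels may differ in rounding and
# clamping details. float8_e4m3fn carries only a 3-bit mantissa, so round
# trips are lossy, which is why the tests below use loose (~10%) tolerances.
def _reference_roundtrip(x, scale, dtype=torch.float8_e4m3fn):
    qdata = (x / scale).to(dtype)      # quantize: scale down, cast to FP8
    return qdata.to(x.dtype) * scale   # dequantize: cast back up, rescale
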

class TestQuantizedTensor(unittest.TestCase):
    """Test the QuantizedTensor subclass with FP8 layout"""

    def test_creation(self):
        """Test creating a QuantizedTensor with TensorCoreFP8Layout"""
        fp8_data = torch.randn(256, 128, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(2.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.bfloat16}

        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.shape, (256, 128))
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt._layout_params['scale'], scale)
        self.assertEqual(qt._layout_params['orig_dtype'], torch.bfloat16)
        self.assertEqual(qt._layout_type, "TensorCoreFP8Layout")

    def test_dequantize(self):
        """Test explicit dequantization"""
        fp8_data = torch.ones(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(3.0)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)
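        # Stored values are 1.0 with scale 3.0, so dequantize() should return
        # roughly 1.0 * 3.0 = 3.0 everywhere (per the dequantize-rescales
        # scheme sketched at module level).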
        dequantized = qt.dequantize()
        self.assertEqual(dequantized.dtype, torch.float32)
        self.assertTrue(torch.allclose(dequantized, torch.ones(10, 20) * 3.0, rtol=0.1))

    def test_from_float(self):
        """Test creating QuantizedTensor from float tensor"""
        float_tensor = torch.randn(64, 32, dtype=torch.float32)
        scale = torch.tensor(1.5)
        qt = QuantizedTensor.from_float(
            float_tensor,
            "TensorCoreFP8Layout",
            scale=scale,
            dtype=torch.float8_e4m3fn
        )
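        # from_float is assumed to be a convenience constructor that runs the
        # layout's quantize() and wraps the result, replacing the manual
        # quantize-then-construct pattern used in test_creation above.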
        self.assertIsInstance(qt, QuantizedTensor)
        self.assertEqual(qt.dtype, torch.float8_e4m3fn)
        self.assertEqual(qt.shape, (64, 32))

        # Verify dequantization gives approximately original values
        dequantized = qt.dequantize()
        mean_rel_error = ((dequantized - float_tensor).abs() / (float_tensor.abs() + 1e-6)).mean()
        self.assertLess(mean_rel_error, 0.1)


class TestGenericUtilities(unittest.TestCase):
    """Test generic utility operations"""

    def test_detach(self):
        """Test detach operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        # Detach should return a new QuantizedTensor
        qt_detached = qt.detach()
        self.assertIsInstance(qt_detached, QuantizedTensor)
        self.assertEqual(qt_detached.shape, qt.shape)
        self.assertEqual(qt_detached._layout_type, "TensorCoreFP8Layout")

    def test_clone(self):
        """Test clone operation on quantized tensor"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        # Clone should return a new QuantizedTensor
        qt_cloned = qt.clone()
        self.assertIsInstance(qt_cloned, QuantizedTensor)
        self.assertEqual(qt_cloned.shape, qt.shape)
        self.assertEqual(qt_cloned._layout_type, "TensorCoreFP8Layout")

        # Verify it's a deep copy
        self.assertIsNot(qt_cloned._qdata, qt._qdata)

    def test_to_device(self):
        """Test device transfer"""
        fp8_data = torch.randn(10, 20, dtype=torch.float32).to(torch.float8_e4m3fn)
        scale = torch.tensor(1.5)
        layout_params = {'scale': scale, 'orig_dtype': torch.float32}
        qt = QuantizedTensor(fp8_data, "TensorCoreFP8Layout", layout_params)

        # Moving to same device should work (CPU to CPU)
        qt_cpu = qt.to('cpu')
        self.assertIsInstance(qt_cpu, QuantizedTensor)
        self.assertEqual(qt_cpu.device.type, 'cpu')
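        # Layout params such as scale are expected to move together with the
        # quantized payload, so a device transfer must update them as well.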
        self.assertEqual(qt_cpu._layout_params['scale'].device.type, 'cpu')


class TestTensorCoreFP8Layout(unittest.TestCase):
    """Test the TensorCoreFP8Layout implementation"""

    def test_quantize(self):
        """Test quantization method"""
        float_tensor = torch.randn(32, 64, dtype=torch.float32)
        scale = torch.tensor(1.5)
        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )
        self.assertEqual(qdata.dtype, torch.float8_e4m3fn)
        self.assertEqual(qdata.shape, float_tensor.shape)
        self.assertIn('scale', layout_params)
        self.assertIn('orig_dtype', layout_params)
        self.assertEqual(layout_params['orig_dtype'], torch.float32)

    def test_dequantize(self):
        """Test dequantization method"""
        float_tensor = torch.ones(10, 20, dtype=torch.float32) * 3.0
        scale = torch.tensor(1.0)
        qdata, layout_params = TensorCoreFP8Layout.quantize(
            float_tensor,
            scale=scale,
            dtype=torch.float8_e4m3fn
        )
        dequantized = TensorCoreFP8Layout.dequantize(qdata, **layout_params)

        # Should approximately match original
        self.assertTrue(torch.allclose(dequantized, float_tensor, rtol=0.1, atol=0.1))
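

# The fallback exercised below is assumed to hook torch dispatch roughly as
# sketched here (hypothetical pseudocode; the real hook lives in
# comfy.quant_ops and its handler registry may be structured differently):
#
#     @classmethod
#     def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
#         if func in _HANDLERS:                 # registered quantized kernel
#             return _HANDLERS[func](*args, **(kwargs or {}))
#         args = dequantize_nested(args)        # fall back to plain tensors
#         return func(*args, **(kwargs or {}))  # result is a regular Tensor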

class TestFallbackMechanism(unittest.TestCase):
    """Test fallback for unsupported operations"""

    def test_unsupported_op_dequantizes(self):
        """Test that unsupported operations fall back to dequantization"""
        # Set seed for reproducibility
        torch.manual_seed(42)

        # Create quantized tensor
        a_fp32 = torch.randn(10, 20, dtype=torch.float32)
        scale = torch.tensor(1.0)
        a_q = QuantizedTensor.from_float(
            a_fp32,
            "TensorCoreFP8Layout",
            scale=scale,
            dtype=torch.float8_e4m3fn
        )

        # Call an operation that doesn't have a registered handler,
        # for example torch.abs
        result = torch.abs(a_q)

        # Should work via fallback (dequantize → abs → return)
        self.assertNotIsInstance(result, QuantizedTensor)
        expected = torch.abs(a_fp32)

        # FP8 introduces quantization error, so use a loose tolerance
        mean_error = (result - expected).abs().mean()
        self.assertLess(mean_error, 0.05, f"Mean error {mean_error:.4f} is too large")


if __name__ == "__main__":
    unittest.main()