| | |
| |
|
| | import unittest |
| | from itertools import product |
| |
|
| | import mlx.core as mx |
| | import mlx_tests |
| |
|
| |
|
class TestQuantized(mlx_tests.MLXTestCase):
    def test_quantize_dequantize(self):
        """Affine round trip: per-element error stays within one scale step."""
        w = mx.random.normal(shape=(128, 512))
        eps = 1e-6
        for gs, b in product([32, 64, 128], [2, 3, 5, 6, 4, 8]):
            with self.subTest(gs=gs, b=b):
                w_q, scales, biases = mx.quantize(w, group_size=gs, bits=b)
                w_hat = mx.dequantize(w_q, scales, biases, gs, b)
                # Group the errors alongside their group's scale for comparison.
                errors = (w - w_hat).abs().reshape(*scales.shape, -1)
                self.assertTrue((errors <= (scales[..., None] + eps).abs()).all())

        # An all-zero matrix must reconstruct to exactly zero.
        a = mx.zeros((256, 512))
        for gs, b in product([32, 64, 128], [2, 3, 4, 5, 6, 8]):
            w_q, scales, biases = mx.quantize(a, gs, b)
            a_hat = mx.dequantize(w_q, scales, biases, gs, b)
            self.assertTrue(mx.all(a_hat == 0))
| |
|
    def test_mxfp4_quantize_dequantize(self):
        """Round-trip mxfp4 quantization and validate its argument checking."""
        # All 16 representable fp4 values (8 positive magnitudes and their
        # negatives) — presumably the e2m1 code book; confirm against the op.
        lut = mx.array(
            [
                +0.0,
                +0.5,
                +1.0,
                +1.5,
                +2.0,
                +3.0,
                +4.0,
                +6.0,
                -0.0,
                -0.5,
                -1.0,
                -1.5,
                -2.0,
                -3.0,
                -4.0,
                -6.0,
            ]
        )
        # Draw weights exactly from the code book so quantization is lossless,
        # and pin the max magnitude (6) at the head of every group of 32.
        w = lut[mx.random.randint(0, 16, shape=(128, 512))]
        w = w.reshape(-1, 32)
        w[:, 0] = 6
        # Tiny offset then cast to bfloat16, the expected storage dtype.
        w = (w + 3e-6).astype(mx.bfloat16)

        # mxfp4 only accepts bits=4 ...
        with self.assertRaises(ValueError):
            mx.quantize(w, bits=3, group_size=32, mode="mxfp4")

        # ... and group_size=32.
        with self.assertRaises(ValueError):
            mx.quantize(w, group_size=64, bits=4, mode="mxfp4")

        w_q, scales = mx.quantize(w, group_size=32, bits=4, mode="mxfp4")

        # The same parameter restrictions apply to dequantize.
        with self.assertRaises(ValueError):
            mx.dequantize(w_q, scales, bits=3, group_size=32, mode="mxfp4")

        with self.assertRaises(ValueError):
            mx.dequantize(w_q, scales, group_size=64, bits=4, mode="mxfp4")

        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
        self.assertTrue(mx.allclose(w, w_hat, rtol=1e-5, atol=1e-5))

        # Zeros must survive the round trip exactly.
        a = mx.zeros((256, 512))
        w_q, scales = mx.quantize(a, group_size=32, bits=4, mode="mxfp4")
        w_hat = mx.dequantize(w_q, scales, group_size=32, bits=4, mode="mxfp4")
        self.assertTrue(mx.all(w_hat == 0))
| |
|
| | def test_qmm(self): |
| | key = mx.random.key(0) |
| | k1, k2 = mx.random.split(key) |
| | tests = product( |
| | [128, 64, 32], |
| | [2, 4, 8], |
| | [8, 32, 33, 64], |
| | [128, 256], |
| | [128, 256], |
| | [True, False], |
| | ) |
| | for group_size, bits, M, N, K, transposed in tests: |
| | with self.subTest( |
| | shape=(M, N, K), |
| | group_size=group_size, |
| | bits=bits, |
| | transposed=transposed, |
| | ): |
| | x = mx.random.normal(shape=(M, K), key=k1) |
| | w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2) |
| | w_q, scales, biases = mx.quantize(w, group_size, bits) |
| | w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) |
| | y_q = mx.quantized_matmul( |
| | x, w_q, scales, biases, transposed, group_size, bits |
| | ) |
| | y_hat = (x @ w_hat.T) if transposed else (x @ w_hat) |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
    def test_qmm_vjp(self):
        """VJP w.r.t. x of quantized_matmul is the same matmul with transpose flipped."""
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)

        bits = 8
        group_size = 64
        M = 64
        N = 1024
        K = 512

        x = mx.random.normal(shape=(2, M, K), key=k1)
        # Cotangent of ones matching the (2, M, N) output.
        c = mx.ones(shape=(2, M, N))

        transposes = [True, False]
        for transposed in transposes:
            w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2)
            w_q, scales, biases = mx.quantize(w, group_size, bits)

            def fn(x):
                return mx.quantized_matmul(
                    x, w_q, scales, biases, transposed, group_size, bits
                )

            _, vjp_out = mx.vjp(fn, primals=(x,), cotangents=(c,))

            # d/dx of a matmul applied to cotangent c is c times the weight
            # transposed — i.e. the same quantized matmul with `transposed`
            # inverted.
            expected_out = mx.quantized_matmul(
                c, w_q, scales, biases, not transposed, group_size, bits
            )
            self.assertTrue(mx.allclose(vjp_out[0], expected_out))
| |
|
| | def test_qmm_jvp(self): |
| | key = mx.random.key(0) |
| | k1, k2 = mx.random.split(key) |
| |
|
| | bits = 8 |
| | group_size = 64 |
| | M = 64 |
| | N = 128 |
| | K = 128 |
| |
|
| | x = mx.random.normal(shape=(2, M, K), key=k1) |
| | x_tan = mx.ones(shape=(2, M, N)) |
| |
|
| | transposes = [True, False] |
| | for transposed in transposes: |
| | w = mx.random.normal(shape=(N, K) if transposed else (K, N), key=k2) |
| | w_q, scales, biases = mx.quantize(w, group_size, bits) |
| |
|
| | def fn(x): |
| | return mx.quantized_matmul( |
| | x, w_q, scales, biases, transposed, group_size, bits |
| | ) |
| |
|
| | _, jvp_out = mx.jvp(fn, primals=(x,), tangents=(x_tan,)) |
| |
|
| | expected_out = mx.quantized_matmul( |
| | x_tan, w_q, scales, biases, transposed, group_size, bits |
| | ) |
| | self.assertTrue(mx.allclose(jvp_out[0], expected_out)) |
| |
|
| | def test_qmm_shapes(self): |
| | key = mx.random.key(0) |
| | k1, k2 = mx.random.split(key) |
| | group_size = 64 |
| | bits = 4 |
| | w = mx.random.normal(shape=(32, 256), key=k2) |
| | w_q, scales, biases = mx.quantize(w, group_size, bits) |
| | w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) |
| | for s in [(3, 256), (2, 1, 7, 256)]: |
| | x = mx.random.normal(shape=s, key=k1) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, True, group_size, bits) |
| | y_hat = x @ w_hat.T |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | w = mx.random.normal(shape=(256, 256), key=k2) |
| | w_q, scales, biases = mx.quantize(w, group_size, bits) |
| | w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) |
| | for s in [(3, 256), (2, 1, 7, 256)]: |
| | x = mx.random.normal(shape=s, key=k1) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits) |
| | y_hat = x @ w_hat |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | def test_qmv(self): |
| | key = mx.random.key(0) |
| | k1, k2 = mx.random.split(key) |
| | tests = product( |
| | [128, 64, 32], |
| | [2, 3, 4, 5, 6, 8], |
| | [256, 512, 67], |
| | [64, 128], |
| | [0, 1, 3, 8], |
| | ) |
| | for group_size, bits, M, N, B in tests: |
| | if group_size > N: |
| | continue |
| | with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits): |
| | x_shape = (3, 1, N) if B == 0 else (B, 1, N) |
| | w_shape = (M, N) if B == 0 else (B, M, N) |
| | x = mx.random.normal(shape=x_shape, key=k1) |
| | w = mx.random.normal(shape=w_shape, key=k2) |
| | w_q, scales, biases = mx.quantize(w, group_size, bits) |
| | w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) |
| | y_q = mx.quantized_matmul( |
| | x, w_q, scales, biases, True, group_size, bits |
| | ) |
| | y_hat = x @ mx.swapaxes(w_hat, -1, -2) |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | def test_mxfp4_qmv(self): |
| | key = mx.random.key(0) |
| | k1, k2 = mx.random.split(key) |
| | tests = product( |
| | [256, 512, 67], |
| | [64, 128], |
| | [0, 1, 3, 8], |
| | ) |
| | for M, N, B in tests: |
| | with self.subTest(shape=(B, M, N), group_size=32): |
| | x_shape = (3, 1, N) if B == 0 else (B, 1, N) |
| | w_shape = (M, N) if B == 0 else (B, M, N) |
| | x = mx.random.normal(shape=x_shape, key=k1) |
| | w = mx.random.normal(shape=w_shape, key=k2) |
| | w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4") |
| | w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4") |
| | y_q = mx.quantized_matmul( |
| | x, |
| | w_q, |
| | scales, |
| | transpose=True, |
| | group_size=32, |
| | mode="mxfp4", |
| | ) |
| | y_hat = x @ mx.swapaxes(w_hat, -1, -2) |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | def test_qvm(self): |
| | key = mx.random.key(0) |
| | k1, k2 = mx.random.split(key) |
| | tests = product( |
| | [128, 64, 32], |
| | [2, 3, 4, 5, 6, 8], |
| | [32, 128, 256], |
| | [128, 256, 67], |
| | [0, 1, 3, 8], |
| | ) |
| | for group_size, bits, M, N, B in tests: |
| | with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits): |
| | if M < group_size: |
| | continue |
| | x_shape = (1, N) if B == 0 else (B, 1, N) |
| | w_shape = (N, M) if B == 0 else (B, N, M) |
| | x = mx.random.normal(shape=x_shape, key=k1) |
| | w = mx.random.normal(shape=w_shape, key=k2) |
| | w_q, scales, biases = mx.quantize(w, group_size, bits) |
| | w_hat = mx.dequantize(w_q, scales, biases, group_size, bits) |
| | y_q = mx.quantized_matmul( |
| | x, w_q, scales, biases, False, group_size, bits |
| | ) |
| | y_hat = x @ w_hat |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
    def test_qvm_splitk(self):
        """QVM with a very long reduction axis (split-k style workloads)."""
        key = mx.random.key(0)
        k1, k2 = mx.random.split(key)
        tests = product(
            [128, 64, 32],
            [2, 4, 8],
            [128],
            [16384],
            [1, 3],
        )
        for group_size, bits, M, N, B in tests:
            with self.subTest(shape=(B, M, N), group_size=group_size, bits=bits):
                x_shape = (1, N) if B == 0 else (B, 1, N)
                w_shape = (N, M) if B == 0 else (B, N, M)
                # Scale inputs down so the long accumulation stays small.
                x = 1e-1 * mx.random.normal(shape=x_shape, key=k1)
                w = 1e-1 * mx.random.normal(shape=w_shape, key=k2)
                w_q, scales, biases = mx.quantize(w, group_size, bits)
                w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
                y_q = mx.quantized_matmul(
                    x, w_q, scales, biases, False, group_size, bits
                )
                y_hat = x @ w_hat
                self.assertEqual(y_q.shape, y_hat.shape)
                # Looser tolerance than test_qvm for the long reduction.
                self.assertLess((y_q - y_hat).abs().max(), 2e-3)

        # Also check a 1-D input vector (no leading batch/row dimensions).
        group_size = 32
        bits = 8
        N = 2048
        x = 1e-1 * mx.random.normal(shape=(N,), key=k1)
        w = 1e-1 * mx.random.normal(shape=(N, N), key=k2)
        w_q, scales, biases = mx.quantize(w, group_size, bits)
        w_hat = mx.dequantize(w_q, scales, biases, group_size, bits)
        y_q = mx.quantized_matmul(x, w_q, scales, biases, False, group_size, bits)
        y_hat = x @ w_hat
        self.assertEqual(y_q.shape, y_hat.shape)
        self.assertLess((y_q - y_hat).abs().max(), 2e-3)
| |
|
| | def test_mxfp4_qvm(self): |
| | key = mx.random.key(0) |
| | k1, k2 = mx.random.split(key) |
| | tests = product( |
| | [32, 128, 256], |
| | [128, 256, 67], |
| | [0, 1, 3, 8], |
| | ) |
| | |
| | tests = list(tests) |
| | tests.append((128, 16384, 0)) |
| |
|
| | for M, N, B in tests: |
| | with self.subTest(shape=(B, M, N)): |
| | x_shape = (1, N) if B == 0 else (B, 1, N) |
| | w_shape = (N, M) if B == 0 else (B, N, M) |
| | x = mx.random.normal(shape=x_shape, key=k1) |
| | w = mx.random.normal(shape=w_shape, key=k2) |
| | w_q, scales = mx.quantize(w, group_size=32, mode="mxfp4") |
| | w_hat = mx.dequantize(w_q, scales, group_size=32, mode="mxfp4") |
| | y_q = mx.quantized_matmul( |
| | x, |
| | w_q, |
| | scales, |
| | transpose=False, |
| | group_size=32, |
| | mode="mxfp4", |
| | ) |
| | y_hat = x @ w_hat |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 2e-3) |
| |
|
    def test_mode_error_cases(self):
        """Invalid modes, missing biases, and bad dtypes all raise ValueError."""
        w = mx.random.normal(shape=(256, 256))
        x = mx.random.normal(shape=(1, 256))

        # Unknown quantization mode is rejected by every entry point.
        with self.assertRaises(ValueError):
            mx.quantize(w, mode="xyz")

        wq, scales, biases = mx.quantize(w, bits=4, group_size=32)

        with self.assertRaises(ValueError):
            mx.dequantize(wq, scales, biases, bits=4, group_size=32, mode="xyz")

        with self.assertRaises(ValueError):
            mx.quantized_matmul(
                x, wq, scales, biases, bits=4, group_size=32, mode="xyz"
            )

        rhs_indices = mx.array(0)
        with self.assertRaises(ValueError):
            mx.gather_qmm(
                x,
                wq,
                scales,
                biases,
                rhs_indices=rhs_indices,
                bits=4,
                group_size=32,
                mode="xyz",
            )

        # Non-floating-point inputs cannot be quantized in either mode.
        with self.assertRaises(ValueError):
            mx.quantize(mx.zeros((128, 128), mx.int32))

        with self.assertRaises(ValueError):
            mx.quantize(mx.zeros((128, 128), mx.int32), mode="mxfp4")

        # Affine mode requires biases; passing None must raise.
        with self.assertRaises(ValueError):
            mx.dequantize(wq, scales, None, bits=4, group_size=32)

        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, wq, scales, None, bits=4, group_size=32)

        with self.assertRaises(ValueError):
            mx.gather_qmm(
                x, wq, scales, None, rhs_indices=rhs_indices, bits=4, group_size=32
            )

        # Integer x / scales / biases are rejected as well.
        x = mx.zeros(shape=(256,), dtype=mx.int32)
        scales = mx.zeros(scales.shape, dtype=mx.int32)
        biases = mx.zeros(scales.shape, dtype=mx.int32)
        with self.assertRaises(ValueError):
            mx.dequantize(wq, scales, biases, bits=4, group_size=32)

        with self.assertRaises(ValueError):
            mx.quantized_matmul(x, wq, scales, biases, bits=4, group_size=32)

        with self.assertRaises(ValueError):
            mx.gather_qmm(
                x, wq, scales, biases, rhs_indices=rhs_indices, bits=4, group_size=32
            )
| |
|
| | def test_throw(self): |
| | x = mx.random.normal(shape=(10, 512)) |
| | w = mx.random.normal(shape=(32, 512)) |
| | w_q, scales, biases = mx.quantize(w) |
| |
|
| | with self.assertRaises(ValueError): |
| | mx.quantized_matmul(x, w_q.T, scales, biases) |
| | with self.assertRaises(ValueError): |
| | mx.quantized_matmul(x, w_q.T, scales.T, biases) |
| | with self.assertRaises(ValueError): |
| | mx.quantized_matmul(x, w_q, scales, biases, False) |
| | with self.assertRaises(ValueError): |
| | mx.quantized_matmul(x, w_q, scales.T, biases.T) |
| | y = mx.quantized_matmul(x, w_q, scales, biases, True) |
| | mx.eval(y) |
| |
|
| | def test_small_matrix(self): |
| | for w_shape in [(8, 256), (1, 8, 256), (3, 8, 256)]: |
| | with self.subTest(w_shape=w_shape): |
| | w = mx.random.normal(shape=(w_shape)) |
| | w_q, scales, biases = mx.quantize(w) |
| | w_hat = mx.dequantize(w_q, scales, biases) |
| |
|
| | |
| | for shape in [(3, 1, 256), (3, 4, 256)]: |
| | x = mx.random.normal(shape=shape) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) |
| | y_hat = x @ mx.swapaxes(w_hat, -1, -2) |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | |
| | x = mx.random.normal(shape=(3, 10, 256)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) |
| | y_hat = x @ mx.swapaxes(w_hat, -1, -2) |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | |
| | x = mx.random.normal(shape=(3, 1, 8)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) |
| | y_hat = x @ w_hat |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | |
| | x = mx.random.normal(shape=(3, 10, 8)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) |
| | y_hat = x @ w_hat |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | def test_non_multiples(self): |
| | w = mx.random.normal(shape=(33, 256)) |
| | w_q, scales, biases = mx.quantize(w) |
| | w_hat = mx.dequantize(w_q, scales, biases) |
| |
|
| | |
| | x = mx.random.normal(shape=(1, 256)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) |
| | y_hat = x @ w_hat.T |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | |
| | x = mx.random.normal(shape=(10, 256)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) |
| | y_hat = x @ w_hat.T |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | |
| | x = mx.random.normal(shape=(1, 33)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) |
| | y_hat = x @ w_hat |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | |
| | x = mx.random.normal(shape=(10, 33)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) |
| | y_hat = x @ w_hat |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | |
| | w = mx.random.normal(shape=(3, 256)) |
| | w_q, scales, biases = mx.quantize(w) |
| | w_hat = mx.dequantize(w_q, scales, biases) |
| |
|
| | |
| | x = mx.random.normal(shape=(1, 256)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) |
| | y_hat = x @ w_hat.T |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | |
| | x = mx.random.normal(shape=(10, 256)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) |
| | y_hat = x @ w_hat.T |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | |
| | x = mx.random.normal(shape=(1, 3)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) |
| | y_hat = x @ w_hat |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | |
| | x = mx.random.normal(shape=(10, 3)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=False) |
| | y_hat = x @ w_hat |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
| | |
| | w = mx.random.normal(shape=(99, 256)) |
| | w_q, scales, biases = mx.quantize(w) |
| | w_hat = mx.dequantize(w_q, scales, biases) |
| | x = mx.random.normal(shape=(129, 256)) |
| | y_q = mx.quantized_matmul(x, w_q, scales, biases, transpose=True) |
| | y_hat = x @ w_hat.T |
| | self.assertEqual(y_q.shape, y_hat.shape) |
| | self.assertLess((y_q - y_hat).abs().max(), 1e-3) |
| |
|
    def test_gather_qmm(self):
        """gather_qmm matches gather_mm on dequantized weights for many batch/index layouts."""

        def quantize(w, transpose=True, group_size=64, bits=4, mode="affine"):
            # Quantize w and also return the dequantized reference; mxfp4 has
            # no biases, so b is None in that mode.
            if mode == "affine":
                qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
            else:
                qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
                b = None
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        def test_shape(
            M,
            N,
            K,
            dtype=mx.float32,
            batch_A=(),
            batch_B=(),
            lhs_indices=None,
            rhs_indices=None,
            transpose=True,
            group_size=64,
            bits=4,
            mode="affine",
        ):
            # One gather_mm vs. gather_qmm comparison for a single config.
            with self.subTest(
                M=M,
                N=N,
                K=K,
                dtype=dtype,
                batch_A=batch_A,
                batch_B=batch_B,
                lhs_indices=lhs_indices,
                rhs_indices=rhs_indices,
                transpose=transpose,
                group_size=group_size,
                bits=bits,
                mode=mode,
            ):
                x = mx.random.normal(shape=batch_A + (M, K)).astype(dtype)
                w = mx.random.normal(
                    shape=batch_B + ((N, K) if transpose else (K, N))
                ).astype(dtype)
                w_hat, qw, s, b = quantize(w, transpose, group_size, bits, mode=mode)

                if lhs_indices is not None:
                    lhs_indices = mx.array(lhs_indices)
                if rhs_indices is not None:
                    rhs_indices = mx.array(rhs_indices)

                # Reference uses the dequantized weights with the same indices.
                c1 = mx.gather_mm(x, w_hat, lhs_indices, rhs_indices)
                c2 = mx.gather_qmm(
                    x,
                    qw,
                    s,
                    b,
                    lhs_indices,
                    rhs_indices,
                    transpose=transpose,
                    group_size=group_size,
                    bits=bits,
                    mode=mode,
                )
                self.assertTrue(mx.allclose(c1, c2, atol=1e-4))

        # Index/batch combinations: one-sided indices, broadcasting batches,
        # 2-D index grids, and an mxfp4 variant.
        inputs = (
            {
                "batch_A": (1,),
                "lhs_indices": (0,),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (1,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (2,),
                "lhs_indices": None,
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (3,),
                "lhs_indices": (0, 2),
                "batch_B": (1,),
                "rhs_indices": (0,),
            },
            {
                "batch_A": (5,),
                "lhs_indices": (0, 2),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
            },
            {
                "batch_A": (4, 2),
                "lhs_indices": (
                    (7, 6),
                    (5, 4),
                    (1, 2),
                ),
                "batch_B": (4, 1),
                "rhs_indices": ((2,), (0,), (1,)),
            },
            {
                "batch_A": (1,),
                "lhs_indices": (0,),
                "batch_B": (3,),
                "rhs_indices": (2, 1),
                "group_size": 32,
                "mode": "mxfp4",
            },
        )

        # Sweep matrix sizes (including non-multiple-of-32 M=1 rows) on both
        # the transposed and non-transposed paths.
        for kwargs in inputs:
            test_shape(1, 32, 128, **kwargs)
            test_shape(32, 32, 256, **kwargs)
            test_shape(1, 32, 256, **kwargs)
            test_shape(32, 256, 32, transpose=False, **kwargs)
            test_shape(1, 256, 32, transpose=False, **kwargs)
            test_shape(32, 32, 512, **kwargs)
            test_shape(1, 32, 512, **kwargs)
            test_shape(32, 512, 32, transpose=False, **kwargs)
            test_shape(1, 512, 32, transpose=False, **kwargs)
| |
|
    def test_gather_matmul_grad(self):
        """Gradient of gather_qmm w.r.t. x matches the gather_mm reference."""

        def quantize(w, transpose=True, group_size=64, bits=4):
            # Return the dequantized reference alongside the quantized pieces.
            qw, s, b = mx.quantize(w, group_size=group_size, bits=bits)
            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        lhs_indices = mx.array([[7, 6], [4, 1], [0, 2]], dtype=mx.uint32)
        rhs_indices = mx.array([[2], [0], [1]], dtype=mx.uint32)

        x = mx.random.normal((4, 2, 32, 256))
        w = mx.random.normal((4, 1, 32, 256))
        w_hat, qw, s, b = quantize(w)

        def f_ref(x, w, i1, i2):
            # Scalar loss so mx.grad applies directly.
            return mx.gather_mm(x, w, i1, i2).sum()

        def f_test(x, qw, s, b, i1, i2):
            return mx.gather_qmm(x, qw, s, b, i1, i2, transpose=True).sum()

        # Forward values agree ...
        r1 = f_ref(x, w_hat, lhs_indices, rhs_indices)
        r2 = f_test(x, qw, s, b, lhs_indices, rhs_indices)
        self.assertTrue(mx.allclose(r1, r2, atol=1e-4))

        # ... and so do the gradients w.r.t. x (argument 0 of both functions).
        g1 = mx.grad(f_ref)(x, w_hat, lhs_indices, rhs_indices)
        g2 = mx.grad(f_test)(x, qw, s, b, lhs_indices, rhs_indices)
        self.assertTrue(mx.allclose(g1, g2, atol=1e-4))
| |
|
    def test_gather_qmm_sorted(self):
        """gather_qmm with sorted_indices=True matches the unsorted computation."""

        def quantize(w, transpose=True, bits=4, group_size=64, mode="affine"):
            # Quantize w and return the dequantized reference; mxfp4 has no
            # biases, so b is None in that mode.
            if mode == "affine":
                qw, s, b = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
            else:
                qw, s = mx.quantize(w, group_size=group_size, bits=bits, mode=mode)
                b = None

            w_hat = mx.dequantize(qw, s, b, group_size=group_size, bits=bits, mode=mode)
            if transpose:
                w_hat = w_hat.swapaxes(-1, -2)
            return w_hat, qw, s, b

        def gather_sort(x, indices):
            # Flatten and sort the indices, reorder x's rows to match, and
            # return the inverse permutation so results can be unsorted later.
            N, M = indices.shape
            indices = indices.flatten()
            order = mx.argsort(indices)
            inv_order = mx.argsort(order)
            return x.flatten(0, -3)[order // M], indices[order], inv_order

        def scatter_unsort(x, inv_order, shape=None):
            # Undo gather_sort's permutation and restore the leading shape.
            x = x[inv_order]
            if shape is not None:
                x = mx.unflatten(x, 0, shape)
            return x

        # Each case: (L, K, D, E, I, transpose, mode) — L rows of indices with
        # I entries each into E weight matrices; K/D are the matmul dims.
        parameters = [
            (32, 512, 512, 4, 2, True, "affine"),
            (32, 512, 544, 4, 2, True, "mxfp4"),
            (133, 512, 512, 4, 2, True, "affine"),
            (133, 512, 555, 4, 2, True, "affine"),
            (133, 512, 512, 4, 2, True, "affine"),
            (64, 512, 512, 4, 2, False, "affine"),
            (64, 512, 544, 4, 2, False, "mxfp4"),
            (133, 512, 512, 4, 2, False, "affine"),
            (133, 512, 544, 4, 2, False, "affine"),
            (133, 512, 555, 4, 2, False, "affine"),
            (64, 512, 512, 4, 2, False, "affine"),
        ]
        for L, K, D, E, I, transpose, mode in parameters:
            # mxfp4 only supports group size 32.
            if mode == "mxfp4":
                group_size = 32
            else:
                group_size = 64
            K, D = (K, D) if transpose else (D, K)
            ishape = (L, I)
            xshape = (L, 1, 1, K)
            wshape = (E, D, K) if transpose else (E, K, D)

            indices = (mx.random.uniform(shape=ishape) * E).astype(mx.uint32)
            # Scale by 1/sqrt(K) to keep the products well conditioned.
            x = mx.random.normal(xshape) / K**0.5
            w = mx.random.normal(wshape) / K**0.5
            w, *wq = quantize(w, group_size=group_size, mode=mode, transpose=transpose)

            # y1/y2: unsorted reference vs. quantized; y3/y4: sorted variants.
            y1 = mx.gather_mm(x, w, rhs_indices=indices)
            y2 = mx.gather_qmm(
                x,
                *wq,
                group_size=group_size,
                mode=mode,
                transpose=transpose,
                rhs_indices=indices
            )
            xs, idx, inv_order = gather_sort(x, indices)
            y3 = mx.gather_mm(xs, w, rhs_indices=idx, sorted_indices=True)

            y4 = mx.gather_qmm(
                xs,
                *wq,
                group_size=group_size,
                mode=mode,
                rhs_indices=idx,
                transpose=transpose,
                sorted_indices=True
            )
            y3 = scatter_unsort(y3, inv_order, indices.shape)
            y4 = scatter_unsort(y4, inv_order, indices.shape)

            self.assertTrue(mx.allclose(y1, y2, atol=1e-5))
            self.assertTrue(mx.allclose(y1, y3, atol=1e-5))
            self.assertTrue(mx.allclose(y1, y4, atol=1e-5))
| |
|
    def test_gather_qmm_grad(self):
        """VJPs of gather_qmm (w.r.t. x, scales, biases) match a gather-then-matmul reference."""

        def gather_qmm_ref(x, w, s, b, lhs, rhs, trans, sort):
            # Reference: materialize the gathers, then a plain quantized matmul.
            if lhs is not None:
                x = x[lhs]
            if rhs is not None:
                w = w[rhs]
                s = s[rhs]
                b = b[rhs]
            return mx.quantized_matmul(x, w, s, b, transpose=trans)

        def gather_qmm(x, w, s, b, lhs, rhs, trans, sort):
            return mx.gather_qmm(
                x,
                w,
                s,
                b,
                transpose=trans,
                lhs_indices=lhs,
                rhs_indices=rhs,
                sorted_indices=sort,
            )

        x = mx.random.normal((16, 1, 256))
        w, s, b = mx.quantize(mx.random.normal((4, 256, 256)))
        # Pre-sorted indices so the sorted_indices=True path is exercised.
        indices = mx.sort(mx.random.randint(0, 4, shape=(16,)))
        cotan = mx.random.normal((16, 1, 256))

        # Differentiate both implementations w.r.t. x, scales, and biases.
        (o1,), (dx1, ds1, db1) = mx.vjp(
            lambda x, s, b: gather_qmm_ref(x, w, s, b, None, indices, True, True),
            [x, s, b],
            [cotan],
        )
        (o2,), (dx2, ds2, db2) = mx.vjp(
            lambda x, s, b: gather_qmm(x, w, s, b, None, indices, True, True),
            [x, s, b],
            [cotan],
        )

        self.assertTrue(mx.allclose(o1, o2, atol=1e-4))
        self.assertTrue(mx.allclose(dx1, dx2, atol=1e-4))
        self.assertTrue(mx.allclose(ds1, ds2, atol=1e-3))
        self.assertTrue(mx.allclose(db1, db2, atol=1e-3))
| |
|
    def test_vjp_scales_biases(self):
        """Gradients w.r.t. scales/biases agree with central finite differences."""
        mx.random.seed(0)
        x = mx.random.normal(shape=(2, 2, 512))
        w = mx.random.normal(shape=(512, 512))
        wq, s, b = mx.quantize(w, bits=4, group_size=64)

        def mm(sb, x, wq):
            # Scalar loss so mx.grad applies directly; sb = (scales, biases).
            return mx.quantized_matmul(x, wq, *sb, bits=4, group_size=64).sum()

        params = (s, b)
        dparams = mx.grad(mm)((s, b), x, wq)

        eps = 8e-3
        # Spot-check a few (row, group) entries of scales (p=0) and biases (p=1).
        indices = [(0, 0), (11, 4), (22, 7)]
        for idx in indices:
            for p in [0, 1]:
                # Central difference: (f(p+eps) - f(p-eps)) / (2*eps), perturbing
                # the parameter in place and restoring it afterwards.
                params[p][idx] += eps
                out_up = mm(params, x, wq)
                params[p][idx] -= 2 * eps
                out_down = mm(params, x, wq)
                params[p][idx] += eps  # restore the original value
                num_ds = (out_up - out_down) / (2 * eps)
                self.assertAlmostEqual(dparams[p][idx], num_ds, delta=2e-2)
| |
|
    def test_mxfp4_vjp_scales_throws(self):
        """Differentiating w.r.t. mxfp4 scales raises ValueError for both ops."""
        mx.random.seed(0)
        x = mx.random.normal(shape=(2, 512))
        w = mx.random.normal(shape=(512, 512))
        wq, s = mx.quantize(w, bits=4, group_size=32, mode="mxfp4")

        def mm(s, x, wq):
            # Scalar loss differentiated w.r.t. the scales argument.
            return mx.quantized_matmul(
                x, wq, s, bits=4, group_size=32, mode="mxfp4"
            ).sum()

        # quantized_matmul: scale gradients are unsupported in mxfp4 mode.
        with self.assertRaises(ValueError):
            ds = mx.grad(mm)(s, x, wq)

        # Same restriction for gather_qmm.
        rhs_indices = mx.array(0)
        with self.assertRaises(ValueError):

            def gmm(s, x, wq):
                return mx.gather_qmm(
                    x,
                    wq,
                    s,
                    rhs_indices=rhs_indices,
                    bits=4,
                    group_size=32,
                    mode="mxfp4",
                ).sum()

            ds = mx.grad(gmm)(s, x, wq)
| |
|
| |
|
if __name__ == "__main__":
    # Run this module's tests via the project's runner (from mlx_tests);
    # presumably wraps unittest discovery — confirm in mlx_tests.
    mlx_tests.MLXTestRunner()
| |
|