| | |
| |
|
| | import gc |
| | import os |
| | import tempfile |
| | import unittest |
| |
|
| | import mlx.core as mx |
| | import mlx.nn as nn |
| | import mlx_tests |
| |
|
| |
|
class TestExportImport(mlx_tests.MLXTestCase):
    """Round-trip tests for MLX function export and import.

    Each test exports one or more traces of a Python function to an
    ``.mlxfn`` file with ``mx.export_function`` or ``mx.exporter``. It then
    re-loads the file with ``mx.import_function`` and checks that the
    imported callable matches the original function, or that it rejects
    invalid inputs with the expected exception.
    """

    @classmethod
    def setUpClass(cls):
        # TemporaryDirectory creates the directory on construction, so no
        # separate existence check / mkdir is needed.
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name

    @classmethod
    def tearDownClass(cls):
        # Removes the temporary directory together with every exported file.
        cls.test_dir_fid.cleanup()

    def test_basic_export_import(self):
        """Export/import with various input and output containers."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        # Function with no inputs.
        def fun():
            return mx.zeros((3, 3))

        mx.export_function(path, fun)
        imported = mx.import_function(path)

        expected = fun()
        (out,) = imported()
        self.assertTrue(mx.array_equal(out, expected))

        # Function with a single array input.
        def fun(x):
            return mx.abs(mx.sin(x))

        inputs = mx.array([1.0, 2.0, 3.0, 4.0, 5.0])

        mx.export_function(path, fun, inputs)
        imported = mx.import_function(path)

        expected = fun(inputs)
        (out,) = imported(inputs)
        self.assertTrue(mx.allclose(out, expected))

        # Example inputs given as a list or a tuple. The imported function
        # accepts both the wrapped form and the bare positional form.
        def fun(x):
            x = mx.abs(mx.sin(x))
            return x

        mx.export_function(path, fun, [inputs])
        imported = mx.import_function(path)

        expected = fun(inputs)
        (out,) = imported([inputs])
        self.assertTrue(mx.allclose(out, expected))

        (out,) = imported(inputs)
        self.assertTrue(mx.allclose(out, expected))

        mx.export_function(path, fun, (inputs,))
        imported = mx.import_function(path)
        (out,) = imported((inputs,))
        self.assertTrue(mx.allclose(out, expected))

        # Output returned inside a list.
        def fun(x):
            return [mx.abs(mx.sin(x))]

        mx.export_function(path, fun, inputs)
        imported = mx.import_function(path)
        (out,) = imported(inputs)
        self.assertTrue(mx.allclose(out, expected))

        # Output returned inside a tuple.
        def fun(x):
            return (mx.abs(mx.sin(x)),)

        mx.export_function(path, fun, inputs)
        imported = mx.import_function(path)
        (out,) = imported(inputs)
        self.assertTrue(mx.allclose(out, expected))

        # Exporting with a non-array example argument raises ValueError.
        def fun(x):
            return mx.abs(x)

        with self.assertRaises(ValueError):
            mx.export_function(path, fun, "hi")

        with self.assertRaises(ValueError):
            mx.export_function(path, fun, mx.array(1.0), "hi")

        # Nested containers of arrays are rejected as inputs.
        def fun(x):
            return mx.abs(x[0][0])

        with self.assertRaises(ValueError):
            mx.export_function(path, fun, [[mx.array(1.0)]])

        # Outputs containing a Python scalar, or a nested list, are rejected.
        def fun():
            return (mx.zeros((3, 3)), 1)

        with self.assertRaises(ValueError):
            mx.export_function(path, fun)

        def fun():
            return (mx.zeros((3, 3)), [mx.zeros((3, 3))])

        with self.assertRaises(ValueError):
            mx.export_function(path, fun)

        # The imported function rejects calls that do not match the exported
        # signature: a non-array argument, an extra argument, or a nested list.
        def fun(x, y):
            return x + y

        mx.export_function(path, fun, mx.array(1.0), mx.array(1.0))
        imported = mx.import_function(path)

        with self.assertRaises(ValueError):
            imported(mx.array(1.0), 1.0)

        with self.assertRaises(ValueError):
            imported(mx.array(1.0), mx.array(1.0), mx.array(1.0))

        with self.assertRaises(ValueError):
            imported(mx.array(1.0), [mx.array(1.0)])

    def test_export_random_sample(self):
        """An exported random sample matches a freshly seeded call."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        # Seed before exporting. The assertion below expects the imported
        # function to reproduce what `fun` returns after re-seeding with 5.
        mx.random.seed(5)

        def fun():
            return mx.random.uniform(shape=(3,))

        mx.export_function(path, fun)
        imported = mx.import_function(path)

        (out,) = imported()

        mx.random.seed(5)
        expected = fun()

        self.assertTrue(mx.array_equal(out, expected))

    def test_export_with_kwargs(self):
        """Keyword arguments are part of the exported signature."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        def fun(x, z=None):
            out = x
            if z is not None:
                out += z
            return out

        x = mx.array([1, 2, 3])
        y = mx.array([1, 1, 0])
        z = mx.array([2, 2, 2])

        # Example args/kwargs given explicitly as a tuple and a dict.
        mx.export_function(path, fun, (x,), {"z": z})
        imported_fun = mx.import_function(path)

        # "z" must be passed by keyword, and unknown keywords are rejected.
        with self.assertRaises(ValueError):
            imported_fun(x, z)

        with self.assertRaises(ValueError):
            imported_fun(x, y=z)

        with self.assertRaises(ValueError):
            imported_fun((x,), {"y": z})

        # Both the Python-style call and the (args, kwargs) form are accepted.
        out = imported_fun(x, z=z)[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

        out = imported_fun((x,), {"z": z})[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

        # Same signature, exported with Python-style args and kwargs.
        mx.export_function(path, fun, x, z=z)
        imported_fun = mx.import_function(path)
        out = imported_fun(x, z=z)[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

        out = imported_fun((x,), {"z": z})[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

        # Exported with keyword arguments only, so passing x positionally
        # is rejected.
        mx.export_function(path, fun, x=x, z=z)
        imported_fun = mx.import_function(path)
        with self.assertRaises(ValueError):
            out = imported_fun(x, z=z)[0]

        out = imported_fun(x=x, z=z)[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

        out = imported_fun({"x": x, "z": z})[0]
        self.assertTrue(mx.array_equal(out, mx.array([3, 4, 5])))

    def test_export_variable_inputs(self):
        """Several traces of one function can be exported to one file."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        def fun(x, y, z=None):
            out = x + y
            if z is not None:
                out += z
            return out

        # Two traces: one without and one with the optional keyword argument.
        with mx.exporter(path, fun) as exporter:
            exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]))
            exporter(mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2]))

        # Calling the exporter after its context has exited raises.
        with self.assertRaises(RuntimeError):
            exporter(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))

        imported_fun = mx.import_function(path)
        out = imported_fun(mx.array([1, 2, 3]), mx.array([1, 1, 1]))[0]
        self.assertTrue(mx.array_equal(out, mx.array([2, 3, 4])))

        out = imported_fun(
            mx.array([1, 2, 3]), mx.array([1, 1, 1]), z=mx.array([2])
        )[0]
        self.assertTrue(mx.array_equal(out, mx.array([4, 5, 6])))

        # An input shape that was never traced is rejected (this export
        # does not pass shapeless=True).
        with self.assertRaises(ValueError):
            imported_fun(mx.array([1, 2, 3, 4]), mx.array([1, 1, 1, 1]))

        # A large captured constant is shared by five traces (0-4 args).
        constant = mx.zeros((16, 2048))
        mx.eval(constant)

        def fun(*args):
            return constant + sum(args)

        with mx.exporter(path, fun) as exporter:
            for i in range(5):
                exporter(*[mx.array(1)] * i)

        # The file must stay below one copy of the constant plus 8192 bytes,
        # i.e. the constant is not written once per trace. The 8192 is
        # presumably headroom for the graph description — TODO confirm.
        constants_size = constant.nbytes + 8192
        self.assertLess(os.path.getsize(path), constants_size)

    def test_leaks(self):
        """Exporters caught in reference cycles do not leak array memory."""
        path = os.path.join(self.test_dir, "fn.mlxfn")
        mx.synchronize()
        # Active memory is only measured when Metal is available. Otherwise
        # both readings are 0 and the final check is trivially satisfied.
        if mx.metal.is_available():
            mem_pre = mx.get_active_memory()
        else:
            mem_pre = 0

        def outer():
            # `f` closes over `d`, and `d` stores the exporter built from `f`
            # (which presumably retains `f`) plus an array. That forms a
            # reference cycle only the garbage collector can reclaim.
            d = {}

            def f(x):
                return d["x"]

            d["f"] = mx.exporter(path, f)
            d["x"] = mx.array([0] * 1000)

        for _ in range(5):
            outer()
            gc.collect()

        if mx.metal.is_available():
            mem_post = mx.get_active_memory()
        else:
            mem_post = 0

        self.assertEqual(mem_pre, mem_post)

    def test_export_import_shapeless(self):
        """Shapeless traces with different argument counts."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        def fun(*args):
            return sum(args)

        # Three shapeless traces taking one, two and three scalar arrays.
        with mx.exporter(path, fun, shapeless=True) as exporter:
            exporter(mx.array(1))
            exporter(mx.array(1), mx.array(2))
            exporter(mx.array(1), mx.array(2), mx.array(3))

        f2 = mx.import_function(path)
        self.assertEqual(f2(mx.array(1))[0].item(), 1)
        self.assertEqual(f2(mx.array(1), mx.array(1))[0].item(), 2)
        self.assertEqual(f2(mx.array(1), mx.array(1), mx.array(1))[0].item(), 3)
        # Rejected even though shapeless: the second argument is 1-D while
        # every traced argument was a scalar (presumably a rank mismatch).
        with self.assertRaises(ValueError):
            f2(mx.array(10), mx.array([5, 10, 20]))

    def test_export_scatter_gather(self):
        """Gather (take_along_axis) and scatter (put_along_axis) round-trip."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        def fun(a, b):
            return mx.take_along_axis(a, b, axis=0)

        x = mx.random.uniform(shape=(4, 4))
        y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]])
        mx.export_function(path, fun, (x, y))
        imported_fun = mx.import_function(path)
        expected = fun(x, y)
        out = imported_fun(x, y)[0]
        self.assertTrue(mx.array_equal(expected, out))

        def fun(a, b, c):
            return mx.put_along_axis(a, b, c, axis=0)

        x = mx.random.uniform(shape=(4, 4))
        y = mx.array([[0, 1, 2, 3], [1, 2, 0, 3]])
        z = mx.random.uniform(shape=(2, 4))
        mx.export_function(path, fun, (x, y, z))
        imported_fun = mx.import_function(path)
        expected = fun(x, y, z)
        out = imported_fun(x, y, z)[0]
        self.assertTrue(mx.array_equal(expected, out))

    def test_export_conv(self):
        """A small stack of Conv2d layers round-trips at a fixed shape."""
        path = os.path.join(self.test_dir, "fn.mlxfn")

        class Model(nn.Module):
            def __init__(self):
                super().__init__()
                self.c1 = nn.Conv2d(
                    3, 16, kernel_size=3, stride=1, padding=1, bias=False
                )
                self.c2 = nn.Conv2d(
                    16, 16, kernel_size=3, stride=2, padding=1, bias=False
                )
                self.c3 = nn.Conv2d(
                    16, 16, kernel_size=3, stride=1, padding=2, bias=False
                )

            def __call__(self, x):
                return self.c3(self.c2(self.c1(x)))

        model = Model()
        mx.eval(model.parameters())

        def forward(x):
            return model(x)

        input_data = mx.random.normal(shape=(4, 32, 32, 3))
        mx.export_function(path, forward, input_data)

        imported_fn = mx.import_function(path)
        out = imported_fn(input_data)[0]
        expected = forward(input_data)
        self.assertTrue(mx.allclose(expected, out))

    def test_export_conv_shapeless(self):
        """Shapeless conv exports evaluated on shapes other than the example."""
        # Conv1d: batch size and length vary, including an odd length.
        path = os.path.join(self.test_dir, "conv1d.mlxfn")

        class M1(nn.Module):
            def __init__(self):
                super().__init__()
                self.c = nn.Conv1d(3, 8, kernel_size=3, stride=2, padding=1, bias=False)

            def __call__(self, x):
                return self.c(x)

        m1 = M1()
        mx.eval(m1.parameters())

        def f1(x):
            return m1(x)

        x = mx.random.normal(shape=(4, 64, 3))
        mx.export_function(path, f1, x, shapeless=True)
        f1_imp = mx.import_function(path)
        for shape in [(4, 64, 3), (1, 33, 3), (2, 128, 3)]:
            xt = mx.random.normal(shape=shape)
            self.assertTrue(mx.allclose(f1_imp(xt)[0], f1(xt)))

        # Conv2d: square, odd-sized and non-square spatial extents.
        path = os.path.join(self.test_dir, "conv2d.mlxfn")

        class M2(nn.Module):
            def __init__(self):
                super().__init__()
                self.c = nn.Conv2d(3, 6, kernel_size=3, stride=2, padding=1, bias=False)

            def __call__(self, x):
                return self.c(x)

        m2 = M2()
        mx.eval(m2.parameters())

        def f2(x):
            return m2(x)

        x = mx.random.normal(shape=(2, 32, 32, 3))
        mx.export_function(path, f2, x, shapeless=True)
        f2_imp = mx.import_function(path)
        for shape in [(2, 32, 32, 3), (1, 31, 31, 3), (4, 64, 48, 3)]:
            xt = mx.random.normal(shape=shape)
            self.assertTrue(mx.allclose(f2_imp(xt)[0], f2(xt)))

        # Conv3d: uneven spatial extents across the three dimensions.
        path = os.path.join(self.test_dir, "conv3d.mlxfn")

        class M3(nn.Module):
            def __init__(self):
                super().__init__()
                self.c = nn.Conv3d(2, 4, kernel_size=3, stride=2, padding=1, bias=False)

            def __call__(self, x):
                return self.c(x)

        m3 = M3()
        mx.eval(m3.parameters())

        def f3(x):
            return m3(x)

        x = mx.random.normal(shape=(1, 8, 8, 8, 2))
        mx.export_function(path, f3, x, shapeless=True)
        f3_imp = mx.import_function(path)
        for shape in [(1, 8, 8, 8, 2), (2, 7, 8, 9, 2), (1, 16, 16, 4, 2)]:
            xt = mx.random.normal(shape=shape)
            self.assertTrue(mx.allclose(f3_imp(xt)[0], f3(xt)))

        # Grouped Conv2d (groups=2).
        path = os.path.join(self.test_dir, "conv2d_grouped.mlxfn")

        class MG(nn.Module):
            def __init__(self):
                super().__init__()
                self.c = nn.Conv2d(
                    4, 6, kernel_size=3, stride=2, padding=1, groups=2, bias=False
                )

            def __call__(self, x):
                return self.c(x)

        mg = MG()
        mx.eval(mg.parameters())

        def fg(x):
            return mg(x)

        x = mx.random.normal(shape=(2, 32, 32, 4))
        mx.export_function(path, fg, x, shapeless=True)
        fg_imp = mx.import_function(path)
        for shape in [(2, 32, 32, 4), (1, 32, 32, 4), (3, 15, 20, 4)]:
            xt = mx.random.normal(shape=shape)
            self.assertTrue(mx.allclose(fg_imp(xt)[0], fg(xt)))

    def test_export_control_flow(self):
        """Python control flow on input shapes is fixed at export time."""

        def fun(x, y):
            if y.shape[0] <= 2:
                return x + y
            else:
                return x + 2 * y

        # Loop invariants: the scalar input and the export location.
        x = mx.array(1)
        export_path = os.path.join(self.test_dir, "control_flow.mlxfn")
        for y in (mx.array([1, 2, 3]), mx.array([1, 2])):
            for shapeless in (True, False):
                with self.subTest(y=y, shapeless=shapeless):
                    mx.export_function(export_path, fun, x, y, shapeless=shapeless)

                    imported_fn = mx.import_function(export_path)
                    self.assertTrue(mx.array_equal(imported_fn(x, y)[0], fun(x, y)))

    def test_export_quantized_model(self):
        """A quantized nn.Sequential model round-trips, shapeless or not."""
        export_path = os.path.join(self.test_dir, "quantized_linear.mlxfn")
        for shapeless in (True, False):
            with self.subTest(shapeless=shapeless):
                # A fresh model per sub-test, because nn.quantize swaps the
                # Linear layers for QuantizedLinear ones (asserted below).
                model = nn.Sequential(
                    nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 1024)
                )
                model.eval()
                mx.eval(model.parameters())
                input_data = mx.ones(shape=(512, 1024))
                nn.quantize(model)
                self.assertIsInstance(model.layers[0], nn.QuantizedLinear)
                self.assertIsInstance(model.layers[2], nn.QuantizedLinear)
                mx.eval(model.parameters())

                mx.export_function(export_path, model, input_data, shapeless=shapeless)

                imported_fn = mx.import_function(export_path)
                self.assertTrue(
                    mx.array_equal(imported_fn(input_data)[0], model(input_data))
                )
| |
|
| |
|
if __name__ == "__main__":
    # When run as a script, execute the tests via mlx_tests.MLXTestRunner.
    mlx_tests.MLXTestRunner()
| |
|