| | |
| |
|
| | import unittest |
| |
|
| | import mlx.core as mx |
| | import mlx_tests |
| |
|
| |
|
class TestMemory(mlx_tests.MLXTestCase):
    """Tests for the memory limit and memory accounting functions on ``mx``."""

    def test_memory_info(self):
        """Check cache/memory limit setters and active, peak and cache counters."""
        # With the cache limit at zero, a released buffer must not be cached.
        old_limit = mx.set_cache_limit(0)

        a = mx.zeros((4096,))
        mx.eval(a)
        del a
        self.assertEqual(mx.get_cache_memory(), 0)
        # Each setter call is expected to return the limit it replaced:
        # first the 0 installed above, then the value just restored.
        self.assertEqual(mx.set_cache_limit(old_limit), 0)
        self.assertEqual(mx.set_cache_limit(old_limit), old_limit)

        old_limit = mx.set_memory_limit(10)
        # assertEqual rather than assertTrue: assertTrue treats its second
        # argument as the failure message, so it would only check truthiness
        # and never compare the returned limit against the expected value.
        self.assertEqual(mx.set_memory_limit(old_limit), 10)
        self.assertEqual(mx.set_memory_limit(old_limit), old_limit)

        # Keep one array alive and record the active memory it accounts for.
        # 4096 * 4 presumably reflects 4-byte elements -- confirm the default dtype.
        a = mx.zeros((4096,))
        mx.eval(a)
        mx.synchronize()
        active_mem = mx.get_active_memory()
        self.assertGreaterEqual(active_mem, 4096 * 4)

        # Allocate and release a second array of the same size.
        b = mx.zeros((4096,))
        mx.eval(b)
        del b
        mx.synchronize()

        # Releasing ``b`` must bring active memory back to the earlier reading,
        # while the peak must still cover both arrays.
        new_active_mem = mx.get_active_memory()
        self.assertEqual(new_active_mem, active_mem)
        peak_mem = mx.get_peak_memory()
        self.assertGreaterEqual(peak_mem, 4096 * 8)

        # The lower bound on cached memory is only asserted when Metal is available.
        if mx.metal.is_available():
            cache_mem = mx.get_cache_memory()
            self.assertGreaterEqual(cache_mem, 4096 * 4)

        mx.clear_cache()
        self.assertEqual(mx.get_cache_memory(), 0)

        mx.reset_peak_memory()
        self.assertEqual(mx.get_peak_memory(), 0)

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_wired_memory(self):
        """Check that the wired limit setter returns the previous limit and
        rejects values above the recommended working set size."""
        # Remember the limit in effect before the test and restore it afterwards,
        # so this test does not leak a changed wired limit into later tests.
        original_limit = mx.set_wired_limit(1000)
        self.addCleanup(mx.set_wired_limit, original_limit)

        old_limit = mx.set_wired_limit(0)
        self.assertEqual(old_limit, 1000)

        # A limit beyond the device's recommended working set must be refused.
        max_size = mx.metal.device_info()["max_recommended_working_set_size"]
        with self.assertRaises(ValueError):
            mx.set_wired_limit(max_size + 10)
| |
|
| |
|
# When executed directly as a script, hand control to the project's test runner.
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()
| |
|