|
|
|
|
|
|
|
|
import unittest |
|
|
|
|
|
import mlx.core as mx |
|
|
import mlx_tests |
|
|
|
|
|
|
|
|
class TestMemory(mlx_tests.MLXTestCase):
    """Tests for the memory limit, cache, and usage queries exposed on ``mx``."""

    def test_memory_info(self):
        """Exercise the cache/memory limit setters and the memory usage queries.

        The limit setters are expected to return the value they replaced;
        the paired assertions after each setter rely on that contract and
        also leave the limit at its starting value.
        """
        # With the cache limit at zero, a released buffer must not stay cached.
        old_limit = mx.set_cache_limit(0)

        a = mx.zeros((4096,))
        mx.eval(a)
        del a
        self.assertEqual(mx.get_cache_memory(), 0)
        # First call restores the old limit and should report the 0 set above;
        # second call is a no-op that should report the restored limit.
        self.assertEqual(mx.set_cache_limit(old_limit), 0)
        self.assertEqual(mx.set_cache_limit(old_limit), old_limit)

        old_limit = mx.set_memory_limit(10)
        # assertEqual rather than assertTrue: assertTrue(x, y) treats ``y`` as
        # the failure message, so the expected value would never be compared.
        self.assertEqual(mx.set_memory_limit(old_limit), 10)
        self.assertEqual(mx.set_memory_limit(old_limit), old_limit)

        # Query active and peak memory while ``a`` is still alive.
        # NOTE(review): 4096 * 4 presumably assumes 4-byte elements — confirm.
        a = mx.zeros((4096,))
        mx.eval(a)
        mx.synchronize()
        active_mem = mx.get_active_memory()
        self.assertGreaterEqual(active_mem, 4096 * 4)

        # Allocate and release a second buffer: active memory should return to
        # the earlier reading, while the peak must have covered both buffers.
        b = mx.zeros((4096,))
        mx.eval(b)
        del b
        mx.synchronize()

        new_active_mem = mx.get_active_memory()
        self.assertEqual(new_active_mem, active_mem)
        peak_mem = mx.get_peak_memory()
        self.assertGreaterEqual(peak_mem, 4096 * 8)

        # The cache-size expectation is only checked when Metal is available.
        if mx.metal.is_available():
            cache_mem = mx.get_cache_memory()
            self.assertGreaterEqual(cache_mem, 4096 * 4)

        mx.clear_cache()
        self.assertEqual(mx.get_cache_memory(), 0)

        mx.reset_peak_memory()
        self.assertEqual(mx.get_peak_memory(), 0)

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_wired_memory(self):
        """Check the wired limit setter's return value and its upper bound."""
        # Remember the limit in effect before the test and put it back
        # afterwards, even if one of the assertions below fails.
        original_limit = mx.set_wired_limit(1000)
        self.addCleanup(mx.set_wired_limit, original_limit)

        old_limit = mx.set_wired_limit(0)
        self.assertEqual(old_limit, 1000)

        # A limit above the device's recommended working set must be rejected.
        max_size = mx.metal.device_info()["max_recommended_working_set_size"]
        with self.assertRaises(ValueError):
            mx.set_wired_limit(max_size + 10)
|
|
|
|
|
|
|
|
# Script entry point: hand control to the project's test runner.
if __name__ == "__main__":
    mlx_tests.MLXTestRunner()
|
|
|