File size: 1,866 Bytes
712dbf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 |
# Copyright © 2023-2024 Apple Inc.
import unittest
import mlx.core as mx
import mlx_tests
class TestMemory(mlx_tests.MLXTestCase):
    """Tests for the MLX memory accounting and limit-setting functions."""

    def test_memory_info(self):
        """Check cache/memory limit setters and active/peak/cache counters."""
        # With a cache limit of 0, a freed buffer must not be retained.
        old_limit = mx.set_cache_limit(0)

        a = mx.zeros((4096,))
        mx.eval(a)
        del a
        self.assertEqual(mx.get_cache_memory(), 0)
        # The setter is expected to return the limit that was in effect
        # before the call: first the 0 set above, then the restored value.
        self.assertEqual(mx.set_cache_limit(old_limit), 0)
        self.assertEqual(mx.set_cache_limit(old_limit), old_limit)

        # Same round trip for the memory limit. These use assertEqual:
        # assertTrue(x, y) would treat y as the failure message and only
        # check that x is truthy.
        old_limit = mx.set_memory_limit(10)
        self.assertEqual(mx.set_memory_limit(old_limit), 10)
        self.assertEqual(mx.set_memory_limit(old_limit), old_limit)

        # Query active and peak memory
        a = mx.zeros((4096,))
        mx.eval(a)
        mx.synchronize()
        active_mem = mx.get_active_memory()
        # 4096 elements at 4 bytes each (assumes the default dtype is
        # 4 bytes wide -- TODO confirm).
        self.assertGreaterEqual(active_mem, 4096 * 4)

        b = mx.zeros((4096,))
        mx.eval(b)
        del b
        mx.synchronize()

        # Dropping `b` should bring active memory back to the level seen
        # with only `a` alive, while the peak reflects both having existed.
        new_active_mem = mx.get_active_memory()
        self.assertEqual(new_active_mem, active_mem)
        peak_mem = mx.get_peak_memory()
        self.assertGreaterEqual(peak_mem, 4096 * 8)

        if mx.metal.is_available():
            # The buffer released by `del b` is expected to sit in the cache.
            cache_mem = mx.get_cache_memory()
            self.assertGreaterEqual(cache_mem, 4096 * 4)

        mx.clear_cache()
        self.assertEqual(mx.get_cache_memory(), 0)

        mx.reset_peak_memory()
        self.assertEqual(mx.get_peak_memory(), 0)

    @unittest.skipIf(not mx.metal.is_available(), "Metal is not available")
    def test_wired_memory(self):
        """Check the wired-limit setter's return value and upper bound."""
        original_limit = mx.set_wired_limit(1000)
        # Put back whatever limit was in effect before this test, even if
        # an assertion below fails.
        self.addCleanup(mx.set_wired_limit, original_limit)

        old_limit = mx.set_wired_limit(0)
        self.assertEqual(old_limit, 1000)

        # A limit above the device's recommended working set must be rejected.
        max_size = mx.metal.device_info()["max_recommended_working_set_size"]
        with self.assertRaises(ValueError):
            mx.set_wired_limit(max_size + 10)
if __name__ == "__main__":
    # Executed directly as a script: hand off to the runner provided by
    # the shared mlx_tests helper module.
    mlx_tests.MLXTestRunner()
|