|
|
|
|
|
|
|
|
#include "doctest/doctest.h" |
|
|
|
|
|
#include "mlx/mlx.h" |
|
|
|
|
|
using namespace mlx::core; |
|
|
|
|
|
TEST_CASE("test eval") {
  // Evaluate three scalar arrays (built from a floating-point, an integer
  // and a boolean literal) with a single eval call and read each one back.
  {
    array x(1.0);
    array y(1);
    array z(true);
    eval({x, y, z});
    CHECK_EQ(x.item<float>(), 1.0);
    // y and z went through the same eval call; verify them too so that a
    // failure limited to the non-float inputs cannot slip through.
    CHECK_EQ(y.item<int>(), 1);
    CHECK(z.item<bool>());
  }

  // Same idea, but with a 2x2 array of ones mixed in between the scalars.
  {
    array x(1.0);
    array y = ones({2, 2});
    array z(true);
    eval({x, y, z});
    CHECK(array_equal(y, array({1.0, 1.0, 1.0, 1.0}, {2, 2})).item<bool>());
    // The scalars evaluated alongside y must still hold their values.
    CHECK_EQ(x.item<float>(), 1.0);
    CHECK(z.item<bool>());
  }
}
|
|
|
|
|
TEST_CASE("test eval multiple") {
  // --- Arrays handed to eval as a braced list ---
  array lhs = ones({10, 10});
  array rhs = ones({10, 10});
  eval({lhs, rhs});
  CHECK(array_equal(lhs, rhs).item<bool>());

  // Element-wise sum and difference of two all-ones arrays: expect 2s and 0s.
  array sum = lhs + rhs;
  array diff = lhs - rhs;
  eval({sum, diff});
  CHECK(array_equal(sum, full({10, 10}, 2.0f)).item<bool>());
  CHECK(array_equal(diff, full({10, 10}, 0.0f)).item<bool>());

  // --- Same sequence again, arrays passed to eval as separate arguments ---
  lhs = ones({10, 10});
  rhs = ones({10, 10});
  eval(lhs, rhs);
  CHECK(array_equal(lhs, rhs).item<bool>());

  sum = lhs + rhs;
  diff = lhs - rhs;
  eval(sum, diff);
  CHECK(array_equal(sum, full({10, 10}, 2.0f)).item<bool>());
  CHECK(array_equal(diff, full({10, 10}, 0.0f)).item<bool>());
}
|
|
|
|
|
TEST_CASE("test eval with tracer when not tracing") {
  // The tracer flag is set by hand on each array below. Per the test name no
  // tracing is in progress, and the checks expect the flag to have no visible
  // effect: is_tracer() reports false, and after eval the array has no
  // primitive attached and its data is available.

  // Scalar input.
  array value(1);
  value.set_tracer(true);
  CHECK_FALSE(value.is_tracer());
  eval(value);
  CHECK_FALSE(value.has_primitive());
  CHECK(value.is_available());

  // Array produced by ones({2, 3}), flagged the same way before eval.
  value = ones({2, 3});
  value.set_tracer(true);
  eval(value);
  CHECK_FALSE(value.has_primitive());
  CHECK(value.is_available());
}
|
|
|
|
|
TEST_CASE("test eval graph retention when not tracing") {
  // One input carries a manually set tracer flag. Per the test name no
  // tracing is in progress; every evaluated result below is expected to end
  // up without a primitive and with its data available.
  array flagged(1);
  flagged.set_tracer(true);
  array plain(2);

  // Single addition: 1 + 2.
  array total = flagged + plain;
  eval(total);
  CHECK_FALSE(total.has_primitive());
  CHECK(total.is_available());
  CHECK_EQ(total.item<int>(), 3);

  // Clearing the tracer flag on the already evaluated result must not change
  // its value or its state.
  total.set_tracer(false);
  CHECK_EQ(total.item<int>(), 3);
  CHECK_FALSE(total.has_primitive());
  CHECK(total.is_available());

  // Chain of additions where only the last node is passed to eval; the
  // intermediate nodes are then inspected.
  total = flagged + plain;
  array first_step = total + flagged;
  array last_step = first_step + plain;
  eval(last_step);
  CHECK_FALSE(total.has_primitive());
  CHECK(total.is_available());
  CHECK_FALSE(first_step.has_primitive());
  CHECK(first_step.is_available());
}
|
|
|