| | |
| |
|
| | #include "doctest/doctest.h" |
| |
|
| | #include "mlx/mlx.h" |
| | #include "mlx/primitives.h" |
| |
|
| | using namespace mlx::core; |
| |
|
| | void test_arg_reduce_small( |
| | Device d, |
| | const array& x, |
| | ArgReduce::ReduceType r, |
| | Shape out_shape, |
| | int axis, |
| | std::vector<int> expected_output) { |
| | auto s = default_stream(d); |
| | auto y = |
| | array(out_shape, uint32, std::make_shared<ArgReduce>(s, r, axis), {x}); |
| | y.eval(); |
| | const uint32_t* ydata = y.data<uint32_t>(); |
| | for (int i = 0; i < y.size(); i++) { |
| | CHECK_EQ(expected_output[i], ydata[i]); |
| | } |
| | } |
| |
|
| | void test_arg_reduce_against_cpu( |
| | const array& x, |
| | ArgReduce::ReduceType r, |
| | Shape out_shape, |
| | int axis) { |
| | auto y1 = array( |
| | out_shape, |
| | uint32, |
| | std::make_shared<ArgReduce>(default_stream(Device::cpu), r, axis), |
| | {x}); |
| | auto y2 = array( |
| | out_shape, |
| | uint32, |
| | std::make_shared<ArgReduce>(default_stream(Device::gpu), r, axis), |
| | {x}); |
| | y1.eval(); |
| | y2.eval(); |
| | CHECK(array_equal(y1, y2).item<bool>()); |
| | } |
| |
|
| | TEST_CASE("test arg reduce small") { |
| | auto x = array( |
| | {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5, |
| | 0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5}, |
| | {2, 3, 4}); |
| | test_arg_reduce_small( |
| | Device::cpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3}); |
| | test_arg_reduce_small( |
| | Device::cpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 2, 0, 1, 1, 2}); |
| | test_arg_reduce_small( |
| | Device::cpu, |
| | x, |
| | ArgReduce::ArgMin, |
| | {3, 4}, |
| | 0, |
| | {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); |
| | test_arg_reduce_small( |
| | Device::cpu, x, ArgReduce::ArgMax, {2, 3}, 2, {3, 0, 1, 3, 0, 1}); |
| | test_arg_reduce_small( |
| | Device::cpu, x, ArgReduce::ArgMax, {2, 4}, 1, {1, 2, 2, 0, 1, 2, 2, 0}); |
| | test_arg_reduce_small( |
| | Device::cpu, |
| | x, |
| | ArgReduce::ArgMax, |
| | {3, 4}, |
| | 0, |
| | {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); |
| |
|
| | if (!metal::is_available()) { |
| | INFO("Skipping arg reduction gpu tests"); |
| | return; |
| | } |
| |
|
| | test_arg_reduce_small( |
| | Device::gpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 1, 3, 0, 1, 3}); |
| | test_arg_reduce_small( |
| | Device::gpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 2, 0, 1, 1, 2}); |
| | test_arg_reduce_small( |
| | Device::gpu, |
| | x, |
| | ArgReduce::ArgMin, |
| | {3, 4}, |
| | 0, |
| | {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); |
| | test_arg_reduce_small( |
| | Device::gpu, x, ArgReduce::ArgMax, {2, 3}, 2, {3, 0, 1, 3, 0, 1}); |
| | test_arg_reduce_small( |
| | Device::gpu, x, ArgReduce::ArgMax, {2, 4}, 1, {1, 2, 2, 0, 1, 2, 2, 0}); |
| | test_arg_reduce_small( |
| | Device::gpu, |
| | x, |
| | ArgReduce::ArgMax, |
| | {3, 4}, |
| | 0, |
| | {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); |
| | } |
| |
|
| | TEST_CASE("test arg reduce against cpu") { |
| | if (!metal::is_available()) { |
| | INFO("Skipping arg reduction gpu tests"); |
| | return; |
| | } |
| |
|
| | auto x = random::uniform(array(0.0), array(1.0), {127, 92, 55}); |
| | x.eval(); |
| | test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {127, 92}, 2); |
| | test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {127, 55}, 1); |
| | test_arg_reduce_against_cpu(x, ArgReduce::ArgMin, {92, 55}, 0); |
| | test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {127, 92}, 2); |
| | test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {127, 55}, 1); |
| | test_arg_reduce_against_cpu(x, ArgReduce::ArgMax, {92, 55}, 0); |
| |
|
| | auto y = random::uniform(array(0.0), array(1.0), {1234}); |
| | y.eval(); |
| | test_arg_reduce_against_cpu(y, ArgReduce::ArgMin, {}, 0); |
| | test_arg_reduce_against_cpu(y, ArgReduce::ArgMax, {}, 0); |
| | } |
| |
|
| | void test_arg_reduce_small_bool( |
| | Device d, |
| | ArgReduce::ReduceType r, |
| | Shape out_shape, |
| | int axis, |
| | std::vector<int> expected_output) { |
| | auto s = default_stream(d); |
| | auto x = array( |
| | {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5, |
| | 0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5}, |
| | {2, 3, 4}); |
| | x.eval(); |
| | auto y = |
| | array(out_shape, uint32, std::make_shared<ArgReduce>(s, r, axis), {x}); |
| | y.eval(); |
| | const uint32_t* ydata = y.data<uint32_t>(); |
| | for (int i = 0; i < y.size(); i++) { |
| | CHECK_EQ(expected_output[i], ydata[i]); |
| | } |
| | } |
| |
|
| | TEST_CASE("test arg reduce bool") { |
| | if (!metal::is_available()) { |
| | INFO("Skipping arg reduction gpu tests"); |
| | return; |
| | } |
| | auto x = array( |
| | {false, true, true, false, false, false, false, true, |
| | true, false, true, true, false, true, true, false, |
| | false, false, false, true, true, false, true, true}, |
| | {2, 3, 4}); |
| | x.eval(); |
| | test_arg_reduce_small( |
| | Device::gpu, x, ArgReduce::ArgMin, {2, 3}, 2, {0, 0, 1, 0, 0, 1}); |
| | test_arg_reduce_small( |
| | Device::gpu, x, ArgReduce::ArgMin, {2, 4}, 1, {0, 1, 1, 0, 0, 1, 1, 0}); |
| | test_arg_reduce_small( |
| | Device::gpu, |
| | x, |
| | ArgReduce::ArgMin, |
| | {3, 4}, |
| | 0, |
| | {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); |
| | test_arg_reduce_small( |
| | Device::gpu, x, ArgReduce::ArgMax, {2, 3}, 2, {1, 3, 0, 1, 3, 0}); |
| | test_arg_reduce_small( |
| | Device::gpu, x, ArgReduce::ArgMax, {2, 4}, 1, {2, 0, 0, 1, 2, 0, 0, 1}); |
| | test_arg_reduce_small( |
| | Device::gpu, |
| | x, |
| | ArgReduce::ArgMax, |
| | {3, 4}, |
| | 0, |
| | {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}); |
| | } |
| |
|
| | TEST_CASE("test arg reduce edge cases") { |
| | auto a = argmin(array(1.0)); |
| | CHECK_EQ(a.item<uint32_t>(), 0); |
| | auto b = argmax(array(1.0)); |
| | CHECK_EQ(b.item<uint32_t>(), 0); |
| | CHECK_THROWS(argmin(array({}))); |
| | CHECK_THROWS(argmax(array({}))); |
| | } |
| |
|
| | TEST_CASE("test arg reduce irregular strides") { |
| | auto x = array( |
| | {0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5, |
| | 0, 2, 1, 7, 5, -5, 0, 2, 1, 7, 5, -5}, |
| | {2, 3, 4}); |
| | x = transpose(x, {2, 0, 1}); |
| | x.eval(); |
| | test_arg_reduce_small( |
| | Device::cpu, x, ArgReduce::ArgMin, {4, 2}, 2, {0, 0, 1, 1, 1, 1, 2, 2}); |
| |
|
| | if (!metal::is_available()) { |
| | INFO("Skipping arg reduction gpu tests"); |
| | return; |
| | } |
| | } |
| |
|