|
|
|
|
|
|
|
|
#include <climits> |
|
|
|
|
|
#include "doctest/doctest.h" |
|
|
|
|
|
#include "mlx/mlx.h" |
|
|
|
|
|
using namespace mlx::core; |
|
|
|
|
|
TEST_CASE("test array basics") {
  // A 0-d array built from a double literal is stored as float32.
  {
    auto scalar = array(1.0);
    CHECK_EQ(scalar.size(), 1);
    CHECK_EQ(scalar.ndim(), 0);
    CHECK_EQ(scalar.shape(), Shape{});
    // With no axes, every axis index (positive or negative) is out of range.
    CHECK_THROWS_AS(scalar.shape(0), std::out_of_range);
    CHECK_THROWS_AS(scalar.shape(-1), std::out_of_range);
    CHECK_EQ(scalar.strides(), Strides{});
    CHECK_EQ(scalar.itemsize(), sizeof(float));
    CHECK_EQ(scalar.nbytes(), sizeof(float));
    CHECK_EQ(scalar.dtype(), float32);
    CHECK_EQ(scalar.item<float>(), 1.0);
  }

  // An explicit dtype overrides the one implied by the C++ literal.
  {
    auto as_float = array(1, float32);
    CHECK_EQ(as_float.dtype(), float32);
    CHECK_EQ(as_float.item<float>(), 1.0);

    auto as_bool = array(1, bool_);
    CHECK_EQ(as_bool.dtype(), bool_);
    CHECK_EQ(as_bool.itemsize(), sizeof(bool));
    CHECK_EQ(as_bool.nbytes(), sizeof(bool));
    CHECK_EQ(as_bool.item<bool>(), true);
  }

  // A one-element 1-d array supports axis 0 / -1 and nothing beyond.
  {
    auto single = array({1.0});
    CHECK_EQ(single.dtype(), float32);
    CHECK_EQ(single.size(), 1);
    CHECK_EQ(single.ndim(), 1);
    CHECK_EQ(single.shape(), Shape{1});
    CHECK_EQ(single.shape(0), 1);
    CHECK_EQ(single.shape(-1), 1);
    CHECK_THROWS_AS(single.shape(1), std::out_of_range);
    CHECK_THROWS_AS(single.shape(-2), std::out_of_range);
    CHECK_EQ(single.strides(), Strides{1});
    CHECK_EQ(single.item<float>(), 1.0);
  }

  // An empty array defaults to float32 and rejects item().
  {
    auto empty = array({});
    CHECK_EQ(empty.size(), 0);
    CHECK_EQ(empty.dtype(), float32);
    CHECK_EQ(empty.itemsize(), sizeof(float));
    CHECK_EQ(empty.nbytes(), 0);
    CHECK_THROWS_AS(empty.item<float>(), std::invalid_argument);
  }

  // A multi-element array also rejects item(): it is not a single value.
  {
    auto two = array({1.0, 1.0});
    CHECK_EQ(two.size(), 2);
    CHECK_EQ(two.shape(), Shape{2});
    CHECK_EQ(two.itemsize(), sizeof(float));
    CHECK_EQ(two.nbytes(), two.itemsize() * two.size());
    CHECK_THROWS_AS(two.item<float>(), std::invalid_argument);
  }

  // An explicit shape yields row-major strides ...
  {
    auto row = array({1.0, 1.0, 1.0}, {1, 3});
    CHECK_EQ(row.size(), 3);
    CHECK_EQ(row.shape(), Shape{1, 3});
    CHECK_EQ(row.strides(), Strides{3, 1});
  }
  // ... and must describe exactly as many elements as were supplied.
  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {4}), std::invalid_argument);
  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 4}), std::invalid_argument);
  CHECK_THROWS_AS(array({1.0, 1.0, 1.0}, {1, 2}), std::invalid_argument);

  // Copy construction and copy assignment share the same underlying array.
  {
    auto original = array(1.0);
    auto alias = original;
    CHECK_EQ(alias.id(), original.id());

    auto other = array(2.0);
    CHECK_NE(other.id(), original.id());
    other = original;
    CHECK_EQ(other.id(), original.id());
  }

  // Raw pointers and iterators: the dtype follows the element type.
  {
    float raw[] = {0.0, 1.0, 2.0, 3.0};
    auto from_ptr = array(raw, {4});
    CHECK_EQ(from_ptr.dtype(), float32);
    CHECK(array_equal(from_ptr, array({0.0, 1.0, 2.0, 3.0})).item<bool>());
  }
  {
    std::vector<int> ints = {0, 1, 2, 3};
    auto from_ints = array(ints.begin(), {4});
    CHECK_EQ(from_ints.dtype(), int32);
    CHECK(array_equal(from_ints, array({0, 1, 2, 3})).item<bool>());
  }
  {
    std::vector<bool> bits = {false, true, false, true};
    auto from_bits = array(bits.begin(), {4});
    CHECK_EQ(from_bits.dtype(), bool_);
    CHECK(array_equal(from_bits, array({false, true, false, true}))
              .item<bool>());
  }
}
|
|
|
|
|
TEST_CASE("test array types") { |
|
|
#define basic_dtype_test(T, mlx_type) \ |
|
|
T val = 42; \ |
|
|
array x(val); \ |
|
|
CHECK_EQ(x.dtype(), mlx_type); \ |
|
|
CHECK_EQ(x.item<T>(), val); \ |
|
|
x = array({val, val}); \ |
|
|
CHECK_EQ(x.dtype(), mlx_type); |
|
|
|
|
|
|
|
|
{ |
|
|
array x(true); |
|
|
CHECK_EQ(x.dtype(), bool_); |
|
|
CHECK_EQ(x.item<bool>(), true); |
|
|
|
|
|
x = array({true, false}); |
|
|
CHECK_EQ(x.dtype(), bool_); |
|
|
|
|
|
x = array({true, false}, float32); |
|
|
CHECK_EQ(x.dtype(), float32); |
|
|
CHECK(array_equal(x, array({1.0f, 0.0f})).item<bool>()); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
basic_dtype_test(uint8_t, uint8); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
basic_dtype_test(uint16_t, uint16); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
basic_dtype_test(uint32_t, uint32); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
basic_dtype_test(uint64_t, uint64); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
basic_dtype_test(int8_t, int8); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
basic_dtype_test(int16_t, int16); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
basic_dtype_test(int32_t, int32); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
basic_dtype_test(int64_t, int64); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
basic_dtype_test(float16_t, float16); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
basic_dtype_test(float, float32); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
basic_dtype_test(bfloat16_t, bfloat16); |
|
|
} |
|
|
|
|
|
#undef basic_dtype_test |
|
|
|
|
|
|
|
|
{ |
|
|
uint32_t val = UINT_MAX; |
|
|
array x(val); |
|
|
CHECK_EQ(x.dtype(), uint32); |
|
|
CHECK_EQ(x.item<uint32_t>(), val); |
|
|
|
|
|
x = array({1u, 2u}); |
|
|
CHECK_EQ(x.dtype(), uint32); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
array x(-1); |
|
|
CHECK_EQ(x.dtype(), int32); |
|
|
CHECK_EQ(x.item<int>(), -1); |
|
|
|
|
|
x = array({-1, 2}); |
|
|
CHECK_EQ(x.dtype(), int32); |
|
|
|
|
|
std::vector<int> data{0, 1, 2}; |
|
|
x = array(data.data(), {static_cast<int>(data.size())}, bool_); |
|
|
CHECK_EQ(x.dtype(), bool_); |
|
|
CHECK(array_equal(x, array({false, true, true})).item<bool>()); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
int64_t val = static_cast<int64_t>(INT_MIN) - 1; |
|
|
array x(val); |
|
|
CHECK_EQ(x.dtype(), int64); |
|
|
CHECK_EQ(x.item<int64_t>(), val); |
|
|
|
|
|
x = array({val, val}); |
|
|
CHECK_EQ(x.dtype(), int64); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
array x(3.14f); |
|
|
CHECK_EQ(x.dtype(), float32); |
|
|
CHECK_EQ(x.item<float>(), 3.14f); |
|
|
|
|
|
x = array(1.25); |
|
|
CHECK_EQ(x.dtype(), float32); |
|
|
CHECK_EQ(x.item<float>(), 1.25f); |
|
|
|
|
|
x = array({1.0f, 2.0f}); |
|
|
CHECK_EQ(x.dtype(), float32); |
|
|
|
|
|
x = array({1.0, 2.0}); |
|
|
CHECK_EQ(x.dtype(), float32); |
|
|
|
|
|
std::vector<double> data{1.0, 2.0, 4.0}; |
|
|
x = array(data.data(), {static_cast<int>(data.size())}); |
|
|
CHECK_EQ(x.dtype(), float32); |
|
|
CHECK(array_equal(x, array({1.0f, 2.0f, 4.0f})).item<bool>()); |
|
|
} |
|
|
|
|
|
|
|
|
{ |
|
|
CHECK_EQ(sizeof(complex64_t), sizeof(std::complex<float>)); |
|
|
|
|
|
complex64_t v = {1.0f, 1.0f}; |
|
|
array x(v); |
|
|
CHECK_EQ(x.dtype(), complex64); |
|
|
CHECK_EQ(x.item<complex64_t>(), v); |
|
|
|
|
|
array y(std::complex<float>{1.0f, 1.0f}); |
|
|
CHECK_EQ(x.dtype(), complex64); |
|
|
CHECK_EQ(x.item<complex64_t>(), v); |
|
|
} |
|
|
} |
|
|
|
|
|
TEST_CASE("test array metadata") {
  // Scalars and arrays whose axes all have size 1 are contiguous in every
  // layout.
  array x(1.0f);
  CHECK_EQ(x.data_size(), 1);
  CHECK_EQ(x.flags().contiguous, true);
  CHECK_EQ(x.flags().row_contiguous, true);
  CHECK_EQ(x.flags().col_contiguous, true);

  x = array({1.0f}, {1, 1, 1});
  CHECK_EQ(x.data_size(), 1);
  CHECK_EQ(x.flags().contiguous, true);
  CHECK_EQ(x.flags().row_contiguous, true);
  CHECK_EQ(x.flags().col_contiguous, true);

  // Only one axis has size > 1, so row- and column-major coincide.
  x = array({1.0f, 1.0f}, {1, 2});
  CHECK_EQ(x.data_size(), 2);
  CHECK_EQ(x.flags().contiguous, true);
  CHECK_EQ(x.flags().row_contiguous, true);
  CHECK_EQ(x.flags().col_contiguous, true);

  x = zeros({1, 1, 4});
  eval(x);
  CHECK_EQ(x.data_size(), 4);
  CHECK_EQ(x.flags().contiguous, true);
  CHECK_EQ(x.flags().row_contiguous, true);
  CHECK_EQ(x.flags().col_contiguous, true);

  // Two non-trivial axes: row-contiguous but not column-contiguous.
  x = zeros({2, 4});
  eval(x);
  CHECK_EQ(x.data_size(), 8);
  CHECK_EQ(x.flags().contiguous, true);
  CHECK_EQ(x.flags().row_contiguous, true);
  CHECK_EQ(x.flags().col_contiguous, false);

  // Broadcasting a scalar.
  x = array(1.0f);
  auto y = broadcast_to(x, {1, 1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // A broadcast array still stores a single element: it is contiguous over
  // its data but neither row- nor column-contiguous over its shape.
  y = broadcast_to(x, {2, 8, 10});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, false);

  // Broadcasting to an empty shape.
  y = broadcast_to(x, {1, 0});
  eval(y);
  CHECK_EQ(y.data_size(), 0);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  y = broadcast_to(zeros({4, 2, 1}), {4, 2, 0});
  eval(y);
  CHECK_EQ(y.data_size(), 0);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Transposes of scalars and all-size-1 arrays stay fully contiguous.
  x = array(1.0f);
  y = transpose(x);
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({1, 1, 1});
  y = transpose(x);
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({1, 1, 1});
  y = transpose(x, {0, 1, 2});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({1, 1, 1});
  y = transpose(x, {1, 2, 0});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({4, 1});
  y = transpose(x);
  eval(y);
  CHECK_EQ(y.data_size(), 4);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Default transpose of a row-major 3-d array is column-contiguous.
  x = ones({2, 3, 4});
  y = transpose(x);
  eval(y);
  CHECK_EQ(y.data_size(), 24);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, true);

  // A partial permutation is neither row- nor column-contiguous.
  y = transpose(x, {0, 2, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 24);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, false);

  // Applying the same swap twice restores the row-major layout.
  y = transpose(transpose(x, {0, 2, 1}), {0, 2, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 24);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, false);

  // Reshapes of contiguous inputs.
  x = array(1.0f);
  y = reshape(x, {1, 1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({2, 4});
  y = reshape(x, {8});
  eval(y);
  CHECK_EQ(y.data_size(), 8);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  y = reshape(x, {8, 1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 8);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  y = reshape(x, {1, 8, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 8);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  x = ones({12});
  y = reshape(x, {2, 3, 2});
  eval(y);
  CHECK_EQ(y.data_size(), 12);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, false);

  // Slicing.
  x = array(1.0f);
  y = slice(x, {}, {});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Out-of-range start/stop with a large step over a single element.
  x = array({1.0f});
  y = slice(x, {-10}, {10}, {10});
  eval(y);
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Full slice of a row vector.
  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {1, 3}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 3);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Unit-step slice starting past the first element (previously this spot
  // held a verbatim duplicate of the full-slice case above).
  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 1}, {1, 3}, {1, 1});
  eval(y);
  CHECK_EQ(y.shape(), Shape{1, 2});
  CHECK_EQ(y.data_size(), 2);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Empty slice.
  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {0, 3}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 0);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Prefix slice.
  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {1, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 2);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // A strided slice that selects a single element is contiguous.
  x = array({1.0f, 2.0f, 3.0f}, {1, 3});
  y = slice(x, {0, 0}, {1, 2}, {2, 3});
  eval(y);
  CHECK_EQ(y.shape(), Shape{1, 1});
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // A strided slice that skips elements is not contiguous.
  x = array({0.0f, 1.0f, 2.0f, 3.0f}, {1, 4});
  y = slice(x, {0, 0}, {1, 4}, {1, 2});
  eval(y);
  CHECK_EQ(y.shape(), Shape{1, 2});
  CHECK_EQ(y.flags().contiguous, false);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, false);

  // Strided slice of a broadcast scalar keeps the single stored element.
  x = broadcast_to(array(1.0f), {4, 10});
  y = slice(x, {0, 0}, {4, 10}, {2, 2});
  eval(y);
  CHECK_EQ(y.shape(), Shape{2, 5});
  CHECK_EQ(y.data_size(), 1);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, false);

  // Selecting one row of a row-broadcast array.
  x = broadcast_to(array({1.0f, 2.0f}), {4, 2});
  y = slice(x, {0, 0}, {1, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 2);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  y = slice(x, {1, 0}, {2, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 2);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Full slices preserve the input's row/column layout.
  x = array({0.0f, 1.0f, 2.0f, 3.0f}, {2, 2});
  y = slice(x, {0, 0}, {2, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 4);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, true);
  CHECK_EQ(y.flags().col_contiguous, false);

  y = slice(transpose(x), {0, 0}, {2, 2}, {1, 1});
  eval(y);
  CHECK_EQ(y.data_size(), 4);
  CHECK_EQ(y.flags().contiguous, true);
  CHECK_EQ(y.flags().row_contiguous, false);
  CHECK_EQ(y.flags().col_contiguous, true);

  // Splitting along the leading axis yields contiguous pieces; splitting
  // along the trailing axis does not. Bind by const reference to avoid
  // copying each array and shadowing the outer `y`.
  x = ones({2, 4});
  auto out = split(x, 2);
  eval(out);
  for (const auto& part : out) {
    CHECK_EQ(part.data_size(), 4);
    CHECK_EQ(part.flags().contiguous, true);
    CHECK_EQ(part.flags().row_contiguous, true);
    CHECK_EQ(part.flags().col_contiguous, true);
  }
  out = split(x, 4, 1);
  eval(out);
  for (const auto& part : out) {
    CHECK_EQ(part.flags().contiguous, false);
    CHECK_EQ(part.flags().row_contiguous, false);
    CHECK_EQ(part.flags().col_contiguous, false);
  }
}
|
|
|
|
|
TEST_CASE("test array iteration") {
  // A 0-d array cannot be iterated, so asking for an iterator throws.
  auto scalar = array(1);
  CHECK_THROWS(scalar.begin());

  // decltype is unevaluated, so begin() is not called here. This only checks
  // at compile time that dereferencing an iterator yields a const type.
  CHECK(std::is_const_v<decltype(*scalar.begin())>);

  // A 1-d array is visited element by element, in order. Elements are taken
  // by value (a non-const copy), as in the original loop.
  auto vec = array({1, 2, 3, 4, 5});
  int visited = 0;
  for (auto elem : vec) {
    ++visited;
    CHECK_EQ(elem.item<int>(), visited);
  }
  CHECK_EQ(visited, 5);

  // A 2-d array is visited row by row. The iterator supports offsets and
  // compares equal to end() after the last row.
  auto mat = array({1, 2, 3, 4}, {2, 2});
  CHECK(array_equal(*mat.begin(), array({1, 2})).item<bool>());
  CHECK(array_equal(*(mat.begin() + 1), array({3, 4})).item<bool>());
  CHECK_EQ(mat.begin() + 2, mat.end());
}
|
|
|
|
|
TEST_CASE("test array shared buffer") {
  Shape shape = {2, 2};
  auto n_elem = shape[0] * shape[1];

  // Allocate a raw buffer through the MLX allocator and fill it by hand so it
  // can be wrapped by an array with a custom deleter.
  allocator::Buffer buf_b = allocator::malloc(n_elem * sizeof(float));
  void* buf_b_ptr = buf_b.raw_ptr();
  float* float_buf_b = static_cast<float*>(buf_b_ptr);
  for (int i = 0; i < n_elem; ++i) {
    float_buf_b[i] = 2.0f;
  }
  CHECK_EQ(float_buf_b[0], static_cast<float*>(buf_b_ptr)[0]);

  // The deleter should be handed the same buffer that was wrapped. It is
  // responsible for releasing that buffer back to the allocator.
  auto deleter = [float_buf_b](allocator::Buffer buf) {
    CHECK_EQ(float_buf_b, static_cast<float*>(buf.raw_ptr()));
    CHECK_EQ(float_buf_b[0], static_cast<float*>(buf.raw_ptr())[0]);
    allocator::free(buf);
  };

  array a = ones(shape, float32);
  array b = array(buf_b, shape, float32, deleter);

  // Use the wrapped buffer in a computation and verify the result (1 + 2).
  // This catches the wrapped data not actually being read; previously the
  // sum was evaluated but never checked.
  auto sum = a + b;
  eval(sum);
  CHECK(array_equal(sum, full(shape, 3.0f)).item<bool>());
}
|
|
|
|
|
TEST_CASE("test make empty array") {
  // With no dtype given, an empty initializer list yields a float32 array.
  auto untyped = array({});
  CHECK_EQ(untyped.size(), 0);
  CHECK_EQ(untyped.dtype(), float32);

  // An explicitly requested dtype is kept even though there are no elements.
  for (auto requested : {int32, float32, bool_}) {
    auto typed = array({}, requested);
    CHECK_EQ(typed.size(), 0);
    CHECK_EQ(typed.dtype(), requested);
  }
}
|
|
|