File size: 1,958 Bytes
712dbf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
// Copyright © 2023 Apple Inc.
#include <cstdint>
#include <limits>
#include <vector>

#include "doctest/doctest.h"
#include "mlx/mlx.h"

using namespace mlx::core;
// Exercises result_type: a homogeneous collection keeps its dtype, and
// mixed collections are expected to promote to the widest participant.
TEST_CASE("test type promotion") {
  // One array, or two arrays of the same dtype, must report that dtype.
  for (auto dtype : {bool_, uint32, int32, int64, float32}) {
    auto single = array(0, dtype);
    CHECK_EQ(result_type({single}), dtype);

    std::vector<array> pair = {array(0, dtype), array(0, dtype)};
    CHECK_EQ(result_type(pair), dtype);
  }

  // bool alongside int32 is expected to yield int32.
  {
    std::vector<array> bool_and_int = {array(false), array(0, int32)};
    CHECK_EQ(result_type(bool_and_int), int32);
  }

  // Adding a float operand to int32/bool inputs is expected to yield float32.
  {
    std::vector<array> with_float = {
        array(0, int32), array(false), array(0.0f)};
    CHECK_EQ(result_type(with_float), float32);
  }
}
// Exercises normalize_axis_index for a rank-3 shape: non-negative axes are
// expected unchanged, negative axes are expected to count back from ndim, and
// the first out-of-range axis on each side (3 and -4) is expected to throw.
TEST_CASE("test normalize axis") {
  struct AxisCase {
    int axis;
    int ndim;
    int expected;
  };

  const AxisCase cases[] = {
      {0, 3, 0}, {1, 3, 1}, {2, 3, 2}, {-1, 3, 2}, {-2, 3, 1}, {-3, 3, 0}};

  for (const AxisCase& c : cases) {
    CHECK_EQ(normalize_axis_index(c.axis, c.ndim), c.expected);
  }

  // One step past either boundary must be rejected.
  CHECK_THROWS(normalize_axis_index(3, 3));
  CHECK_THROWS(normalize_axis_index(-4, 3));
}
// Exercises finfo for floating and complex dtypes. complex64 is expected to
// report its float32 component dtype and float32's range.
TEST_CASE("test finfo") {
  const auto f32 = finfo(float32);
  CHECK_EQ(f32.dtype, float32);
  CHECK_EQ(f32.min, std::numeric_limits<float>::lowest());
  CHECK_EQ(f32.max, std::numeric_limits<float>::max());

  const auto c64 = finfo(complex64);
  CHECK_EQ(c64.dtype, float32);
  CHECK_EQ(c64.min, std::numeric_limits<float>::lowest());
  CHECK_EQ(c64.max, std::numeric_limits<float>::max());

  // +/-65504 is the largest finite magnitude of IEEE half precision.
  const auto f16 = finfo(float16);
  CHECK_EQ(f16.dtype, float16);
  CHECK_EQ(f16.min, -65504);
  CHECK_EQ(f16.max, 65504);
}
// Exercises iinfo for integer dtypes: each result should report the queried
// dtype and the same bounds as std::numeric_limits for the matching C++ type.
// The original version asserted iinfo(uint64).max twice (copy/paste slip);
// that duplicate is replaced by a uint64 dtype check, and int8, which was
// already constructed here, now has its bounds verified as well.
TEST_CASE("test iinfo") {
  CHECK_EQ(iinfo(int8).dtype, int8);
  CHECK_EQ(iinfo(int8).min, std::numeric_limits<int8_t>::min());
  CHECK_EQ(iinfo(int8).max, std::numeric_limits<int8_t>::max());

  CHECK_EQ(iinfo(int64).dtype, int64);
  CHECK_EQ(iinfo(int64).min, std::numeric_limits<int64_t>::min());
  CHECK_EQ(iinfo(int64).max, std::numeric_limits<int64_t>::max());

  CHECK_EQ(iinfo(uint64).dtype, uint64);
  CHECK_EQ(iinfo(uint64).min, 0);
  CHECK_EQ(iinfo(uint64).max, std::numeric_limits<uint64_t>::max());
}
|