File size: 1,958 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
// Copyright © 2023 Apple Inc.

#include <cstdint>
#include <limits>
#include <vector>

#include "doctest/doctest.h"

#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test type promotion") {
  // A dtype promotes to itself, whether it appears once or is repeated.
  for (auto dtype : {bool_, uint32, int32, int64, float32}) {
    CHECK_EQ(result_type({array(0, dtype)}), dtype);

    std::vector<array> same_dtype = {array(0, dtype), array(0, dtype)};
    CHECK_EQ(result_type(same_dtype), dtype);
  }

  // bool mixed with int32 promotes to int32.
  std::vector<array> bool_with_int = {array(false), array(0, int32)};
  CHECK_EQ(result_type(bool_with_int), int32);

  // Adding a float32 operand to int32 and bool promotes the result to float32.
  std::vector<array> int_bool_float = {
      array(0, int32), array(false), array(0.0f)};
  CHECK_EQ(result_type(int_bool_float), float32);
}

TEST_CASE("test normalize axis") {
  // One table row per check: the input axis, the rank, and the expected
  // non-negative axis. Negative axes count back from the last dimension.
  struct AxisCase {
    int axis;
    int ndim;
    int expected;
  };

  const std::vector<AxisCase> cases = {
      {0, 3, 0}, {1, 3, 1}, {2, 3, 2}, {-1, 3, 2}, {-2, 3, 1}, {-3, 3, 0}};

  for (const auto& c : cases) {
    CHECK_EQ(normalize_axis_index(c.axis, c.ndim), c.expected);
  }

  // Axes just outside [-ndim, ndim) must be rejected.
  CHECK_THROWS(normalize_axis_index(3, 3));
  CHECK_THROWS(normalize_axis_index(-4, 3));
}

TEST_CASE("test finfo") {
  // Reported dtype: real float types report themselves, and complex64
  // reports float32.
  CHECK_EQ(finfo(float32).dtype, float32);
  CHECK_EQ(finfo(complex64).dtype, float32);
  CHECK_EQ(finfo(float16).dtype, float16);

  // float32 and complex64 both span the full finite range of a C++ float.
  const float single_lowest = std::numeric_limits<float>::lowest();
  const float single_max = std::numeric_limits<float>::max();
  for (auto dtype : {float32, complex64}) {
    CHECK_EQ(finfo(dtype).min, single_lowest);
    CHECK_EQ(finfo(dtype).max, single_max);
  }

  // float16's finite range is [-65504, 65504].
  CHECK_EQ(finfo(float16).min, -65504);
  CHECK_EQ(finfo(float16).max, 65504);
}

TEST_CASE("test iinfo") {
  // iinfo reports back the integer dtype it was queried with, for both
  // signed and unsigned types.
  CHECK_EQ(iinfo(int8).dtype, int8);
  CHECK_EQ(iinfo(int64).dtype, int64);
  CHECK_EQ(iinfo(uint64).dtype, uint64);

  // Bounds must match std::numeric_limits for the corresponding C++ type.
  // The original test checked iinfo(uint64).max twice (a copy-paste
  // duplicate) and never checked the int8 bounds; both gaps are covered here.
  CHECK_EQ(iinfo(int8).min, std::numeric_limits<int8_t>::min());
  CHECK_EQ(iinfo(int8).max, std::numeric_limits<int8_t>::max());
  CHECK_EQ(iinfo(int64).min, std::numeric_limits<int64_t>::min());
  CHECK_EQ(iinfo(int64).max, std::numeric_limits<int64_t>::max());
  CHECK_EQ(iinfo(uint64).min, 0);
  CHECK_EQ(iinfo(uint64).max, std::numeric_limits<uint64_t>::max());
}