File size: 2,069 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
// Copyright © 2023 Apple Inc.

#include "doctest/doctest.h"

#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test eval") {
  // Scalars of different dtypes (float, int, bool) evaluated in one call.
  {
    array x(1.0);
    array y(1);
    array z(true);
    eval({x, y, z});
    CHECK_EQ(x.item<float>(), 1.0f);
    // Every output of the batched eval must hold its value, not just the
    // first one.
    CHECK_EQ(y.item<int>(), 1);
    CHECK(z.item<bool>());
  }

  // Scalars mixed with a non-scalar (2x2) array in the same eval call.
  {
    array x(1.0);
    array y = ones({2, 2});
    array z(true);
    eval({x, y, z});
    CHECK(array_equal(y, array({1.0, 1.0, 1.0, 1.0}, {2, 2})).item<bool>());
    CHECK_EQ(x.item<float>(), 1.0f);
    CHECK(z.item<bool>());
  }
}

TEST_CASE("test eval multiple") {
  // Evaluates several independent arrays in a single eval call and verifies
  // each result. The scenario runs twice: first through the
  // initializer-list form eval({...}), then through the argument-list form
  // eval(a, b).
  auto expect_sum_and_diff_of_ones = [](array& sum, array& diff) {
    CHECK(array_equal(sum, full({10, 10}, 2.0f)).item<bool>());
    CHECK(array_equal(diff, full({10, 10}, 0.0f)).item<bool>());
  };

  // Initializer-list form.
  auto x = ones({10, 10});
  auto y = ones({10, 10});
  eval({x, y});
  CHECK(array_equal(x, y).item<bool>());

  auto a = x + y;
  auto b = x - y;
  eval({a, b});
  expect_sum_and_diff_of_ones(a, b);

  // Argument-list form.
  x = ones({10, 10});
  y = ones({10, 10});
  eval(x, y);
  CHECK(array_equal(x, y).item<bool>());

  a = x + y;
  b = x - y;
  eval(a, b);
  expect_sum_and_diff_of_ones(a, b);
}

TEST_CASE("test eval with tracer when not tracing") {
  // Outside of a tracing context the tracer flag has no effect: is_tracer()
  // reports false, and eval detaches the array (drops its primitive) just as
  // it would for an ordinary array.
  auto x = array(1);
  x.set_tracer(true);
  CHECK(!x.is_tracer());
  eval(x);
  CHECK(!x.has_primitive());
  CHECK(x.is_available());

  // Same expectations for a non-scalar array created with ones().
  x = ones({2, 3});
  x.set_tracer(true);
  CHECK(!x.is_tracer());
  eval(x);
  CHECK(!x.has_primitive());
  CHECK(x.is_available());
}

TEST_CASE("test eval graph retention when not tracing") {
  // Outside of a tracing context, marking an input as a tracer does not keep
  // the graph alive: arrays are still detached once evaluated.
  auto x = array(1);
  x.set_tracer(true);
  auto y = array(2);
  auto z = x + y;
  eval(z);
  CHECK(!z.has_primitive());
  CHECK(z.is_available());
  CHECK_EQ(z.item<int>(), 3);

  // Clearing the flag on an already-evaluated array leaves its value and
  // detached state intact.
  z.set_tracer(false);
  CHECK_EQ(z.item<int>(), 3);
  CHECK(!z.has_primitive());
  CHECK(z.is_available());

  // Intermediates of a longer chain are detached as well once the final
  // output is evaluated.
  z = x + y;
  auto a = z + x;
  auto b = a + y;
  eval(b);
  CHECK(!z.has_primitive());
  CHECK(z.is_available());
  CHECK(!a.has_primitive());
  CHECK(a.is_available());

  // The evaluated output must be available, and every array in the chain
  // must hold the right value: z = 1 + 2, a = z + 1, b = a + 2.
  CHECK(b.is_available());
  CHECK_EQ(z.item<int>(), 3);
  CHECK_EQ(a.item<int>(), 4);
  CHECK_EQ(b.item<int>(), 6);
}