File size: 1,520 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
// Copyright © 2023-2024 Apple Inc.

#include "doctest/doctest.h"

#include "mlx/mlx.h"

using namespace mlx::core;

TEST_CASE("test simple custom vjp") {
  // `unit` is used both as the cotangent for every output and as the
  // gradient the custom VJP hands back for every input.
  auto unit = array(1.0);
  auto a = array(2.0);
  auto b = array(3.0);

  // Two-output forward function: {a * b, a + b}.
  auto product_and_sum = [](const std::vector<array>& args) {
    auto prod = args[0] * args[1];
    auto sum = args[0] + args[1];
    return std::vector<array>{prod, sum};
  };

  // Replacement VJP that ignores all three of its arguments and reports a
  // gradient of one for each of the two inputs.
  auto constant_vjp = [unit](
                          const std::vector<array>&,
                          const std::vector<array>&,
                          const std::vector<array>&) {
    return std::vector<array>{unit, unit};
  };
  auto with_custom_grad = custom_vjp(product_and_sum, constant_vjp);

  // Plain autodiff: outputs are {6, 5}; with unit cotangents the gradients
  // are d/da = b + 1 = 4 and d/db = a + 1 = 3.
  {
    auto [outputs, grads] = vjp(product_and_sum, {a, b}, {unit, unit});
    CHECK_EQ(outputs[0].item<float>(), 6.0f);
    CHECK_EQ(outputs[1].item<float>(), 5.0f);
    CHECK_EQ(grads[0].item<float>(), 4.0f);
    CHECK_EQ(grads[1].item<float>(), 3.0f);
  }

  // Custom VJP: forward outputs are unchanged, but the gradients come from
  // `constant_vjp` and are therefore both one.
  {
    auto [outputs, grads] = vjp(with_custom_grad, {a, b}, {unit, unit});
    CHECK_EQ(outputs[0].item<float>(), 6.0f);
    CHECK_EQ(outputs[1].item<float>(), 5.0f);
    CHECK_EQ(grads[0].item<float>(), 1.0f);
    CHECK_EQ(grads[1].item<float>(), 1.0f);
  }
}

TEST_CASE("test checkpointing") {
  auto unit = array(1.0);
  auto a = array(2.0);
  auto b = array(3.0);

  // Counts executions of the forward body. A single vjp through the
  // checkpointed function is expected to run it twice: once for the forward
  // pass and once more when the checkpoint re-evaluates it for gradients.
  int calls = 0;
  auto body = [&calls](const std::vector<array>& args) {
    ++calls;
    auto prod = args[0] * args[1];
    auto sum = args[0] + args[1];
    return std::vector<array>{square(prod + sum)};
  };
  auto checkpointed_body = checkpoint(body);

  // f(a, b) = (a*b + a + b)^2 = 11^2 = 121 at (2, 3).
  // df/da = 2 * 11 * (b + 1) = 88, df/db = 2 * 11 * (a + 1) = 66.
  auto [outputs, grads] = vjp(checkpointed_body, {a, b}, {unit});
  CHECK_EQ(outputs[0].item<float>(), 121.0f);
  CHECK_EQ(grads[0].item<float>(), 88.0f);
  CHECK_EQ(grads[1].item<float>(), 66.0f);
  CHECK_EQ(calls, 2);
}