File size: 892 Bytes
712dbf0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
// Copyright © 2023 Apple Inc.

#include <iostream>

#include "mlx/mlx.h"
#include "time_utils.h"

namespace mx = mlx::core;

void time_value_and_grad() {
  auto x = mx::ones({200, 1000});
  mx::eval(x);
  auto fn = [](mx::array x) {
    for (int i = 0; i < 20; ++i) {
      x = mx::log(mx::exp(x));
    }
    return mx::sum(x);
  };

  auto grad_fn = mx::grad(fn);
  auto independent_value_and_grad = [&]() {
    auto value = fn(x);
    auto dfdx = grad_fn(x);
    return std::vector<mx::array>{value, dfdx};
  };
  TIME(independent_value_and_grad);

  auto value_and_grad_fn = mx::value_and_grad(fn);
  auto combined_value_and_grad = [&]() {
    auto [value, dfdx] = value_and_grad_fn(x);
    return std::vector<mx::array>{value, dfdx};
  };
  TIME(combined_value_and_grad);
}

// Entry point: prints which device the benchmarks target, then runs the
// value-and-grad timing comparison.
int main() {
  const auto& device = mx::default_device();
  std::cout << "Benchmarks for " << device << std::endl;

  time_value_and_grad();
  return 0;
}