File size: 4,576 Bytes
712dbf0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
// Copyright © 2024 Apple Inc.
#include <filesystem>
#include <stdexcept>
#include <vector>
#include "doctest/doctest.h"
#include "mlx/export.h"
#include "mlx/mlx.h"
using namespace mlx::core;
namespace {

// Returns `name` joined onto the system temporary directory, as a string path.
std::string get_temp_file(const std::string& name) {
  const std::filesystem::path tmp_dir = std::filesystem::temp_directory_path();
  const std::filesystem::path full = tmp_dir / name;
  return full.string();
}

} // namespace
TEST_CASE("test export basic functions") {
  std::string file_path = get_temp_file("model.mlxfn");

  auto fun = [](std::vector<array> x) -> std::vector<array> {
    return {negative(exp(x[0]))};
  };

  // Trace and export with a single input of shape {2}.
  export_function(file_path, fun, {array({1.0, 2.0})});
  auto imported_fun = import_function(file_path);

  // Check num inputs mismatch throws
  CHECK_THROWS_AS(
      imported_fun({array({1.0}), array({2.0})}), std::invalid_argument);

  // Check shape mismatch throws (shape {1} instead of the exported {2})
  CHECK_THROWS_AS(imported_fun({array({1.0})}), std::invalid_argument);

  // Check type mismatch throws. The input keeps the exported shape {2} so
  // that only the dtype differs; with a mismatched shape the call would
  // already throw on the shape check above and the dtype check would go
  // untested.
  CHECK_THROWS_AS(
      imported_fun({array({1.0, 2.0}, float16)}), std::invalid_argument);

  // A matching signature must reproduce the original function's output.
  auto expected = fun({array({1.0, -1.0})});
  auto out = imported_fun({array({1.0, -1.0})});
  CHECK(allclose(expected[0], out[0]).item<bool>());
}
TEST_CASE("test export function with no inputs") {
  const std::string path = get_temp_file("model.mlxfn");

  // Nullary function: ignores its (empty) argument list and builds a
  // constant 2x2 array of zeros.
  auto make_zeros = [](std::vector<array> /* unused */)
      -> std::vector<array> { return {zeros({2, 2})}; };

  // Export and re-import with an empty input list.
  export_function(path, make_zeros, {});
  auto loaded = import_function(path);

  // The imported function must reproduce the original's output.
  auto reference = make_zeros({});
  auto result = loaded({});
  CHECK(allclose(reference[0], result[0]).item<bool>());
}
TEST_CASE("test export multi output primitives") {
  // divmod returns two outputs (quotient and remainder); both must
  // survive the export/import round trip.
  auto quot_rem = [](std::vector<array> x) -> std::vector<array> {
    return divmod(x[0], x[1]);
  };

  const std::vector<array> operands = {
      array({5.0, -10.0}), array({3.0, -2.0})};
  auto path = get_temp_file("model.mlxfn");

  export_function(path, quot_rem, operands);
  auto loaded = import_function(path);

  auto want = quot_rem(operands);
  auto got = loaded(operands);
  // Compare each of the two outputs.
  for (int i = 0; i < 2; ++i) {
    CHECK(allclose(want[i], got[i]).item<bool>());
  }
}
TEST_CASE("test export primitives with state") {
  // argpartition is called with non-tensor arguments (kth and axis); these
  // presumably become primitive state that the export must serialize.
  auto partition_idx = [](std::vector<array> x) -> std::vector<array> {
    constexpr int kth = 2;
    constexpr int axis = 0;
    return {argpartition(x[0], kth, axis)};
  };

  // 4x2 integer input.
  auto input = array({1, 3, 2, 4, 5, 7, 6, 8}, {4, 2});
  auto path = get_temp_file("model.mlxfn");

  export_function(path, partition_idx, {input});
  auto loaded = import_function(path);

  auto reference = partition_idx({input});
  auto result = loaded({input});
  CHECK(allclose(reference[0], result[0]).item<bool>());
}
TEST_CASE("test export functions with kwargs") {
  auto path = get_temp_file("model.mlxfn");

  // Adds the keyword arguments "x" and "y".
  auto add_xy = [](const Kwargs& kw) -> std::vector<array> {
    const auto& lhs = kw.at("x");
    const auto& rhs = kw.at("y");
    return {lhs + rhs};
  };

  export_function(path, add_xy, {{"x", array(1)}, {"y", array(2)}});
  auto imported = import_function(path);

  // Positional arguments are rejected for a kwargs-only export.
  CHECK_THROWS(imported({array(1), array(2)}));
  // An extra key is rejected.
  CHECK_THROWS(imported({{"x", array(1)}, {"y", array(2)}, {"z", array(3)}}));
  // Unknown key names are rejected.
  CHECK_THROWS(imported({{"a", array(1)}, {"b", array(2)}}));

  // Correct keys work, passed alone...
  auto kw_only = imported({{"x", array(1)}, {"y", array(2)}})[0];
  CHECK_EQ(kw_only.item<int>(), 3);
  // ...or together with an empty positional list.
  auto with_empty_args = imported({}, {{"x", array(1)}, {"y", array(2)}})[0];
  CHECK_EQ(with_empty_args.item<int>(), 3);
}
TEST_CASE("test export function with variable inputs") {
  auto path = get_temp_file("model.mlxfn");

  // Adds every positional argument onto a base array of four ones.
  auto accumulate_onto_ones =
      [](const std::vector<array>& args) -> std::vector<array> {
    auto total = array({1, 1, 1, 1});
    for (const auto& term : args) {
      total = total + term;
    }
    return {total};
  };

  {
    // Record two traces (two and three inputs) with a single exporter. The
    // exporter lives only in this scope; presumably its destruction finishes
    // writing the file before the import below — confirm against exporter docs.
    auto recorder = exporter(path, accumulate_onto_ones);
    recorder({array(0), array(0)});
    recorder({array(0), array(0), array(0)});
  }

  auto loaded = import_function(path);

  // Two inputs: 1 + 1 + 2 = 4 per element.
  auto two_args = loaded({array(1), array(2)})[0];
  CHECK(array_equal(two_args, array({4, 4, 4, 4})).item<bool>());

  // Three inputs: 1 + 1 + 2 + 3 = 7 per element.
  auto three_args = loaded({array(1), array(2), array(3)})[0];
  CHECK(array_equal(three_args, array({7, 7, 7, 7})).item<bool>());
}
TEST_CASE("test export function on different stream") {
  std::string file_path = get_temp_file("model.mlxfn");

  // The exported function references the stream its op was built on
  // (Stream(1000, Device::cpu) below). Importing does not create streams:
  // the caller is responsible for setting up streams before importing
  // functions. This test never creates stream 1000, so the import is
  // expected to throw.
  auto fun = [](const std::vector<array>& args) -> std::vector<array> {
    return {abs(args[0], Stream(1000, Device::cpu))};
  };
  export_function(file_path, fun, {array({0, 1, 2})});
  CHECK_THROWS(import_function(file_path));
}
|