|
|
|
|
|
|
|
|
#include <filesystem> |
|
|
#include <stdexcept> |
|
|
#include <vector> |
|
|
|
|
|
#include "doctest/doctest.h" |
|
|
|
|
|
#include "mlx/mlx.h" |
|
|
|
|
|
using namespace mlx::core; |
|
|
|
|
|
/// Compose the path of `name` inside the system temporary directory.
///
/// Only the path string is built here; nothing is created on disk.
///
/// @param name File name (or relative path) to place under the temp directory.
/// @return The combined path as a native-format string.
std::string get_temp_file(const std::string& name) {
  const std::filesystem::path full_path =
      std::filesystem::temp_directory_path() / name;
  return full_path.string();
}
|
|
|
|
|
TEST_CASE("test save_safetensors") {
  // Round-trip two arrays and a string->string metadata map through a
  // .safetensors file, then verify what comes back.
  const std::string file_path = get_temp_file("test_arr.safetensors");

  const std::unordered_map<std::string, array> tensors = {
      {"test", array({1.0, 2.0, 3.0, 4.0})},
      {"test2", ones({2, 2})}};
  const std::unordered_map<std::string, std::string> expected_metadata = {
      {"test", "test"},
      {"test2", "test2"}};

  save_safetensors(file_path, tensors, expected_metadata);
  auto [loaded, loaded_metadata] = load_safetensors(file_path);

  // Metadata must survive unchanged.
  CHECK_EQ(loaded_metadata, expected_metadata);

  // Both tensors, and only those, must be present.
  CHECK_EQ(loaded.size(), 2);
  CHECK_EQ(loaded.count("test"), 1);
  CHECK_EQ(loaded.count("test2"), 1);

  // The 1-D tensor is expected back as float32 with four elements.
  const array& vec = loaded.at("test");
  CHECK_EQ(vec.dtype(), float32);
  CHECK_EQ(vec.shape(), Shape{4});
  CHECK(array_equal(vec, array({1.0, 2.0, 3.0, 4.0})).item<bool>());

  // The 2x2 tensor of ones is expected back as float32 as well.
  const array& mat = loaded.at("test2");
  CHECK_EQ(mat.dtype(), float32);
  CHECK_EQ(mat.shape(), Shape{2, 2});
  CHECK(array_equal(mat, ones({2, 2})).item<bool>());
}
|
|
|
|
|
TEST_CASE("test gguf") {
  // Exercises save_gguf / load_gguf: weights only, weights plus a string
  // metadata entry, dtypes the test expects to be rejected, and dtypes the
  // test expects to round-trip.
  const std::string file_path = get_temp_file("test_arr.gguf");
  using dict = std::unordered_map<std::string, array>;
  const dict original_weights = {
      {"test", array({1.0f, 2.0f, 3.0f, 4.0f})},
      {"test2", reshape(arange(6), {3, 2})}};

  // Shared verification: both weights are present and equal to the originals.
  // Entries are bound by const reference so no key/array pair is copied
  // while iterating.
  const auto check_weights = [&original_weights](const dict& loaded) {
    CHECK_EQ(loaded.size(), 2);
    CHECK_EQ(loaded.count("test"), 1);
    CHECK_EQ(loaded.count("test2"), 1);
    for (const auto& [name, value] : loaded) {
      CHECK(array_equal(value, original_weights.at(name)).item<bool>());
    }
  };

  // Save and load without any metadata.
  {
    save_gguf(file_path, original_weights);
    auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
    CHECK_EQ(loaded_metadata.size(), 0);
    check_weights(loaded_weights);
  }

  // Save and load with a single string metadata entry.
  std::unordered_map<std::string, GGUFMetaData> original_metadata;
  original_metadata.insert({"test_str", "my string"});

  save_gguf(file_path, original_weights, original_metadata);
  auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
  CHECK_EQ(loaded_metadata.size(), 1);
  CHECK_EQ(loaded_metadata.count("test_str"), 1);
  CHECK_EQ(std::get<std::string>(loaded_metadata.at("test_str")), "my string");
  check_weights(loaded_weights);

  // Saving any of these dtypes is expected to throw.
  const std::vector<Dtype> unsupported_types = {
      bool_, uint8, uint32, uint64, int64, bfloat16, complex64};
  for (const auto dtype : unsupported_types) {
    const dict to_save = {{"test", astype(arange(5), dtype)}};
    CHECK_THROWS(save_gguf(file_path, to_save, original_metadata));
  }

  // Saving these dtypes is expected to succeed and round-trip exactly.
  const std::vector<Dtype> supported_types = {int8, int32, float16, float32};
  for (const auto dtype : supported_types) {
    const auto arr = astype(arange(5), dtype);
    const dict to_save = {{"test", arr}};
    save_gguf(file_path, to_save, original_metadata);
    // Keep the whole result under one name: only the weights are inspected,
    // and this avoids shadowing the bindings declared above.
    const auto reloaded = load_gguf(file_path);
    CHECK(array_equal(reloaded.first.at("test"), arr).item<bool>());
  }
}
|
|
|
|
|
TEST_CASE("test gguf metadata") {
  // Exercises the metadata side of save_gguf / load_gguf with scalar arrays,
  // 1-D arrays, rejected array shapes, string lists, and a mix of all kinds.
  const std::string file_path = get_temp_file("test_arr.gguf");
  using dict = std::unordered_map<std::string, array>;
  const dict original_weights = {
      {"test", array({1.0f, 2.0f, 3.0f, 4.0f})},
      {"test2", reshape(arange(6), {3, 2})}};

  // Element-wise comparison of a loaded string list against the expected one.
  // The size check is a REQUIRE so a length mismatch stops the test case
  // before any element is read out of range.
  const auto check_strings = [](const std::vector<std::string>& loaded,
                                const std::vector<std::string>& expected) {
    REQUIRE_EQ(loaded.size(), expected.size());
    auto want = expected.cbegin();
    for (const auto& got : loaded) {
      CHECK_EQ(got, *want);
      ++want;
    }
  };

  // Scalar array as metadata.
  {
    std::unordered_map<std::string, GGUFMetaData> original_metadata;
    original_metadata.insert({"test_arr", array(1.0)});
    save_gguf(file_path, original_weights, original_metadata);

    auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
    CHECK_EQ(loaded_metadata.size(), 1);
    CHECK_EQ(loaded_metadata.count("test_arr"), 1);

    auto arr = std::get<array>(loaded_metadata.at("test_arr"));
    CHECK_EQ(arr.item<float>(), 1.0f);
  }

  // 1-D array as metadata, then the same key overwritten with a
  // single-element 1-D array.
  {
    std::unordered_map<std::string, GGUFMetaData> original_metadata;
    auto arr = array({1.0, 2.0});
    original_metadata.insert({"test_arr", arr});
    save_gguf(file_path, original_weights, original_metadata);

    auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
    CHECK_EQ(loaded_metadata.size(), 1);
    CHECK_EQ(loaded_metadata.count("test_arr"), 1);

    auto loaded_arr = std::get<array>(loaded_metadata.at("test_arr"));
    CHECK(array_equal(arr, loaded_arr).item<bool>());

    arr = array({1.0});
    original_metadata["test_arr"] = arr;
    save_gguf(file_path, original_weights, original_metadata);

    std::tie(loaded_weights, loaded_metadata) = load_gguf(file_path);
    CHECK_EQ(loaded_metadata.size(), 1);
    CHECK_EQ(loaded_metadata.count("test_arr"), 1);

    loaded_arr = std::get<array>(loaded_metadata.at("test_arr"));
    CHECK(array_equal(arr, loaded_arr).item<bool>());
  }

  // An array with shape {1, 1} as metadata is expected to be rejected.
  {
    std::unordered_map<std::string, GGUFMetaData> original_metadata;
    original_metadata.insert({"test_arr", array({1.0}, {1, 1})});
    CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
  }

  // An array built from an empty initializer is expected to be rejected.
  {
    std::unordered_map<std::string, GGUFMetaData> original_metadata;
    original_metadata.insert({"test_arr", array({})});
    CHECK_THROWS(save_gguf(file_path, original_weights, original_metadata));
  }

  // List of strings as metadata.
  {
    std::unordered_map<std::string, GGUFMetaData> original_metadata;
    const std::vector<std::string> data = {"data1", "data2", "data1234"};
    original_metadata.insert({"meta", data});
    save_gguf(file_path, original_weights, original_metadata);

    auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
    CHECK_EQ(loaded_metadata.size(), 1);
    CHECK_EQ(loaded_metadata.count("meta"), 1);
    auto& strs = std::get<std::vector<std::string>>(loaded_metadata["meta"]);
    check_strings(strs, data);
  }

  // Several metadata kinds stored side by side.
  {
    std::unordered_map<std::string, GGUFMetaData> original_metadata;
    const std::vector<std::string> data = {"data1", "data2", "data1234"};
    original_metadata.insert({"meta1", data});
    original_metadata.insert({"meta2", array(2.5)});
    original_metadata.insert({"meta3", array({1, 2, 3})});
    original_metadata.insert({"meta4", "last"});
    save_gguf(file_path, original_weights, original_metadata);

    auto [loaded_weights, loaded_metadata] = load_gguf(file_path);
    CHECK_EQ(loaded_metadata.size(), 4);
    auto& strs = std::get<std::vector<std::string>>(loaded_metadata["meta1"]);
    check_strings(strs, data);

    // Each entry gets its own reference, so inspecting "meta3" does not
    // assign over the "meta2" entry in the loaded map.
    auto& scalar = std::get<array>(loaded_metadata["meta2"]);
    CHECK_EQ(scalar.item<float>(), 2.5);

    auto& vec = std::get<array>(loaded_metadata["meta3"]);
    CHECK(array_equal(vec, array({1, 2, 3})).item<bool>());

    auto& str = std::get<std::string>(loaded_metadata["meta4"]);
    CHECK_EQ(str, "last");
  }
}
|
|
|
|
|
TEST_CASE("test single array serialization") {
  // Draw a uniform float32 array in [-5, 5) of the given shape, save it to a
  // file in the temp directory, load it back, and require that dtype, shape
  // and contents are unchanged.
  const auto check_roundtrip = [](const Shape& shape,
                                  const std::string& file_name) {
    auto original = random::uniform(-5.f, 5.f, shape, float32);

    const std::string file_path = get_temp_file(file_name);

    save(file_path, original);
    auto restored = load(file_path);

    CHECK_EQ(original.dtype(), restored.dtype());
    CHECK_EQ(original.shape(), restored.shape());
    CHECK(array_equal(original, restored).item<bool>());
  };

  // 3-D array.
  check_roundtrip({2, 5, 12}, "test_arr.npy");

  // 1-D array with a single element.
  check_roundtrip({1}, "test_arr_0.npy");

  // 1-D array with several elements.
  check_roundtrip({46}, "test_arr_1.npy");

  // 5-D array that includes a unit dimension.
  check_roundtrip({5, 2, 1, 3, 4}, "test_arr_2.npy");
}
|
|
|