|
|
|
|
|
|
|
|
#include <nanobind/stl/vector.h> |
|
|
#include <cstring> |
|
|
#include <fstream> |
|
|
#include <stdexcept> |
|
|
#include <string_view> |
|
|
#include <unordered_map> |
|
|
#include <vector> |
|
|
|
|
|
#include "mlx/io/load.h" |
|
|
#include "mlx/ops.h" |
|
|
#include "mlx/utils.h" |
|
|
#include "python/src/load.h" |
|
|
#include "python/src/small_vector.h" |
|
|
#include "python/src/utils.h" |
|
|
|
|
|
namespace mx = mlx::core; |
|
|
namespace nb = nanobind; |
|
|
using namespace nb::literals; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
bool is_str_or_path(nb::object obj) { |
|
|
if (nb::isinstance<nb::str>(obj)) { |
|
|
return true; |
|
|
} |
|
|
nb::object path_type = nb::module_::import_("pathlib").attr("Path"); |
|
|
return nb::isinstance(obj, path_type); |
|
|
} |
|
|
|
|
|
bool is_istream_object(const nb::object& file) { |
|
|
return nb::hasattr(file, "readinto") && nb::hasattr(file, "seek") && |
|
|
nb::hasattr(file, "tell") && nb::hasattr(file, "closed"); |
|
|
} |
|
|
|
|
|
bool is_ostream_object(const nb::object& file) { |
|
|
return nb::hasattr(file, "write") && nb::hasattr(file, "seek") && |
|
|
nb::hasattr(file, "tell") && nb::hasattr(file, "closed"); |
|
|
} |
|
|
|
|
|
bool is_zip_file(const nb::module_& zipfile, const nb::object& file) { |
|
|
if (is_istream_object(file)) { |
|
|
auto st_pos = file.attr("tell")(); |
|
|
bool r = nb::cast<bool>(zipfile.attr("is_zipfile")(file)); |
|
|
file.attr("seek")(st_pos, 0); |
|
|
return r; |
|
|
} |
|
|
return nb::cast<bool>(zipfile.attr("is_zipfile")(file)); |
|
|
} |
|
|
|
|
|
class ZipFileWrapper { |
|
|
public: |
|
|
ZipFileWrapper( |
|
|
const nb::module_& zipfile, |
|
|
const nb::object& file, |
|
|
char mode = 'r', |
|
|
int compression = 0) |
|
|
: zipfile_module_(zipfile), |
|
|
zipfile_object_(zipfile.attr("ZipFile")( |
|
|
file, |
|
|
"mode"_a = mode, |
|
|
"compression"_a = compression, |
|
|
"allowZip64"_a = true)), |
|
|
files_list_(zipfile_object_.attr("namelist")()), |
|
|
open_func_(zipfile_object_.attr("open")), |
|
|
read_func_(zipfile_object_.attr("read")), |
|
|
close_func_(zipfile_object_.attr("close")) {} |
|
|
|
|
|
std::vector<std::string> namelist() const { |
|
|
return nb::cast<std::vector<std::string>>(files_list_); |
|
|
} |
|
|
|
|
|
nb::object open(const std::string& key, char mode = 'r') { |
|
|
|
|
|
|
|
|
if (mode == 'w') { |
|
|
return open_func_(key, "mode"_a = mode, "force_zip64"_a = true); |
|
|
} |
|
|
return open_func_(key, "mode"_a = mode); |
|
|
} |
|
|
|
|
|
private: |
|
|
nb::module_ zipfile_module_; |
|
|
nb::object zipfile_object_; |
|
|
nb::list files_list_; |
|
|
nb::object open_func_; |
|
|
nb::object read_func_; |
|
|
nb::object close_func_; |
|
|
}; |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PyFileReader : public mx::io::Reader { |
|
|
public: |
|
|
PyFileReader(nb::object file) |
|
|
: pyistream_(file), |
|
|
readinto_func_(file.attr("readinto")), |
|
|
seek_func_(file.attr("seek")), |
|
|
tell_func_(file.attr("tell")) {} |
|
|
|
|
|
~PyFileReader() { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
pyistream_.release().dec_ref(); |
|
|
readinto_func_.release().dec_ref(); |
|
|
seek_func_.release().dec_ref(); |
|
|
tell_func_.release().dec_ref(); |
|
|
} |
|
|
|
|
|
bool is_open() const override { |
|
|
bool out; |
|
|
{ |
|
|
nb::gil_scoped_acquire gil; |
|
|
out = !nb::cast<bool>(pyistream_.attr("closed")); |
|
|
} |
|
|
return out; |
|
|
} |
|
|
|
|
|
bool good() const override { |
|
|
bool out; |
|
|
{ |
|
|
nb::gil_scoped_acquire gil; |
|
|
out = !pyistream_.is_none(); |
|
|
} |
|
|
return out; |
|
|
} |
|
|
|
|
|
size_t tell() override { |
|
|
size_t out; |
|
|
{ |
|
|
nb::gil_scoped_acquire gil; |
|
|
out = nb::cast<size_t>(tell_func_()); |
|
|
} |
|
|
return out; |
|
|
} |
|
|
|
|
|
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) |
|
|
override { |
|
|
nb::gil_scoped_acquire gil; |
|
|
seek_func_(off, (int)way); |
|
|
} |
|
|
|
|
|
void read(char* data, size_t n) override { |
|
|
nb::gil_scoped_acquire gil; |
|
|
_read(data, n); |
|
|
} |
|
|
|
|
|
void read(char* data, size_t n, size_t offset) override { |
|
|
nb::gil_scoped_acquire gil; |
|
|
seek_func_(offset, (int)std::ios_base::beg); |
|
|
_read(data, n); |
|
|
} |
|
|
|
|
|
std::string label() const override { |
|
|
return "python file object"; |
|
|
} |
|
|
|
|
|
private: |
|
|
void _read(char* data, size_t n) { |
|
|
auto memview = PyMemoryView_FromMemory(data, n, PyBUF_WRITE); |
|
|
nb::object bytes_read = readinto_func_(nb::handle(memview)); |
|
|
|
|
|
if (bytes_read.is_none() || nb::cast<size_t>(bytes_read) < n) { |
|
|
throw std::runtime_error("[load] Failed to read from python stream"); |
|
|
} |
|
|
} |
|
|
|
|
|
nb::object pyistream_; |
|
|
nb::object readinto_func_; |
|
|
nb::object seek_func_; |
|
|
nb::object tell_func_; |
|
|
}; |
|
|
|
|
|
std::pair< |
|
|
std::unordered_map<std::string, mx::array>, |
|
|
std::unordered_map<std::string, std::string>> |
|
|
mlx_load_safetensor_helper(nb::object file, mx::StreamOrDevice s) { |
|
|
if (is_str_or_path(file)) { |
|
|
auto file_str = nb::cast<std::string>(nb::str(file)); |
|
|
return mx::load_safetensors(file_str, s); |
|
|
} else if (is_istream_object(file)) { |
|
|
|
|
|
auto res = mx::load_safetensors(std::make_shared<PyFileReader>(file), s); |
|
|
{ |
|
|
nb::gil_scoped_release gil; |
|
|
for (auto& [key, arr] : std::get<0>(res)) { |
|
|
arr.eval(); |
|
|
} |
|
|
} |
|
|
return res; |
|
|
} |
|
|
|
|
|
throw std::invalid_argument( |
|
|
"[load_safetensors] Input must be a file-like object, or string"); |
|
|
} |
|
|
|
|
|
mx::GGUFLoad mlx_load_gguf_helper(nb::object file, mx::StreamOrDevice s) { |
|
|
if (is_str_or_path(file)) { |
|
|
auto file_str = nb::cast<std::string>(nb::str(file)); |
|
|
return mx::load_gguf(file_str, s); |
|
|
} |
|
|
|
|
|
throw std::invalid_argument("[load_gguf] Input must be a string"); |
|
|
} |
|
|
|
|
|
std::unordered_map<std::string, mx::array> mlx_load_npz_helper( |
|
|
nb::object file, |
|
|
mx::StreamOrDevice s) { |
|
|
bool own_file = is_str_or_path(file); |
|
|
|
|
|
nb::module_ zipfile = nb::module_::import_("zipfile"); |
|
|
if (!is_zip_file(zipfile, file)) { |
|
|
throw std::invalid_argument( |
|
|
"[load_npz] Input must be a zip file or a file-like object that can be " |
|
|
"opened with zipfile.ZipFile"); |
|
|
} |
|
|
|
|
|
std::unordered_map<std::string, mx::array> array_dict; |
|
|
|
|
|
|
|
|
ZipFileWrapper zipfile_object(zipfile, file); |
|
|
for (const std::string& st : zipfile_object.namelist()) { |
|
|
|
|
|
nb::object sub_file = zipfile_object.open(st); |
|
|
|
|
|
|
|
|
auto arr = mx::load(std::make_shared<PyFileReader>(sub_file), s); |
|
|
|
|
|
|
|
|
auto key = st; |
|
|
if (st.length() > 4 && st.substr(st.length() - 4, 4) == ".npy") |
|
|
key = st.substr(0, st.length() - 4); |
|
|
|
|
|
|
|
|
array_dict.insert({key, arr}); |
|
|
} |
|
|
|
|
|
|
|
|
if (!own_file) { |
|
|
nb::gil_scoped_release gil; |
|
|
for (auto& [key, arr] : array_dict) { |
|
|
arr.eval(); |
|
|
} |
|
|
} |
|
|
|
|
|
return array_dict; |
|
|
} |
|
|
|
|
|
mx::array mlx_load_npy_helper(nb::object file, mx::StreamOrDevice s) { |
|
|
if (is_str_or_path(file)) { |
|
|
auto file_str = nb::cast<std::string>(nb::str(file)); |
|
|
return mx::load(file_str, s); |
|
|
} else if (is_istream_object(file)) { |
|
|
|
|
|
auto arr = mx::load(std::make_shared<PyFileReader>(file), s); |
|
|
{ |
|
|
nb::gil_scoped_release gil; |
|
|
arr.eval(); |
|
|
} |
|
|
return arr; |
|
|
} |
|
|
throw std::invalid_argument( |
|
|
"[load_npy] Input must be a file-like object, or string"); |
|
|
} |
|
|
|
|
|
LoadOutputTypes mlx_load_helper( |
|
|
nb::object file, |
|
|
std::optional<std::string> format, |
|
|
bool return_metadata, |
|
|
mx::StreamOrDevice s) { |
|
|
if (!format.has_value()) { |
|
|
std::string fname; |
|
|
if (is_str_or_path(file)) { |
|
|
fname = nb::cast<std::string>(nb::str(file)); |
|
|
} else if (is_istream_object(file)) { |
|
|
fname = nb::cast<std::string>(file.attr("name")); |
|
|
} else { |
|
|
throw std::invalid_argument( |
|
|
"[load] Input must be a file-like object opened in binary mode, or string"); |
|
|
} |
|
|
size_t ext = fname.find_last_of('.'); |
|
|
if (ext == std::string::npos) { |
|
|
throw std::invalid_argument( |
|
|
"[load] Could not infer file format from extension"); |
|
|
} |
|
|
format.emplace(fname.substr(ext + 1)); |
|
|
} |
|
|
|
|
|
if (return_metadata && (format.value() == "npy" || format.value() == "npz")) { |
|
|
throw std::invalid_argument( |
|
|
"[load] metadata not supported for format " + format.value()); |
|
|
} |
|
|
if (format.value() == "safetensors") { |
|
|
auto [dict, metadata] = mlx_load_safetensor_helper(file, s); |
|
|
if (return_metadata) { |
|
|
return std::make_pair(dict, metadata); |
|
|
} |
|
|
return dict; |
|
|
} else if (format.value() == "npz") { |
|
|
return mlx_load_npz_helper(file, s); |
|
|
} else if (format.value() == "npy") { |
|
|
return mlx_load_npy_helper(file, s); |
|
|
} else if (format.value() == "gguf") { |
|
|
auto [weights, metadata] = mlx_load_gguf_helper(file, s); |
|
|
if (return_metadata) { |
|
|
return std::make_pair(weights, metadata); |
|
|
} else { |
|
|
return weights; |
|
|
} |
|
|
} else { |
|
|
throw std::invalid_argument("[load] Unknown file format " + format.value()); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PyFileWriter : public mx::io::Writer { |
|
|
public: |
|
|
PyFileWriter(nb::object file) |
|
|
: pyostream_(file), |
|
|
write_func_(file.attr("write")), |
|
|
seek_func_(file.attr("seek")), |
|
|
tell_func_(file.attr("tell")) {} |
|
|
|
|
|
~PyFileWriter() { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
pyostream_.release().dec_ref(); |
|
|
write_func_.release().dec_ref(); |
|
|
seek_func_.release().dec_ref(); |
|
|
tell_func_.release().dec_ref(); |
|
|
} |
|
|
|
|
|
bool is_open() const override { |
|
|
bool out; |
|
|
{ |
|
|
nb::gil_scoped_acquire gil; |
|
|
out = !nb::cast<bool>(pyostream_.attr("closed")); |
|
|
} |
|
|
return out; |
|
|
} |
|
|
|
|
|
bool good() const override { |
|
|
bool out; |
|
|
{ |
|
|
nb::gil_scoped_acquire gil; |
|
|
out = !pyostream_.is_none(); |
|
|
} |
|
|
return out; |
|
|
} |
|
|
|
|
|
size_t tell() override { |
|
|
size_t out; |
|
|
{ |
|
|
nb::gil_scoped_acquire gil; |
|
|
out = nb::cast<size_t>(tell_func_()); |
|
|
} |
|
|
return out; |
|
|
} |
|
|
|
|
|
void seek(int64_t off, std::ios_base::seekdir way = std::ios_base::beg) |
|
|
override { |
|
|
nb::gil_scoped_acquire gil; |
|
|
seek_func_(off, (int)way); |
|
|
} |
|
|
|
|
|
void write(const char* data, size_t n) override { |
|
|
nb::gil_scoped_acquire gil; |
|
|
|
|
|
auto memview = |
|
|
PyMemoryView_FromMemory(const_cast<char*>(data), n, PyBUF_READ); |
|
|
nb::object bytes_written = write_func_(nb::handle(memview)); |
|
|
|
|
|
if (bytes_written.is_none() || nb::cast<size_t>(bytes_written) < n) { |
|
|
throw std::runtime_error("[load] Failed to write to python stream"); |
|
|
} |
|
|
} |
|
|
|
|
|
std::string label() const override { |
|
|
return "python file object"; |
|
|
} |
|
|
|
|
|
private: |
|
|
nb::object pyostream_; |
|
|
nb::object write_func_; |
|
|
nb::object seek_func_; |
|
|
nb::object tell_func_; |
|
|
}; |
|
|
|
|
|
void mlx_save_helper(nb::object file, mx::array a) { |
|
|
if (is_str_or_path(file)) { |
|
|
auto file_str = nb::cast<std::string>(nb::str(file)); |
|
|
mx::save(file_str, a); |
|
|
return; |
|
|
} else if (is_ostream_object(file)) { |
|
|
auto writer = std::make_shared<PyFileWriter>(file); |
|
|
{ |
|
|
nb::gil_scoped_release gil; |
|
|
mx::save(writer, a); |
|
|
} |
|
|
|
|
|
return; |
|
|
} |
|
|
|
|
|
throw std::invalid_argument( |
|
|
"[save] Input must be a file-like object, or string"); |
|
|
} |
|
|
|
|
|
void mlx_savez_helper( |
|
|
nb::object file_, |
|
|
nb::args args, |
|
|
const nb::kwargs& kwargs, |
|
|
bool compressed) { |
|
|
|
|
|
nb::object file = file_; |
|
|
|
|
|
if (is_str_or_path(file)) { |
|
|
std::string fname = nb::cast<std::string>(nb::str(file_)); |
|
|
|
|
|
|
|
|
if (fname.length() < 4 || fname.substr(fname.length() - 4, 4) != ".npz") |
|
|
fname += ".npz"; |
|
|
|
|
|
file = nb::cast(fname); |
|
|
} |
|
|
|
|
|
|
|
|
auto arrays_dict = |
|
|
nb::cast<std::unordered_map<std::string, mx::array>>(kwargs); |
|
|
auto arrays_list = nb::cast<std::vector<mx::array>>(args); |
|
|
|
|
|
for (int i = 0; i < arrays_list.size(); i++) { |
|
|
std::string arr_name = "arr_" + std::to_string(i); |
|
|
|
|
|
if (arrays_dict.count(arr_name) > 0) { |
|
|
throw std::invalid_argument( |
|
|
"[savez] Cannot use un-named variables and keyword " + arr_name); |
|
|
} |
|
|
|
|
|
arrays_dict.insert({arr_name, arrays_list[i]}); |
|
|
} |
|
|
|
|
|
|
|
|
nb::module_ zipfile = nb::module_::import_("zipfile"); |
|
|
int compression = nb::cast<int>( |
|
|
compressed ? zipfile.attr("ZIP_DEFLATED") : zipfile.attr("ZIP_STORED")); |
|
|
char mode = 'w'; |
|
|
ZipFileWrapper zipfile_object(zipfile, file, mode, compression); |
|
|
|
|
|
|
|
|
for (auto [k, a] : arrays_dict) { |
|
|
std::string fname = k + ".npy"; |
|
|
auto py_ostream = zipfile_object.open(fname, 'w'); |
|
|
auto writer = std::make_shared<PyFileWriter>(py_ostream); |
|
|
{ |
|
|
nb::gil_scoped_release nogil; |
|
|
mx::save(writer, a); |
|
|
} |
|
|
} |
|
|
|
|
|
return; |
|
|
} |
|
|
|
|
|
void mlx_save_safetensor_helper( |
|
|
nb::object file, |
|
|
nb::dict d, |
|
|
std::optional<nb::dict> m) { |
|
|
std::unordered_map<std::string, std::string> metadata_map; |
|
|
if (m) { |
|
|
try { |
|
|
metadata_map = |
|
|
nb::cast<std::unordered_map<std::string, std::string>>(m.value()); |
|
|
} catch (const nb::cast_error& e) { |
|
|
throw std::invalid_argument( |
|
|
"[save_safetensors] Metadata must be a dictionary with string keys and values"); |
|
|
} |
|
|
} else { |
|
|
metadata_map = std::unordered_map<std::string, std::string>(); |
|
|
} |
|
|
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(d); |
|
|
if (is_str_or_path(file)) { |
|
|
{ |
|
|
auto file_str = nb::cast<std::string>(nb::str(file)); |
|
|
nb::gil_scoped_release nogil; |
|
|
mx::save_safetensors(file_str, arrays_map, metadata_map); |
|
|
} |
|
|
} else if (is_ostream_object(file)) { |
|
|
auto writer = std::make_shared<PyFileWriter>(file); |
|
|
{ |
|
|
nb::gil_scoped_release nogil; |
|
|
mx::save_safetensors(writer, arrays_map, metadata_map); |
|
|
} |
|
|
} else { |
|
|
throw std::invalid_argument( |
|
|
"[save_safetensors] Input must be a file-like object, or string"); |
|
|
} |
|
|
} |
|
|
|
|
|
void mlx_save_gguf_helper( |
|
|
nb::object file, |
|
|
nb::dict a, |
|
|
std::optional<nb::dict> m) { |
|
|
auto arrays_map = nb::cast<std::unordered_map<std::string, mx::array>>(a); |
|
|
if (is_str_or_path(file)) { |
|
|
if (m) { |
|
|
auto metadata_map = |
|
|
nb::cast<std::unordered_map<std::string, mx::GGUFMetaData>>( |
|
|
m.value()); |
|
|
{ |
|
|
auto file_str = nb::cast<std::string>(nb::str(file)); |
|
|
nb::gil_scoped_release nogil; |
|
|
mx::save_gguf(file_str, arrays_map, metadata_map); |
|
|
} |
|
|
} else { |
|
|
{ |
|
|
auto file_str = nb::cast<std::string>(nb::str(file)); |
|
|
nb::gil_scoped_release nogil; |
|
|
mx::save_gguf(file_str, arrays_map); |
|
|
} |
|
|
} |
|
|
} else { |
|
|
throw std::invalid_argument("[save_gguf] Input must be a string"); |
|
|
} |
|
|
} |
|
|
|