|
|
|
|
|
|
|
|
#include <dlfcn.h>

#include <filesystem>
#include <iostream>
#include <sstream>

#include "mlx/backend/common/utils.h"
#include "mlx/backend/cpu/encoder.h"
#include "mlx/utils.h"

#include "axpby/axpby.h"
|
|
|
|
|
#ifdef _METAL_ |
|
|
#include "mlx/backend/metal/device.h" |
|
|
#include "mlx/backend/metal/utils.h" |
|
|
#endif |
|
|
|
|
|
namespace my_ext { |
|
|
|
|
|
|
|
|
|
|
|
/// Return the directory containing the shared object this code was
/// compiled into (used to locate the bundled Metal library).
///
/// The lookup runs once; the result is cached in a function-local static.
///
/// @throws std::runtime_error if dladdr() cannot resolve the address of
///         this function (no symbol info available).
std::string current_binary_dir() {
  static std::string binary_dir = []() {
    Dl_info info;
    // Fixed mojibake: the address-of expression had been corrupted to
    // "¤t_binary_dir"; it must take the address of this function so
    // dladdr() reports the binary that contains it.
    if (!dladdr(reinterpret_cast<void*>(&current_binary_dir), &info)) {
      throw std::runtime_error("Unable to get current binary dir.");
    }
    // dli_fname is the full path of the containing binary; keep its dir.
    return std::filesystem::path(info.dli_fname).parent_path().string();
  }();
  return binary_dir;
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
///////////////////////////////////////////////////////////////////////////////
// Operation implementation
///////////////////////////////////////////////////////////////////////////////

/// Scale and sum two arrays element-wise: alpha * x + beta * y.
///
/// Inputs are upcast to a common floating-point dtype and broadcast to a
/// common shape. The returned array is lazy: the Axpby primitive runs
/// when evaluation is requested.
///
/// @param x     First input array.
/// @param y     Second input array.
/// @param alpha Coefficient applied to x.
/// @param beta  Coefficient applied to y.
/// @param s     Stream or device on which to schedule the computation.
mx::array axpby(
    const mx::array& x,
    const mx::array& y,
    const float alpha,
    const float beta,
    mx::StreamOrDevice s
) {
  // Promote the input dtypes, then force a floating-point result: if the
  // promoted dtype is not already a float subtype, promote it with float32.
  auto dtype = promote_types(x.dtype(), y.dtype());
  if (!mx::issubdtype(dtype, mx::float32)) {
    dtype = promote_types(dtype, mx::float32);
  }

  // Cast both operands to the output dtype, then broadcast them to a
  // single common shape.
  auto inputs =
      broadcast_arrays({mx::astype(x, dtype, s), mx::astype(y, dtype, s)}, s);

  // Construct the lazy output array carrying the Axpby primitive.
  return mx::array(
      inputs[0].shape(),
      dtype,
      std::make_shared<Axpby>(to_stream(s), alpha, beta),
      inputs);
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
template <typename T> |
|
|
void axpby_impl( |
|
|
const mx::array& x, |
|
|
const mx::array& y, |
|
|
mx::array& out, |
|
|
float alpha_, |
|
|
float beta_, |
|
|
mx::Stream stream) { |
|
|
out.set_data(mx::allocator::malloc(out.nbytes())); |
|
|
|
|
|
|
|
|
auto& encoder = mx::cpu::get_command_encoder(stream); |
|
|
encoder.set_input_array(x); |
|
|
encoder.set_input_array(y); |
|
|
encoder.set_output_array(out); |
|
|
|
|
|
|
|
|
encoder.dispatch([x_ptr = x.data<T>(), |
|
|
y_ptr = y.data<T>(), |
|
|
out_ptr = out.data<T>(), |
|
|
size = out.size(), |
|
|
shape = out.shape(), |
|
|
x_strides = x.strides(), |
|
|
y_strides = y.strides(), |
|
|
alpha_, |
|
|
beta_]() { |
|
|
|
|
|
T alpha = static_cast<T>(alpha_); |
|
|
T beta = static_cast<T>(beta_); |
|
|
|
|
|
|
|
|
for (size_t out_idx = 0; out_idx < size; out_idx++) { |
|
|
|
|
|
auto x_offset = mx::elem_to_loc(out_idx, shape, x_strides); |
|
|
auto y_offset = mx::elem_to_loc(out_idx, shape, y_strides); |
|
|
|
|
|
|
|
|
|
|
|
out_ptr[out_idx] = alpha * x_ptr[x_offset] + beta * y_ptr[y_offset]; |
|
|
} |
|
|
}); |
|
|
} |
|
|
|
|
|
void Axpby::eval_cpu( |
|
|
const std::vector<mx::array>& inputs, |
|
|
std::vector<mx::array>& outputs) { |
|
|
auto& x = inputs[0]; |
|
|
auto& y = inputs[1]; |
|
|
auto& out = outputs[0]; |
|
|
|
|
|
|
|
|
if (out.dtype() == mx::float32) { |
|
|
return axpby_impl<float>(x, y, out, alpha_, beta_, stream()); |
|
|
} else if (out.dtype() == mx::float16) { |
|
|
return axpby_impl<mx::float16_t>(x, y, out, alpha_, beta_, stream()); |
|
|
} else if (out.dtype() == mx::bfloat16) { |
|
|
return axpby_impl<mx::bfloat16_t>(x, y, out, alpha_, beta_, stream()); |
|
|
} else if (out.dtype() == mx::complex64) { |
|
|
return axpby_impl<mx::complex64_t>(x, y, out, alpha_, beta_, stream()); |
|
|
} else { |
|
|
throw std::runtime_error( |
|
|
"Axpby is only supported for floating point types."); |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#ifdef _METAL_ |
|
|
|
|
|
|
|
|
void Axpby::eval_gpu( |
|
|
const std::vector<mx::array>& inputs, |
|
|
std::vector<mx::array>& outputs) { |
|
|
|
|
|
auto& x = inputs[0]; |
|
|
auto& y = inputs[1]; |
|
|
auto& out = outputs[0]; |
|
|
|
|
|
|
|
|
|
|
|
auto& s = stream(); |
|
|
|
|
|
auto& d = mx::metal::device(s.device); |
|
|
|
|
|
|
|
|
bool contiguous_kernel = |
|
|
(x.flags().row_contiguous && y.flags().row_contiguous) || |
|
|
(x.flags().col_contiguous && y.flags().col_contiguous); |
|
|
|
|
|
|
|
|
if (contiguous_kernel) { |
|
|
out.set_data( |
|
|
mx::allocator::malloc(x.data_size() * out.itemsize()), |
|
|
x.data_size(), |
|
|
x.strides(), |
|
|
x.flags()); |
|
|
} else { |
|
|
out.set_data(mx::allocator::malloc(out.nbytes())); |
|
|
} |
|
|
|
|
|
|
|
|
std::string kname = "axpby_"; |
|
|
kname += (contiguous_kernel ? "contiguous_" : "general_"); |
|
|
kname += type_to_name(out); |
|
|
|
|
|
|
|
|
auto lib = d.get_library("mlx_ext", current_binary_dir()); |
|
|
|
|
|
|
|
|
auto kernel = d.get_kernel(kname, lib); |
|
|
|
|
|
|
|
|
auto& compute_encoder = d.get_command_encoder(s.index); |
|
|
compute_encoder.set_compute_pipeline_state(kernel); |
|
|
|
|
|
|
|
|
|
|
|
int ndim = out.ndim(); |
|
|
size_t nelem = out.size(); |
|
|
|
|
|
|
|
|
compute_encoder.set_input_array(x, 0); |
|
|
compute_encoder.set_input_array(y, 1); |
|
|
|
|
|
|
|
|
compute_encoder.set_output_array(out, 2); |
|
|
|
|
|
|
|
|
compute_encoder.set_bytes(alpha_, 3); |
|
|
compute_encoder.set_bytes(beta_, 4); |
|
|
|
|
|
|
|
|
if (!contiguous_kernel) { |
|
|
compute_encoder.set_vector_bytes(x.shape(), 5); |
|
|
compute_encoder.set_vector_bytes(x.strides(), 6); |
|
|
compute_encoder.set_vector_bytes(y.strides(), 7); |
|
|
compute_encoder.set_bytes(ndim, 8); |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
size_t tgp_size = std::min(nelem, kernel->maxTotalThreadsPerThreadgroup()); |
|
|
|
|
|
|
|
|
MTL::Size group_dims = MTL::Size(tgp_size, 1, 1); |
|
|
|
|
|
|
|
|
MTL::Size grid_dims = MTL::Size(nelem, 1, 1); |
|
|
|
|
|
|
|
|
|
|
|
compute_encoder.dispatch_threads(grid_dims, group_dims); |
|
|
} |
|
|
|
|
|
#else |
|
|
|
|
|
|
|
|
void Axpby::eval_gpu( |
|
|
const std::vector<mx::array>& inputs, |
|
|
std::vector<mx::array>& out) { |
|
|
throw std::runtime_error("Axpby has no GPU implementation."); |
|
|
} |
|
|
|
|
|
#endif |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
std::vector<mx::array> Axpby::jvp( |
|
|
const std::vector<mx::array>& primals, |
|
|
const std::vector<mx::array>& tangents, |
|
|
const std::vector<int>& argnums) { |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if (argnums.size() > 1) { |
|
|
auto scale = argnums[0] == 0 ? alpha_ : beta_; |
|
|
auto scale_arr = mx::array(scale, tangents[0].dtype()); |
|
|
return {mx::multiply(scale_arr, tangents[0], stream())}; |
|
|
} |
|
|
|
|
|
|
|
|
else { |
|
|
return {axpby(tangents[0], tangents[1], alpha_, beta_, stream())}; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
std::vector<mx::array> Axpby::vjp( |
|
|
const std::vector<mx::array>& primals, |
|
|
const std::vector<mx::array>& cotangents, |
|
|
const std::vector<int>& argnums, |
|
|
const std::vector<mx::array>&) { |
|
|
|
|
|
std::vector<mx::array> vjps; |
|
|
for (auto arg : argnums) { |
|
|
auto scale = arg == 0 ? alpha_ : beta_; |
|
|
auto scale_arr = mx::array(scale, cotangents[0].dtype()); |
|
|
vjps.push_back(mx::multiply(scale_arr, cotangents[0], stream())); |
|
|
} |
|
|
return vjps; |
|
|
} |
|
|
|
|
|
|
|
|
std::pair<std::vector<mx::array>, std::vector<int>> Axpby::vmap( |
|
|
const std::vector<mx::array>& inputs, |
|
|
const std::vector<int>& axes) { |
|
|
throw std::runtime_error("Axpby has no vmap implementation."); |
|
|
} |
|
|
|
|
|
|
|
|
bool Axpby::is_equivalent(const Primitive& other) const { |
|
|
const Axpby& r_other = static_cast<const Axpby&>(other); |
|
|
return alpha_ == r_other.alpha_ && beta_ == r_other.beta_; |
|
|
} |
|
|
|
|
|
} |
|
|
|