// Uploaded by medmekk using huggingface_hub (commit 20347e1, verified).
// Copyright © 2023-2024 Apple Inc.
#pragma once
#include <metal_math>
#include "bf16.h"
#include "defines.h"
// Alias so kernels can spell Metal's 16-bit floating-point type `half` as float16_t.
using float16_t = half;
// Work-per-thread value for element type U: n = 8 / sizeof(U) elements
// (8 for 1-byte types, 4 for 2-byte, 2 for 4-byte, 1 for 8-byte types).
// These values are expected to match get_work_per_thread in
// mlx/backend/metal/utils.h.
template <typename U>
struct WorkPerThread {
  // Anything wider than 8 bytes would yield n == 0.
  static_assert(sizeof(U) <= 8, "Type too large");
  static constexpr int constant n = static_cast<int>(8 / sizeof(U));
};
///////////////////////////////////////////////////////////////////////////////
// Type limits utils
///////////////////////////////////////////////////////////////////////////////
// Generic fallback for Limits<U>: every field is taken directly from
// metal::numeric_limits<U>. The integer, floating-point and bool
// specializations in this file override it.
// NOTE(review): if a floating-point-like U ever reaches this template, min()
// (under std::numeric_limits semantics) is the smallest positive normal value,
// not the most negative one — confirm no such U is instantiated here.
template <typename U>
struct Limits {
static const constant U max = metal::numeric_limits<U>::max();
static const constant U min = metal::numeric_limits<U>::min();
static const constant U finite_max = metal::numeric_limits<U>::max();
static const constant U finite_min = metal::numeric_limits<U>::min();
};
// Specializes Limits<type> for an integer type. Integers have no infinities,
// so max/min and finite_max/finite_min are the same
// metal::numeric_limits<type>::max()/min() values.
#define instantiate_default_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = metal::numeric_limits<type>::max(); \
static constexpr constant type min = metal::numeric_limits<type>::min(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
metal::numeric_limits<type>::min(); \
};
// Limits specializations for all fixed-width integer types.
instantiate_default_limit(uint8_t);
instantiate_default_limit(uint16_t);
instantiate_default_limit(uint32_t);
instantiate_default_limit(uint64_t);
instantiate_default_limit(int8_t);
instantiate_default_limit(int16_t);
instantiate_default_limit(int32_t);
instantiate_default_limit(int64_t);
// Specializes Limits<type> for a floating-point type.
// - max/min are +infinity and -infinity.
// - finite_max/finite_min are +/- the largest finite value.
// finite_min is -max() rather than min() because, assuming
// metal::numeric_limits follows std::numeric_limits, min() of a floating-point
// type is the smallest positive normal value, not the lowest value.
#define instantiate_float_limit(type) \
template <> \
struct Limits<type> { \
static constexpr constant type max = \
metal::numeric_limits<type>::infinity(); \
static constexpr constant type min = \
-metal::numeric_limits<type>::infinity(); \
static constexpr constant type finite_max = \
metal::numeric_limits<type>::max(); \
static constexpr constant type finite_min = \
-metal::numeric_limits<type>::max(); \
};
// Limits specializations for the floating-point types used by the kernels.
instantiate_float_limit(half);
instantiate_float_limit(float);
instantiate_float_limit(bfloat16_t);
// Limits for bool: false is the lowest value and true the highest.
// bool has no infinities, so finite_max/finite_min equal max/min. They are
// provided so that generic code can use Limits<T>::finite_max/finite_min for
// every T, just like the integer and floating-point specializations.
template <>
struct Limits<bool> {
  static constexpr constant bool max = true;
  static constexpr constant bool min = false;
  static constexpr constant bool finite_max = true;
  static constexpr constant bool finite_min = false;
};
// complex64_t specialization removed - not needed for BnB kernels
///////////////////////////////////////////////////////////////////////////////
// Indexing utils
///////////////////////////////////////////////////////////////////////////////
// Place immediately before a loop to ask clang to fully unroll it.
#define MLX_MTL_PRAGMA_UNROLL _Pragma("clang loop unroll(full)")
///////////////////////////////////////////////////////////////////////////////
// Single Array with generic dims
// Converts a flat (row-major) element index into a memory offset for an array
// with `ndim` dimensions described by `shape` and `strides`.
// Axes are peeled off from the innermost (ndim - 1) outward. The walk stops
// early once the remaining index reaches 0, because every remaining
// coordinate is then 0.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    IdxT elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  IdxT offset = 0;
  int axis = ndim - 1;
  while (axis >= 0 && elem > 0) {
    offset += (elem % shape[axis]) * IdxT(strides[axis]);
    elem /= shape[axis];
    --axis;
  }
  return offset;
}
// Overload taking a 3D grid position (still templated on IdxT) that handles
// an arbitrary number of dimensions:
// - elem.x is the coordinate on the innermost axis (ndim - 1).
// - elem.y is the coordinate on the next axis (ndim - 2).
// - elem.z is a flat index over all remaining outer axes, unravelled from axis
//   ndim - 3 outward.
// NOTE(review): strides[ndim - 2] is read unconditionally, so callers
// presumably guarantee ndim >= 2 — confirm.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc(
    uint3 elem,
    constant const int* shape,
    constant const int64_t* strides,
    int ndim) {
  auto innermost = elem.x * IdxT(strides[ndim - 1]);
  auto second = elem.y * IdxT(strides[ndim - 2]);
  IdxT loc = innermost + second;

  uint outer = elem.z;
  int axis = ndim - 3;
  while (axis >= 0) {
    loc += (outer % shape[axis]) * IdxT(strides[axis]);
    outer /= shape[axis];
    --axis;
  }
  return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Single Array with fixed N dims
// Memory offset of element `elem` along a single axis with the given stride.
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_1(uint elem, constant const int64_t& stride) {
  const IdxT step = IdxT(stride);
  return elem * step;
}
// Memory offset for a 2D position: elem.x indexes the inner axis (strides[1])
// and elem.y the outer axis (strides[0]).
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_2(uint2 elem, constant const int64_t strides[2]) {
  auto inner = elem.x * IdxT(strides[1]);
  auto outer = elem.y * IdxT(strides[0]);
  return inner + outer;
}
// Memory offset for a 3D position: elem.x pairs with strides[2] (innermost),
// elem.y with strides[1], and elem.z with strides[0] (outermost).
template <typename IdxT = int64_t>
METAL_FUNC IdxT elem_to_loc_3(uint3 elem, constant const int64_t strides[3]) {
  auto inner = elem.x * IdxT(strides[2]);
  auto middle = elem.y * IdxT(strides[1]);
  auto outer = elem.z * IdxT(strides[0]);
  return inner + middle + outer;
}
///////////////////////////////////////////////////////////////////////////////
// Multiple Arrays with generic dims
// Computes the offsets of the same logical element in two arrays, a and b,
// that share `shape` but have their own strides.
// - elem.x / elem.y are the coordinates on axes ndim - 1 and ndim - 2.
// - elem.z is a flat index over the remaining outer axes, unravelled from
//   axis ndim - 3 outward.
// Returns {offset into a, offset into b}.
// NOTE(review): the *_strides[ndim - 2] reads presumably require ndim >= 2 —
// confirm at call sites.
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 2> elem_to_loc_2_nd(
uint3 elem,
constant const int* shape,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
int ndim) {
// NOTE(review): the a-component casts elem.y to IdxT before multiplying, the
// b-component does not. For 64-bit IdxT both are identical. For a 32-bit
// signed IdxT the products are computed in signed vs unsigned arithmetic.
vec<IdxT, 2> loc = {
IdxT(
elem.x * IdxT(a_strides[ndim - 1]) +
IdxT(elem.y) * IdxT(a_strides[ndim - 2])),
IdxT(
elem.x * IdxT(b_strides[ndim - 1]) +
elem.y * IdxT(b_strides[ndim - 2]))};
// Unravel elem.z over the outer axes, accumulating into both offsets.
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * IdxT(a_strides[d]);
loc.y += l * IdxT(b_strides[d]);
elem.z /= shape[d];
}
return loc;
}
// Three-array variant of elem_to_loc_2_nd. It computes the offsets of the same
// logical element in arrays a, b and c, which share `shape` but have their own
// strides.
// - elem.x / elem.y are the coordinates on axes ndim - 1 and ndim - 2.
// - elem.z is a flat index over the remaining outer axes.
// Returns {offset into a, offset into b, offset into c}.
// NOTE(review): the *_strides[ndim - 2] reads presumably require ndim >= 2 —
// confirm at call sites.
template <typename IdxT = int64_t>
METAL_FUNC vec<IdxT, 3> elem_to_loc_3_nd(
uint3 elem,
constant const int* shape,
constant const int64_t* a_strides,
constant const int64_t* b_strides,
constant const int64_t* c_strides,
int ndim) {
// Contributions of the two innermost axes; each product is cast to IdxT
// before the sum.
vec<IdxT, 3> loc = {
IdxT(elem.x * IdxT(a_strides[ndim - 1])) +
IdxT(elem.y * IdxT(a_strides[ndim - 2])),
IdxT(elem.x * IdxT(b_strides[ndim - 1])) +
IdxT(elem.y * IdxT(b_strides[ndim - 2])),
IdxT(elem.x * IdxT(c_strides[ndim - 1])) +
IdxT(elem.y * IdxT(c_strides[ndim - 2]))};
// Unravel elem.z over the outer axes, accumulating into all three offsets.
for (int d = ndim - 3; d >= 0; --d) {
uint l = elem.z % shape[d];
loc.x += l * IdxT(a_strides[d]);
loc.y += l * IdxT(b_strides[d]);
loc.z += l * IdxT(c_strides[d]);
elem.z /= shape[d];
}
return loc;
}
///////////////////////////////////////////////////////////////////////////////
// Elem to loc in a loop utils
///////////////////////////////////////////////////////////////////////////////
// Incrementally tracks the memory offset of a position that is stepped through
// a strided array, so the offset does not have to be recomputed from scratch
// on every step.
// - DIM is the compile-time nesting depth of loopers.
// - `dim` is the runtime number of axes this looper covers.
// This level owns axis dim - 1, with `index` as the coordinate on it. Axes
// 0..dim-2 are delegated to `inner_looper`, and the recursion ends in the
// DIM == 1 specializations below (selected by `General`).
template <int DIM, typename OffsetT = size_t, bool General = true>
struct LoopedElemToLoc {
int dim;
LoopedElemToLoc<DIM - 1, OffsetT, General> inner_looper;
OffsetT offset{0};
int index{0};
LoopedElemToLoc(int dim) : dim(dim), inner_looper(dim - 1) {}
// Advance by one element along axis dim - 1. When that axis wraps, its
// coordinate resets to 0 and the inner looper advances by one. The offset
// is then taken from the inner looper, which describes this axis at
// coordinate 0.
void next(const constant int* shape, const constant int64_t* strides) {
// dim == 0: this looper covers no axes, nothing to advance.
if (dim == 0) {
return;
}
index++;
offset += OffsetT(strides[dim - 1]);
if (index >= shape[dim - 1]) {
index = 0;
inner_looper.next(shape, strides);
offset = inner_looper.offset;
}
}
// Advance by n elements along axis dim - 1, carrying any overflow into the
// inner looper. Presumably n >= 0 — TODO confirm at call sites.
void next(int n, const constant int* shape, const constant int64_t* strides) {
if (dim == 0) {
return;
}
index += n;
offset += n * OffsetT(strides[dim - 1]);
if (index >= shape[dim - 1]) {
// Number of elements past the end of this axis.
int extra = index - shape[dim - 1];
if (extra >= shape[dim - 1]) {
// Overflow spans at least one more full row: carry 1 + extra / shape
// rows in one call and keep only the remainder.
inner_looper.next(1 + extra / shape[dim - 1], shape, strides);
extra = extra % shape[dim - 1];
} else {
inner_looper.next(shape, strides);
}
// Restart this axis at coordinate 0, then step over the leftover elements.
index = 0;
offset = inner_looper.offset;
if (extra > 0) {
next(extra, shape, strides);
}
}
}
// Current memory offset.
OffsetT location() {
return offset;
}
};
// Base case of the LoopedElemToLoc recursion for General == true. It covers
// all `dim` remaining axes (0..dim-1) and tracks a flat element index over
// them.
// - With more than one axis, the offset is recomputed from the flat index via
//   elem_to_loc.
// - With a single axis, the offset is stepped by strides[0].
// NOTE(review): `index` is uint here but int in the general template.
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, true> {
int dim;
OffsetT offset{0};
uint index{0};
LoopedElemToLoc(int dim) : dim(dim) {}
// Advance by one element.
void next(const constant int* shape, const constant int64_t* strides) {
index++;
if (dim > 1) {
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset += OffsetT(strides[0]);
}
}
// Advance by n elements; the offset is recomputed from the updated flat index.
void next(int n, const constant int* shape, const constant int64_t* strides) {
index += n;
if (dim > 1) {
offset = elem_to_loc<OffsetT>(index, shape, strides, dim);
} else {
offset = index * OffsetT(strides[0]);
}
}
// Current memory offset.
OffsetT location() {
return offset;
}
};
// Base case of the LoopedElemToLoc recursion for General == false. It ignores
// `dim` and `shape` and simply moves the offset by strides[0] per element.
template <typename OffsetT>
struct LoopedElemToLoc<1, OffsetT, false> {
  OffsetT offset{0};

  LoopedElemToLoc(int) {}

  // Step forward by a single element.
  void next(const constant int*, const constant int64_t* strides) {
    const OffsetT step = OffsetT(strides[0]);
    offset += step;
  }

  // Step forward by n elements.
  void next(int n, const constant int*, const constant int64_t* strides) {
    const OffsetT step = OffsetT(strides[0]);
    offset += n * step;
  }

  // Current memory offset.
  OffsetT location() {
    return offset;
  }
};
///////////////////////////////////////////////////////////////////////////////
// Calculation utils
///////////////////////////////////////////////////////////////////////////////
/**
 * Integer ceiling division: ceil((float)N / (float)M) without floating point.
 * Intended for non-negative N and positive M (TODO confirm at call sites);
 * N + M - 1 may overflow for N close to the type's maximum.
 */
template <typename T, typename U>
inline T ceildiv(T N, U M) {
  // Bump N by M - 1 so that truncating division rounds up.
  const auto bumped = N + M - 1;
  return bumped / M;
}
// https://docs.oracle.com/cd/E19957-01/806-3568/ncg_goldberg.html#1202
// log(1 + x) with good accuracy for small |x|, using the technique from the
// reference above:
//   1. Form xp1 = 1 + x, which rounds.
//   2. Scale x by log(xp1) / (xp1 - 1).
// Because the ratio uses the same rounded xp1 in its numerator and
// denominator, the rounding error made when forming 1 + x largely cancels.
inline float log1p(float x) {
float xp1 = 1.0f + x;
// Limits<float>::max is +infinity in this file, so x is +infinity here.
// Return early to avoid log(inf) / inf = NaN.
if (xp1 == Limits<float>::max) {
return Limits<float>::max;
}
// x is too small to change 1.0f, so log(1 + x) equals x to within rounding.
// This also avoids the 0 / 0 below.
if (xp1 == 1.0f) {
return x;
}
return x * (metal::log(xp1) / (xp1 - 1.0f));
}
// bfloat16_t overload of log1p. It uses the same technique as the float
// overload: 1 + x and the log ratio are formed in float, and the result is
// converted back to bfloat16_t.
inline bfloat16_t log1p(bfloat16_t x) {
float xp1 = 1.0f + static_cast<float>(x);
// Limits<float>::max is +infinity: x is +infinity, so avoid log(inf) / inf.
if (xp1 == Limits<float>::max) {
return Limits<bfloat16_t>::max;
}
// 1 + x rounds to 1 in float, so log1p(x) is x to within rounding.
if (xp1 == 1.0f) {
return x;
}
// NOTE(review): x is bfloat16_t here. Presumably bf16.h promotes
// x * float to float before the final bfloat16_t conversion — confirm.
return bfloat16_t(x * (metal::log(xp1) / (xp1 - 1.0f)));
}
///////////////////////////////////////////////////////////////////////////////
// SIMD shuffle ops
///////////////////////////////////////////////////////////////////////////////
inline uint64_t simd_shuffle_down(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline int64_t simd_shuffle_down(int64_t data, uint16_t delta) {
return as_type<int64_t>(
metal::simd_shuffle_down(as_type<uint2>(data), delta));
}
inline bool simd_shuffle_down(bool data, uint16_t delta) {
return simd_shuffle_down(static_cast<uint32_t>(data), delta);
}
inline uint64_t simd_shuffle_up(uint64_t data, uint16_t delta) {
return as_type<uint64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
}
inline int64_t simd_shuffle_up(int64_t data, uint16_t delta) {
return as_type<int64_t>(metal::simd_shuffle_up(as_type<uint2>(data), delta));
}
inline bool simd_shuffle_up(bool data, uint16_t delta) {
return simd_shuffle_up(static_cast<uint32_t>(data), delta);
}
inline uint64_t
simd_shuffle_and_fill_up(uint64_t data, uint64_t filling, uint16_t delta) {
return as_type<uint64_t>(metal::simd_shuffle_and_fill_up(
as_type<uint2>(data), as_type<uint2>(filling), delta));
}
inline int64_t
simd_shuffle_and_fill_up(int64_t data, int64_t filling, uint16_t delta) {
return as_type<int64_t>(metal::simd_shuffle_and_fill_up(
as_type<uint2>(data), as_type<uint2>(filling), delta));
}
inline bool simd_shuffle_and_fill_up(bool data, bool filling, uint16_t delta) {
return simd_shuffle_and_fill_up(
static_cast<uint32_t>(data), static_cast<uint32_t>(filling), delta);
}
inline uint64_t simd_shuffle(uint64_t data, uint16_t lane) {
return as_type<uint64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
}
inline int64_t simd_shuffle(int64_t data, uint16_t lane) {
return as_type<int64_t>(metal::simd_shuffle(as_type<uint2>(data), lane));
}
inline bool simd_shuffle(bool data, uint16_t lane) {
return simd_shuffle(static_cast<uint32_t>(data), lane);
}
// Minimal stand-in for std::conditional, which is not included with Metal.
// ConditionalType<c, T, U>::type is T when c is true, otherwise U.
template <bool condition, typename T, typename U>
struct ConditionalType {
  using type = T;
};
template <typename T, typename U>
struct ConditionalType<false, T, U> {
  using type = U;
};