// Copyright © 2023 Apple Inc.

#pragma once

#include <metal_stdlib>

using namespace metal;

#if __METAL_VERSION__ >= 310

// Metal 3.1 and later: bfloat16_t aliases the built-in `bfloat` type.
typedef bfloat bfloat16_t;

// Reinterprets the 16 bits of `x` as an unsigned integer.
// This is a bit cast, not a numeric conversion.
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
  return as_type<uint16_t>(x);
}

// Reinterprets the 16-bit pattern `x` as a bfloat16_t.
// This is a bit cast, not a numeric conversion.
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
  return as_type<bfloat16_t>(x);
}

#else

// bfloat not available before Metal 3.1; use a stub so the file parses
// but only half/float kernels will be instantiated.
typedef half bfloat16_t;

// Reinterprets the 16 bits of `x` as an unsigned integer.
// NOTE: in this branch bfloat16_t is `half`, so the result is the bit
// pattern of a half value, not a bfloat16 encoding.
inline uint16_t bfloat16_to_uint16(const bfloat16_t x) {
  return as_type<uint16_t>(x);
}

// Reinterprets the 16-bit pattern `x` as a bfloat16_t.
// NOTE: in this branch bfloat16_t is `half`, so `x` is interpreted as the
// bit pattern of a half value, not a bfloat16 encoding.
inline bfloat16_t uint16_to_bfloat16(const uint16_t x) {
  return as_type<bfloat16_t>(x);
}

#endif