File size: 1,479 Bytes
d1d4335 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 |
#pragma once
// Set of global constants that could be shareable between CPU and Metal code
#ifdef __METAL__
#define C10_METAL_CONSTEXPR constant constexpr
#else
#define C10_METAL_CONSTEXPR constexpr
#endif
#if !defined(__METAL__) || __METAL_VERSION__ >= 310
#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
_(Byte, 0) \
_(Char, 1) \
_(Short, 2) \
_(Int, 3) \
_(Long, 4) \
_(Half, 5) \
_(Float, 6) \
_(ComplexHalf, 8) \
_(ComplexFloat, 9) \
_(Bool, 11) \
_(BFloat16, 15)
#else
#define C10_METAL_ALL_TYPES_FUNCTOR(_) \
_(Byte, 0) \
_(Char, 1) \
_(Short, 2) \
_(Int, 3) \
_(Long, 4) \
_(Half, 5) \
_(Float, 6) \
_(ComplexHalf, 8) \
_(ComplexFloat, 9) \
_(Bool, 11)
#endif
namespace c10 {
namespace metal {
C10_METAL_CONSTEXPR unsigned max_ndim = 16;
enum class ScalarType {
#define _DEFINE_ENUM_VAL_(_v, _n) _v = _n,
C10_METAL_ALL_TYPES_FUNCTOR(_DEFINE_ENUM_VAL_)
#undef _DEFINE_ENUM_VAL_
};
} // namespace metal
} // namespace c10
|