#pragma once

#include "builtin.h"

namespace wp
{

#if FP_CHECK

#define FP_ASSERT_FWD(value) \
    print(value); \
    printf(")\n"); \
    assert(0);

#define FP_ASSERT_ADJ(value, adj_value) \
    print(value); \
    printf(", "); \
    print(adj_value); \
    printf(")\n"); \
    assert(0);

#define FP_VERIFY_FWD(value) \
    if (!isfinite(value)) \
    { \
        printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
        FP_ASSERT_FWD(value) \
    }

#define FP_VERIFY_FWD_1(value) \
    if (!isfinite(value)) \
    { \
        printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
        FP_ASSERT_FWD(value) \
    }

#define FP_VERIFY_FWD_2(value) \
    if (!isfinite(value)) \
    { \
        printf("%s:%d - %s(arr, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j); \
        FP_ASSERT_FWD(value) \
    }

#define FP_VERIFY_FWD_3(value) \
    if (!isfinite(value)) \
    { \
        printf("%s:%d - %s(arr, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k); \
        FP_ASSERT_FWD(value) \
    }

#define FP_VERIFY_FWD_4(value) \
    if (!isfinite(value)) \
    { \
        printf("%s:%d - %s(arr, %d, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k, l); \
        FP_ASSERT_FWD(value) \
    }

#define FP_VERIFY_ADJ(value, adj_value) \
    if (!isfinite(value) || !isfinite(adj_value)) \
    { \
        printf("%s:%d - %s(addr", __FILE__, __LINE__, __FUNCTION__); \
        FP_ASSERT_ADJ(value, adj_value); \
    }

#define FP_VERIFY_ADJ_1(value, adj_value) \
    if (!isfinite(value) || !isfinite(adj_value)) \
    { \
        printf("%s:%d - %s(arr, %d) ", __FILE__, __LINE__, __FUNCTION__, i); \
        FP_ASSERT_ADJ(value, adj_value); \
    }

#define FP_VERIFY_ADJ_2(value, adj_value) \
    if (!isfinite(value) || !isfinite(adj_value)) \
    { \
        printf("%s:%d - %s(arr, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j); \
        FP_ASSERT_ADJ(value, adj_value); \
    }

#define FP_VERIFY_ADJ_3(value, adj_value) \
    if (!isfinite(value) || !isfinite(adj_value)) \
    { \
        printf("%s:%d - %s(arr, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k); \
        FP_ASSERT_ADJ(value, adj_value); \
    }

#define FP_VERIFY_ADJ_4(value, adj_value) \
    if (!isfinite(value) || !isfinite(adj_value)) \
    { \
        printf("%s:%d - %s(arr, %d, %d, %d, %d) ", __FILE__, __LINE__, __FUNCTION__, i, j, k, l); \
        FP_ASSERT_ADJ(value, adj_value); \
    }

#else

#define FP_VERIFY_FWD(value) {}
#define FP_VERIFY_FWD_1(value) {}
#define FP_VERIFY_FWD_2(value) {}
#define FP_VERIFY_FWD_3(value) {}
#define FP_VERIFY_FWD_4(value) {}
#define FP_VERIFY_ADJ(value, adj_value) {}
#define FP_VERIFY_ADJ_1(value, adj_value) {}
#define FP_VERIFY_ADJ_2(value, adj_value) {}
#define FP_VERIFY_ADJ_3(value, adj_value) {}
#define FP_VERIFY_ADJ_4(value, adj_value) {}

#endif // FP_CHECK

const int ARRAY_MAX_DIMS = 4; // must match constant in types.py

// must match constants in types.py
const int ARRAY_TYPE_REGULAR = 0;
const int ARRAY_TYPE_INDEXED = 1;
const int ARRAY_TYPE_FABRIC = 2;
const int ARRAY_TYPE_FABRIC_INDEXED = 3;

struct shape_t
{
    int dims[ARRAY_MAX_DIMS];

    CUDA_CALLABLE inline shape_t() : dims() {}

    CUDA_CALLABLE inline int operator[](int i) const
    {
        assert(i < ARRAY_MAX_DIMS);
        return dims[i];
    }

    CUDA_CALLABLE inline int& operator[](int i)
    {
        assert(i < ARRAY_MAX_DIMS);
        return dims[i];
    }
};

CUDA_CALLABLE inline int extract(const shape_t& s, int i)
{
    return s.dims[i];
}

CUDA_CALLABLE inline void adj_extract(const shape_t& s, int i, const shape_t& adj_s, int adj_i, int adj_ret) {}

inline CUDA_CALLABLE void print(shape_t s)
{
    // todo: only print valid dims; shape currently has a fixed size,
    // but we don't know how many dims are valid (e.g. 1d, 2d, etc.),
    // so we should probably store ndim with the shape
    printf("(%d, %d, %d, %d)\n", s.dims[0], s.dims[1], s.dims[2], s.dims[3]);
}
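// Illustrative sketch (not part of the original header): shape_t is a fixed
// four-slot dimension record, with unused trailing dims left at zero. A
// hypothetical caller:
//
//   wp::shape_t s;
//   s[0] = 4;                  // operator[] asserts i < ARRAY_MAX_DIMS
//   s[1] = 8;
//   int n = wp::extract(s, 1); // n == 8; adj_extract() is a no-op since an
//                              // integer shape query carries no gradient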
inline CUDA_CALLABLE void adj_print(shape_t s, shape_t& adj_s) {}

template <typename T>
struct array_t
{
    CUDA_CALLABLE inline array_t() {}
    CUDA_CALLABLE inline array_t(int) {} // for backward a = 0 initialization syntax

    CUDA_CALLABLE array_t(T* data, int size, T* grad = nullptr)
        : data(data), grad(grad)
    {
        // constructor for 1d array
        shape.dims[0] = size;
        shape.dims[1] = 0;
        shape.dims[2] = 0;
        shape.dims[3] = 0;
        ndim = 1;

        strides[0] = sizeof(T);
        strides[1] = 0;
        strides[2] = 0;
        strides[3] = 0;
    }

    CUDA_CALLABLE array_t(T* data, int dim0, int dim1, T* grad = nullptr)
        : data(data), grad(grad)
    {
        // constructor for 2d array
        shape.dims[0] = dim0;
        shape.dims[1] = dim1;
        shape.dims[2] = 0;
        shape.dims[3] = 0;
        ndim = 2;

        strides[0] = dim1 * sizeof(T);
        strides[1] = sizeof(T);
        strides[2] = 0;
        strides[3] = 0;
    }

    CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, T* grad = nullptr)
        : data(data), grad(grad)
    {
        // constructor for 3d array
        shape.dims[0] = dim0;
        shape.dims[1] = dim1;
        shape.dims[2] = dim2;
        shape.dims[3] = 0;
        ndim = 3;

        strides[0] = dim1 * dim2 * sizeof(T);
        strides[1] = dim2 * sizeof(T);
        strides[2] = sizeof(T);
        strides[3] = 0;
    }

    CUDA_CALLABLE array_t(T* data, int dim0, int dim1, int dim2, int dim3, T* grad = nullptr)
        : data(data), grad(grad)
    {
        // constructor for 4d array
        shape.dims[0] = dim0;
        shape.dims[1] = dim1;
        shape.dims[2] = dim2;
        shape.dims[3] = dim3;
        ndim = 4;

        strides[0] = dim1 * dim2 * dim3 * sizeof(T);
        strides[1] = dim2 * dim3 * sizeof(T);
        strides[2] = dim3 * sizeof(T);
        strides[3] = sizeof(T);
    }

    CUDA_CALLABLE inline bool empty() const { return !data; }

    T* data{nullptr};
    T* grad{nullptr};
    shape_t shape;
    int strides[ARRAY_MAX_DIMS];
    int ndim;

    CUDA_CALLABLE inline operator T*() const { return data; }
};

// TODO:
// - templated index type?
// - templated dimensionality? (also for array_t to save space when passing arrays to kernels)
template <typename T>
struct indexedarray_t
{
    CUDA_CALLABLE inline indexedarray_t() {}
    CUDA_CALLABLE inline indexedarray_t(int) {} // for backward a = 0 initialization syntax

    CUDA_CALLABLE inline bool empty() const { return !arr.data; }

    array_t<T> arr;
    int* indices[ARRAY_MAX_DIMS]; // index array per dimension (can be NULL)
    shape_t shape;                // element count per dimension (num. indices if indexed, array dim if not)
};
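// Illustrative sketch (not part of the original header): constructing a 2d
// array_t over a caller-owned buffer; the stride values follow the row-major
// math in the 2d constructor above.
//
//   float data[6] = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};
//   wp::array_t<float> arr(data, 2, 3);
//   // arr.ndim == 2, arr.shape == (2, 3, 0, 0)
//   // arr.strides == {3 * sizeof(float), sizeof(float), 0, 0} == {12, 4, 0, 0}
//   // element (i, j) therefore lives at byte offset i*12 + j*4,
//   // e.g. (1, 2) -> 20 bytes -> data[5]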
// return stride (in bytes) of the given index
template <typename T>
CUDA_CALLABLE inline size_t stride(const array_t<T>& a, int dim)
{
    return size_t(a.strides[dim]);
}

template <typename T>
CUDA_CALLABLE inline T* data_at_byte_offset(const array_t<T>& a, size_t byte_offset)
{
    return reinterpret_cast<T*>(reinterpret_cast<char*>(a.data) + byte_offset);
}

template <typename T>
CUDA_CALLABLE inline T* grad_at_byte_offset(const array_t<T>& a, size_t byte_offset)
{
    return reinterpret_cast<T*>(reinterpret_cast<char*>(a.grad) + byte_offset);
}

template <typename T>
CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i)
{
    assert(i >= 0 && i < arr.shape[0]);

    return i * stride(arr, 0);
}

template <typename T>
CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j)
{
    assert(i >= 0 && i < arr.shape[0]);
    assert(j >= 0 && j < arr.shape[1]);

    return i * stride(arr, 0) + j * stride(arr, 1);
}

template <typename T>
CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j, int k)
{
    assert(i >= 0 && i < arr.shape[0]);
    assert(j >= 0 && j < arr.shape[1]);
    assert(k >= 0 && k < arr.shape[2]);

    return i * stride(arr, 0) + j * stride(arr, 1) + k * stride(arr, 2);
}

template <typename T>
CUDA_CALLABLE inline size_t byte_offset(const array_t<T>& arr, int i, int j, int k, int l)
{
    assert(i >= 0 && i < arr.shape[0]);
    assert(j >= 0 && j < arr.shape[1]);
    assert(k >= 0 && k < arr.shape[2]);
    assert(l >= 0 && l < arr.shape[3]);

    return i * stride(arr, 0) + j * stride(arr, 1) + k * stride(arr, 2) + l * stride(arr, 3);
}

template <typename T>
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i)
{
    assert(arr.ndim == 1);

    T& result = *data_at_byte_offset(arr, byte_offset(arr, i));
    FP_VERIFY_FWD_1(result)

    return result;
}

template <typename T>
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j)
{
    assert(arr.ndim == 2);

    T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j));
    FP_VERIFY_FWD_2(result)

    return result;
}

template <typename T>
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k)
{
    assert(arr.ndim == 3);

    T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k));
    FP_VERIFY_FWD_3(result)

    return result;
}

template <typename T>
CUDA_CALLABLE inline T& index(const array_t<T>& arr, int i, int j, int k, int l)
{
    assert(arr.ndim == 4);

    T& result = *data_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
    FP_VERIFY_FWD_4(result)

    return result;
}

template <typename T>
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i)
{
    T& result = *grad_at_byte_offset(arr, byte_offset(arr, i));
    FP_VERIFY_FWD_1(result)

    return result;
}

template <typename T>
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j)
{
    T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j));
    FP_VERIFY_FWD_2(result)

    return result;
}

template <typename T>
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k)
{
    T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k));
    FP_VERIFY_FWD_3(result)

    return result;
}

template <typename T>
CUDA_CALLABLE inline T& index_grad(const array_t<T>& arr, int i, int j, int k, int l)
{
    T& result = *grad_at_byte_offset(arr, byte_offset(arr, i, j, k, l));
    FP_VERIFY_FWD_4(result)

    return result;
}

template <typename T>
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i)
{
    assert(iarr.arr.ndim == 1);
    assert(i >= 0 && i < iarr.shape[0]);

    if (iarr.indices[0])
    {
        i = iarr.indices[0][i];
        assert(i >= 0 && i < iarr.arr.shape[0]);
    }

    T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i));
    FP_VERIFY_FWD_1(result)

    return result;
}

template <typename T>
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j)
{
    assert(iarr.arr.ndim == 2);
    assert(i >= 0 && i < iarr.shape[0]);
    assert(j >= 0 && j < iarr.shape[1]);

    if (iarr.indices[0])
    {
        i = iarr.indices[0][i];
        assert(i >= 0 && i < iarr.arr.shape[0]);
    }
    if (iarr.indices[1])
    {
        j = iarr.indices[1][j];
        assert(j >= 0 && j < iarr.arr.shape[1]);
    }

    T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j));
    FP_VERIFY_FWD_2(result)

    return result;
}
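// Illustrative sketch (not part of the original header): indexed arrays remap
// each coordinate through an optional per-dimension index table before the
// usual byte-offset lookup. Assuming a hypothetical 1d setup:
//
//   float base[4] = {10.f, 11.f, 12.f, 13.f};
//   int idx[2] = {2, 0};
//   wp::indexedarray_t<float> iarr;
//   iarr.arr = wp::array_t<float>(base, 4);
//   iarr.indices[0] = idx;
//   iarr.indices[1] = iarr.indices[2] = iarr.indices[3] = nullptr;
//   iarr.shape[0] = 2;  // two visible elements
//   // index(iarr, 0) remaps 0 -> idx[0] == 2 and reads base[2] == 12.f
//   // index(iarr, 1) remaps 1 -> idx[1] == 0 and reads base[0] == 10.f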
template <typename T>
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k)
{
    assert(iarr.arr.ndim == 3);
    assert(i >= 0 && i < iarr.shape[0]);
    assert(j >= 0 && j < iarr.shape[1]);
    assert(k >= 0 && k < iarr.shape[2]);

    if (iarr.indices[0])
    {
        i = iarr.indices[0][i];
        assert(i >= 0 && i < iarr.arr.shape[0]);
    }
    if (iarr.indices[1])
    {
        j = iarr.indices[1][j];
        assert(j >= 0 && j < iarr.arr.shape[1]);
    }
    if (iarr.indices[2])
    {
        k = iarr.indices[2][k];
        assert(k >= 0 && k < iarr.arr.shape[2]);
    }

    T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j, k));
    FP_VERIFY_FWD_3(result)

    return result;
}

template <typename T>
CUDA_CALLABLE inline T& index(const indexedarray_t<T>& iarr, int i, int j, int k, int l)
{
    assert(iarr.arr.ndim == 4);
    assert(i >= 0 && i < iarr.shape[0]);
    assert(j >= 0 && j < iarr.shape[1]);
    assert(k >= 0 && k < iarr.shape[2]);
    assert(l >= 0 && l < iarr.shape[3]);

    if (iarr.indices[0])
    {
        i = iarr.indices[0][i];
        assert(i >= 0 && i < iarr.arr.shape[0]);
    }
    if (iarr.indices[1])
    {
        j = iarr.indices[1][j];
        assert(j >= 0 && j < iarr.arr.shape[1]);
    }
    if (iarr.indices[2])
    {
        k = iarr.indices[2][k];
        assert(k >= 0 && k < iarr.arr.shape[2]);
    }
    if (iarr.indices[3])
    {
        l = iarr.indices[3][l];
        assert(l >= 0 && l < iarr.arr.shape[3]);
    }

    T& result = *data_at_byte_offset(iarr.arr, byte_offset(iarr.arr, i, j, k, l));
    FP_VERIFY_FWD_4(result)

    return result;
}

template <typename T>
CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i)
{
    assert(src.ndim > 1);
    assert(i >= 0 && i < src.shape[0]);

    array_t<T> a;
    a.data = data_at_byte_offset(src, byte_offset(src, i));
    a.shape[0] = src.shape[1];
    a.shape[1] = src.shape[2];
    a.shape[2] = src.shape[3];
    a.strides[0] = src.strides[1];
    a.strides[1] = src.strides[2];
    a.strides[2] = src.strides[3];
    a.ndim = src.ndim - 1;

    return a;
}

template <typename T>
CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j)
{
    assert(src.ndim > 2);
    assert(i >= 0 && i < src.shape[0]);
    assert(j >= 0 && j < src.shape[1]);

    array_t<T> a;
    a.data = data_at_byte_offset(src, byte_offset(src, i, j));
    a.shape[0] = src.shape[2];
    a.shape[1] = src.shape[3];
    a.strides[0] = src.strides[2];
    a.strides[1] = src.strides[3];
    a.ndim = src.ndim - 2;

    return a;
}

template <typename T>
CUDA_CALLABLE inline array_t<T> view(array_t<T>& src, int i, int j, int k)
{
    assert(src.ndim > 3);
    assert(i >= 0 && i < src.shape[0]);
    assert(j >= 0 && j < src.shape[1]);
    assert(k >= 0 && k < src.shape[2]);

    array_t<T> a;
    a.data = data_at_byte_offset(src, byte_offset(src, i, j, k));
    a.shape[0] = src.shape[3];
    a.strides[0] = src.strides[3];
    a.ndim = src.ndim - 3;

    return a;
}

template <typename T>
CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i)
{
    assert(src.arr.ndim > 1);

    if (src.indices[0])
    {
        assert(i >= 0 && i < src.shape[0]);
        i = src.indices[0][i];
    }

    indexedarray_t<T> a;
    a.arr = view(src.arr, i);
    a.indices[0] = src.indices[1];
    a.indices[1] = src.indices[2];
    a.indices[2] = src.indices[3];
    a.shape[0] = src.shape[1];
    a.shape[1] = src.shape[2];
    a.shape[2] = src.shape[3];

    return a;
}

template <typename T>
CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j)
{
    assert(src.arr.ndim > 2);

    if (src.indices[0])
    {
        assert(i >= 0 && i < src.shape[0]);
        i = src.indices[0][i];
    }
    if (src.indices[1])
    {
        assert(j >= 0 && j < src.shape[1]);
        j = src.indices[1][j];
    }

    indexedarray_t<T> a;
    a.arr = view(src.arr, i, j);
    a.indices[0] = src.indices[2];
    a.indices[1] = src.indices[3];
    a.shape[0] = src.shape[2];
    a.shape[1] = src.shape[3];

    return a;
}
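// Illustrative sketch (not part of the original header): view() peels leading
// dimensions without copying, for plain and indexed arrays alike. Continuing
// the hypothetical 2x3 example above:
//
//   wp::array_t<float> row = wp::view(arr, 1);
//   // row.ndim == 1, row.shape[0] == 3, row.strides[0] == sizeof(float)
//   // row.data == (char*)arr.data + 1 * arr.strides[0], i.e. &data[3]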
template <typename T>
CUDA_CALLABLE inline indexedarray_t<T> view(indexedarray_t<T>& src, int i, int j, int k)
{
    assert(src.arr.ndim > 3);

    if (src.indices[0])
    {
        assert(i >= 0 && i < src.shape[0]);
        i = src.indices[0][i];
    }
    if (src.indices[1])
    {
        assert(j >= 0 && j < src.shape[1]);
        j = src.indices[1][j];
    }
    if (src.indices[2])
    {
        assert(k >= 0 && k < src.shape[2]);
        k = src.indices[2][k];
    }

    indexedarray_t<T> a;
    a.arr = view(src.arr, i, j, k);
    a.indices[0] = src.indices[3];
    a.shape[0] = src.shape[3];

    return a;
}

template <template <typename> class A1, template <typename> class A2, template <typename> class A3, typename T>
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, A2<T>& adj_src, int adj_i, A3<T> adj_ret) {}

template <template <typename> class A1, template <typename> class A2, template <typename> class A3, typename T>
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, A2<T>& adj_src, int adj_i, int adj_j, A3<T> adj_ret) {}

template <template <typename> class A1, template <typename> class A2, template <typename> class A3, typename T>
inline CUDA_CALLABLE void adj_view(A1<T>& src, int i, int j, int k, A2<T>& adj_src, int adj_i, int adj_j, int adj_k, A3<T> adj_ret) {}

// TODO: lower_bound() for indexed arrays?

template <typename T>
CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value)
{
    assert(arr.ndim == 1);

    int lower = arr_begin;
    int upper = arr_end - 1;

    while (lower < upper)
    {
        int mid = lower + (upper - lower) / 2;

        if (arr[mid] < value)
        {
            lower = mid + 1;
        }
        else
        {
            upper = mid;
        }
    }

    return lower;
}

template <typename T>
CUDA_CALLABLE inline int lower_bound(const array_t<T>& arr, T value)
{
    return lower_bound(arr, 0, arr.shape[0], value);
}

template <typename T>
inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, T value, array_t<T> adj_arr, T adj_value, int adj_ret) {}

template <typename T>
inline CUDA_CALLABLE void adj_lower_bound(const array_t<T>& arr, int arr_begin, int arr_end, T value, array_t<T> adj_arr, int adj_arr_begin, int adj_arr_end, T adj_value, int adj_ret) {}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, T value)
{
    return atomic_add(&index(buf, i), value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, T value)
{
    return atomic_add(&index(buf, i, j), value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, int k, T value)
{
    return atomic_add(&index(buf, i, j, k), value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_add(const A<T>& buf, int i, int j, int k, int l, T value)
{
    return atomic_add(&index(buf, i, j, k, l), value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, T value)
{
    return atomic_add(&index(buf, i), -value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, T value)
{
    return atomic_add(&index(buf, i, j), -value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, int k, T value)
{
    return atomic_add(&index(buf, i, j, k), -value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_sub(const A<T>& buf, int i, int j, int k, int l, T value)
{
    return atomic_add(&index(buf, i, j, k, l), -value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, T value)
{
    return atomic_min(&index(buf, i), value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, T value)
{
    return atomic_min(&index(buf, i, j), value);
}
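// Illustrative sketch (not part of the original header): lower_bound() above
// is a standard binary search over a sorted 1d array, returning the first
// index whose element is not less than value. Note that the result is clamped
// to arr_end - 1 when value exceeds every element, unlike std::lower_bound.
//
//   float keys_data[5] = {1.f, 3.f, 5.f, 7.f, 9.f};
//   wp::array_t<float> keys(keys_data, 5);
//   int a = wp::lower_bound(keys, 5.f); // a == 2 (exact match)
//   int b = wp::lower_bound(keys, 6.f); // b == 3 (first element >= 6)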
template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, int k, T value)
{
    return atomic_min(&index(buf, i, j, k), value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_min(const A<T>& buf, int i, int j, int k, int l, T value)
{
    return atomic_min(&index(buf, i, j, k, l), value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, T value)
{
    return atomic_max(&index(buf, i), value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, T value)
{
    return atomic_max(&index(buf, i, j), value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, T value)
{
    return atomic_max(&index(buf, i, j, k), value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T atomic_max(const A<T>& buf, int i, int j, int k, int l, T value)
{
    return atomic_max(&index(buf, i, j, k, l), value);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T* address(const A<T>& buf, int i)
{
    return &index(buf, i);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j)
{
    return &index(buf, i, j);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k)
{
    return &index(buf, i, j, k);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE T* address(const A<T>& buf, int i, int j, int k, int l)
{
    return &index(buf, i, j, k, l);
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, T value)
{
    FP_VERIFY_FWD_1(value)

    index(buf, i) = value;
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, T value)
{
    FP_VERIFY_FWD_2(value)

    index(buf, i, j) = value;
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, T value)
{
    FP_VERIFY_FWD_3(value)

    index(buf, i, j, k) = value;
}

template <template <typename> class A, typename T>
inline CUDA_CALLABLE void array_store(const A<T>& buf, int i, int j, int k, int l, T value)
{
    FP_VERIFY_FWD_4(value)

    index(buf, i, j, k, l) = value;
}

template <typename T>
inline CUDA_CALLABLE void store(T* address, T value)
{
    FP_VERIFY_FWD(value)

    *address = value;
}

template <typename T>
inline CUDA_CALLABLE T load(T* address)
{
    T value = *address;
    FP_VERIFY_FWD(value)

    return value;
}

// select operator to check for array being null
template <typename T1, typename T2>
CUDA_CALLABLE inline T2 select(const array_t<T1>& arr, const T2& a, const T2& b)
{
    return arr.data ? b : a;
}

template <typename T1, typename T2>
CUDA_CALLABLE inline void adj_select(const array_t<T1>& arr, const T2& a, const T2& b, const array_t<T1>& adj_cond, T2& adj_a, T2& adj_b, const T2& adj_ret)
{
    if (arr.data)
        adj_b += adj_ret;
    else
        adj_a += adj_ret;
}

// stub for the case where we have a nested array inside a struct and
// atomic add the whole struct onto an array (e.g. during the backward pass)
template <typename T>
CUDA_CALLABLE inline void atomic_add(array_t<T>*, array_t<T>) {}

// for float and vector types this is just an alias for an atomic add
template <typename T>
CUDA_CALLABLE inline void adj_atomic_add(T* buf, T value)
{
    atomic_add(buf, value);
}
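// Illustrative sketch (not part of the original header): select() above tests
// the array pointer itself, so generated code can branch on whether an
// optional array argument was supplied.
//
//   wp::array_t<float> opt;              // default: data == nullptr
//   float x = wp::select(opt, 1.f, 2.f); // x == 1.f (array is null)
//   float y = wp::select(arr, 1.f, 2.f); // y == 2.f (array is non-null)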
// for integral types we do not accumulate gradients
CUDA_CALLABLE inline void adj_atomic_add(int8* buf, int8 value) {}
CUDA_CALLABLE inline void adj_atomic_add(uint8* buf, uint8 value) {}
CUDA_CALLABLE inline void adj_atomic_add(int16* buf, int16 value) {}
CUDA_CALLABLE inline void adj_atomic_add(uint16* buf, uint16 value) {}
CUDA_CALLABLE inline void adj_atomic_add(int32* buf, int32 value) {}
CUDA_CALLABLE inline void adj_atomic_add(uint32* buf, uint32 value) {}
CUDA_CALLABLE inline void adj_atomic_add(int64* buf, int64 value) {}
CUDA_CALLABLE inline void adj_atomic_add(uint64* buf, uint64 value) {}
CUDA_CALLABLE inline void adj_atomic_add(bool* buf, bool value) {}

// only generate gradients for T types
template <typename T>
inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, const array_t<T>& adj_buf, int& adj_i, const T& adj_output)
{
    if (buf.grad)
        adj_atomic_add(&index_grad(buf, i), adj_output);
}

template <typename T>
inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, const array_t<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output)
{
    if (buf.grad)
        adj_atomic_add(&index_grad(buf, i, j), adj_output);
}

template <typename T>
inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output)
{
    if (buf.grad)
        adj_atomic_add(&index_grad(buf, i, j, k), adj_output);
}

template <typename T>
inline CUDA_CALLABLE void adj_address(const array_t<T>& buf, int i, int j, int k, int l, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output)
{
    if (buf.grad)
        adj_atomic_add(&index_grad(buf, i, j, k, l), adj_output);
}

template <typename T>
inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value)
{
    if (buf.grad)
        adj_value += index_grad(buf, i);

    FP_VERIFY_ADJ_1(value, adj_value)
}

template <typename T>
inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value)
{
    if (buf.grad)
        adj_value += index_grad(buf, i, j);

    FP_VERIFY_ADJ_2(value, adj_value)
}

template <typename T>
inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value)
{
    if (buf.grad)
        adj_value += index_grad(buf, i, j, k);

    FP_VERIFY_ADJ_3(value, adj_value)
}

template <typename T>
inline CUDA_CALLABLE void adj_array_store(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value)
{
    if (buf.grad)
        adj_value += index_grad(buf, i, j, k, l);

    FP_VERIFY_ADJ_4(value, adj_value)
}

template <typename T>
inline CUDA_CALLABLE void adj_store(const T* address, T value, const T& adj_address, T& adj_value)
{
    // nop; generic store() operations are not differentiable, only array_store() is
    FP_VERIFY_ADJ(value, adj_value)
}

template <typename T>
inline CUDA_CALLABLE void adj_load(const T* address, const T& adj_address, T& adj_value)
{
    // nop; generic load() operations are not differentiable
}

template <typename T>
inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_value += index_grad(buf, i);

    FP_VERIFY_ADJ_1(value, adj_value)
}

template <typename T>
inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_value += index_grad(buf, i, j);

    FP_VERIFY_ADJ_2(value, adj_value)
}

template <typename T>
inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_value += index_grad(buf, i, j, k);

    FP_VERIFY_ADJ_3(value, adj_value)
}

template <typename T>
inline CUDA_CALLABLE void adj_atomic_add(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_value += index_grad(buf, i, j, k, l);

    FP_VERIFY_ADJ_4(value, adj_value)
}
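// Illustrative sketch (not part of the original header): how the adjoints
// above pair with the forward ops. For a kernel statement x = buf[i], the
// generated code is roughly
//
//   forward:  T* p = address(buf, i); x = load(p);
//   backward: adj_address(buf, i, adj_buf, adj_i, adj_x);
//
// so the incoming gradient adj_x is atomically accumulated into buf.grad at
// element i, and only when a gradient buffer is attached (buf.grad != null).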
template <typename T>
inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, T value, const array_t<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_value -= index_grad(buf, i);

    FP_VERIFY_ADJ_1(value, adj_value)
}

template <typename T>
inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_value -= index_grad(buf, i, j);

    FP_VERIFY_ADJ_2(value, adj_value)
}

template <typename T>
inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, int k, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_value -= index_grad(buf, i, j, k);

    FP_VERIFY_ADJ_3(value, adj_value)
}

template <typename T>
inline CUDA_CALLABLE void adj_atomic_sub(const array_t<T>& buf, int i, int j, int k, int l, T value, const array_t<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_value -= index_grad(buf, i, j, k, l);

    FP_VERIFY_ADJ_4(value, adj_value)
}

// generic array types that do not support gradient computation (indexedarray, etc.)
template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, const A2<T>& adj_buf, int& adj_i, const T& adj_output) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, const A2<T>& adj_buf, int& adj_i, int& adj_j, const T& adj_output) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, const T& adj_output) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_address(const A1<T>& buf, int i, int j, int k, int l, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, const T& adj_output) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_array_store(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_add(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}
template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret) {}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_sub(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret) {}

// generic handler for scalar values
template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);

    FP_VERIFY_ADJ_1(value, adj_value)
}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);

    FP_VERIFY_ADJ_2(value, adj_value)
}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);

    FP_VERIFY_ADJ_3(value, adj_value)
}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_min(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);

    FP_VERIFY_ADJ_4(value, adj_value)
}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, T value, const A2<T>& adj_buf, int& adj_i, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_atomic_minmax(&index(buf, i), &index_grad(buf, i), value, adj_value);

    FP_VERIFY_ADJ_1(value, adj_value)
}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_atomic_minmax(&index(buf, i, j), &index_grad(buf, i, j), value, adj_value);

    FP_VERIFY_ADJ_2(value, adj_value)
}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_atomic_minmax(&index(buf, i, j, k), &index_grad(buf, i, j, k), value, adj_value);

    FP_VERIFY_ADJ_3(value, adj_value)
}

template <template <typename> class A1, template <typename> class A2, typename T>
inline CUDA_CALLABLE void adj_atomic_max(const A1<T>& buf, int i, int j, int k, int l, T value, const A2<T>& adj_buf, int& adj_i, int& adj_j, int& adj_k, int& adj_l, T& adj_value, const T& adj_ret)
{
    if (buf.grad)
        adj_atomic_minmax(&index(buf, i, j, k, l), &index_grad(buf, i, j, k, l), value, adj_value);

    FP_VERIFY_ADJ_4(value, adj_value)
}
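// Note (assumption, see builtin.h for the exact contract): adj_atomic_min and
// adj_atomic_max defer to adj_atomic_minmax(), which is expected to implement
// the usual subgradient convention for min/max: the incoming gradient flows to
// value only if it is the element the atomic op actually selected.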
} // namespace wp

#include "fabric.h"