import cupy as cp from cupyx import scatter_add def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1): """ Calculates the indices needed to extract image patches. """ N, C, H, W = x_shape out_height = int((H + 2 * padding - field_height) / stride + 1) out_width = int((W + 2 * padding - field_width) / stride + 1) i0 = cp.repeat(cp.arange(field_height), field_width) i0 = cp.tile(i0, C) i1 = stride * cp.repeat(cp.arange(out_height), out_width) j0 = cp.tile(cp.arange(field_width), field_height * C) j1 = stride * cp.tile(cp.arange(out_width), out_height) i = i0.reshape(-1, 1) + i1.reshape(1, -1) j = j0.reshape(-1, 1) + j1.reshape(1, -1) k = cp.repeat(cp.arange(C), field_height * field_width).reshape(-1, 1) return k, i, j def im2col_indices(x, field_height, field_width, padding=1, stride=1): """ Transforms the 4D image tensor into a 2D matrix of stretched out patches. """ p = padding x_padded = cp.pad(x, ((0, 0), (0, 0), (p, p), (p, p)), mode='constant') k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding, stride) cols = x_padded[:, k, i, j] C = x.shape[1] cols = cols.transpose(1, 2, 0).reshape(field_height * field_width * C, -1) return cols def col2im_indices(cols, x_shape, field_height, field_width, padding=1, stride=1): """Routes the 2D gradient matrix back into a 4D image tensor.""" N, C, H, W = x_shape H_padded, W_padded = H + 2 * padding, W + 2 * padding x_padded = cp.zeros((N, C, H_padded, W_padded), dtype=cols.dtype) k, i, j = get_im2col_indices(x_shape, field_height, field_width, padding, stride) cols_reshaped = cols.reshape(C * field_height * field_width, -1, N) cols_reshaped = cols_reshaped.transpose(2, 0, 1) scatter_add(x_padded, (slice(None), k, i, j), cols_reshaped) if padding == 0: return x_padded return x_padded[:, :, padding:-padding, padding:-padding]