Spaces:
Sleeping
Sleeping
| import cupy as cp | |
| from cupyx import scatter_add | |
def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    """
    Build the index arrays used to gather image patches with fancy indexing.

    Args:
        x_shape: Shape ``(N, C, H, W)`` of the unpadded input.
        field_height: Patch (kernel) height.
        field_width: Patch (kernel) width.
        padding: Number of zero rows/columns added on each spatial border.
        stride: Step between the origins of neighbouring patches.

    Returns:
        ``(k, i, j)``.
        ``k`` has shape ``(C * field_height * field_width, 1)``.
        ``i`` and ``j`` have shape
        ``(C * field_height * field_width, out_height * out_width)``.
        Indexing a padded tensor as ``x_padded[:, k, i, j]`` yields shape
        ``(N, C * field_height * field_width, out_height * out_width)``.
        Its rows enumerate (channel, kernel row, kernel column).
        Its columns enumerate patch positions in row-major order.

    Raises:
        ValueError: If ``stride`` is smaller than 1, or if the patch does not
            fit inside the padded input (no output positions).
    """
    _, C, H, W = x_shape
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    # Exact integer floor division: windows that would overrun the padded
    # border are dropped. For valid inputs this equals int(a / stride + 1),
    # but avoids the float round-trip.
    out_height = (H + 2 * padding - field_height) // stride + 1
    out_width = (W + 2 * padding - field_width) // stride + 1
    if out_height < 1 or out_width < 1:
        raise ValueError(
            f"{field_height}x{field_width} patch does not fit a "
            f"{H + 2 * padding}x{W + 2 * padding} padded input"
        )

    # Row offset inside a patch (each kernel row repeated once per kernel
    # column), tiled per channel; plus the row origin of every patch position.
    i0 = cp.tile(cp.repeat(cp.arange(field_height), field_width), C)
    i1 = stride * cp.repeat(cp.arange(out_height), out_width)
    # Column offset inside a patch, plus the column origin of every patch.
    j0 = cp.tile(cp.arange(field_width), field_height * C)
    j1 = stride * cp.tile(cp.arange(out_width), out_height)

    # Broadcast every in-patch offset (rows) against every patch origin (cols).
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)
    # Channel of each row; kept as a column so it broadcasts with i and j.
    k = cp.repeat(cp.arange(C), field_height * field_width).reshape(-1, 1)
    return k, i, j
def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """
    Unfold a 4-D image batch into a 2-D matrix of flattened patches.

    Args:
        x: Array of shape ``(N, C, H, W)``.
        field_height: Patch (kernel) height.
        field_width: Patch (kernel) width.
        padding: Zeros added before and after each spatial axis.
        stride: Step between consecutive patch origins.

    Returns:
        Array of shape ``(C * field_height * field_width, L * N)``, where
        ``L`` is the number of patch positions. Each column is one patch.
        Within a patch position, the batch index varies fastest.
    """
    pad_spec = ((0, 0), (0, 0), (padding, padding), (padding, padding))
    padded = cp.pad(x, pad_spec, mode='constant')
    chan_idx, row_idx, col_idx = get_im2col_indices(
        x.shape, field_height, field_width, padding, stride
    )
    # The gather gives (N, C*fh*fw, L); move the batch axis last before flattening.
    patches = padded[:, chan_idx, row_idx, col_idx].transpose(1, 2, 0)
    patch_len = field_height * field_width * x.shape[1]
    return patches.reshape(patch_len, -1)
def col2im_indices(cols, x_shape, field_height, field_width, padding=1, stride=1):
    """
    Fold a patch matrix back into a 4-D tensor, summing overlapping entries.

    Reverses the column layout produced by ``im2col_indices``. Each patch
    entry is scatter-added into the padded position it was gathered from,
    so positions covered by several windows receive the sum of all
    contributions. The zero border is then cropped away.

    Args:
        cols: Matrix of shape ``(C * field_height * field_width, L * N)``,
            laid out with the batch index varying fastest.
        x_shape: Shape ``(N, C, H, W)`` of the unpadded tensor to rebuild.
        field_height: Patch (kernel) height.
        field_width: Patch (kernel) width.
        padding: Border width used when the patches were extracted.
        stride: Stride used when the patches were extracted.

    Returns:
        Array of shape ``(N, C, H, W)`` with the same dtype as ``cols``.
    """
    N, C, H, W = x_shape
    k, i, j = get_im2col_indices(x_shape, field_height, field_width, padding, stride)
    accum = cp.zeros((N, C, H + 2 * padding, W + 2 * padding), dtype=cols.dtype)
    # Undo im2col's layout: (C*fh*fw, L*N) -> (C*fh*fw, L, N) -> (N, C*fh*fw, L).
    patches = cols.reshape(C * field_height * field_width, -1, N).transpose(2, 0, 1)
    # scatter_add rather than `accum[...] += patches` so that duplicate
    # indices from overlapping windows accumulate instead of overwriting.
    scatter_add(accum, (slice(None), k, i, j), patches)
    if padding != 0:
        accum = accum[:, :, padding:-padding, padding:-padding]
    return accum