popboat1
initial commit
ab81f90
Raw
History Blame Contribute Delete
2.04 kB
import cupy as cp
from cupyx import scatter_add
def get_im2col_indices(x_shape, field_height, field_width, padding=1, stride=1):
    """
    Calculates the indices needed to extract image patches.

    Builds advanced-indexing arrays ``(k, i, j)`` such that, for a padded
    array ``x_padded`` of shape ``(N, C, H + 2*padding, W + 2*padding)``,
    ``x_padded[:, k, i, j]`` has shape
    ``(N, C * field_height * field_width, out_height * out_width)``.
    Rows are ordered channel first, then kernel row, then kernel column;
    columns are ordered output row first, then output column.

    Args:
        x_shape: 4-tuple ``(N, C, H, W)`` describing the unpadded input.
        field_height: Patch (kernel) height.
        field_width: Patch (kernel) width.
        padding: Number of padded rows/columns on each spatial border.
        stride: Step between consecutive patches; must be >= 1.

    Returns:
        Tuple ``(k, i, j)``. ``k`` has shape ``(C*field_height*field_width, 1)``
        (channel index per row). ``i`` and ``j`` have shape
        ``(C*field_height*field_width, out_height*out_width)`` (row / column
        index into the padded input).

    Raises:
        ValueError: If ``stride < 1``, or if the kernel does not fit in the
            padded input (which would yield no output positions).
    """
    _, C, H, W = x_shape
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    # Standard convolution output size. Integer floor division matches the
    # previous int(float) computation for every non-negative result and
    # avoids a float round-trip.
    out_height = (H + 2 * padding - field_height) // stride + 1
    out_width = (W + 2 * padding - field_width) // stride + 1
    if out_height < 1 or out_width < 1:
        raise ValueError(
            f"kernel {field_height}x{field_width} does not fit input {H}x{W} "
            f"with padding={padding}, stride={stride}"
        )

    # Offsets inside one patch. The row offset repeats per kernel column;
    # the whole pattern is then tiled once per channel.
    i0 = cp.tile(cp.repeat(cp.arange(field_height), field_width), C)
    j0 = cp.tile(cp.arange(field_width), field_height * C)
    # Top-left corner of every patch position in the padded input.
    i1 = stride * cp.repeat(cp.arange(out_height), out_width)
    j1 = stride * cp.tile(cp.arange(out_width), out_height)

    # Broadcast (rows, 1) + (1, positions) -> (rows, positions).
    i = i0.reshape(-1, 1) + i1.reshape(1, -1)
    j = j0.reshape(-1, 1) + j1.reshape(1, -1)
    k = cp.repeat(cp.arange(C), field_height * field_width).reshape(-1, 1)
    return k, i, j
def im2col_indices(x, field_height, field_width, padding=1, stride=1):
    """
    Transforms the 4D image tensor into a 2D matrix of stretched out patches.

    Args:
        x: Input array indexed as ``(N, C, H, W)``.
        field_height: Patch (kernel) height.
        field_width: Patch (kernel) width.
        padding: Constant zero padding added on each spatial border.
        stride: Step between consecutive patches.

    Returns:
        Array of shape ``(C * field_height * field_width, L * N)``, where ``L``
        is the number of patch positions. Column ``l * N + n`` holds patch
        position ``l`` of sample ``n``.
    """
    pad_spec = ((0, 0), (0, 0), (padding, padding), (padding, padding))
    padded = cp.pad(x, pad_spec, mode='constant')
    k, i, j = get_im2col_indices(x.shape, field_height, field_width, padding, stride)
    num_channels = x.shape[1]

    # Gather gives (N, rows, L). Moving the batch axis last keeps all samples
    # of the same patch position adjacent once the last two axes are flattened.
    patches = padded[:, k, i, j].transpose(1, 2, 0)
    return patches.reshape(field_height * field_width * num_channels, -1)
def col2im_indices(cols, x_shape, field_height, field_width, padding=1, stride=1):
    """
    Routes the 2D gradient matrix back into a 4D image tensor.

    Entries of ``cols`` are scatter-added into the positions they index, so
    values from overlapping patches accumulate instead of overwriting each
    other.

    Args:
        cols: Array with ``C * field_height * field_width`` rows and
            ``L * N`` columns, in the column layout produced by
            ``im2col_indices`` (column ``l * N + n``).
        x_shape: ``(N, C, H, W)`` of the unpadded image to rebuild.
        field_height: Patch (kernel) height.
        field_width: Patch (kernel) width.
        padding: Border width that was padded on each spatial side; it is
            cropped off again before returning.
        stride: Step between consecutive patches.

    Returns:
        Array of shape ``(N, C, H, W)`` with dtype ``cols.dtype``.
    """
    N, C, H, W = x_shape
    canvas = cp.zeros((N, C, H + 2 * padding, W + 2 * padding), dtype=cols.dtype)
    k, i, j = get_im2col_indices(x_shape, field_height, field_width, padding, stride)

    # (rows, L*N) -> (rows, L, N) -> (N, rows, L): the same shape that
    # indexing with [:, k, i, j] produces, so values line up with the indices.
    per_sample = cols.reshape(C * field_height * field_width, -1, N).transpose(2, 0, 1)
    # Add instead of assign so overlapping patches sum their contributions.
    # NOTE(review): CuPy also provides `cupy.add.at`; if this cupyx helper is
    # deprecated in the CuPy version in use, consider migrating — confirm.
    scatter_add(canvas, (slice(None), k, i, j), per_sample)

    # A zero padding would make the `padding:-padding` slice empty, so skip it.
    return canvas if padding == 0 else canvas[:, :, padding:-padding, padding:-padding]