| import torch |
| import re |
| import cupy |
|
|
| from modules.cupy_module.cupy_utils import cupy_launch |
|
|
| |
|
|
# CUDA source template for the forward splatting pass.
#
# SIZE_n(...), OFFSET_4(...) and VALUE_4(...) are placeholders, not CUDA: cupy_kernel()
# replaces them with literal sizes and stride arithmetic before the source is passed on
# to cupy_launch().
#
# Work distribution: a grid-stride loop over the n flat element indices; each index is
# decomposed into (intN, intC, intY, intX) using the sizes of `output`.
#
# Per element: the pixel at (intY, intX) is displaced by flow channel 0 (x) and flow
# channel 1 (y). Its value is then added with atomicAdd to the four integer neighbours of
# the displaced position (northwest / northeast / southwest / southeast), each scaled by
# its bilinear weight. Neighbours that fall outside `output` are skipped.
#
# The kernel only ever adds into `output`; forward() allocates it with new_zeros.
kernel_Softsplat_updateOutput = '''
    extern "C" __global__ void kernel_Softsplat_updateOutput(
        const int n,
        const float* input,
        const float* flow,
        float* output
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        const int intN = ( intIndex / SIZE_3(output) / SIZE_2(output) / SIZE_1(output) ) % SIZE_0(output);
        const int intC = ( intIndex / SIZE_3(output) / SIZE_2(output) ) % SIZE_1(output);
        const int intY = ( intIndex / SIZE_3(output) ) % SIZE_2(output);
        const int intX = ( intIndex ) % SIZE_3(output);

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
        float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
        float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
        float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));

        if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(output)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intNorthwestY, intNorthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltNorthwest);
        }

        if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(output)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intNortheastY, intNortheastX)], VALUE_4(input, intN, intC, intY, intX) * fltNortheast);
        }

        if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(output)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intSouthwestY, intSouthwestX)], VALUE_4(input, intN, intC, intY, intX) * fltSouthwest);
        }

        if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(output)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(output))) {
            atomicAdd(&output[OFFSET_4(output, intN, intC, intSoutheastY, intSoutheastX)], VALUE_4(input, intN, intC, intY, intX) * fltSoutheast);
        }
    } }
'''
|
|
# CUDA source template for the gradient with respect to `input`.
#
# SIZE_n(...) and VALUE_4(...) are placeholders expanded by cupy_kernel() before the
# source is passed on to cupy_launch().
#
# Work distribution: a grid-stride loop over the n flat element indices; each index is
# decomposed into (intN, intC, intY, intX) using the sizes of `gradInput`.
#
# Per element: the displaced position and the four bilinear weights are recomputed exactly
# as in the forward kernel. Instead of scattering, the kernel reads gradOutput at the four
# neighbours (skipping any outside gradOutput), sums the weighted values and stores the
# result in gradInput[intIndex].
#
# The `input` and `gradFlow` parameters are part of the signature but are not read here;
# backward() passes None for gradFlow when launching this kernel.
kernel_Softsplat_updateGradInput = '''
    extern "C" __global__ void kernel_Softsplat_updateGradInput(
        const int n,
        const float* input,
        const float* flow,
        const float* gradOutput,
        float* gradInput,
        float* gradFlow
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        const int intN = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) / SIZE_1(gradInput) ) % SIZE_0(gradInput);
        const int intC = ( intIndex / SIZE_3(gradInput) / SIZE_2(gradInput) ) % SIZE_1(gradInput);
        const int intY = ( intIndex / SIZE_3(gradInput) ) % SIZE_2(gradInput);
        const int intX = ( intIndex ) % SIZE_3(gradInput);

        float fltGradInput = 0.0;

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (intSoutheastY) - fltOutputY);
        float fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (intSouthwestY) - fltOutputY);
        float fltSouthwest = ((float) (intNortheastX) - fltOutputX) * (fltOutputY - (float) (intNortheastY));
        float fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * (fltOutputY - (float) (intNorthwestY));

        if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intNorthwestY, intNorthwestX) * fltNorthwest;
        }

        if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intNortheastY, intNortheastX) * fltNortheast;
        }

        if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intSouthwestY, intSouthwestX) * fltSouthwest;
        }

        if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
            fltGradInput += VALUE_4(gradOutput, intN, intC, intSoutheastY, intSoutheastX) * fltSoutheast;
        }

        gradInput[intIndex] = fltGradInput;
    } }
'''
|
|
# CUDA source template for the gradient with respect to `flow`.
#
# SIZE_n(...) and VALUE_4(...) are placeholders expanded by cupy_kernel() before the
# source is passed on to cupy_launch().
#
# Work distribution: a grid-stride loop over the n flat element indices; each index is
# decomposed into (intN, intC, intY, intX) using the sizes of `gradFlow`, so intC is the
# flow component being differentiated.
#
# Per element: the four corner weights are the partial derivatives of the forward
# bilinear weights with respect to the displaced x position (intC == 0) or the displaced
# y position (intC == 1); for any other intC they stay 0. The kernel then loops over every
# channel of gradOutput, accumulating input * gradOutput(neighbour) * weight for each
# in-bounds neighbour, and stores the total in gradFlow[intIndex].
#
# The `gradInput` parameter is part of the signature but is not read here; backward()
# passes None for it when launching this kernel.
kernel_Softsplat_updateGradFlow = '''
    extern "C" __global__ void kernel_Softsplat_updateGradFlow(
        const int n,
        const float* input,
        const float* flow,
        const float* gradOutput,
        float* gradInput,
        float* gradFlow
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        float fltGradFlow = 0.0;

        const int intN = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) / SIZE_1(gradFlow) ) % SIZE_0(gradFlow);
        const int intC = ( intIndex / SIZE_3(gradFlow) / SIZE_2(gradFlow) ) % SIZE_1(gradFlow);
        const int intY = ( intIndex / SIZE_3(gradFlow) ) % SIZE_2(gradFlow);
        const int intX = ( intIndex ) % SIZE_3(gradFlow);

        float fltOutputX = (float) (intX) + VALUE_4(flow, intN, 0, intY, intX);
        float fltOutputY = (float) (intY) + VALUE_4(flow, intN, 1, intY, intX);

        int intNorthwestX = (int) (floor(fltOutputX));
        int intNorthwestY = (int) (floor(fltOutputY));
        int intNortheastX = intNorthwestX + 1;
        int intNortheastY = intNorthwestY;
        int intSouthwestX = intNorthwestX;
        int intSouthwestY = intNorthwestY + 1;
        int intSoutheastX = intNorthwestX + 1;
        int intSoutheastY = intNorthwestY + 1;

        float fltNorthwest = 0.0;
        float fltNortheast = 0.0;
        float fltSouthwest = 0.0;
        float fltSoutheast = 0.0;

        if (intC == 0) {
            fltNorthwest = ((float) (-1.0)) * ((float) (intSoutheastY) - fltOutputY);
            fltNortheast = ((float) (+1.0)) * ((float) (intSouthwestY) - fltOutputY);
            fltSouthwest = ((float) (-1.0)) * (fltOutputY - (float) (intNortheastY));
            fltSoutheast = ((float) (+1.0)) * (fltOutputY - (float) (intNorthwestY));

        } else if (intC == 1) {
            fltNorthwest = ((float) (intSoutheastX) - fltOutputX) * ((float) (-1.0));
            fltNortheast = (fltOutputX - (float) (intSouthwestX)) * ((float) (-1.0));
            fltSouthwest = ((float) (intNortheastX) - fltOutputX) * ((float) (+1.0));
            fltSoutheast = (fltOutputX - (float) (intNorthwestX)) * ((float) (+1.0));

        }

        for (int intChannel = 0; intChannel < SIZE_1(gradOutput); intChannel += 1) {
            float fltInput = VALUE_4(input, intN, intChannel, intY, intX);

            if ((intNorthwestX >= 0) & (intNorthwestX < SIZE_3(gradOutput)) & (intNorthwestY >= 0) & (intNorthwestY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNorthwestY, intNorthwestX) * fltNorthwest;
            }

            if ((intNortheastX >= 0) & (intNortheastX < SIZE_3(gradOutput)) & (intNortheastY >= 0) & (intNortheastY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intNortheastY, intNortheastX) * fltNortheast;
            }

            if ((intSouthwestX >= 0) & (intSouthwestX < SIZE_3(gradOutput)) & (intSouthwestY >= 0) & (intSouthwestY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSouthwestY, intSouthwestX) * fltSouthwest;
            }

            if ((intSoutheastX >= 0) & (intSoutheastX < SIZE_3(gradOutput)) & (intSoutheastY >= 0) & (intSoutheastY < SIZE_2(gradOutput))) {
                fltGradFlow += fltInput * VALUE_4(gradOutput, intN, intChannel, intSoutheastY, intSoutheastX) * fltSoutheast;
            }
        }

        gradFlow[intIndex] = fltGradFlow;
    } }
'''
|
|
def cupy_kernel(strFunction, objVariables):
    """Expand the size/stride macros of a kernel template into plain CUDA source.

    The template is looked up by name among this module's globals, and three
    macro families are substituted, in this order:

    * ``SIZE_d(tensor)`` becomes the literal size of dimension ``d``.
    * ``OFFSET_k(tensor, i0, ..., ik-1)`` becomes the flat element offset
      ``(((i0)*stride0)+((i1)*stride1)+...)``.
    * ``VALUE_k(tensor, i0, ..., ik-1)`` becomes ``tensor[<same terms>]``.

    Macro arguments cannot contain ``)`` (the patterns stop at the first
    closing parenthesis), so curly braces inside an index argument are
    rewritten to parentheses to allow grouped sub-expressions.

    Args:
        strFunction: Name of a module-level string holding the template.
        objVariables: Mapping from the tensor names used in the template to
            objects exposing ``size()`` and ``stride()``.

    Returns:
        The kernel source with every macro replaced.

    Raises:
        KeyError: If ``strFunction`` is not a module global, or a macro names
            a tensor that is missing from ``objVariables``.
    """
    strKernel = globals()[strFunction]

    def _expand_size(objMatch):
        # Surrounding whitespace in the tensor name is tolerated.
        intSizes = objVariables[objMatch.group(2).strip()].size()
        return str(intSizes[int(objMatch.group(1))])

    def _stride_terms(objMatch):
        # Shared by OFFSET_ and VALUE_: the first argument is the tensor name,
        # the following ones are the per-dimension index expressions.
        intArgs = int(objMatch.group(1))
        strArgs = objMatch.group(2).split(',')
        strTensor = strArgs[0].strip()
        intStrides = objVariables[strTensor].stride()
        strTerms = [
            '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip()
            + ')*' + str(intStrides[intArg]) + ')'
            for intArg in range(intArgs)
        ]
        return strTensor, '+'.join(strTerms)

    def _expand_offset(objMatch):
        return '(' + _stride_terms(objMatch)[1] + ')'

    def _expand_value(objMatch):
        strTensor, strOffset = _stride_terms(objMatch)
        return strTensor + '[' + strOffset + ']'

    # Raw strings: "\(" in a normal literal is an invalid escape sequence.
    # SIZE_ goes first so that sizes used inside index arguments are already
    # plain numbers when OFFSET_/VALUE_ are expanded.
    strKernel = re.sub(r'SIZE_([0-4])\(([^\)]*)\)', _expand_size, strKernel)
    strKernel = re.sub(r'OFFSET_([0-4])\(([^\)]+)\)', _expand_offset, strKernel)
    strKernel = re.sub(r'VALUE_([0-4])\(([^\)]+)\)', _expand_value, strKernel)

    return strKernel
| |
|
|
class _FunctionSoftsplat(torch.autograd.Function):
    """Autograd function that forward-warps ``input`` along ``flow`` on the GPU.

    Both passes specialise a CUDA template with ``cupy_kernel`` and launch it
    through ``cupy_launch`` using 512-thread blocks. Only CUDA tensors are
    supported.
    """

    @staticmethod
    def forward(self, input, flow):
        """Splat ``input`` to the positions given by ``flow``.

        ``flow`` must have two channels and the same height and width as
        ``input``. Returns a new tensor with the same shape as ``input``;
        ``input`` and ``flow`` (made contiguous) are saved for backward.
        """
        intSamples = input.shape[0]
        intInputDepth = input.shape[1]
        intInputHeight = input.shape[2]
        intInputWidth = input.shape[3]
        intFlowDepth = flow.shape[1]
        intFlowHeight = flow.shape[2]
        intFlowWidth = flow.shape[3]

        assert intFlowDepth == 2
        assert intInputHeight == intFlowHeight
        assert intInputWidth == intFlowWidth

        # The kernels index through strides baked in by cupy_kernel, so the
        # tensors are made contiguous before those strides are read.
        input = input.contiguous()
        assert input.is_cuda
        flow = flow.contiguous()
        assert flow.is_cuda

        # The forward kernel accumulates with atomicAdd, hence the zero fill.
        output = input.new_zeros([intSamples, intInputDepth, intInputHeight, intInputWidth])

        if not input.is_cuda:
            raise NotImplementedError()

        n = output.nelement()
        strSource = cupy_kernel('kernel_Softsplat_updateOutput', {
            'input': input,
            'flow': flow,
            'output': output,
        })
        objLaunch = cupy_launch('kernel_Softsplat_updateOutput', strSource)
        objLaunch(
            grid=((n + 512 - 1) // 512, 1, 1),
            block=(512, 1, 1),
            args=[cupy.int32(n), input.data_ptr(), flow.data_ptr(), output.data_ptr()],
        )

        self.save_for_backward(input, flow)

        return output

    @staticmethod
    def backward(self, gradOutput):
        """Return ``(gradInput, gradFlow)`` for the saved ``input`` and ``flow``.

        Each gradient is computed only when ``needs_input_grad`` asks for it
        and is ``None`` otherwise.
        """
        input, flow = self.saved_tensors

        intSamples = input.shape[0]
        intInputDepth = input.shape[1]
        intInputHeight = input.shape[2]
        intInputWidth = input.shape[3]
        intFlowDepth = flow.shape[1]
        intFlowHeight = flow.shape[2]
        intFlowWidth = flow.shape[3]

        assert intFlowDepth == 2
        assert intInputHeight == intFlowHeight
        assert intInputWidth == intFlowWidth

        gradOutput = gradOutput.contiguous()
        assert gradOutput.is_cuda

        gradInput = None
        if self.needs_input_grad[0]:
            gradInput = input.new_zeros([intSamples, intInputDepth, intInputHeight, intInputWidth])

        gradFlow = None
        if self.needs_input_grad[1]:
            gradFlow = input.new_zeros([intSamples, intFlowDepth, intFlowHeight, intFlowWidth])

        if not input.is_cuda:
            raise NotImplementedError()

        # Both backward templates are specialised against the same tensors.
        objTensors = {
            'input': input,
            'flow': flow,
            'gradOutput': gradOutput,
            'gradInput': gradInput,
            'gradFlow': gradFlow,
        }

        if gradInput is not None:
            n = gradInput.nelement()
            strSource = cupy_kernel('kernel_Softsplat_updateGradInput', objTensors)
            objLaunch = cupy_launch('kernel_Softsplat_updateGradInput', strSource)
            objLaunch(
                grid=((n + 512 - 1) // 512, 1, 1),
                block=(512, 1, 1),
                # gradFlow is not read by this kernel, so no pointer is passed.
                args=[cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), gradInput.data_ptr(), None],
            )

        if gradFlow is not None:
            n = gradFlow.nelement()
            strSource = cupy_kernel('kernel_Softsplat_updateGradFlow', objTensors)
            objLaunch = cupy_launch('kernel_Softsplat_updateGradFlow', strSource)
            objLaunch(
                grid=((n + 512 - 1) // 512, 1, 1),
                block=(512, 1, 1),
                # gradInput is not read by this kernel, so no pointer is passed.
                args=[cupy.int32(n), input.data_ptr(), flow.data_ptr(), gradOutput.data_ptr(), None, gradFlow.data_ptr()],
            )

        return gradInput, gradFlow
| |
| |
|
|
def FunctionSoftsplat(tenInput, tenFlow, tenMetric, strType):
    """Forward-warp ``tenInput`` along ``tenFlow`` using one of four modes.

    * ``'summation'``: the splatted input is returned as is.
    * ``'average'``: a channel of ones is splatted alongside the input and
      used to normalise the result.
    * ``'linear'``: the input is weighted by ``tenMetric`` and normalised by
      the splatted ``tenMetric``.
    * ``'softmax'``: as ``'linear'`` but with ``tenMetric.exp()`` as weight.

    For the normalised modes, positions whose splatted weight is exactly zero
    are divided by one instead.

    Args:
        tenInput: Tensor to warp.
        tenFlow: Flow passed on to ``_FunctionSoftsplat``.
        tenMetric: ``None`` or a tensor with a single channel; it is read by
            the ``'linear'`` and ``'softmax'`` modes.
        strType: One of ``'summation'``, ``'average'``, ``'linear'``,
            ``'softmax'``.

    Returns:
        The warped tensor, without the auxiliary weight channel.
    """
    assert tenMetric is None or tenMetric.shape[1] == 1
    assert strType in ['summation', 'average', 'linear', 'softmax']

    if strType == 'summation':
        return _FunctionSoftsplat.apply(tenInput, tenFlow)

    # Every other mode appends one weight channel that is splatted together
    # with the (weighted) input and divided out afterwards.
    if strType == 'average':
        tenOnes = tenInput.new_ones(tenInput.shape[0], 1, tenInput.shape[2], tenInput.shape[3])
        tenInput = torch.cat([tenInput, tenOnes], 1)
    elif strType == 'linear':
        tenInput = torch.cat([tenInput * tenMetric, tenMetric], 1)
    else:
        tenInput = torch.cat([tenInput * tenMetric.exp(), tenMetric.exp()], 1)

    tenOutput = _FunctionSoftsplat.apply(tenInput, tenFlow)

    tenNormalize = tenOutput[:, -1:, :, :]
    # Guard against division by zero where no weight was splatted.
    tenNormalize[tenNormalize == 0.0] = 1.0

    return tenOutput[:, :-1, :, :] / tenNormalize
| |
|
|
class ModuleSoftsplat(torch.nn.Module):
    """``torch.nn.Module`` wrapper around ``FunctionSoftsplat`` with a fixed mode."""

    def __init__(self, strType):
        """Store ``strType``, the mode handed to ``FunctionSoftsplat`` on every call."""
        super(ModuleSoftsplat, self).__init__()
        self.strType = strType

    def forward(self, tenInput, tenFlow, tenMetric):
        """Warp ``tenInput`` along ``tenFlow`` using the mode chosen at construction."""
        strMode = self.strType
        return FunctionSoftsplat(tenInput, tenFlow, tenMetric, strMode)
| |
| |