File size: 1,031 Bytes
e9fe176 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
--[[
Input: A length-2 table of length-K tables
Input[1]: a table of K elements
Input[2]: a table of K elements
Output: A length-K table of length-2 tables
Output[i]: {Input[1][i], Input[2][i]}
--]]
local CrossMergeTable, parent = torch.class('nn.CrossMergeTable', 'nn.Module')

--- Construct a CrossMergeTable module (no arguments).
-- Both the forward result and the backward result of this module are plain
-- Lua tables, so both buffers installed by parent.__init are replaced with
-- empty tables here (in stock torch/nn those defaults are empty tensors).
function CrossMergeTable:__init()
   parent.__init(self)
   self.output = {}
   self.gradInput = {}
end
--- Forward pass: zip two equal-length tables into a table of pairs.
-- @param input  a length-2 table {A, B}, where A and B have the same length K
-- @return self.output, a freshly allocated length-K table whose i-th entry
--         is the pair {A[i], B[i]}
-- Raises an error if input does not have exactly two entries or if the two
-- entries differ in length.
function CrossMergeTable:updateOutput(input)
   assert(#input == 2,
      'CrossMergeTable: expected an input table with exactly 2 entries')
   local first, second = input[1], input[2]
   assert(#first == #second,
      'CrossMergeTable: the two input tables must have the same length')

   -- A new table on every call means no stale pairs survive from a previous,
   -- longer input.
   local output = {}
   for i = 1, #first do
      output[i] = {first[i], second[i]}
   end
   self.output = output
   return output
end
--- Backward pass: unzip K gradient pairs back into two length-K tables.
-- @param input       the forward input (not read by this method)
-- @param gradOutput  a length-K table; gradOutput[i] is the pair {gA_i, gB_i}
-- @return self.gradInput = {gA, gB}, where gA[i] = gradOutput[i][1] and
--         gB[i] = gradOutput[i][2] for i = 1..K
function CrossMergeTable:updateGradInput(input, gradOutput)
   local n = #gradOutput
   for j = 1, 2 do
      -- Reuse the existing container when there is one; create it lazily
      -- otherwise (e.g. after self.gradInput was reset to {}).
      local grad = self.gradInput[j] or {}
      self.gradInput[j] = grad
      for i = 1, n do
         grad[i] = gradOutput[i][j]
      end
      -- A previous call may have had a longer sequence. Drop its leftover
      -- entries so the result has exactly n elements. Clearing existing
      -- fields while traversing with pairs is permitted by the Lua manual.
      for k in pairs(grad) do
         if type(k) == 'number' and k > n then
            grad[k] = nil
         end
      end
   end
   return self.gradInput
end
--- Printable representation: the module's torch type name followed by "()".
function CrossMergeTable:__tostring__()
   local typeName = torch.type(self)
   return ('%s()'):format(typeName)
end
|