FloorPlanTransformation / util /lua /nn /CrossMergeTable.lua
rawanessam's picture
Upload 79 files
e9fe176 verified
--[[
Input: A length-2 table of length-K tables
Input[1]: K tables
Input[2]: K tables
Output: A length-K table of length-2 tables
Output[i]: {Input[1][i], Input[2][i]}
--]]
--- nn.CrossMergeTable: regroups a pair of K-element tables into K pairs.
-- Registered through torch.class with 'nn.Module' as its parent class;
-- `parent` is the parent-class table handed back by torch.class.
local CrossMergeTable, parent = torch.class('nn.CrossMergeTable', 'nn.Module')

--- Constructor.
-- Runs the parent (nn.Module) initializer first, then installs a plain Lua
-- table as gradInput. That table later holds the two per-branch gradient
-- tables, so it replaces whatever gradInput the parent initializer set
-- (presumably a tensor -- confirm against nn.Module).
function CrossMergeTable:__init()
   parent.__init(self)
   local grads = {}
   self.gradInput = grads
end
--- Forward pass: zip two K-element tables into K two-element tables.
-- @param input length-2 table {a, b}; a and b must have the same length K
-- @return self.output, a new length-K table with output[i] = {a[i], b[i]}.
--   The entries are the input objects themselves (no copies are made).
-- Raises an error if input does not hold exactly 2 tables, or if the two
-- tables differ in length.
function CrossMergeTable:updateOutput(input)
   assert(#input == 2, 'CrossMergeTable: expected input to be a table of 2 tables')
   local first, second = input[1], input[2]
   local k = #first
   if k ~= #second then
      -- Build the message only on the failure path to keep forward cheap.
      error(string.format(
         'CrossMergeTable: input tables differ in length (%d vs %d)', k, #second))
   end
   -- A fresh table on every call, so no pairs linger from an earlier, longer input.
   local output = {}
   for i = 1, k do
      output[i] = {first[i], second[i]}
   end
   self.output = output
   return self.output
end
--- Backward pass: undo the pairing done in updateOutput.
-- gradOutput is expected to mirror the forward output: a length-K table
-- whose i-th entry is a pair {g1, g2}. The result is self.gradInput, a
-- length-2 table where
--   self.gradInput[1][i] = gradOutput[i][1]
--   self.gradInput[2][i] = gradOutput[i][2]
-- Gradients are passed through by reference (no copies), matching how
-- updateOutput forwards the input entries.
-- @param input the forward input (not read here; kept for the nn.Module API)
-- @param gradOutput length-K table of length-2 tables
-- @return self.gradInput
function CrossMergeTable:updateGradInput(input, gradOutput)
   local k = #gradOutput
   for branch = 1, 2 do
      local grads = self.gradInput[branch]
      if not grads then
         grads = {}
         self.gradInput[branch] = grads
      end
      for i = 1, k do
         grads[i] = gradOutput[i][branch]
      end
      -- These tables are reused across calls. Drop entries left over from an
      -- earlier call with a larger K so that stale gradients are not handed
      -- to upstream modules.
      for i = #grads, k + 1, -1 do
         grads[i] = nil
      end
   end
   return self.gradInput
end
--- Printable form of the module: its torch type name followed by "()".
-- @return string, e.g. "nn.CrossMergeTable()"
function CrossMergeTable:__tostring__()
   local name = torch.type(self)
   return name .. '()'
end