-- FloorPlanTransformation/util/lua/nn/CrossConvolveParallel.lua
-- nn.Module that convolves each sample in a batch with its own, per-sample
-- kernel (supplied as the second input), using a pool of SpatialConvolution
-- modules run in parallel over the batch.
local CrossConvolveParallel, parent = torch.class('nn.CrossConvolveParallel', 'nn.Module')

--- Construct a cross-convolution module.
-- @param nInputPlane  number of input channels per sample (default 1)
-- @param nOutputPlane number of output channels per sample (default 1)
-- @param kW           kernel width (required)
-- @param kH           kernel height (defaults to kW, i.e. a square kernel)
function CrossConvolveParallel:__init(nInputPlane, nOutputPlane, kW, kH)
   parent.__init(self)
   -- kW is mandatory; without this check a nil kW only fails later inside
   -- math.floor(kW / 2) with a cryptic arithmetic error.
   assert(kW, 'CrossConvolveParallel: kernel width (kW) is required')
   kH = kH or kW
   -- get args
   self.nInputPlane = nInputPlane or 1
   self.nOutputPlane = nOutputPlane or 1
   self.kW = kW
   self.kH = kH
   -- "same" padding so the spatial output size matches the input
   self.padH = math.floor(kH / 2)
   self.padW = math.floor(kW / 2)
   -- lazily-grown pool of per-sample convolver modules (one per batch entry)
   self.convolver = {}
   -- gradInput is a table because the module takes a table input {data, kernels}
   self.gradInput = {}
end
--- Forward pass.
-- input[1]: batch of data, nSample x nInputPlane x H x W
-- input[2]: batch of kernels, with size(4) == kW and size(5) == kH
--           (presumably nSample x nOutputPlane x nInputPlane x kW x kH
--            -- TODO confirm against callers; note SpatialConvolution's own
--            weight layout is nOut x nIn x kH x kW, which only matches when
--            kW == kH)
-- Returns: nSample x nOutputPlane x H x W tensor.
function CrossConvolveParallel:updateOutput(input)
   assert(input[2]:size(5) == self.kH)
   assert(input[2]:size(4) == self.kW)
   local nSample = input[1]:size(1)
   -- Grow the convolver pool lazily, creating ONLY the missing entries.
   -- (The previous code re-created and appended the full count whenever the
   -- pool was too small, duplicating existing modules and leaking memory on
   -- every batch-size increase.)
   for i = #self.convolver + 1, nSample do
      local convolver = nn.Sequential()
      convolver:add(nn.SpatialZeroPadding(self.padW, self.padW, self.padH, self.padH))
      convolver:add(nn.SpatialConvolution(self.nInputPlane, self.nOutputPlane, self.kW, self.kH))
      -- bias stays zero: only the injected per-sample weights should matter
      convolver.modules[2].bias:zero()
      if self:type() == 'torch.CudaTensor' then
         convolver:cuda()
      end
      table.insert(self.convolver, convolver)
   end
   -- compute output, one convolver per sample
   for i = 1, nSample do
      -- inject this sample's kernel as the convolution weight
      self.convolver[i].modules[2].weight = input[2][i]
      local tmp = self.convolver[i]:updateOutput(input[1][i])
      -- (re)allocate the output buffer on first use or on batch-size change
      if (self.output == nil) or (self.output:size():size() == 0) or (self.output:size(1) ~= nSample) then
         local h = tmp:size(2)
         local w = tmp:size(3)
         if self:type() == 'torch.CudaTensor' then
            self.output = torch.CudaTensor(nSample, self.nOutputPlane, h, w)
         else
            self.output = torch.DoubleTensor(nSample, self.nOutputPlane, h, w)
         end
      end
      self.output[i] = tmp
   end
   -- done
   return self.output
end
--- Backward pass.
-- Computes gradients w.r.t. both table inputs:
--   gradInput[1][i] = gradient w.r.t. the data sample (via the i-th
--                     convolver's updateGradInput),
--   gradInput[2][i] = gradient w.r.t. the i-th per-sample kernel, read out
--                     of the convolver's accumulated gradWeight.
-- NOTE(review): correctness relies on the forward pass having already run,
-- so that convolver[i].modules[2].weight still holds input[2][i].
function CrossConvolveParallel:updateGradInput(input, gradOutput)
-- resize grad
for i = 1, #input do
if self.gradInput[i] == nil then
-- allocate a buffer of the same tensor type as the matching input
self.gradInput[i] = input[i].new()
end
self.gradInput[i]:resizeAs(input[i]):zero()
end
-- backprop
for i = 1, input[1]:size(1) do
-- zero first: gradWeight accumulates, and convolvers are reused across calls
self.convolver[i]:zeroGradParameters()
self.gradInput[1][i]:add(self.convolver[i]:updateGradInput(input[1][i], gradOutput[i]))
-- accGradParameters fills modules[2].gradWeight, which IS the kernel gradient
self.convolver[i]:accGradParameters(input[1][i], gradOutput[i])
self.gradInput[2][i]:add(self.convolver[i].modules[2].gradWeight)
end
-- done
return self.gradInput
end
--- Release intermediate buffers (for serialization / memory savings).
-- self.convolver is a plain Lua table of nn.Sequential modules, not a module
-- itself, so it has no :clearState() method -- the original call
-- `self.convolver:clearState()` raised "attempt to call method 'clearState'
-- (a nil value)". Clear each pooled convolver individually instead.
function CrossConvolveParallel:clearState()
   for i = 1, #self.convolver do
      self.convolver[i]:clearState()
   end
   return parent.clearState(self)
end
--- Human-readable description: class name plus kernel geometry
-- (nInputPlane x nOutputPlane x kW x kH).
function CrossConvolveParallel:__tostring__()
   local geometry = string.format(' (%dx%dx%dx%d)',
      self.nInputPlane, self.nOutputPlane, self.kW, self.kH)
   return torch.type(self) .. geometry
end