|
|
--- nn.CrossConvolveParallel: convolve each sample with its own kernel.
-- Forward takes a table `{x, w}` and convolves sample `x[i]` with kernel
-- `w[i]`. The kernel is installed as the weight of a per-sample, zero-bias
-- nn.SpatialConvolution (preceded by zero padding). The filters are
-- therefore data, not parameters: the module defines no weight/bias of its
-- own, and updateGradInput returns gradients for both `x` and `w`.
local CrossConvolveParallel, parent = torch.class('nn.CrossConvolveParallel', 'nn.Module')
|
|
|
|
|
--- Construct the module.
-- @param nInputPlane number of input feature planes (defaults to 1)
-- @param nOutputPlane number of output feature planes (defaults to 1)
-- @param kW kernel width
-- @param kH kernel height (defaults to kW, i.e. a square kernel)
function CrossConvolveParallel:__init(nInputPlane, nOutputPlane, kW, kH)
   parent.__init(self)

   if not kH then
      kH = kW
   end

   self.nInputPlane, self.nOutputPlane = nInputPlane or 1, nOutputPlane or 1
   self.kW, self.kH = kW, kH

   -- Zero padding of half the kernel extent on each side; with the default
   -- unit stride this keeps the spatial size unchanged for odd kernels.
   self.padH = math.floor(kH / 2)
   self.padW = math.floor(kW / 2)

   -- Per-sample convolver modules, built lazily by updateOutput.
   self.convolver = {}
   -- One gradient tensor per entry of the input table (see updateGradInput).
   self.gradInput = {}
end
|
|
|
|
|
--- Forward pass: convolve every sample of `input[1]` with its own kernel.
-- Sample `i` of `input[1]` is zero-padded by `padW`/`padH` on each side and
-- passed through an nn.SpatialConvolution whose weight is `input[2][i]` and
-- whose bias is zero.
-- @param input table `{x, w}`. `x[i]` is one sample fed to
--   nn.SpatialConvolution (presumably nInputPlane x H x W; confirm with
--   callers). `w` is 5-D, and `w[i]` is laid out like nn.SpatialConvolution's
--   weight: (nOutputPlane x nInputPlane x kH x kW).
-- @return self.output, resized to (N x nOutputPlane x H' x W'), where N is
--   the number of samples and H', W' come from the convolution result.
function CrossConvolveParallel:updateOutput(input)
   local x, w = input[1], input[2]

   -- w[i] becomes nn.SpatialConvolution's weight, which that module stores
   -- as (nOutputPlane, nInputPlane, kH, kW): dim 4 is height, dim 5 is width.
   assert(w:size(4) == self.kH, 'input[2]:size(4) must equal kH')
   assert(w:size(5) == self.kW, 'input[2]:size(5) must equal kW')

   local nSample = w:size(1)
   assert(x:size(1) == nSample,
      'input[1] and input[2] must have the same number of samples')

   -- Grow the pool of per-sample convolvers. Only the missing ones are built,
   -- and each is converted to this module's tensor type (cuda, float, ...).
   for _ = #self.convolver + 1, nSample do
      local convolver = nn.Sequential()
      convolver:add(nn.SpatialZeroPadding(self.padW, self.padW, self.padH, self.padH))
      convolver:add(nn.SpatialConvolution(self.nInputPlane, self.nOutputPlane, self.kW, self.kH))
      -- The bias is not an input of this module, so it is held at zero.
      convolver.modules[2].bias:zero()
      convolver:type(self:type())
      self.convolver[#self.convolver + 1] = convolver
   end

   for i = 1, nSample do
      local convolver = self.convolver[i]
      convolver.modules[2].weight = w[i]
      local result = convolver:updateOutput(x[i])
      if i == 1 then
         -- Size the output from the first result on every call, so changes in
         -- batch size or spatial size are both handled. resize keeps
         -- self.output's tensor type, which follows :type()/:cuda().
         self.output = self.output or result.new()
         self.output:resize(nSample, self.nOutputPlane, result:size(2), result:size(3))
      end
      self.output[i]:copy(result)
   end

   return self.output
end
|
|
|
|
|
--- Backward pass with respect to both entries of the input table.
-- For each sample `i`:
--   * gradInput[1][i] is the convolver's gradient w.r.t. `input[1][i]`;
--   * gradInput[2][i] is the convolver's gradWeight, i.e. the gradient
--     w.r.t. the kernel `input[2][i]`.
-- Relies on updateOutput having been run on the same input, so that each
-- convolver holds the matching weight and forward state.
-- @param input table `{x, w}` as passed to updateOutput
-- @param gradOutput gradient of the loss w.r.t. self.output
-- @return self.gradInput, a table `{gradX, gradW}`
function CrossConvolveParallel:updateGradInput(input, gradOutput)
   local gradInput = self.gradInput

   -- Allocate lazily, then reset each gradient buffer to the input's shape.
   for k = 1, #input do
      local source = input[k]
      gradInput[k] = gradInput[k] or source.new()
      gradInput[k]:resizeAs(source):zero()
   end

   local gradX, gradW = gradInput[1], gradInput[2]
   local nSample = input[1]:size(1)
   for i = 1, nSample do
      local convolver = self.convolver[i]
      local sample, sampleGradOut = input[1][i], gradOutput[i]

      -- Clear stale parameter gradients so gradWeight holds only this call's
      -- contribution.
      convolver:zeroGradParameters()
      gradX[i]:add(convolver:updateGradInput(sample, sampleGradOut))
      convolver:accGradParameters(sample, sampleGradOut)
      gradW[i]:add(convolver.modules[2].gradWeight)
   end

   return gradInput
end
|
|
|
|
|
--- Release intermediate buffers held by this module and its convolvers.
-- @return self (via nn.Module.clearState)
function CrossConvolveParallel:clearState()
   -- self.convolver is a plain Lua table of nn.Sequential modules, not an
   -- nn.Module, so it has no clearState method of its own. Clear each
   -- convolver individually instead.
   for _, convolver in ipairs(self.convolver) do
      convolver:clearState()
   end
   return parent.clearState(self)
end
|
|
|
|
|
--- Human-readable description: the class name followed by
-- (nInputPlane x nOutputPlane x kW x kH).
function CrossConvolveParallel:__tostring__()
   local name = torch.type(self)
   local dims = string.format(' (%dx%dx%dx%d)',
      self.nInputPlane, self.nOutputPlane, self.kW, self.kH)
   return name .. dims
end
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|