File size: 2,735 Bytes
e9fe176 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 |
--[[
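Input: a table of two tensors
Input[1]: Responses, batchSize x nChannels x iH x iW
Input[2]: Crop offsets (minH, minW), batchSize x nFeatures x 2: the 1-based first
          row (height axis) and first column (width axis) of each oH x oW window
Output: Cropped responses, batchSize x (nFeatures x nChannels) x oH x oW;
        the window for feature q fills channels (q-1)*nChannels+1 .. q*nChannels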
--]]
local Crop, parent = torch.class('nn.Crop', 'nn.Module')

--- Create a module that cuts fixed-size windows out of spatial responses.
-- @param iW expected width  (size of dim 4) of the response tensor
-- @param iH expected height (size of dim 3) of the response tensor
-- @param oW width of every cropped window
-- @param oH height of every cropped window
function Crop:__init(iW, iH, oW, oH)
   parent.__init(self)
   -- input spatial size, checked on every forward pass
   self.iH, self.iW = iH, iW
   -- output window size
   self.oH, self.oW = oH, oW
   -- one gradient tensor per input entry; allocated lazily in updateGradInput
   self.gradInput = {}
end
--- Return whatever is stored in self.coords.
-- NOTE(review): nothing in this file assigns self.coords, so this yields nil
-- unless external code sets the field; confirm the intended use.
function Crop:getCoordinates()
   local coords = self.coords
   return coords
end
--- Forward pass: cut one oH x oW window per (sample, feature) pair.
-- @param input table of two tensors:
--   input[1] responses, batchSize x nChannels x iH x iW
--   input[2] offsets, batchSize x nFeatures x 2; [p][q][1] is the 1-based
--            first row (height axis) and [p][q][2] the 1-based first column
--            (width axis) of the q-th window of sample p
-- @return self.output, batchSize x (nFeatures * nChannels) x oH x oW; the
--   window for feature q is written to channels (q-1)*nChannels+1 .. q*nChannels
-- Raises an error if shapes mismatch or a window does not fit in the input.
function Crop:updateOutput(input)
   assert(#input == 2, 'nn.Crop: expected input table {responses, offsets}')
   local responses, offsets = input[1], input[2]
   assert(responses:dim() == 4,
      'nn.Crop: input[1] must be batchSize x nChannels x iH x iW')
   assert(offsets:dim() == 3 and offsets:size(3) >= 2,
      'nn.Crop: input[2] must be batchSize x nFeatures x 2')
   assert(responses:size(1) == offsets:size(1),
      'nn.Crop: input[1] and input[2] have different batch sizes')
   assert(responses:size(3) == self.iH, 'nn.Crop: input[1] height differs from iH')
   assert(responses:size(4) == self.iW, 'nn.Crop: input[1] width differs from iW')

   local batchSize, nChannels = responses:size(1), responses:size(2)
   local nFeatures = offsets:size(2)
   local iH, iW, oH, oW = self.iH, self.iW, self.oH, self.oW

   self.output:resize(batchSize, nChannels * nFeatures, oH, oW)
   for p = 1, batchSize do
      local sample = responses[p]       -- nChannels x iH x iW view
      local outSample = self.output[p]  -- (nFeatures * nChannels) x oH x oW view
      for q = 1, nFeatures do
         -- floor() matches the truncation torch applies to numeric indices
         local row = math.floor(offsets[{p, q, 1}])
         local col = math.floor(offsets[{p, q, 2}])
         -- Reject windows that leave the input (e.g. zero-based or negative
         -- offsets) with a descriptive message instead of cropping a wrong
         -- region or failing inside torch's indexing code.
         if row < 1 or col < 1 or row + oH - 1 > iH or col + oW - 1 > iW then
            error(string.format(
               'nn.Crop: window at (row %g, col %g) of size %gx%g does not fit in %gx%g input',
               row, col, oH, oW, iH, iW))
         end
         local window = sample:narrow(2, row, oH):narrow(3, col, oW)
         outSample:narrow(1, (q - 1) * nChannels + 1, nChannels):copy(window)
      end
   end
   return self.output
end
--- Backward pass: route each window's gradient back to where it was cropped.
-- @param input same table {responses, offsets} given to updateOutput
-- @param gradOutput batchSize x (nFeatures * nChannels) x oH x oW
-- @return self.gradInput, a table {gradResponses, gradOffsets}:
--   gradResponses has input[1]'s shape; gradients of overlapping windows are
--   summed. gradOffsets is all zeros (offsets are treated as non-differentiable).
function Crop:updateGradInput(input, gradOutput)
   -- Allocate buffers lazily, and replace one whose tensor type no longer
   -- matches its input; resizeAs/add would otherwise reject the mismatch.
   for i = 1, #input do
      local buf = self.gradInput[i]
      if buf == nil or torch.type(buf) ~= torch.type(input[i]) then
         buf = input[i].new()
         self.gradInput[i] = buf
      end
      buf:resizeAs(input[i]):zero()
   end

   local responses, offsets = input[1], input[2]
   local batchSize, nChannels = responses:size(1), responses:size(2)
   local nFeatures = offsets:size(2)
   local oH, oW = self.oH, self.oW
   assert(gradOutput:dim() == 4
      and gradOutput:size(1) == batchSize
      and gradOutput:size(2) == nChannels * nFeatures
      and gradOutput:size(3) == oH
      and gradOutput:size(4) == oW,
      'nn.Crop: gradOutput must be batchSize x (nFeatures * nChannels) x oH x oW')

   local gradResponses = self.gradInput[1]
   for p = 1, batchSize do
      local gradSample = gradResponses[p]  -- nChannels x iH x iW view
      local gradOutSample = gradOutput[p]  -- (nFeatures * nChannels) x oH x oW view
      for q = 1, nFeatures do
         -- floor() matches the truncation torch applies to numeric indices
         local row = math.floor(offsets[{p, q, 1}])
         local col = math.floor(offsets[{p, q, 2}])
         local window = gradSample:narrow(2, row, oH):narrow(3, col, oW)
         -- add (not copy) so overlapping windows accumulate their gradients
         window:add(gradOutSample:narrow(1, (q - 1) * nChannels + 1, nChannels))
      end
   end
   return self.gradInput
end
--- Printable representation: the torch type name followed by "()".
function Crop:__tostring__()
   local name = torch.type(self)
   return name .. '()'
end
|