|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-- nn.VRMaskRegressReward: criterion whose per-sample reward is the negated,
-- masked squared regression error (see updateOutput). During backward the
-- reward, minus a learned baseline, is handed to the wrapped module's
-- reinforce() method, and the baseline is trained with a separate criterion.
local VRMaskRegressReward, parent = torch.class("nn.VRMaskRegressReward", "nn.Criterion")
|
|
|
|
|
--- Build the criterion.
-- @param module    module whose `reinforce` method receives the baseline-
--                  corrected reward in updateGradInput
-- @param scale     factor applied to the negated squared error (default 1)
-- @param rho       number of reward/baseline entries per input row (default 1)
-- @param criterion criterion that regresses the baseline onto the reward
--                  (default nn.MSECriterion())
function VRMaskRegressReward:__init(module, scale, rho, criterion)
   -- parent.__init must run first: it installs the default output/gradInput
   -- fields that are overridden below.
   parent.__init(self)

   self.module = module
   -- Numeric hyper-parameters, each falling back to 1 when omitted.
   self.scale, self.rho = scale or 1, rho or 1
   self.criterion = criterion or nn.MSECriterion()

   self.sizeAverage = true
   -- gradInput is a table: {grad wrt prediction, grad wrt baseline}.
   self.gradInput = {}
end
|
|
|
|
|
--- Compute per-sample rewards and the resulting loss.
-- For every input row i the reward is
--    mean( -scale * (input[i] - target[i])^2 )
-- taken over the elements selected by mask[i], or 0 when the mask selects
-- nothing. That value is written to the `rho` consecutive entries of
-- `self.reward` that belong to row i, so `self.reward` has one entry per
-- baseline row.
-- @param inputTable  {prediction, baseline}; baseline must have rho times as
--                    many rows as prediction
-- @param targetTable {target, mask}; mask is passed to maskedSelect
--                    (presumably a ByteTensor — TODO confirm with callers)
-- @return the negated reward sum, divided by the number of baseline rows
--         when sizeAverage is set
function VRMaskRegressReward:updateOutput(inputTable, targetTable)
   assert(torch.type(inputTable) == 'table',
      "VRMaskRegressReward: expecting input table {prediction, baseline}")
   local input = self:toBatch(inputTable[1], 1)
   local baseline = self:toBatch(inputTable[2], 1)
   local nInput = (#input)[1]
   local nBaseline = (#baseline)[1]
   assert(nInput * self.rho == nBaseline,
      "VRMaskRegressReward: baseline must have rho * #input rows")

   assert(torch.type(targetTable) == 'table',
      "VRMaskRegressReward: expecting target table {target, mask}")
   local target = self:toBatch(targetTable[1], 1)
   local mask = self:toBatch(targetTable[2], 1)

   self.reward = self.reward or baseline.new()
   self.reward:resize(nBaseline)

   for i = 1, nInput do
      local rowMask = mask[i]
      local diff = (input[i]:maskedSelect(rowMask) -
         target[i]:maskedSelect(rowMask)):pow(2):mul(-self.scale)
      local span = {{(i - 1) * self.rho + 1, i * self.rho}}
      -- nElement() is 0 both for a dim-0 empty tensor and for a zero-length
      -- 1-D result, so a row whose mask selects nothing never reaches mean().
      if diff:nElement() > 0 then
         self.reward[span] = diff:mean()
      else
         self.reward[span] = 0
      end
   end

   -- Loss is the negated total reward (higher reward => lower loss).
   self.output = -self.reward:sum()
   if self.sizeAverage then
      self.output = self.output / nBaseline
   end
   return self.output
end
|
|
|
|
|
--- Broadcast the baseline-corrected reward and compute gradients.
-- Side effects: calls self.module:reinforce(reward - baseline), and runs
-- self.criterion forward/backward to regress the baseline onto self.reward.
-- Requires updateOutput to have been called first (it fills self.reward).
-- @param inputTable {prediction, baseline}
-- @param target     unused here; the reward cached by updateOutput is used
-- @return {zero gradient for prediction, criterion gradient for baseline}
function VRMaskRegressReward:updateGradInput(inputTable, target)
   assert(self.reward,
      "VRMaskRegressReward: updateOutput must be called before updateGradInput")
   local input = self:toBatch(inputTable[1], 1)
   local baseline = self:toBatch(inputTable[2], 1)

   -- Variance reduction: subtract the baseline prediction from the reward.
   self.vrReward = self.vrReward or self.reward.new()
   self.vrReward:resizeAs(self.reward):copy(self.reward)
   self.vrReward:add(-1, baseline)
   if self.sizeAverage then
      -- NOTE(review): updateOutput averages over #baseline rows (= N * rho),
      -- but this divides by input:size(1) (= N). The two agree only when
      -- rho == 1. Kept unchanged because altering it rescales the REINFORCE
      -- gradient — confirm which normalisation is intended.
      self.vrReward:div(input:size(1))
   end
   self.module:reinforce(self.vrReward)

   -- The prediction receives no gradient from this criterion.
   self.gradInput = self.gradInput or {}
   self.gradInput[1] = self.gradInput[1] or input.new()
   self.gradInput[1]:resizeAs(input):zero()
   self.gradInput[1] = self:fromBatch(self.gradInput[1], 1)

   -- Train the baseline towards the observed reward. forward() is run before
   -- backward() because some nn criteria (e.g. CrossEntropyCriterion) reuse
   -- state computed during forward.
   self.criterion:forward(baseline, self.reward)
   self.gradInput[2] = self.criterion:backward(baseline, self.reward)
   self.gradInput[2] = self:fromBatch(self.gradInput[2], 1)
   return self.gradInput
end
|
|
|
|
|
--- Convert this criterion's tensors to `type`.
-- The wrapped module is detached for the duration of parent.type, so its
-- own tensor type is not changed by this call, and it is reattached afterwards.
-- @param type tensor type name, forwarded to parent.type
-- @return whatever parent.type returns
function VRMaskRegressReward:type(type)
   -- Discard these cached fields instead of converting them (they are not
   -- referenced elsewhere in the visible code of this class).
   self._maxVal, self._maxIdx, self._target = nil, nil, nil

   local detached = self.module
   self.module = nil
   local converted = parent.type(self, type)
   self.module = detached
   return converted
end
|
|
|