|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
-- nn.VRRegressReward: variance-reduced REINFORCE reward criterion for regression.
-- Input is a table {prediction, baseline}. Each sample's reward is
-- -scale * mean squared error between prediction and target, repeated rho times
-- to line up with the baseline rows. The reward minus the baseline is broadcast
-- to `module` via reinforce(). The baseline receives gradients from `criterion`
-- that pull it toward the reward.
local VRRegressReward, parent = torch.class("nn.VRRegressReward", "nn.Criterion")
|
|
|
|
|
--- Construct the criterion.
-- @param module module whose reinforce() receives the variance-reduced reward
-- @param[opt=1] scale multiplier applied to the negative squared error
-- @param[opt=1] rho number of reward/baseline entries per prediction row; the
--   baseline (second input) must have rho times as many rows as the prediction
-- @param[opt=nn.MSECriterion()] criterion loss used to pull the baseline
--   toward the reward
function VRRegressReward:__init(module, scale, rho, criterion)
   parent.__init(self)
   self.module = module
   self.scale = scale or 1
   self.rho = rho or 1
   self.criterion = criterion or nn.MSECriterion()
   self.sizeAverage = true
   -- nn.Criterion.__init sets gradInput to a single Tensor. This criterion
   -- takes a table input {prediction, baseline} and returns a table of
   -- gradients. Indexing gradInput[1] on an empty Tensor raises, so start
   -- from a table.
   self.gradInput = {torch.Tensor()}
end
|
|
|
|
|
--- Forward pass: compute the per-sample reward and the negated total reward.
-- @param inputTable {prediction, baseline}. The prediction is batched along
--   dim 1 (a non-batched prediction is viewed as a batch of one). The baseline
--   must have rho rows per prediction row.
-- @param target regression target with as many elements as the prediction
-- @return -sum(reward), divided by the number of reward entries
--   (baseline:size(1)) when sizeAverage is true
function VRRegressReward:updateOutput(inputTable, target)
   assert(torch.type(inputTable) == 'table',
      "VRRegressReward expects a table input {prediction, baseline}")
   local input = self:toBatch(inputTable[1], 1)
   local baseline = self:toBatch(inputTable[2], 1)
   local batchSize = input:size(1)
   local nReward = baseline:size(1)
   assert(batchSize * self.rho == nReward,
      "VRRegressReward: baseline must have rho rows per prediction row")

   -- reward per sample = -scale * mean squared error over that sample's
   -- elements. Reuse a buffer instead of allocating (input - target) per call.
   self._diff = self._diff or input.new()
   self._diff:resizeAs(input):copy(input):add(-1, target)
   self._diff:pow(2):mul(-self.scale)
   -- resizeAs leaves _diff contiguous, so view() is safe. The trailing view
   -- makes the shape batchSize x 1 whether or not mean() keeps the reduced dim.
   local perSample = self._diff:view(batchSize, -1):mean(2):view(batchSize, 1)

   -- each sample's reward fills its rho consecutive reward/baseline slots
   self.reward = self.reward or baseline.new()
   self.reward:resize(nReward)
   self.reward:view(batchSize, self.rho)
      :copy(perSample:expand(batchSize, self.rho))

   self.output = -self.reward:sum()
   if self.sizeAverage then
      self.output = self.output / nReward
   end
   return self.output
end
|
|
|
|
|
--- Backward pass: broadcast the variance-reduced reward and return gradients.
-- The prediction gets a zero gradient; it is trained only through
-- module:reinforce(). The baseline gets the criterion's gradient toward the
-- reward. Must be called after updateOutput, which fills self.reward.
-- @param inputTable {prediction, baseline}, as passed to updateOutput
-- @param target unused here (the reward was already computed in updateOutput)
-- @return table {gradPrediction, gradBaseline}
function VRRegressReward:updateGradInput(inputTable, target)
   assert(self.reward,
      "VRRegressReward: updateOutput must be called before updateGradInput")
   local input = self:toBatch(inputTable[1], 1)
   local baseline = self:toBatch(inputTable[2], 1)

   -- variance reduction: subtract the baseline from the reward
   self.vrReward = self.vrReward or self.reward.new()
   self.vrReward:resizeAs(self.reward):copy(self.reward)
   self.vrReward:add(-1, baseline)
   if self.sizeAverage then
      -- NOTE(review): updateOutput averages over all baseline:size(1) reward
      -- entries, but this divides by the prediction batch size only (a factor
      -- of rho apart). Confirm the asymmetry is intended.
      self.vrReward:div(input:size(1))
   end
   self.module:reinforce(self.vrReward)

   -- nn.Criterion.__init leaves gradInput as a Tensor. This criterion returns a
   -- table, so replace anything that is not already one (indexing an empty
   -- Tensor with [1] would raise).
   if torch.type(self.gradInput) ~= 'table' then
      self.gradInput = {}
   end

   -- no gradient flows to the prediction from this criterion
   self.gradInput[1] = self.gradInput[1] or input.new()
   self.gradInput[1]:resizeAs(input):zero()
   self.gradInput[1] = self:fromBatch(self.gradInput[1], 1)

   -- train the baseline toward the observed reward
   self.gradInput[2] = self.criterion:backward(baseline, self.reward)
   self.gradInput[2] = self:fromBatch(self.gradInput[2], 1)
   return self.gradInput
end
|
|
|
|
|
--- Convert this criterion's tensors to another tensor type.
-- self.module is detached during the conversion so parent.type leaves it
-- untouched. NOTE(review): presumably it is converted with the enclosing
-- network; confirm. The module is restored even if the conversion raises.
-- @param dstType tensor type name, e.g. 'torch.FloatTensor'
-- @param ... forwarded to parent.type (e.g. a tensorCache)
-- @return the value returned by parent.type
function VRRegressReward:type(dstType, ...)
   -- Clear these cached fields so they are not converted.
   -- NOTE(review): none of them is assigned in this class as shown; they look
   -- inherited from a sibling criterion. Confirm before removing.
   self._maxVal = nil
   self._maxIdx = nil
   self._target = nil

   local module = self.module
   self.module = nil
   local ok, ret = pcall(parent.type, self, dstType, ...)
   -- restore the module before propagating any conversion error
   self.module = module
   if not ok then
      error(ret, 0)
   end
   return ret
end
|
|
|