File size: 3,724 Bytes
e9fe176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
grad = require 'autograd'
--[[
    nn.Projection: maps a batch of packed, normalized 3D model parameters to
    normalized 2D image coordinates, one pair per model point.

    Input:  batchSize x num3DParams tensor. Each row is packed as
            [Alpha (numAlpha), F (numF), Theta (numTheta), Trans (numTrans)];
            F is read as the inverse focal length.
    Output: batchSize x numPoints x 2 tensor. As the code below is written,
            channel 1 holds the projected Y value (divided by imH, plus 0.5)
            and channel 2 the projected X value (divided by imW, plus 0.5).

    NOTE(review): this header previously listed the input as
    [AlphaScale, Alpha, F, Theta, Trans] and the output as [X, Y]. The index
    arithmetic in __init has no separate AlphaScale slot, and the channel
    order in the code is Y then X -- confirm which is intended.
--]]

local Projection, parent = torch.class('nn.Projection', 'nn.Module')

-- Construct the layer.
--   Bs                   basis tensor; dimension 1 must equal numAlpha and
--                        dimension 2 must equal 3. The size of dimension 3
--                        becomes the number of projected points.
--   minParams, maxParams per-parameter bounds used to de-normalize the input
--   w                    per-parameter divisor applied before rescaling
--   numAlpha             number of basis coefficients (required)
--   numF, numTheta, numTrans
--                        widths of the remaining parameter groups
--                        (defaults: 1, 6, 2)
--   imW, imH             divisors for the projected X / Y values
--                        (defaults: 240, 320)
function Projection:__init(Bs, minParams, maxParams, w, numAlpha, numF, numTheta, numTrans, imW, imH)
    parent.__init(self)

    -- the length operator on the basis is indexed per dimension
    local dims = #Bs
    assert(dims[1] == numAlpha)
    assert(dims[2] == 3)

    -- widths of the packed parameter groups, with defaults applied
    local alphaCount = numAlpha
    local fCount = numF or 1
    local thetaCount = numTheta or 6
    local transCount = numTrans or 2

    self.Bs = Bs
    self.minParams = minParams
    self.maxParams = maxParams
    self.w = w
    self.numPoints = Bs[1]:size(2)

    self.numAlpha = alphaCount
    self.numF = fCount
    self.numTheta = thetaCount
    self.numTrans = transCount
    self.num3DParams = alphaCount + fCount + thetaCount + transCount

    -- zero-based offsets: each group starts at (offset + 1) in a parameter row
    self.indF = alphaCount
    self.indR = self.indF + fCount
    self.indT = self.indR + thetaCount

    self.imH = imH or 320
    self.imW = imW or 240
end

-- Compute ONE normalized image coordinate for one model point of one sample.
-- It is a scalar-valued function of `input`, which is what lets
-- Projection:updateOutput wrap it with grad().
--   input  indexed as input[p][i]; row p holds the packed, normalized parameters
--   p      sample (row) index
--   q      model point index into self.Bs
--   r      1 -> value derived from Y3D and imH; 2 -> value derived from X3D and imW
--   self   the layer, read for Bs, w, minParams, maxParams, the ind* offsets,
--          imW and imH
-- Returns the requested coordinate; raises if r is neither 1 nor 2.
local function updateScalarOutput(input, p, q, r, self)
    assert(r == 1 or r == 2)

    -- undo the normalization: value / w * (max - min) + min, per parameter
    local params = {}
    for k = 1, self.num3DParams do
        local lo, hi = self.minParams[k], self.maxParams[k]
        params[k] = input[p][k] / self.w[k] * (hi - lo) + lo
    end

    -- 3D position of point q as a weighted sum of the basis slices
    local point = params[1] * self.Bs[{1, {}, q}]
    for k = 2, self.numAlpha do
        point = point + params[k] * self.Bs[{k, {}, q}]
    end

    -- six rotation parameters: three "S" terms followed by three "C" terms
    local rot, shift = self.indR, self.indT
    local s1, s2, s3 = params[rot + 1], params[rot + 2], params[rot + 3]
    local c1, c2, c3 = params[rot + 4], params[rot + 5], params[rot + 6]

    -- rotate the point; only X and Y receive a translation
    local px = c2 * c3 * point[1] - c2 * s3 * point[2] + s2 * point[3] + params[shift + 1]
    local py = (s1 * s2 * c3 + c1 * s3) * point[1] + (-s1 * s2 * s3 + c1 * c3) * point[2] - s1 * c2 * point[3] + params[shift + 2]
    local depth = (-c1 * s2 * c3 + s1 * s3) * point[1] + (c1 * s2 * s3 + s1 * c3) * point[2] + c1 * c2 * point[3]

    -- the focal parameter is stored inverted
    local focal = 1 / params[self.indF + 1]
    local perspective = focal / (focal + depth)

    if r == 2 then
        return (perspective * px) / self.imW + 0.5
    end
    return (perspective * py) / self.imH + 0.5
end

-- Forward pass.
--   input   batchSize x num3DParams tensor of packed, normalized parameters
-- Returns self.output, resized to batchSize x numPoints x 2, where entry
-- [p][q][r] is updateScalarOutput(input, p, q, r, self).
-- Side effect: rebuilds self.dnet as a [p][q][r] table. In training mode every
-- slot holds the autograd derivative of updateScalarOutput, which
-- updateGradInput later calls; outside training mode the slots stay empty.
function Projection:updateOutput(input)
    assert(input:size(2) == self.num3DParams)

    local batchSize = input:size(1)
    self.output:resize(batchSize, self.numPoints, 2)

    -- grad() only wraps the function and does not depend on p, q or r, so a
    -- single wrapper is built per call and shared by every dnet slot instead
    -- of allocating batchSize * numPoints * 2 identical closures.
    local dScalar
    if self.train then
        dScalar = grad(updateScalarOutput)
    end

    self.dnet = {}
    for p = 1, batchSize do
        local perPoint = {}
        self.dnet[p] = perPoint
        for q = 1, self.numPoints do
            local perCoord = {}
            perPoint[q] = perCoord
            for r = 1, 2 do
                self.output[p][q][r] = updateScalarOutput(input, p, q, r, self)
                perCoord[r] = dScalar
            end
        end
    end

    return self.output
end

-- Backward pass.
--   input       the tensor that was passed to the preceding updateOutput call
--   gradOutput  indexed as gradOutput[p][q][r], matching self.output
-- Returns self.gradInput (same size as input): the sum over every (p, q, r) of
-- the derivative of that scalar output w.r.t. input, scaled by
-- gradOutput[p][q][r].
-- Requires training mode and a preceding updateOutput call; self.dnet is
-- consumed (set to nil) on return. Prints a diagnostic line when a
-- contribution sums to NaN or when the result leaves [-10, 10].
function Projection:updateGradInput(input, gradOutput)
    local batchSize = input:size(1)
    assert(self.train == true, 'should be in training mode when self.train is true')
    assert(self.dnet, 'must call :updateOutput() first')

    self.gradInput:resizeAs(input):zero()

    for p = 1, batchSize do
        for q = 1, self.numPoints do
            for r = 1, 2 do
                -- the derivative call's first return value is the gradient w.r.t. input
                local stepGrad = self.dnet[p][q][r](input, p, q, r, self) * gradOutput[p][q][r]

                -- NaN is the only value that differs from itself
                local total = stepGrad:sum()
                if total ~= total then
                    print(p..' '..q..' '..r..' nan')
                end

                -- accumulate in place: keeps the buffer resized above instead of
                -- allocating a fresh full-size tensor on every iteration
                self.gradInput:add(stepGrad)
            end
        end
    end

    if self.gradInput:min() < -10 or self.gradInput:max() > 10 then
        print 'Projection Gradient Exploding'
    end
    self.dnet = nil

    return self.gradInput
end

-- Printable name of the layer: its torch type name followed by "()".
function Projection:__tostring__()
    local typeName = torch.type(self)
    return ('%s()'):format(typeName)
end