File size: 4,683 Bytes
e9fe176 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
--
-- code derived from https://github.com/soumith/dcgan.torch
--
-- Module table: every helper in this file is attached to it, and it is
-- returned at the end of the file.
local util = {}
-- Torch supplies the tensor constructors and functions used by the helpers.
require 'torch'
--- Linearly rescale an image so that its values span 0 .. 1.
-- The work is done on a fresh FloatTensor copy; the argument is not modified.
-- A constant image (max == min) has no range to stretch: it is returned as
-- all zeros instead of being multiplied by 1/0, which would yield NaNs.
-- @param img tensor to rescale
-- @return torch.FloatTensor of the same size as `img`
function util.normalize(img)
  local lo = img:min()
  local hi = img:max()
  local out = torch.FloatTensor(img:size()):copy(img)
  -- shift so the smallest value sits at 0
  out:add(-lo)
  local range = hi - lo
  if range > 0 then
    out:mul(1 / range)
  end
  return out
end
-- Runs util.normalize over every entry along the first dimension of `batch`.
-- Each entry is squeezed, normalized on its own, and written back into the
-- same slot, so `batch` is updated in place and then returned.
function util.normalizeBatch(batch)
  local count = batch:size(1)
  for idx = 1, count do
    local entry = batch[idx]:squeeze()
    batch[idx] = util.normalize(entry)
  end
  return batch
end
-- Replaces every entry of the array `batch` with paths.basename of that entry.
-- The table is rewritten in place and also returned.
function util.basename_batch(batch)
  local total = #batch
  for idx = 1, total do
    local full_path = batch[idx]
    batch[idx] = paths.basename(full_path)
  end
  return batch
end
-- default preprocessing
--
-- Prepares an image for a net: the three slices along dimension 1 are put in
-- reverse order (RGB -> BGR) and the values are mapped from [0,1] to [-1,1].
-- Raises "badly scaled inputs" when the mapped values leave [-1,1].
function util.preprocess(img)
  -- reorder channels 1,2,3 -> 3,2,1; the scaling below is applied to the
  -- tensor returned by index()
  local channel_order = torch.LongTensor{3, 2, 1}
  local out = img:index(1, channel_order)
  -- x -> 2x - 1
  out:mul(2)
  out:add(-1)
  -- sanity check on the output range
  assert(out:max() <= 1, "badly scaled inputs")
  assert(out:min() >= -1, "badly scaled inputs")
  return out
end
-- Inverse of util.preprocess: reverses the three slices along dimension 1
-- (BGR -> RGB) and maps the values from [-1,1] back to [0,1].
function util.deprocess(img)
  local channel_order = torch.LongTensor{3, 2, 1}
  local out = img:index(1, channel_order)
  -- x -> (x + 1) / 2
  out:add(1)
  out:div(2)
  return out
end
-- Applies util.preprocess to every entry along the first dimension of `batch`.
-- Entries are squeezed before processing and written back into their slot;
-- the same `batch` tensor is returned.
function util.preprocess_batch(batch)
  local count = batch:size(1)
  for idx = 1, count do
    local entry = batch[idx]:squeeze()
    batch[idx] = util.preprocess(entry)
  end
  return batch
end
-- Applies util.deprocess to every entry along the first dimension of `batch`.
-- Entries are squeezed before processing and written back into their slot;
-- the same `batch` tensor is returned.
function util.deprocess_batch(batch)
  local count = batch:size(1)
  for idx = 1, count do
    local entry = batch[idx]:squeeze()
    batch[idx] = util.deprocess(entry)
  end
  return batch
end
-- deprocessing specific to colorization
--- Rebuild an RGB image from separately stored L and AB planes.
-- Both inputs are copied first, so the caller's tensors are not modified.
-- L is mapped with (x + 1) * 50 and AB is clamped to [-1, 1] then multiplied
-- by 110 before the planes are concatenated and passed to image.lab2rgb.
-- @param L  lightness plane; if 3-dimensional only its first slice is used
-- @param AB colour planes -- assumed to be 2 x H x W matching L; TODO confirm
-- @return the image.lab2rgb result, clamped so that values lie in [0, 1]
function util.deprocessLAB(L, AB)
  local L2 = torch.Tensor(L:size()):copy(L)
  if L2:dim() == 3 then
    L2 = L2[{1, {}, {}}]
  end
  local AB2 = torch.Tensor(AB:size()):copy(AB)
  AB2 = torch.clamp(AB2, -1.0, 1.0)

  L2 = L2:add(1):mul(50.0)
  AB2 = AB2:mul(110.0)

  -- give L a leading singleton dimension so it can be stacked with AB
  L2 = L2:reshape(1, L2:size(1), L2:size(2))

  -- kept local: these were previously leaked as globals on every call
  local im_lab = torch.cat(L2, AB2, 1)
  local im_rgb = torch.clamp(image.lab2rgb(im_lab):mul(255.0), 0.0, 255.0)/255.0
  return im_rgb
end
--- Turn a single lightness plane into a 3-channel grey image.
-- The input is copied, mapped with (x + 1) * 255/2, repeated three times
-- along a leading dimension and finally divided by 255.
-- @param L lightness tensor, either H x W or already carrying a leading
--          dimension -- presumably 1 x H x W; TODO confirm with callers
-- @return new tensor with the plane repeated 3 times along dimension 1
function util.deprocessL(L)
  local L2 = torch.Tensor(L:size()):copy(L)
  L2 = L2:add(1):mul(255.0/2.0)
  if L2:dim() == 2 then
    L2 = L2:reshape(1, L2:size(1), L2:size(2))
  end
  -- Function form on purpose: the method form `L2:repeatTensor(L2, ...)`
  -- passes L2 as both the result and the source tensor.
  L2 = torch.repeatTensor(L2, 3, 1, 1)/255.0
  return L2
end
-- Runs util.deprocessL on every entry along the first dimension of `batch`.
-- Unlike the in-place batch helpers, the results are collected into a new
-- Lua table (one image per index) and `batch` itself is left as it was.
function util.deprocessL_batch(batch)
  local images = {}
  local count = batch:size(1)
  for idx = 1, count do
    local lightness = batch[idx]:squeeze()
    images[idx] = util.deprocessL(lightness)
  end
  return images
end
-- Pairs up entry i of `batchL` with entry i of `batchAB`, squeezes both and
-- converts them with util.deprocessLAB. The number of entries is taken from
-- `batchL`; the results are returned as a new Lua table.
function util.deprocessLAB_batch(batchL, batchAB)
  local images = {}
  local count = batchL:size(1)
  for idx = 1, count do
    local lightness = batchL[idx]:squeeze()
    local colour = batchAB[idx]:squeeze()
    images[idx] = util.deprocessLAB(lightness, colour)
  end
  return images
end
--- Rescale every image of a batch to a common spatial size.
-- @param batch 4-dimensional tensor; dimensions 1 and 2 are kept as they are
-- @param s1 size of dimension 3 of the result (image height)
-- @param s2 size of dimension 4 of the result (image width)
-- @return new torch.Tensor of size batch:size(1) x batch:size(2) x s1 x s2
function util.scaleBatch(batch, s1, s2)
  local scaled_batch = torch.Tensor(batch:size(1), batch:size(2), s1, s2)
  for i = 1, batch:size(1) do
    -- image.scale takes (src, width, height), whereas the tensor layout is
    -- height before width, so the arguments are swapped here. Passing
    -- (s1, s2) only matched the destination slot when s1 == s2.
    scaled_batch[i] = image.scale(batch[i], s2, s1):squeeze()
  end
  return scaled_batch
end
-- Wraps a 3-dimensional tensor into a batch of one by reshaping it to
-- 1 x size(1) x size(2) x size(3).
function util.toTrivialBatch(input)
  local d1, d2, d3 = input:size(1), input:size(2), input:size(3)
  return input:reshape(1, d1, d2, d3)
end
-- Counterpart of util.toTrivialBatch: hands back the first (only) entry
-- of the batch.
function util.fromTrivialBatch(input)
  local first = input[1]
  return first
end
-- Scales an image to loadSize x loadSize with image.scale.
-- An image whose first dimension has size 1 (black and white) is first
-- repeated three times along that dimension so the result has 3 channels.
function util.scaleImage(input, loadSize)
  local img = input
  local single_channel = img:size(1) == 1
  if single_channel then
    img = torch.repeatTensor(img, 3, 1, 1)
  end
  local scaled = image.scale(img, loadSize, loadSize)
  return scaled
end
-- Loads the image at `path` (3 channels, float) and returns the size of its
-- third dimension divided by the size of its second dimension, i.e. width
-- over height for a channel x height x width tensor.
function util.getAspectRatio(path)
  local img = image.load(path, 3, 'float')
  local width, height = img:size(3), img:size(2)
  return width / height
end
-- Loads the image at `path` as a 3-channel float tensor, scales it with
-- util.scaleImage and runs util.preprocess on the result.
-- When `nc` is 1 only the first channel is returned (the channel dimension
-- is kept); for any other `nc` all three channels are returned.
function util.loadImage(path, loadSize, nc)
  local raw = image.load(path, 3, 'float')
  local scaled = util.scaleImage(raw, loadSize)
  local input = util.preprocess(scaled)
  if nc ~= 1 then
    return input
  end
  return input[{{1}, {}, {}}]
end
-- TODO: the loading code is rather hacky; clean it up and make sure it works on all types of nets and all CPU/GPU configurations
--- Load a serialized net and move it to the configured device.
-- @param filename path handed to torch.load
-- @param opt table with numeric fields `cudnn` (> 0 requires 'cudnn' before
--            loading) and `gpu` (> 0 converts the net to CUDA, otherwise the
--            net is converted to float)
-- @return the loaded net, with gradWeight/gradBias buffers re-created as
--         zeroed clones for every module that has a weight
function util.load(filename, opt)
  if opt.cudnn > 0 then
    require 'cudnn'
  end
  local net = torch.load(filename)
  if opt.gpu > 0 then
    require 'cunn'
    net:cuda()
    -- calling cuda on cudnn saved nngraphs doesn't change all variables to
    -- cuda, so convert the module of every forward node explicitly
    if net.forwardnodes then
      for i = 1, #net.forwardnodes do
        local node_module = net.forwardnodes[i].data.module
        if node_module then
          node_module:cuda()
        end
      end
    end
  else
    net:float()
  end
  net:apply(function(m)
    if m.weight then
      m.gradWeight = m.weight:clone():zero()
      -- modules can have a weight but no bias; indexing a nil bias used to
      -- abort the whole load
      if m.bias then
        m.gradBias = m.bias:clone():zero()
      end
    end
  end)
  return net
end
-- Converts `net` through cudnn_convert_custom, passing the cudnn package as
-- the second argument. Both 'cudnn' and 'util/cudnn_convert_custom' are
-- required on each call, before the conversion runs.
function util.cudnn(net)
  require 'cudnn'
  require 'util/cudnn_convert_custom'
  local convert, backend = cudnn_convert_custom, cudnn
  return convert(net, backend)
end
-- Export the module table so callers receive it from require().
return util
|