rawanessam's picture
Upload 79 files
e9fe176 verified
--
-- code derived from https://github.com/soumith/dcgan.torch
--
-- Module table: every helper in this file is attached to it as util.<name>.
local util = {}
require 'torch'
--- Linearly rescale a tensor so that its values span [0, 1].
-- The data is first copied into a new FloatTensor, so the caller's tensor
-- is never modified.
-- @param img tensor of any shape
-- @return FloatTensor of the same size as img, with img:min() mapped to 0
--   and img:max() mapped to 1. A constant tensor (max == min) has no range
--   to stretch and is returned as all zeros, instead of the NaN/inf that
--   dividing by a zero range would produce.
function util.normalize(img)
   local min = img:min()
   local max = img:max()
   local out = torch.FloatTensor(img:size()):copy(img)
   out:add(-min)
   local range = max - min
   -- only stretch when there is a non-empty range; otherwise out is already 0
   if range > 0 then
      out:mul(1 / range)
   end
   return out
end
--- Rescale every sample of a batch to [0, 1], writing results back in place.
-- Each slice is squeezed, passed through util.normalize, and copied back
-- into the same slot of the batch tensor.
-- @param batch tensor whose first dimension indexes samples
-- @return the same (modified) batch tensor
function util.normalizeBatch(batch)
   local count = batch:size(1)
   for idx = 1, count do
      local sample = batch[idx]:squeeze()
      batch[idx] = util.normalize(sample)
   end
   return batch
end
--- Replace every path in a list by its base name, in place.
-- @param batch Lua sequence of path strings
-- @return the same table, now holding paths.basename of each entry
function util.basename_batch(batch)
   local n = #batch
   for k = 1, n do
      local full_path = batch[k]
      batch[k] = paths.basename(full_path)
   end
   return batch
end
-- default preprocessing

--- Prepare an image before passing it to a net.
-- Reorders channels from RGB to BGR and rescales values from [0,1] to [-1,1].
-- img:index returns a fresh tensor, so the caller's img is left untouched.
-- @param img tensor with the channel axis first; channels 3, 2, 1 are
--   indexed, so it needs at least three channels
-- @return new BGR tensor with values in [-1,1]
-- Raises "badly scaled inputs" if the result leaves [-1,1], i.e. the input
-- was not in [0,1].
function util.preprocess(img)
   local bgr_order = torch.LongTensor{3, 2, 1}
   local out = img:index(1, bgr_order)
   -- x -> 2x - 1, done in place on the copy
   out:mul(2):add(-1)
   assert(out:max() <= 1, "badly scaled inputs")
   assert(out:min() >= -1, "badly scaled inputs")
   return out
end
--- Undo util.preprocess.
-- Reorders channels from BGR back to RGB and maps values from [-1,1] to [0,1].
-- img:index returns a fresh tensor, so the caller's img is left untouched.
-- @param img tensor with the channel axis first (at least three channels)
-- @return new RGB tensor with values in [0,1]
function util.deprocess(img)
   local rgb_order = torch.LongTensor{3, 2, 1}
   local out = img:index(1, rgb_order)
   -- x -> (x + 1) / 2, in place on the copy
   return out:add(1):div(2)
end
--- Apply util.preprocess to every sample of a batch, writing back in place.
-- @param batch tensor whose first dimension indexes samples
-- @return the same (modified) batch tensor
function util.preprocess_batch(batch)
   local num_samples = batch:size(1)
   for s = 1, num_samples do
      local processed = util.preprocess(batch[s]:squeeze())
      batch[s] = processed
   end
   return batch
end
--- Apply util.deprocess to every sample of a batch, writing back in place.
-- @param batch tensor whose first dimension indexes samples
-- @return the same (modified) batch tensor
function util.deprocess_batch(batch)
   local num_samples = batch:size(1)
   for s = 1, num_samples do
      local restored = util.deprocess(batch[s]:squeeze())
      batch[s] = restored
   end
   return batch
end
-- preprocessing specific to colorization

--- Convert network Lab outputs back into an RGB image in [0,1].
-- Both inputs are copied first, so the caller's tensors are not modified.
-- @param L lightness scaled to [-1,1]; either 2-D (HxW) or 3-D, in which
--   case only the first channel is used
-- @param AB chroma scaled to [-1,1]; clamped to [-1,1] before use.
--   Presumably 2xHxW so it can be concatenated with L along dim 1 -- TODO confirm
-- @return RGB tensor with values clamped to [0,1]
function util.deprocessLAB(L, AB)
   local L2 = torch.Tensor(L:size()):copy(L)
   if L2:dim() == 3 then
      L2 = L2[{1, {}, {}}]
   end
   local AB2 = torch.clamp(torch.Tensor(AB:size()):copy(AB), -1.0, 1.0)
   -- [-1,1] -> [0,100] for L and [-110,110] for a/b
   L2:add(1):mul(50.0)
   AB2:mul(110.0)
   -- L2 is 2-D here; give it a leading channel axis for the concatenation
   L2 = L2:reshape(1, L2:size(1), L2:size(2))
   -- keep intermediates local so they do not leak into the global table
   local im_lab = torch.cat(L2, AB2, 1)
   local im_rgb = torch.clamp(image.lab2rgb(im_lab):mul(255.0), 0.0, 255.0) / 255.0
   return im_rgb
end
--- Turn a lightness tensor in [-1,1] into a 3-channel grey image in [0,1].
-- The input is copied first, so the caller's tensor is not modified.
-- @param L tensor scaled to [-1,1]; a 2-D (HxW) input gets a leading channel
--   axis, a 3-D input has its channel axis tiled three times
--   (so 1xHxW becomes 3xHxW)
-- @return new tensor with values (L + 1) / 2 replicated across channels
function util.deprocessL(L)
   local L2 = torch.Tensor(L:size()):copy(L)
   -- [-1,1] -> [0,255]
   L2:add(1):mul(255.0 / 2.0)
   if L2:dim() == 2 then
      L2 = L2:reshape(1, L2:size(1), L2:size(2))
   end
   -- Replicate into a fresh tensor. Passing L2 as both the result and the
   -- source of repeatTensor would resize the tensor while reading from it.
   L2 = torch.repeatTensor(L2, 3, 1, 1) / 255.0
   return L2
end
--- Apply util.deprocessL to every sample of a batch.
-- @param batch tensor whose first dimension indexes samples
-- @return Lua sequence of per-sample tensors (not a stacked tensor)
function util.deprocessL_batch(batch)
   local out = {}
   local n = batch:size(1)
   for k = 1, n do
      local sample = batch[k]:squeeze()
      out[k] = util.deprocessL(sample)
   end
   return out
end
--- Apply util.deprocessLAB to every (L, AB) pair of two parallel batches.
-- @param batchL tensor of lightness samples, first dimension indexes samples
-- @param batchAB tensor of chroma samples, indexed like batchL
-- @return Lua sequence of per-sample RGB tensors
function util.deprocessLAB_batch(batchL, batchAB)
   local images = {}
   local n = batchL:size(1)
   for k = 1, n do
      local l_k = batchL[k]:squeeze()
      local ab_k = batchAB[k]:squeeze()
      images[k] = util.deprocessLAB(l_k, ab_k)
   end
   return images
end
--- Resize every image in a batch.
-- @param batch tensor whose first two dimensions are sample and channel
--   (presumably NxCxHxW -- TODO confirm against callers)
-- @param s1 output height
-- @param s2 output width
-- @return new tensor (default tensor type) of size N x C x s1 x s2
function util.scaleBatch(batch, s1, s2)
   local n, c = batch:size(1), batch:size(2)
   local scaled_batch = torch.Tensor(n, c, s1, s2)
   for i = 1, n do
      -- image.scale takes (src, width, height). The destination slot is
      -- s1 x s2 (height x width), so pass s2 as width and s1 as height.
      -- The reverse order still "fits" because copy only checks element
      -- counts, but it scrambles pixels whenever s1 ~= s2.
      scaled_batch[i] = image.scale(batch[i], s2, s1):squeeze()
   end
   return scaled_batch
end
--- Wrap a single 3-D image into a batch of one.
-- Only the first three sizes of input are used, giving a 1 x C x H x W result.
-- @param input tensor with (at least) three dimensions
-- @return reshaped tensor of size 1 x size(1) x size(2) x size(3)
function util.toTrivialBatch(input)
   local c, h, w = input:size(1), input:size(2), input:size(3)
   return input:reshape(1, c, h, w)
end
--- Unwrap a batch of one: return its first sample.
-- @param input batched tensor
-- @return input[1]
function util.fromTrivialBatch(input)
   local sample = input[1]
   return sample
end
--- Resize an image to loadSize x loadSize.
-- A single-channel (black & white) image is first replicated to three
-- channels; any other channel count is passed through unchanged.
-- @param input image tensor with the channel axis first
-- @param loadSize output width and height
-- @return scaled image tensor
function util.scaleImage(input, loadSize)
   local src = input
   if src:size(1) == 1 then
      src = torch.repeatTensor(src, 3, 1, 1)
   end
   local scaled = image.scale(src, loadSize, loadSize)
   return scaled
end
--- Width-to-height ratio of the image file at path.
-- Decodes the whole image (as 3-channel float) just to read its size.
-- @param path image file path
-- @return width / height
function util.getAspectRatio(path)
   local img = image.load(path, 3, 'float')
   local height, width = img:size(2), img:size(3)
   return width / height
end
--- Load an image file, resize it and run the default preprocessing.
-- @param path image file path (always decoded as 3-channel float)
-- @param loadSize side length passed to util.scaleImage
-- @param nc requested channel count; when 1, only the first channel of the
--   preprocessed image is kept (as a 1 x H x W view)
-- @return preprocessed image tensor in [-1,1]
function util.loadImage(path, loadSize, nc)
   local raw = image.load(path, 3, 'float')
   local input = util.preprocess(util.scaleImage(raw, loadSize))
   if nc ~= 1 then
      return input
   end
   -- NOTE(review): preprocess reorders RGB -> BGR, so channel 1 here is the
   -- original blue channel; equivalent only for grey images -- confirm intent.
   return input[{{1}, {}, {}}]
end
-- TO DO: loading code is rather hacky; clean it up and make sure it works on all types of nets / cpu/gpu configurations

--- Load a serialized network and move it to the requested device.
-- @param filename path passed to torch.load
-- @param opt options table (nil is treated as {}):
--   opt.cudnn > 0 requires 'cudnn' before loading,
--   opt.gpu > 0 requires 'cunn' and moves the net to CUDA,
--   otherwise the net is cast to float. Missing fields count as 0.
-- @return the network, with zeroed gradWeight (and gradBias, where a bias
--   exists) allocated for every module that has a weight
function util.load(filename, opt)
   opt = opt or {}
   local use_cudnn = (opt.cudnn or 0) > 0
   local use_gpu = (opt.gpu or 0) > 0
   -- presumably required up front so cudnn module classes are available
   -- when torch.load deserializes the net
   if use_cudnn then
      require 'cudnn'
   end
   local net = torch.load(filename)
   if use_gpu then
      require 'cunn'
      net:cuda()
      -- calling cuda on cudnn saved nngraphs doesn't change all variables to cuda, so do it below
      if net.forwardnodes then
         for i = 1, #net.forwardnodes do
            local node_module = net.forwardnodes[i].data.module
            if node_module then
               node_module:cuda()
            end
         end
      end
   else
      net:float()
   end
   -- Allocate zeroed gradient buffers. The bias is checked on its own because
   -- a module can have a weight but no bias, and cloning nil would error.
   net:apply(function(m)
      if m.weight then
         m.gradWeight = m.weight:clone():zero()
         if m.bias then
            m.gradBias = m.bias:clone():zero()
         end
      end
   end)
   return net
end
--- Convert a network using the project's custom cudnn converter.
-- Loads the cudnn package and the util/cudnn_convert_custom script, then
-- delegates to cudnn_convert_custom, passing the global cudnn table.
-- NOTE(review): cudnn_convert_custom is referenced as a global, presumably
-- defined by util/cudnn_convert_custom -- confirm.
-- @param net network to convert
-- @return whatever cudnn_convert_custom returns
function util.cudnn(net)
   require('cudnn')
   require('util/cudnn_convert_custom')
   return cudnn_convert_custom(net, cudnn)
end
-- export the module table to require() callers
return util