-- File size: 656 Bytes
-- Revision: e9fe176
-- Load all common packages.
require 'nn'
require 'cunn'
require 'cudnn'
require 'fbnn'
require 'fbcunn'

local nnutils = require 'fbcode.deeplearning.experimental.yuandong.utils.nnutils'
local pl = require 'pl.import_into'()

-- Command-line options. Both paths are needed in practice; the empty defaults
-- are kept so the option spec is unchanged, and are rejected explicitly below
-- instead of surfacing later as an obscure torch.load('') / torch.save('')
-- failure.
local opt = pl.lapp[[
   -i,--input         (default "")  Input model
   -o,--output        (default "")  Output model 
]]

assert(opt.input ~= '', 'an input model path must be given with -i/--input')
assert(opt.output ~= '', 'an output model path must be given with -o/--output')

-- The loaded object is handled in one of two shapes: a plain table holding
-- the network under its `model` key, or the network object itself.
local saved_check_point = torch.load(opt.input)

-- Torch classes (e.g. nn.Sequential) are themselves Lua tables, so plain
-- type() reports 'table' for a bare network as well. torch.type() returns the
-- registered class name for Torch objects and 'table' only for plain tables,
-- which is what distinguishes a checkpoint table from a bare network.
if torch.type(saved_check_point) == 'table' then
    assert(saved_check_point.model ~= nil,
           'checkpoint table has no "model" field: ' .. opt.input)
    -- nnutils.remove_batchnorm is defined outside this file; its result
    -- replaces the stored network.
    saved_check_point.model = nnutils.remove_batchnorm(saved_check_point.model)
else
    saved_check_point = nnutils.remove_batchnorm(saved_check_point)
end

-- Write back in the same shape that was loaded (checkpoint table or bare net).
torch.save(opt.output, saved_check_point)