File size: 1,828 Bytes
e9fe176
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
require 'nn'
require 'cunn'
require 'cudnn'
require 'fbnn'
require 'fbcunn'
local cjson = require 'cjson'
-- Convert a Torch-serialized file into JSON files, one file per named array of numbers.

local pl = require'pl.import_into'()

--- Append the sequence elements of `t_src` onto the end of `t_dst`.
-- Like ipairs, copying stops at the first nil in `t_src`.
-- @tparam table t_dst destination sequence; modified in place
-- @tparam table t_src source sequence; left unchanged
-- @treturn table the same `t_dst` table
local function merge(t_dst, t_src)
    for _, item in ipairs(t_src) do
        t_dst[#t_dst + 1] = item
    end
    return t_dst
end

--- Flatten `t` into a new plain Lua array.
-- Accepted inputs are:
--   * a Lua number (emitted as a single element);
--   * a torch Tensor (every element, in the order `Tensor:apply` visits them);
--   * a sequence table whose elements are themselves accepted inputs, nested
--     to any depth. Only the sequence part is read (ipairs), so string-keyed
--     fields of such a table are ignored.
-- Elements are emitted depth-first, in sequence order.
-- @param t number, Tensor, or (nested) sequence of these
-- @treturn table a fresh sequence containing the flattened values
-- @raise error if `t`, or anything nested inside it, is of another type
local function extract_array(t)
    local out = {}
    local n = 0

    -- Append into one shared accumulator rather than building and copying
    -- an intermediate table at every nesting level.
    local function collect(x)
        if type(x) == 'number' then
            n = n + 1
            out[n] = x
        elseif type(x) == 'table' then
            for _, v in ipairs(x) do
                collect(v)
            end
        else
            local tname = torch.typename(x)
            if tname and tname:match('Tensor') then
                -- The callback returns nothing, so apply leaves x unchanged.
                x:apply(function(e)
                    n = n + 1
                    out[n] = e
                end)
            else
                error(string.format(
                    "Input is not a number, a table or a tensor (got %s)!",
                    tname or type(x)))
            end
        end
    end

    collect(t)
    return out
end

--- Recursively gather the contents of `t` into named flat arrays.
-- * Each string-keyed field `k` is recursed into under the name
--   `name_prefix .. "_" .. k`.
-- * All number-keyed entries are flattened with extract_array. They are
--   concatenated in ascending key order into one array stored under
--   `name_prefix`.
-- * A non-table `t` is flattened directly with extract_array.
-- Keys of any other type are ignored. Each visited name is printed.
-- @param t value to walk (table, or anything extract_array accepts)
-- @tparam string name_prefix name for data found at this level
-- @treturn table map from name to a non-empty flat array
local function recursive_save(t, name_prefix)
    local all_array = {}
    local all_save = {}

    -- Progress trace: echo each name as it is visited.
    print(name_prefix)

    if type(t) == 'table' then
        local num_keys = {}
        local num_vals = {}
        for k, v in pairs(t) do
            if type(k) == 'string' then
                -- Must be local: this used to leak a global `save_content`.
                local child_save = recursive_save(v, name_prefix .. "_" .. k)
                for child_name, child_array in pairs(child_save) do
                    all_save[child_name] = child_array
                end
            elseif type(k) == 'number' then
                num_keys[#num_keys + 1] = k
                num_vals[k] = v
            end
        end

        -- pairs() visits keys in unspecified order; sort the numeric keys so
        -- the concatenated array is deterministic and in key order.
        table.sort(num_keys)
        for _, k in ipairs(num_keys) do
            -- For a tensor, every element is saved.
            merge(all_array, extract_array(num_vals[k]))
        end
    else
        all_array = extract_array(t)
    end

    if #all_array > 0 then all_save[name_prefix] = all_array end
    return all_save
end

-- Command-line entry point: load a torch-serialized file and write each
-- collected array to its own JSON file (the collected name is the path).
local opt = pl.lapp[[
   -i,--input         (default "")  Input model
   -o,--outputprefix  (default "")  Prefix for the output JSON file names
]]

-- Fail early with a usage message instead of an obscure torch.load("") error.
if opt.input == "" then
    pl.lapp.error("an input file is required (-i/--input)")
end

local save_content = recursive_save(torch.load(opt.input), opt.outputprefix)
for name, content in pairs(save_content) do
    local f = assert(io.open(name, "w"))
    -- Surface write failures (e.g. disk full) instead of silently truncating.
    assert(f:write(cjson.encode(content)))
    f:close()
end