File size: 1,401 Bytes
2940390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch, argparse
from model.OneRestore import OneRestore
from model.Embedder import Embedder

# Prefix that ``torch.nn.DataParallel`` prepends to every parameter name of the
# model it wraps; it has to be removed before loading into an unwrapped model.
DP_PREFIX = 'module.'

# Degradation combinations passed verbatim to ``Embedder``.
# NOTE(review): presumably this list (and its order) must match the one the
# Embedder checkpoint was trained with -- confirm before editing.
COMBINE_TYPE = ['clear', 'low', 'haze', 'rain', 'snow',
                'low_haze', 'low_rain', 'low_snow', 'haze_rain',
                'haze_snow', 'low_haze_rain', 'low_haze_snow']


def build_parser():
    """Return the command-line parser for this conversion script.

    ``--type`` selects the model class the checkpoint belongs to; any value
    other than ``'OneRestore'`` or ``'Embedder'`` is rejected by argparse
    (usage message on stderr, exit status 2).  By default the output path
    equals the input path, so the checkpoint file is overwritten in place.
    """
    parser = argparse.ArgumentParser(
        description="Load a checkpoint's 'state_dict' into the matching model "
                    "and save the model's bare state_dict.")
    parser.add_argument("--type", type=str, default='OneRestore',
                        choices=['OneRestore', 'Embedder'])
    parser.add_argument("--input-file", type=str,
                        default='./ckpts/onerestore_cdd-11.tar')
    parser.add_argument("--output-file", type=str,
                        default='./ckpts/onerestore_cdd-11.tar')
    return parser


def strip_module_prefix(state_dict):
    """Return a copy of *state_dict* with a leading ``'module.'`` removed from keys.

    Only a *leading* prefix is stripped, so parameter names that merely
    contain ``'module.'`` (e.g. ``'block.submodule.w'``) are left intact.
    Keys without the prefix are kept unchanged and key order is preserved.
    """
    return {
        (key[len(DP_PREFIX):] if key.startswith(DP_PREFIX) else key): value
        for key, value in state_dict.items()
    }


def convert_checkpoint(model, input_file, output_file, device):
    """Load ``checkpoint['state_dict']`` from *input_file* into *model* and save it.

    Args:
        model: module whose parameters receive the checkpoint weights.
        input_file: path to a ``torch.save``'d dict holding a ``'state_dict'``.
        output_file: path that ``model.state_dict()`` is written to.
        device: used as ``map_location`` so a checkpoint saved on GPU can be
            loaded on a CPU-only machine.

    Raises:
        KeyError: if the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: from ``load_state_dict`` when keys or shapes mismatch.
    """
    checkpoint = torch.load(input_file, map_location=device)
    model.load_state_dict(strip_module_prefix(checkpoint['state_dict']))
    torch.save(model.state_dict(), output_file)


def main(argv=None):
    """Parse *argv* (default: ``sys.argv[1:]``) and run the conversion."""
    args = build_parser().parse_args(argv)
    # One device for both the model and ``map_location`` so the script also
    # works on machines without CUDA.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if args.type == 'OneRestore':
        model = OneRestore()
    else:  # 'Embedder' -- argparse ``choices`` rules out anything else
        model = Embedder(COMBINE_TYPE)
    convert_checkpoint(model.to(device), args.input_file, args.output_file, device)


if __name__ == "__main__":
    main()