alan-chen-intel commited on
Commit
e4bc86b
·
1 Parent(s): 3ece203

fix map location error. all previous commits were made by me, config was wrong for git

Browse files
Files changed (1) hide show
  1. run.py +3 -3
run.py CHANGED
@@ -155,7 +155,7 @@ kp_detector = KPDetector(**config['model_params']['kp_detector_params'], **confi
155
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
156
  cpu = False if torch.cuda.is_available() else True
157
 
158
- g_checkpoint = torch.load("checkpoints/generator.pth", map_location = device)
159
  kp_checkpoint = torch.load("checkpoints/kp_detector.pth", map_location=device)
160
 
161
  ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
@@ -165,8 +165,8 @@ kp_detector.load_state_dict(ckp_kp_detector)
165
 
166
  depth_encoder = depth.ResnetEncoder(18, False)
167
  depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
168
- loaded_dict_enc = torch.load('checkpoints/encoder.pth')
169
- loaded_dict_dec = torch.load('checkpoints/depth.pth')
170
 
171
  filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
172
  depth_encoder.load_state_dict(filtered_dict_enc)
 
155
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
156
  cpu = False if torch.cuda.is_available() else True
157
 
158
+ g_checkpoint = torch.load("checkpoints/generator.pth", map_location=device)
159
  kp_checkpoint = torch.load("checkpoints/kp_detector.pth", map_location=device)
160
 
161
  ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
 
165
 
166
  depth_encoder = depth.ResnetEncoder(18, False)
167
  depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
168
+ loaded_dict_enc = torch.load('checkpoints/encoder.pth', map_location=device)
169
+ loaded_dict_dec = torch.load('checkpoints/depth.pth', map_location=device)
170
 
171
  filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
172
  depth_encoder.load_state_dict(filtered_dict_enc)