Spaces:
Runtime error
Runtime error
Commit ·
e4bc86b
1
Parent(s): 3ece203
fix map location error. all previous commits were made by me, config was wrong for git
Browse files
run.py
CHANGED
|
@@ -155,7 +155,7 @@ kp_detector = KPDetector(**config['model_params']['kp_detector_params'], **confi
|
|
| 155 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 156 |
cpu = False if torch.cuda.is_available() else True
|
| 157 |
|
| 158 |
-
g_checkpoint = torch.load("checkpoints/generator.pth", map_location
|
| 159 |
kp_checkpoint = torch.load("checkpoints/kp_detector.pth", map_location=device)
|
| 160 |
|
| 161 |
ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
|
|
@@ -165,8 +165,8 @@ kp_detector.load_state_dict(ckp_kp_detector)
|
|
| 165 |
|
| 166 |
depth_encoder = depth.ResnetEncoder(18, False)
|
| 167 |
depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
|
| 168 |
-
loaded_dict_enc = torch.load('checkpoints/encoder.pth')
|
| 169 |
-
loaded_dict_dec = torch.load('checkpoints/depth.pth')
|
| 170 |
|
| 171 |
filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
|
| 172 |
depth_encoder.load_state_dict(filtered_dict_enc)
|
|
|
|
| 155 |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 156 |
cpu = False if torch.cuda.is_available() else True
|
| 157 |
|
| 158 |
+
g_checkpoint = torch.load("checkpoints/generator.pth", map_location=device)
|
| 159 |
kp_checkpoint = torch.load("checkpoints/kp_detector.pth", map_location=device)
|
| 160 |
|
| 161 |
ckp_generator = OrderedDict((k.replace('module.',''),v) for k,v in g_checkpoint.items())
|
|
|
|
| 165 |
|
| 166 |
depth_encoder = depth.ResnetEncoder(18, False)
|
| 167 |
depth_decoder = depth.DepthDecoder(num_ch_enc=depth_encoder.num_ch_enc, scales=range(4))
|
| 168 |
+
loaded_dict_enc = torch.load('checkpoints/encoder.pth', map_location=device)
|
| 169 |
+
loaded_dict_dec = torch.load('checkpoints/depth.pth', map_location=device)
|
| 170 |
|
| 171 |
filtered_dict_enc = {k: v for k, v in loaded_dict_enc.items() if k in depth_encoder.state_dict()}
|
| 172 |
depth_encoder.load_state_dict(filtered_dict_enc)
|