# File size: 1,443 Bytes
# 79cf6ef
import os
import yaml
import torch
# Classifier snapshots to convert. main() joins each key onto
# config['ARTIFACTS_PATH'] (source) and each value onto config['MODELS_PATH']
# (destination, written by copy_weights()).
WEIGHTS_MAPPING = {
    'snapshots/efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k/snapshot_100000.pth':
        'efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k_v4_cad79a/snapshot_100000.fp16.pth',
    'snapshots/efficientnet-b7_ns_aa-original-mstd0.5_re_100k/snapshot_100000.pth':
        'efficientnet-b7_ns_aa-original-mstd0.5_re_100k_v4_cad79a/snapshot_100000.fp16.pth',
    'snapshots/efficientnet-b7_ns_seq_aa-original-mstd0.5_100k/snapshot_100000.pth':
        'efficientnet-b7_ns_seq_aa-original-mstd0.5_100k_v4_cad79a/snapshot_100000.fp16.pth',
}

# Detector checkpoint: the source path is passed to copy_weights() as-is,
# while the destination file name is joined onto config['MODELS_PATH'].
SRC_DETECTOR_WEIGHTS = 'external_data/WIDERFace_DSFD_RES152.pth'
DST_DETECTOR_WEIGHTS = 'WIDERFace_DSFD_RES152.fp16.pth'
def copy_weights(src_path, dst_path):
    """Write a half-precision copy of the checkpoint at ``src_path`` to ``dst_path``.

    The checkpoint is read with ``torch.load`` and must be a flat mapping
    whose values are tensors. Floating-point tensors are cast to fp16;
    tensors of any other dtype (e.g. integer counters) are stored unchanged,
    since casting them to fp16 would lose or overflow their values.

    Args:
        src_path: Path of the checkpoint to read.
        dst_path: Path of the converted checkpoint to write. Missing parent
            directories are created.

    Raises:
        AttributeError: If a value in the checkpoint is not a tensor.
    """
    # The identity map_location leaves storages where torch deserialized them
    # instead of moving them back to the device they were saved from.
    state = torch.load(src_path, map_location=lambda storage, loc: storage)
    state = {
        key: value.half() if value.is_floating_point() else value
        for key, value in state.items()
    }

    # dirname() is '' for a bare file name, and os.makedirs('') raises.
    dst_dir = os.path.dirname(dst_path)
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
    torch.save(state, dst_path)
def main():
    """Convert every configured checkpoint to its half-precision copy.

    Reads ``config.yaml`` from the current working directory and uses its
    ``ARTIFACTS_PATH`` and ``MODELS_PATH`` entries as the source and
    destination roots for the snapshots listed in ``WEIGHTS_MAPPING``. The
    detector checkpoint is read from ``SRC_DETECTOR_WEIGHTS`` as given and
    written under ``MODELS_PATH``.

    Raises:
        KeyError: If ``ARTIFACTS_PATH`` or ``MODELS_PATH`` is missing from
            the configuration.
    """
    with open('config.yaml', 'r') as f:
        # safe_load instead of yaml.load(f): calling load() without a Loader
        # is an error on PyYAML >= 6 and allows arbitrary object construction
        # on older versions.
        config = yaml.safe_load(f)

    artifacts_root = config['ARTIFACTS_PATH']
    models_root = config['MODELS_PATH']

    for src_rel_path, dst_rel_path in WEIGHTS_MAPPING.items():
        src_path = os.path.join(artifacts_root, src_rel_path)
        dst_path = os.path.join(models_root, dst_rel_path)
        copy_weights(src_path, dst_path)

    # The detector source path is not joined onto ARTIFACTS_PATH.
    copy_weights(SRC_DETECTOR_WEIGHTS, os.path.join(models_root, DST_DETECTOR_WEIGHTS))
# Run the conversion only when executed as a script, not on import.
if __name__ == '__main__':
    main()