"""Convert trained checkpoints to half precision (fp16) for deployment.

Reads ``ARTIFACTS_PATH`` and ``MODELS_PATH`` from ``config.yaml`` in the
current working directory, converts every checkpoint listed in
``WEIGHTS_MAPPING`` plus the detector weights to fp16, and writes the results
under ``MODELS_PATH``.
"""
import os

import torch
import yaml

# Source checkpoint path (relative to config['ARTIFACTS_PATH']) ->
# fp16 destination path (relative to config['MODELS_PATH']).
WEIGHTS_MAPPING = {
    'snapshots/efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k/snapshot_100000.pth':
        'efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k_v4_cad79a/snapshot_100000.fp16.pth',
    'snapshots/efficientnet-b7_ns_aa-original-mstd0.5_re_100k/snapshot_100000.pth':
        'efficientnet-b7_ns_aa-original-mstd0.5_re_100k_v4_cad79a/snapshot_100000.fp16.pth',
    'snapshots/efficientnet-b7_ns_seq_aa-original-mstd0.5_100k/snapshot_100000.pth':
        'efficientnet-b7_ns_seq_aa-original-mstd0.5_100k_v4_cad79a/snapshot_100000.fp16.pth',
}

# Detector weights: the source path is used as-is (relative to the current
# working directory); the destination is joined onto config['MODELS_PATH'].
SRC_DETECTOR_WEIGHTS = 'external_data/WIDERFace_DSFD_RES152.pth'
DST_DETECTOR_WEIGHTS = 'WIDERFace_DSFD_RES152.fp16.pth'


def copy_weights(src_path, dst_path):
    """Load a tensor mapping from ``src_path`` and save an fp16 copy to ``dst_path``.

    Floating-point tensors are cast to ``float16``. Integer/bool tensors
    (e.g. step counters) are kept unchanged, because casting them to fp16
    loses precision and overflows above 65504.

    Args:
        src_path: Path of a ``torch.save``-d mapping of name -> tensor.
            Presumably a flat ``state_dict``; a value that is not a tensor
            raises ``AttributeError``, as before.
        dst_path: Output file path. Missing parent directories are created.
    """
    # Load onto the CPU regardless of the device the tensors were saved from.
    state = torch.load(src_path, map_location='cpu')
    state = {
        key: value.half() if value.is_floating_point() else value
        for key, value in state.items()
    }
    # dirname() is '' for a bare filename, and os.makedirs('') would raise.
    dst_dir = os.path.dirname(dst_path)
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
    torch.save(state, dst_path)


def main():
    """Convert all configured checkpoints and the detector weights to fp16."""
    with open('config.yaml', 'r', encoding='utf-8') as f:
        # safe_load: bare yaml.load(f) requires a Loader in PyYAML >= 6 and
        # can construct arbitrary Python objects; a plain config needs neither.
        config = yaml.safe_load(f)

    for src_rel_path, dst_rel_path in WEIGHTS_MAPPING.items():
        src_path = os.path.join(config['ARTIFACTS_PATH'], src_rel_path)
        dst_path = os.path.join(config['MODELS_PATH'], dst_rel_path)
        copy_weights(src_path, dst_path)

    copy_weights(SRC_DETECTOR_WEIGHTS, os.path.join(config['MODELS_PATH'], DST_DETECTOR_WEIGHTS))


if __name__ == '__main__':
    main()