deep / copy_weights.py
Aryan6192's picture
deep
79cf6ef verified
import os
import yaml
import torch
# Training snapshots to export, as (source, destination) pairs.
# main() joins each source with config['ARTIFACTS_PATH'] and each destination
# with config['MODELS_PATH']; copy_weights() writes the destination as fp16.
WEIGHTS_MAPPING = dict([
    (
        'snapshots/efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k/snapshot_100000.pth',
        'efficientnet-b7_ns_aa-original-mstd0.5_large_crop_100k_v4_cad79a/snapshot_100000.fp16.pth',
    ),
    (
        'snapshots/efficientnet-b7_ns_aa-original-mstd0.5_re_100k/snapshot_100000.pth',
        'efficientnet-b7_ns_aa-original-mstd0.5_re_100k_v4_cad79a/snapshot_100000.fp16.pth',
    ),
    (
        'snapshots/efficientnet-b7_ns_seq_aa-original-mstd0.5_100k/snapshot_100000.pth',
        'efficientnet-b7_ns_seq_aa-original-mstd0.5_100k_v4_cad79a/snapshot_100000.fp16.pth',
    ),
])

# Detector weights. main() passes the source path to copy_weights() as-is
# (so it is resolved against the working directory, not ARTIFACTS_PATH) and
# joins the destination name with config['MODELS_PATH'].
SRC_DETECTOR_WEIGHTS, DST_DETECTOR_WEIGHTS = (
    'external_data/WIDERFace_DSFD_RES152.pth',
    'WIDERFace_DSFD_RES152.fp16.pth',
)
def copy_weights(src_path, dst_path):
    """Re-save the checkpoint at ``src_path`` to ``dst_path`` with fp16 weights.

    Floating-point tensors are cast to ``torch.float16``. Integer/bool tensors
    and non-tensor values are kept unchanged. Nested dicts are converted
    recursively.

    Args:
        src_path: Path of the checkpoint to read. Presumably a state dict
            (mapping of name -> tensor) -- TODO confirm against the producer.
        dst_path: Path the converted checkpoint is written to. Missing parent
            directories are created.
    """
    # Deserialize every storage onto the CPU so no GPU is needed to convert.
    state = torch.load(src_path, map_location='cpu')
    state = _to_half(state)
    # os.path.dirname() is '' for a bare file name and os.makedirs('') raises
    # FileNotFoundError, so only create a directory when there is one.
    dst_dir = os.path.dirname(dst_path)
    if dst_dir:
        os.makedirs(dst_dir, exist_ok=True)
    torch.save(state, dst_path)


def _to_half(obj):
    """Return ``obj`` with every floating-point tensor cast to float16."""
    if isinstance(obj, torch.Tensor):
        # Only floats are halved: fp16 saturates above 65504, so casting an
        # integer buffer (e.g. a step counter of 100000) would yield inf.
        return obj.half() if obj.is_floating_point() else obj
    if isinstance(obj, dict):
        return {key: _to_half(value) for key, value in obj.items()}
    return obj
def main():
    """Export the configured snapshots and the detector weights as fp16.

    Reads ``config.yaml`` from the current working directory. Its
    ``ARTIFACTS_PATH`` entry is the root for the source paths in
    ``WEIGHTS_MAPPING``. Its ``MODELS_PATH`` entry is the root the fp16
    copies are written under.
    """
    # safe_load: yaml.load() without an explicit Loader can construct
    # arbitrary Python objects, and since PyYAML 6.0 it raises TypeError.
    with open('config.yaml', 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    artifacts_path = config['ARTIFACTS_PATH']
    models_path = config['MODELS_PATH']
    for src_rel_path, dst_rel_path in WEIGHTS_MAPPING.items():
        src_path = os.path.join(artifacts_path, src_rel_path)
        dst_path = os.path.join(models_path, dst_rel_path)
        copy_weights(src_path, dst_path)

    # The detector source path is used as-is (relative to the working
    # directory), not joined with ARTIFACTS_PATH.
    copy_weights(SRC_DETECTOR_WEIGHTS, os.path.join(models_path, DST_DETECTOR_WEIGHTS))


if __name__ == '__main__':
    main()