Spaces:
Sleeping
Sleeping
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import argparse | |
| from collections import OrderedDict | |
| import torch | |
def _remap_state_dict(state_dict, two_dis):
    """Return a new ``OrderedDict`` holding only the renamed, kept entries.

    Each key is tested against an ordered list of rules; the first rule whose
    match-prefix the key starts with renames it, and keys matching no rule are
    dropped. Offsets are derived from the prefix strings themselves rather than
    hard-coded, so editing a prefix cannot silently desynchronise the slice.

    Args:
        state_dict: Mapping of parameter name -> value from the checkpoint.
        two_dis: When true, keep teacher backbone/neck, student head and the
            ``distill_losses.loss_mgd.down*`` entries (the latter re-rooted
            under ``head.``); otherwise keep every ``student.`` entry.

    Returns:
        OrderedDict preserving the input iteration order of the kept keys.
    """
    if two_dis:
        # (prefix the key must start with, prefix to cut off, prefix to add)
        rules = (
            ('teacher.backbone', 'teacher.', ''),
            ('distill_losses.loss_mgd.down', 'distill_losses.loss_mgd.',
             'head.'),
            ('teacher.neck', 'teacher.', ''),
            ('student.head', 'student.', ''),
        )
    else:
        rules = (('student.', 'student.', ''),)

    renamed = []
    for name, value in state_dict.items():
        for match, strip, prepend in rules:
            if name.startswith(match):
                renamed.append((prepend + name[len(strip):], value))
                break  # first matching rule wins; unmatched keys are dropped
    return OrderedDict(renamed)


def change_model(args):
    """Strip a distillation checkpoint down to a renamed ``state_dict``.

    Loads ``args.dis_path``, renames/filters the keys of its ``'state_dict'``
    entry (see ``_remap_state_dict``), removes every top-level entry other
    than ``'meta'`` and ``'state_dict'``, and saves the result to
    ``args.output_path``.

    Args:
        args: Namespace with ``dis_path`` (input checkpoint path),
            ``output_path`` (destination path) and ``two_dis`` (bool
            selecting the renaming rules).

    Raises:
        KeyError: If the loaded checkpoint has no ``'state_dict'`` entry.
    """
    # NOTE: torch.load unpickles the file; only pass checkpoints from a
    # trusted source.
    dis_model = torch.load(args.dis_path, map_location='cpu')
    dis_model['state_dict'] = _remap_state_dict(dis_model['state_dict'],
                                                args.two_dis)

    # Drop optimizer state and any other extras; mutate in place so the
    # loaded container's type is preserved.
    save_keys = ('meta', 'state_dict')
    for key in list(dis_model):
        if key not in save_keys:
            del dis_model[key]
    torch.save(dis_model, args.output_path)
if __name__ == '__main__':
    # Command-line entry point:
    #   <script> DIS_PATH OUTPUT_PATH [--two_dis]
    cli = argparse.ArgumentParser(description='Transfer CKPT')
    cli.add_argument('dis_path', help='dis_model path')
    cli.add_argument('output_path', help='output path')
    # Flag selecting the alternative key-renaming rules in change_model.
    cli.add_argument('--two_dis', action='store_true', default=False,
                     help='if two dis')
    change_model(cli.parse_args())