| | import torch |
| | import os |
| | import argparse |
| | import copy |
| |
|
| | def parse_args(): |
| | parser = argparse.ArgumentParser() |
| | parser.add_argument('--source', type=str) |
| | parser.add_argument('--target', type=str, default=None) |
| | args = parser.parse_args() |
| | return args |
| |
|
| | def main(): |
| |
|
| | args = parse_args() |
| |
|
| | if args.target is None: |
| | args.target = '/'.join(args.source.split('/')[:-1]) |
| |
|
| | ckpt = torch.load(args.source, map_location='cpu') |
| | |
| | experts = dict() |
| |
|
| | new_ckpt = copy.deepcopy(ckpt) |
| |
|
| | state_dict = new_ckpt['state_dict'] |
| |
|
| | for key, value in state_dict.items(): |
| | if 'mlp.experts' in key: |
| | experts[key] = value |
| |
|
| | keys = ckpt['state_dict'].keys() |
| |
|
| | target_expert = 0 |
| | new_ckpt = copy.deepcopy(ckpt) |
| |
|
| | for key in keys: |
| | if 'mlp.fc2' in key: |
| | value = new_ckpt['state_dict'][key] |
| | value = torch.cat([value, experts[key.replace('fc2.', f'experts.{target_expert}.')]], dim=0) |
| | new_ckpt['state_dict'][key] = value |
| |
|
| | torch.save(new_ckpt, os.path.join(args.target, 'coco.pth')) |
| |
|
| | names = ['aic', 'mpii', 'ap10k', 'apt36k','wholebody'] |
| | num_keypoints = [14, 16, 17, 17, 133] |
| | weight_names = ['keypoint_head.deconv_layers.0.weight', |
| | 'keypoint_head.deconv_layers.1.weight', |
| | 'keypoint_head.deconv_layers.1.bias', |
| | 'keypoint_head.deconv_layers.1.running_mean', |
| | 'keypoint_head.deconv_layers.1.running_var', |
| | 'keypoint_head.deconv_layers.1.num_batches_tracked', |
| | 'keypoint_head.deconv_layers.3.weight', |
| | 'keypoint_head.deconv_layers.4.weight', |
| | 'keypoint_head.deconv_layers.4.bias', |
| | 'keypoint_head.deconv_layers.4.running_mean', |
| | 'keypoint_head.deconv_layers.4.running_var', |
| | 'keypoint_head.deconv_layers.4.num_batches_tracked', |
| | 'keypoint_head.final_layer.weight', |
| | 'keypoint_head.final_layer.bias'] |
| | |
| | exist_range = True |
| |
|
| | for i in range(5): |
| |
|
| | new_ckpt = copy.deepcopy(ckpt) |
| |
|
| | target_expert = i + 1 |
| |
|
| | for key in keys: |
| | if 'mlp.fc2' in key: |
| | expert_key = key.replace('fc2.', f'experts.{target_expert}.') |
| | if expert_key in experts: |
| | value = new_ckpt['state_dict'][key] |
| | value = torch.cat([value, experts[expert_key]], dim=0) |
| | else: |
| | exist_range = False |
| |
|
| | new_ckpt['state_dict'][key] = value |
| |
|
| | if not exist_range: |
| | break |
| |
|
| | for tensor_name in weight_names: |
| | new_ckpt['state_dict'][tensor_name] = new_ckpt['state_dict'][tensor_name.replace('keypoint_head', f'associate_keypoint_heads.{i}')] |
| |
|
| | for tensor_name in ['keypoint_head.final_layer.weight', 'keypoint_head.final_layer.bias']: |
| | new_ckpt['state_dict'][tensor_name] = new_ckpt['state_dict'][tensor_name][:num_keypoints[i]] |
| | |
| | |
| | for j in range(5): |
| | |
| | for tensor_name in weight_names: |
| | new_ckpt['state_dict'].pop(tensor_name.replace('keypoint_head', f'associate_keypoint_heads.{j}')) |
| | |
| | keys = new_ckpt['state_dict'].keys() |
| | for key in list(keys): |
| | if 'expert' in keys: |
| | new_ckpt['state_dict'].pop(key) |
| | |
| | torch.save(new_ckpt, os.path.join(args.target, f'{names[i]}.pth')) |
| |
|
| | if __name__ == '__main__': |
| | main() |
| |
|