| |
import argparse
import os
import re
import tempfile
from collections import OrderedDict

import torch
from mmengine import Config
|
|
|
|
def is_head(key):
    """Return True if ``key`` is a state-dict key of a known head module."""
    head_prefixes = (
        'bbox_head',
        'mask_head',
        'semantic_head',
        'grid_head',
        'mask_iou_head',
    )
    for prefix in head_prefixes:
        if key.startswith(prefix):
            return True
    return False
|
|
|
|
def parse_config(config_strings):
    """Infer detector properties from a config given as a string.

    The config text is written to a temporary ``.py`` file (removed before
    returning) so that ``Config.fromfile`` can parse it.

    Args:
        config_strings (str): Python source of the model config.

    Returns:
        tuple[bool, bool, bool, bool]: ``(is_two_stage, is_ssd, is_retina,
        reg_cls_agnostic)``.
    """
    # Config.fromfile only accepts a real file, so materialize the string.
    # delete=False + explicit os.remove: the original created a second
    # '<tmpname>.py' file that was never deleted (file leak).
    with tempfile.NamedTemporaryFile(
            'w', suffix='.py', delete=False) as temp_file:
        temp_file.write(config_strings)
        config_path = temp_file.name

    is_two_stage = True
    is_ssd = False
    is_retina = False
    reg_cls_agnostic = False
    try:
        config = Config.fromfile(config_path)
        if 'rpn_head' not in config.model:
            # Single-stage detector: bbox_head.type identifies which one.
            # NOTE(review): probing bbox_head.type is only safe here, in
            # the single-stage branch — confirm against old-style configs.
            is_two_stage = False
            if config.model.bbox_head.type == 'SSDHead':
                is_ssd = True
            elif config.model.bbox_head.type == 'RetinaHead':
                is_retina = True
        elif isinstance(config.model['bbox_head'], list):
            # A list of heads (cascade-style) is treated as class-agnostic.
            reg_cls_agnostic = True
        elif 'reg_class_agnostic' in config.model.bbox_head:
            reg_cls_agnostic = config.model.bbox_head.reg_class_agnostic
    finally:
        os.remove(config_path)
    return is_two_stage, is_ssd, is_retina, reg_cls_agnostic
|
|
|
|
def reorder_cls_channel(val, num_classes=81):
    """Rotate the class channels so channel 0 moves to the last slot.

    Args:
        val (Tensor): Classification weight (>=2-D) or bias (1-D).
        num_classes (int): Number of classes including background.

    Returns:
        Tensor: Tensor with class channels rotated by one position, or
        ``val`` unchanged when the leading dim does not match.
    """
    if val.dim() == 1:
        # Bias vector: rotate the first entry to the end unconditionally.
        return torch.cat((val[1:], val[:1]), dim=0)

    out_channels, in_channels = val.shape[:2]
    if out_channels != num_classes and out_channels % num_classes == 0:
        # Multiple groups of class channels (e.g. per anchor): rotate
        # within every group of ``num_classes`` channels.
        grouped = val.reshape(-1, num_classes, in_channels, *val.shape[2:])
        rotated = torch.cat((grouped[:, 1:], grouped[:, :1]), dim=1)
        return rotated.reshape(val.size())
    if out_channels == num_classes:
        return torch.cat((val[1:], val[:1]), dim=0)
    return val
|
|
|
|
def truncate_cls_channel(val, num_classes=81):
    """Drop one class channel from a class-wise parameter.

    Args:
        val (Tensor): Weight (>=2-D) or bias (1-D) whose leading dim is
            class-wise.
        num_classes (int): Number of classes including background.

    Returns:
        Tensor: Truncated tensor, or ``val`` unchanged when the leading
        dimension is not a multiple of ``num_classes``.
    """
    if val.dim() == 1:
        if val.size(0) % num_classes:
            return val
        # NOTE(review): the 1-D branch keeps the first num_classes - 1
        # entries while the N-D branch drops index 0; preserved as-is.
        return val[:num_classes - 1]

    out_channels, in_channels = val.shape[:2]
    if out_channels % num_classes:
        return val
    kept = val.reshape(num_classes, in_channels, *val.shape[2:])[1:]
    return kept.reshape(-1, *val.shape[1:])
|
|
|
|
def truncate_reg_channel(val, num_classes=81):
    """Drop one class worth of box-regression channels.

    Args:
        val (Tensor): Regression weight (>=2-D) or bias (1-D); channels
            are grouped per class (e.g. 4 deltas per class).
        num_classes (int): Number of classes including background.

    Returns:
        Tensor: Truncated tensor, or ``val`` unchanged when the leading
        dimension is not a multiple of ``num_classes``.
    """
    if val.dim() == 1:
        # Bias: keep the first ``num_classes - 1`` per-class groups.
        if val.size(0) % num_classes:
            return val
        trimmed = val.reshape(num_classes, -1)[:num_classes - 1]
        return trimmed.reshape(-1)

    out_channels, in_channels = val.shape[:2]
    if out_channels % num_classes:
        return val
    # Weight: drop the first per-class group of regression channels.
    # NOTE(review): the 1-D branch drops the *last* group while this one
    # drops the *first*; preserved exactly as in the original.
    trimmed = val.reshape(num_classes, -1, in_channels, *val.shape[2:])[1:]
    return trimmed.reshape(-1, *val.shape[1:])
|
|
|
|
def convert(in_file, out_file, num_classes):
    """Convert keys in checkpoints.

    There can be some breaking changes during the development of mmdetection,
    and this tool is used for upgrading checkpoints trained with old versions
    to the latest one.

    Args:
        in_file (str): Path of the checkpoint to upgrade.
        out_file (str): Path where the upgraded checkpoint is saved.
        num_classes (int): Number of classes (including background) of the
            original model.
    """
    # map_location='cpu' lets GPU-trained checkpoints be converted on
    # machines without CUDA.
    checkpoint = torch.load(in_file, map_location='cpu')
    in_state_dict = checkpoint.pop('state_dict')
    out_state_dict = OrderedDict()
    meta_info = checkpoint['meta']
    is_two_stage, is_ssd, is_retina, reg_cls_agnostic = parse_config(
        '#' + meta_info['config'])
    if meta_info['mmdet_version'] <= '0.5.3' and is_retina:
        upgrade_retina = True
    else:
        upgrade_retina = False

    # MMDetection v2.5.0 unified the RPN heads; checkpoints older than that
    # also need their rpn_cls / rpn_reg channels upgraded.
    if meta_info['mmdet_version'] < '2.5.0':
        upgrade_rpn = True
    else:
        upgrade_rpn = False

    for key, val in in_state_dict.items():
        new_key = key
        new_val = val
        if is_two_stage and is_head(key):
            new_key = f'roi_head.{key}'

        # Reorder the classification channels (background class moved).
        # Dots are escaped: the original patterns used bare '.', which
        # matches any character instead of a literal dot.
        if upgrade_rpn:
            m = re.search(
                r'(conv_cls|retina_cls|rpn_cls|fc_cls|fcos_cls|'
                r'fovea_cls)\.(weight|bias)', new_key)
        else:
            m = re.search(
                r'(conv_cls|retina_cls|fc_cls|fcos_cls|'
                r'fovea_cls)\.(weight|bias)', new_key)
        if m is not None:
            print(f'reorder cls channels of {new_key}')
            new_val = reorder_cls_channel(val, num_classes)

        # Truncate the box-regression channels of one class.
        if upgrade_rpn:
            m = re.search(r'(fc_reg)\.(weight|bias)', new_key)
        else:
            m = re.search(r'(fc_reg|rpn_reg)\.(weight|bias)', new_key)
        if m is not None and not reg_cls_agnostic:
            print(f'truncate regression channels of {new_key}')
            new_val = truncate_reg_channel(val, num_classes)

        # Truncate the mask-prediction channels of one class.
        m = re.search(r'(conv_logits)\.(weight|bias)', new_key)
        if m is not None:
            print(f'truncate mask prediction channels of {new_key}')
            new_val = truncate_cls_channel(val, num_classes)

        # Rename the conv layers of RetinaHead checkpoints <= v0.5.3.
        m = re.search(r'(cls_convs|reg_convs)\.\d\.(weight|bias)', key)
        if m is not None and upgrade_retina:
            param = m.groups()[1]
            new_key = key.replace(param, f'conv.{param}')
            out_state_dict[new_key] = val
            print(f'rename the name of {key} to {new_key}')
            continue

        # SSD class convs also carry class channels that need reordering.
        m = re.search(r'(cls_convs)\.\d\.(weight|bias)', key)
        if m is not None and is_ssd:
            print(f'reorder cls channels of {new_key}')
            new_val = reorder_cls_channel(val, num_classes)

        out_state_dict[new_key] = new_val
    checkpoint['state_dict'] = out_state_dict
    torch.save(checkpoint, out_file)
|
|
|
|
def main():
    """CLI entry point: parse the arguments and upgrade the checkpoint."""
    arg_parser = argparse.ArgumentParser(description='Upgrade model version')
    arg_parser.add_argument('in_file', help='input checkpoint file')
    arg_parser.add_argument('out_file', help='output checkpoint file')
    arg_parser.add_argument(
        '--num-classes',
        type=int,
        default=81,
        help='number of classes of the original model')
    parsed = arg_parser.parse_args()
    convert(parsed.in_file, parsed.out_file, parsed.num_classes)
|
|
|
|
# Run the conversion only when executed as a script, not on import.
if __name__ == '__main__':
    main()
|
|