| | from __future__ import absolute_import |
| | from __future__ import division |
| | from __future__ import print_function |
| |
|
| | import paddle |
| | import paddle.nn as nn |
| | import paddle.nn.functional as F |
| |
|
| | from ppdet.core.workspace import load_config, merge_config, create |
| | from ppdet.utils.checkpoint import load_weight, load_pretrain_weight |
| | from ppdet.utils.logger import setup_logger |
| | from ppdet.core.workspace import register, serializable |
| |
|
| | from paddle.utils import try_import |
| |
|
# Module-level logger; OFA.__call__ uses it to report the token map,
# search candidates and search space of the built supernet.
logger = setup_logger(__name__)
| |
|
| |
|
@register
@serializable
class OFA(object):
    """Wrap a detection model into a PaddleSlim Once-For-All (OFA) supernet.

    ``ofa_config`` is a dict taken from the slim config. Keys read here:

    * ``task`` (required) -- passed to ``set_task`` on the PaddleSlim OFA
      wrapper.
    * ``expand_ratio`` (required) -- passed to ``supernet(expand_ratio=...)``.
    * ``RunConfig`` (required) -- keyword arguments for PaddleSlim
      ``RunConfig``. An optional ``skip_layers`` list inside it is extended
      with neck/head sublayer names on a *copy*; the config is never mutated.
    * ``skip_neck`` / ``skip_head`` (optional, default False) -- when true,
      every sublayer whose name contains ``'neck.'`` / ``'head.'`` is added
      to ``skip_layers``.
    * ``input_shape`` (optional, default ``[1, 3, 640, 640]``) -- the
      ``[N, C, H, W]`` shape of the dummy input used to trace the search
      space.
    """

    def __init__(self, ofa_config):
        super(OFA, self).__init__()
        self.ofa_config = ofa_config

    def __call__(self, model, param_state_dict):
        """Convert ``model`` to an OFA supernet and load ``param_state_dict``.

        Args:
            model: the detection architecture to convert (presumably a
                ``paddle.nn.Layer``; it must provide ``named_sublayers``).
            param_state_dict: weights loaded into the converted model via
                PaddleSlim ``utils.set_state_dict``.

        Returns:
            The PaddleSlim OFA wrapper around the converted model.

        Raises:
            KeyError: if ``task``, ``expand_ratio`` or ``RunConfig`` is
                missing from ``ofa_config``.
        """
        # try_import raises an informative error when paddleslim is absent.
        try_import('paddleslim')
        # Aliased so the PaddleSlim class does not shadow this class's name.
        from paddleslim.nas.ofa import OFA as SlimOFA, RunConfig, utils
        from paddleslim.nas.ofa.convert_super import Convert, supernet

        task = self.ofa_config['task']
        expand_ratio = self.ofa_config['expand_ratio']

        sp_config = supernet(expand_ratio=expand_ratio)
        model = Convert(sp_config).convert(model)

        # Skip-layer names are collected from the *converted* model.
        run_config = RunConfig(**self._run_config_kwargs(model))

        ofa_model = SlimOFA(model, run_config=run_config)
        ofa_model.set_epoch(0)
        ofa_model.set_task(task)

        ofa_model._clear_search_space(input_spec=self._dummy_input_spec())
        ofa_model._build_ss = True
        # The return value was never used; the call is kept as in the
        # original, presumably for its effect on ofa_model -- TODO confirm.
        ofa_model._sample_config('expand_ratio', phase=None)
        ofa_model.tokenize()

        logger.info('Token map is {}'.format(ofa_model.token_map))
        logger.info('Search candidates is {}'.format(ofa_model.search_cands))
        logger.info('The length of search_space is {}, search_space is {}'.
                    format(len(ofa_model._ofa_layers), ofa_model._ofa_layers))

        utils.set_state_dict(ofa_model.model, param_state_dict)
        return ofa_model

    def _run_config_kwargs(self, model):
        """Return a fresh copy of the ``RunConfig`` kwargs with skipped layers.

        Both the configured dict and its ``skip_layers`` list are copied, so
        calling this repeatedly neither mutates ``self.ofa_config`` nor
        accumulates duplicate layer names across calls.
        """
        run_config = dict(self.ofa_config['RunConfig'])
        # `or []` also tolerates an explicit `skip_layers: null` in the config.
        skip_layers = list(run_config.get('skip_layers') or [])

        skip_names = []
        if self.ofa_config.get('skip_neck', False):
            skip_names.append('neck.')
        if self.ofa_config.get('skip_head', False):
            skip_names.append('head.')

        if skip_names:
            for name, _ in model.named_sublayers():
                # Substring match as before; any() adds each name at most once.
                if any(n in name for n in skip_names):
                    skip_layers.append(name)

        run_config['skip_layers'] = skip_layers
        return run_config

    def _dummy_input_spec(self):
        """Build the dummy feed used to trace the supernet's search space.

        The shape comes from ``ofa_config['input_shape']`` ([N, C, H, W]) and
        defaults to [1, 3, 640, 640], which reproduces the original tensors.
        """
        batch, channels, height, width = self.ofa_config.get(
            'input_shape', [1, 3, 640, 640])
        return [{
            "image": paddle.ones(
                shape=[batch, channels, height, width], dtype='float32'),
            # One [H, W] row per batch element.
            "im_shape": paddle.to_tensor(
                [[float(height), float(width)]] * batch, dtype='float32'),
            "scale_factor": paddle.ones(
                shape=[batch, 2], dtype='float32'),
        }]
| |
|