Spaces:
Build error
Build error
| """Utils for NNCf optimization.""" | |
| # Copyright (C) 2022 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
| import logging | |
| from copy import copy | |
| from typing import Any, Dict, Iterator, List, Tuple | |
| from nncf import NNCFConfig | |
| from nncf.api.compression import CompressionAlgorithmController | |
| from nncf.torch import create_compressed_model, load_state, register_default_init_args | |
| from nncf.torch.initialization import PTInitializingDataLoader | |
| from nncf.torch.nncf_network import NNCFNetwork | |
| from torch import nn | |
| from torch.utils.data.dataloader import DataLoader | |
# Module-level logger; messages from this module are emitted under the "NNCF compression" name.
logger = logging.getLogger("NNCF compression")
class InitLoader(PTInitializingDataLoader):
    """Initializing data loader for NNCF to be used with unsupervised training algorithms.

    Wraps a regular ``DataLoader`` and hands NNCF only the ``"image"`` entry of each batch;
    no ground truth is supplied (see ``get_target``). Batches are assumed to be indexable
    by the ``"image"`` key.
    """

    def __init__(self, data_loader: DataLoader):
        super().__init__(data_loader)
        # Only declared here; the actual iterator is created by ``__iter__``.
        self._data_loader_iter: Iterator

    def __iter__(self):
        """Begin a fresh pass over the wrapped data loader and return this object as the iterator."""
        self._data_loader_iter = iter(self._data_loader)
        return self

    def __next__(self) -> Any:
        """Fetch the next batch from the wrapped loader and return only its ``"image"`` entry."""
        batch = next(self._data_loader_iter)
        image = batch["image"]
        return image

    def get_inputs(self, dataloader_output) -> Tuple[Tuple, Dict]:
        """Build the model call NNCF performs during initialization.

        Returns:
            Tuple[Tuple, Dict]: positional arguments ``(dataloader_output,)`` and empty keyword arguments.
        """
        positional_args = (dataloader_output,)
        keyword_args = {}
        return positional_args, keyword_args

    def get_target(self, _):
        """Provide the ground-truth structure for the loss criterion.

        Unsupervised setting: there is no target, so this is a placeholder.

        Returns:
            None
        """
        return None
def wrap_nncf_model(
    model: nn.Module, config: Dict, dataloader: DataLoader = None, init_state_dict: Dict = None
) -> Tuple[CompressionAlgorithmController, NNCFNetwork]:
    """Wrap model by NNCF.

    :param model: Anomalib model.
    :param config: NNCF config as a plain dictionary; converted with ``NNCFConfig.from_dict``.
    :param dataloader: Optional dataloader used to initialize the NNCF compression algorithms
        (its batches are adapted through ``InitLoader``).
    :param init_state_dict: Optional checkpoint of a previously compressed model. Its ``"model"``
        entry is loaded into the compressed network with ``load_state(..., is_resume=True)`` and
        its ``"compression_state"`` entry is passed to ``create_compressed_model``.
    :return: compression controller, compressed model
    """
    nncf_config = NNCFConfig.from_dict(config)

    # Compare the dataloader against None explicitly: truth-testing a DataLoader calls its
    # ``__len__``, which raises TypeError for iterable-style datasets that define no length.
    if dataloader is None and not init_state_dict:
        logger.warning(
            "Either dataloader or NNCF pre-trained "
            "model checkpoint should be set. Without this, "
            "quantizers will not be initialized"
        )

    compression_state = None
    resuming_state_dict = None
    if init_state_dict:
        resuming_state_dict = init_state_dict.get("model")
        compression_state = init_state_dict.get("compression_state")

    if dataloader is not None:
        init_loader = InitLoader(dataloader)  # type: ignore
        nncf_config = register_default_init_args(nncf_config, init_loader)

    nncf_ctrl, nncf_model = create_compressed_model(
        model=model, config=nncf_config, dump_graphs=False, compression_state=compression_state
    )

    # Restore weights only when the checkpoint actually carried a non-empty model state.
    if resuming_state_dict:
        load_state(nncf_model, resuming_state_dict, is_resume=True)
    return nncf_ctrl, nncf_model
def is_state_nncf(state: Dict) -> bool:
    """Check whether a checkpoint state was produced by an NNCF-compressed model.

    :param state: Checkpoint dictionary; its optional ``"meta"`` mapping is inspected.
    :return: ``True`` when ``state["meta"]["nncf_enable_compression"]`` is truthy,
        ``False`` otherwise (including when either key is missing).
    """
    meta = state.get("meta", {})
    compression_flag = meta.get("nncf_enable_compression", False)
    return bool(compression_flag)
def compose_nncf_config(nncf_config: Dict, enabled_options: List[str]) -> Dict:
    """Compose NNCF config by selected options.

    The config must contain a ``"base"`` part. Every enabled option names another part of the
    config, which is merged on top of the base with ``merge_dicts_and_lists_b_into_a``.

    :param nncf_config: Optimisation config holding a ``"base"`` part and the optional parts
        that can be enabled. It may also hold an ``"order_of_parts"`` list that fixes the
        order in which the enabled parts are merged.
    :param enabled_options: Names of the parts to merge into the base part. When the config has
        no ``"order_of_parts"``, they are merged in the order given here (each name once).
    :return: config
    :raises AssertionError: If ``"order_of_parts"`` is not a list, an enabled option is missing
        from it, or the config lacks the ``"base"`` part or an enabled part.
    :raises RuntimeError: If merging an enabled part into the accumulated config fails.
    """
    optimisation_parts = nncf_config
    if "order_of_parts" in optimisation_parts:
        # The result of applying the changes from optimisation parts
        # may depend on the order of applying the changes
        # (e.g. if for nncf_quantization it is sufficient to have `total_epochs=2`,
        # but for sparsity it is required `total_epochs=50`)
        # So, user can define `order_of_parts` in the optimisation_config
        # to specify the order of applying the parts.
        order_of_parts = optimisation_parts["order_of_parts"]
        assert isinstance(order_of_parts, list), 'The field "order_of_parts" in optimisation config should be a list'
        for part in enabled_options:
            # Both halves must be f-strings; the original second literal printed "{order_of_parts}" verbatim.
            assert part in order_of_parts, (
                f"The part {part} is selected, but it is absent in order_of_parts={order_of_parts}"
            )
        optimisation_parts_to_choose = [part for part in order_of_parts if part in enabled_options]
    else:
        # No explicit order: apply the enabled parts in the requested order instead of silently
        # dropping them. dict.fromkeys removes duplicates while keeping first-seen order.
        optimisation_parts_to_choose = list(dict.fromkeys(enabled_options))

    assert "base" in optimisation_parts, 'Error: the optimisation config does not contain the "base" part'
    nncf_config_part = optimisation_parts["base"]
    for part in optimisation_parts_to_choose:
        assert part in optimisation_parts, f'Error: the optimisation config does not contain the part "{part}"'
        optimisation_part_dict = optimisation_parts[part]
        try:
            nncf_config_part = merge_dicts_and_lists_b_into_a(nncf_config_part, optimisation_part_dict)
        except AssertionError as cur_error:
            err_descr = (
                f"Error during merging the parts of nncf configs:\n"
                f"the current part={part}, "
                f"the order of merging parts into base is {optimisation_parts_to_choose}.\n"
                f"The error is:\n{cur_error}"
            )
            raise RuntimeError(err_descr) from None
    return nncf_config_part
# pylint: disable=invalid-name
def merge_dicts_and_lists_b_into_a(a, b):
    """Merge config ``b`` into config ``a`` and return the merged result.

    Public entry point of the recursive merge; the recursion starts with an empty key path.
    """
    return _merge_dicts_and_lists_b_into_a(a, b, cur_key="")
def _merge_dicts_and_lists_b_into_a(a, b, cur_key=None):
    """The function is inspired by mmcv.Config._merge_a_into_b.

    * works with usual dicts and lists and derived types
    * supports merging of lists (by concatenating the lists)
    * makes recursive merging for dict + dict case
    * overwrites when merging scalar into scalar

    Note that we merge b into a (whereas Config makes merge a into b),
    since otherwise the order of list merging is counter-intuitive.

    The inputs are not modified in place, but the result may share nested objects with them.

    :param a: Base dict or list.
    :param b: Dict or list merged into ``a``; must be the same kind of container as ``a``.
    :param cur_key: Dotted path of the key currently being merged, used only in error messages.
        ``None`` or ``""`` denotes the top-level structures.
    :return: The merged dict (a shallow copy of ``a`` updated from ``b``) or the concatenated list.
    :raises AssertionError: If the container kinds of ``a`` and ``b`` (or of a nested value pair) differ.
    """

    def _err_str(_a, _b, _key):
        # The public wrapper starts with cur_key="", so an empty key must also be reported
        # as the top level (previously only None was, yielding "key=``").
        if not _key:
            _key_str = "of whole structures"
        else:
            _key_str = f"during merging for key=`{_key}`"
        return (
            f"Error in merging parts of config: different types {_key_str},"
            f" type(a) = {type(_a)},"
            f" type(b) = {type(_b)}"
        )

    assert isinstance(a, (dict, list)), f"Can merge only dicts and lists, whereas type(a)={type(a)}"
    assert isinstance(b, (dict, list)), _err_str(a, b, cur_key)
    assert isinstance(a, list) == isinstance(b, list), _err_str(a, b, cur_key)
    if isinstance(a, list):
        # the main diff w.r.t. mmcv.Config -- merging of lists
        return a + b

    # Work on a shallow copy so the caller's dict is left untouched.
    a = copy(a)
    for k in b.keys():
        if k not in a:
            a[k] = copy(b[k])
            continue
        new_cur_key = cur_key + "." + k if cur_key else k
        if isinstance(a[k], (dict, list)):
            a[k] = _merge_dicts_and_lists_b_into_a(a[k], b[k], new_cur_key)
            continue
        assert not isinstance(b[k], (dict, list)), _err_str(a[k], b[k], new_cur_key)
        # suppose here that a[k] and b[k] are scalars, just overwrite
        a[k] = b[k]
    return a