Spaces:
Running
Running
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| import os | |
| import argparse | |
| import logging | |
| import json | |
| import base64 | |
| from .runtime.common import enable_multi_thread, enable_multi_phase | |
| from .runtime.msg_dispatcher import MsgDispatcher | |
| from .tools.package_utils import create_builtin_class_instance, create_customized_class_instance | |
# Module-level logger shared by every function in this entry-point module.
logger = logging.getLogger('nni.main')
logger.debug('START')

# When COVERAGE_PROCESS_START is set (to any non-empty value), start coverage
# measurement for this process before any dispatcher code runs.
if os.environ.get('COVERAGE_PROCESS_START'):
    import coverage
    coverage.process_startup()
def main():
    """Parse the experiment parameters and run the advisor or tuner dispatcher.

    The ``--exp_params`` command-line argument is required and must be a
    base64-encoded, UTF-8 JSON document. Unknown command-line arguments are
    ignored. Depending on the decoded parameters this function:

    * enables multi-thread and/or multi-phase mode when ``multiThread`` /
      ``multiPhase`` are truthy;
    * runs the advisor when ``advisor`` is present; otherwise
    * creates the tuner (and the assessor, when ``assessor`` is present),
      wraps them in a ``MsgDispatcher`` and runs it.

    Raises:
        AssertionError: if neither ``advisor`` nor ``tuner`` is configured.
        Exception: any error raised while the dispatcher runs is logged, the
            tuner/assessor ``_on_error`` hooks are invoked, and the error is
            re-raised.
    """
    parser = argparse.ArgumentParser(description='Dispatcher command line parser')
    parser.add_argument('--exp_params', type=str, required=True)
    # parse_known_args tolerates extra arguments that this parser does not define.
    args, _ = parser.parse_known_args()

    exp_params_decode = base64.b64decode(args.exp_params).decode('utf-8')
    logger.debug('decoded exp_params: [%s]', exp_params_decode)
    exp_params = json.loads(exp_params_decode)
    logger.debug('exp_params json obj: [%s]', json.dumps(exp_params, indent=4))

    if exp_params.get('multiThread'):
        enable_multi_thread()
    if exp_params.get('multiPhase'):
        enable_multi_phase()

    if exp_params.get('advisor') is not None:
        # advisor is enabled and starts to run
        _run_advisor(exp_params)
        return

    # tuner (and assessor) is enabled and starts to run
    # Raise explicitly instead of using ``assert``: assert statements are
    # removed under ``python -O``, which would silently skip this validation.
    if exp_params.get('tuner') is None:
        raise AssertionError("exp_params must define either 'advisor' or 'tuner'")

    tuner = _create_tuner(exp_params)
    if exp_params.get('assessor') is not None:
        assessor = _create_assessor(exp_params)
    else:
        assessor = None
    dispatcher = MsgDispatcher(tuner, assessor)

    try:
        dispatcher.run()
        # The exit hooks run inside the try block, so a failure in one of them
        # is logged and also triggers the error hooks below.
        tuner._on_exit()
        if assessor is not None:
            assessor._on_exit()
    except Exception as exception:
        logger.exception(exception)
        tuner._on_error()
        if assessor is not None:
            assessor._on_error()
        raise
def _run_advisor(exp_params):
    """Instantiate the configured advisor and run it.

    A truthy ``builtinAdvisorName`` in ``exp_params['advisor']`` selects a
    built-in advisor, constructed with the accompanying ``classArgs``;
    otherwise the whole advisor section is passed to the customized-class
    factory.

    Raises:
        AssertionError: if the factory returned ``None``.
        Exception: any error raised by the advisor's ``run()`` is logged and
            re-raised.
    """
    advisor_config = exp_params.get('advisor')
    builtin_name = advisor_config.get('builtinAdvisorName')

    if builtin_name:
        dispatcher = create_builtin_class_instance(
            builtin_name,
            advisor_config.get('classArgs'),
            'advisors')
    else:
        dispatcher = create_customized_class_instance(advisor_config)

    if dispatcher is None:
        raise AssertionError('Failed to create Advisor instance')

    try:
        dispatcher.run()
    except Exception as err:
        logger.exception(err)
        raise
def _create_tuner(exp_params):
    """Build and return the tuner described by ``exp_params['tuner']``.

    A truthy ``builtinTunerName`` selects a built-in tuner, constructed with
    the accompanying ``classArgs``; otherwise the whole tuner section is
    passed to the customized-class factory.

    Raises:
        AssertionError: if the factory returned ``None``.
    """
    tuner_config = exp_params.get('tuner')
    builtin_name = tuner_config.get('builtinTunerName')

    if builtin_name:
        tuner = create_builtin_class_instance(
            builtin_name,
            tuner_config.get('classArgs'),
            'tuners')
    else:
        tuner = create_customized_class_instance(tuner_config)

    if tuner is None:
        raise AssertionError('Failed to create Tuner instance')
    return tuner
def _create_assessor(exp_params):
    """Build and return the assessor described by ``exp_params['assessor']``.

    A truthy ``builtinAssessorName`` selects a built-in assessor, constructed
    with the accompanying ``classArgs``; otherwise the whole assessor section
    is passed to the customized-class factory.

    Raises:
        AssertionError: if the factory returned ``None``.
    """
    assessor_config = exp_params.get('assessor')
    builtin_name = assessor_config.get('builtinAssessorName')

    if builtin_name:
        assessor = create_builtin_class_instance(
            builtin_name,
            assessor_config.get('classArgs'),
            'assessors')
    else:
        assessor = create_customized_class_instance(assessor_config)

    if assessor is None:
        raise AssertionError('Failed to create Assessor instance')
    return assessor
# Script entry point: run the dispatcher and make sure any failure is written
# to the log before it propagates.
if __name__ == '__main__':
    try:
        main()
    except Exception as err:
        logger.exception(err)
        raise