Spaces:
Running
Running
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| """ A python wrapper for nni rest api | |
| Example: | |
| from nni.experiment import Experiment | |
| exp = Experiment() | |
| exp.start_experiment('../../../../examples/trials/mnist-pytorch/config.yml') | |
| exp.update_concurrency(3) | |
| print(exp.get_experiment_status()) | |
| print(exp.get_job_statistics()) | |
| print(exp.list_trial_jobs()) | |
| exp.stop_experiment() | |
| """ | |
| import sys | |
| import os | |
| import subprocess | |
| import re | |
| import json | |
| import requests | |
| __all__ = [ | |
| 'Experiment', | |
| 'TrialResult', | |
| 'TrialMetricData', | |
| 'TrialHyperParameters', | |
| 'TrialJob' | |
| ] | |
| EXPERIMENT_PATH = 'experiment' | |
| STATUS_PATH = 'check-status' | |
| JOB_STATISTICS_PATH = 'job-statistics' | |
| TRIAL_JOBS_PATH = 'trial-jobs' | |
| METRICS_PATH = 'metric-data' | |
| EXPORT_DATA_PATH = 'export-data' | |
| API_ROOT_PATH = 'api/v1/nni' | |
| def _nni_rest_get(endpoint, api_path, response_type='json'): | |
| _check_endpoint(endpoint) | |
| uri = '{}/{}/{}'.format(endpoint.strip('/'), API_ROOT_PATH, api_path) | |
| res = requests.get(uri) | |
| if _http_succeed(res.status_code): | |
| if response_type == 'json': | |
| return res.json() | |
| elif response_type == 'text': | |
| return res.text | |
| else: | |
| raise RuntimeError('Incorrect response_type') | |
| else: | |
| return None | |
| def _http_succeed(status_code): | |
| return status_code // 100 == 2 | |
| def _create_process(cmd): | |
| if sys.platform == 'win32': | |
| process = subprocess.Popen(cmd, stdout=subprocess.PIPE, creationflags=subprocess.CREATE_NEW_PROCESS_GROUP) | |
| else: | |
| process = subprocess.Popen(cmd, stdout=subprocess.PIPE) | |
| while process.poll() is None: | |
| output = process.stdout.readline() | |
| if output: | |
| print(output.decode('utf-8').strip()) | |
| return process.returncode | |
| def _check_endpoint(endpoint): | |
| if endpoint is None: | |
| raise RuntimeError("This instance hasn't been connect to an experiment.") | |
| class TrialResult: | |
| """ | |
| TrialResult stores the result information of a trial job. | |
| Parameters | |
| ---------- | |
| json_obj: dict | |
| Json object that stores the result information. | |
| Attributes | |
| ---------- | |
| parameter: dict | |
| Hyper parameters for this trial. | |
| value: serializable object, usually a number, or a dict with key "default" and other extra keys | |
| Final result. | |
| trialJobId: str | |
| Trial job id. | |
| """ | |
| def __init__(self, json_obj): | |
| self.parameter = None | |
| self.value = None | |
| self.trialJobId = None | |
| for key in json_obj.keys(): | |
| if key == 'id': | |
| setattr(self, 'trialJobId', json_obj[key]) | |
| elif hasattr(self, key): | |
| setattr(self, key, json_obj[key]) | |
| self.value = json.loads(self.value) | |
| def __repr__(self): | |
| return "TrialResult(parameter: {} value: {} trialJobId: {})".format(self.parameter, self.value, self.trialJobId) | |
| class TrialMetricData: | |
| """ | |
| TrialMetricData stores the metric data of a trial job. | |
| A trial job may have both intermediate metric and final metric. | |
| Parameters | |
| ---------- | |
| json_obj: dict | |
| Json object that stores the metric data. | |
| Attributes | |
| ---------- | |
| timestamp: int | |
| Time stamp. | |
| trialJobId: str | |
| Trial job id. | |
| parameterId: int | |
| Parameter id. | |
| type: str | |
| Metric type, `PERIODICAL` for intermediate result and `FINAL` for final result. | |
| sequence: int | |
| Sequence number in this trial. | |
| data: serializable object, usually a number, or a dict with key "default" and other extra keys | |
| Metric data. | |
| """ | |
| def __init__(self, json_obj): | |
| self.timestamp = None | |
| self.trialJobId = None | |
| self.parameterId = None | |
| self.type = None | |
| self.sequence = None | |
| self.data = None | |
| for key in json_obj.keys(): | |
| setattr(self, key, json_obj[key]) | |
| self.data = json.loads(json.loads(self.data)) | |
| def __repr__(self): | |
| return "TrialMetricData(timestamp: {} trialJobId: {} parameterId: {} type: {} sequence: {} data: {})" \ | |
| .format(self.timestamp, self.trialJobId, self.parameterId, self.type, self.sequence, self.data) | |
| class TrialHyperParameters: | |
| """ | |
| TrialHyperParameters stores the hyper parameters of a trial job. | |
| Parameters | |
| ---------- | |
| json_obj: dict | |
| Json object that stores the hyper parameters. | |
| Attributes | |
| ---------- | |
| parameter_id: int | |
| Parameter id. | |
| parameter_source: str | |
| Parameter source. | |
| parameters: dict | |
| Hyper parameters. | |
| parameter_index: int | |
| Parameter index. | |
| """ | |
| def __init__(self, json_obj): | |
| self.parameter_id = None | |
| self.parameter_source = None | |
| self.parameters = None | |
| self.parameter_index = None | |
| for key in json_obj.keys(): | |
| if hasattr(self, key): | |
| setattr(self, key, json_obj[key]) | |
| def __repr__(self): | |
| return "TrialHyperParameters(parameter_id: {} parameter_source: {} parameters: {} parameter_index: {})" \ | |
| .format(self.parameter_id, self.parameter_source, self.parameters, self.parameter_index) | |
| class TrialJob: | |
| """ | |
| TrialJob stores the information of a trial job. | |
| Parameters | |
| ---------- | |
| json_obj: dict | |
| json object that stores the hyper parameters | |
| Attributes | |
| ---------- | |
| trialJobId: str | |
| Trial job id. | |
| status: str | |
| Job status. | |
| hyperParameters: list of `nni.experiment.TrialHyperParameters` | |
| See `nni.experiment.TrialHyperParameters`. | |
| logPath: str | |
| Log path. | |
| startTime: int | |
| Job start time (timestamp). | |
| endTime: int | |
| Job end time (timestamp). | |
| finalMetricData: list of `nni.experiment.TrialMetricData` | |
| See `nni.experiment.TrialMetricData`. | |
| parameter_index: int | |
| Parameter index. | |
| """ | |
| def __init__(self, json_obj): | |
| self.trialJobId = None | |
| self.status = None | |
| self.hyperParameters = None | |
| self.logPath = None | |
| self.startTime = None | |
| self.endTime = None | |
| self.finalMetricData = None | |
| self.stderrPath = None | |
| for key in json_obj.keys(): | |
| if key == 'id': | |
| setattr(self, 'trialJobId', json_obj[key]) | |
| elif hasattr(self, key): | |
| setattr(self, key, json_obj[key]) | |
| if self.hyperParameters: | |
| self.hyperParameters = [TrialHyperParameters(json.loads(e)) for e in self.hyperParameters] | |
| if self.finalMetricData: | |
| self.finalMetricData = [TrialMetricData(e) for e in self.finalMetricData] | |
| def __repr__(self): | |
| return ("TrialJob(trialJobId: {} status: {} hyperParameters: {} logPath: {} startTime: {} " | |
| "endTime: {} finalMetricData: {} stderrPath: {})") \ | |
| .format(self.trialJobId, self.status, self.hyperParameters, self.logPath, | |
| self.startTime, self.endTime, self.finalMetricData, self.stderrPath) | |
| class Experiment: | |
| def __init__(self): | |
| self._endpoint = None | |
| self._exp_id = None | |
| self._port = None | |
| def endpoint(self): | |
| return self._endpoint | |
| def exp_id(self): | |
| return self._exp_id | |
| def port(self): | |
| return self._port | |
| def _exec_command(self, cmd, port=None): | |
| if self._endpoint is not None: | |
| raise RuntimeError('This instance has been connected to an experiment.') | |
| if _create_process(cmd) != 0: | |
| raise RuntimeError('Failed to establish experiment, please check your config.') | |
| else: | |
| if port: | |
| self._port = port | |
| else: | |
| self._port = 8080 | |
| self._endpoint = 'http://localhost:{}'.format(self._port) | |
| self._exp_id = self.get_experiment_profile()['id'] | |
| def start_experiment(self, config_file, port=None, debug=False): | |
| """ | |
| Start an experiment with specified configuration file and connect to it. | |
| Parameters | |
| ---------- | |
| config_file: str | |
| Path to the config file. | |
| port: int | |
| The port of restful server, bigger than 1024. | |
| debug: boolean | |
| Set debug mode. | |
| """ | |
| cmd = 'nnictl create --config {}'.format(config_file).split(' ') | |
| if port: | |
| cmd += '--port {}'.format(port).split(' ') | |
| if debug: | |
| cmd += ['--debug'] | |
| self._exec_command(cmd, port) | |
| def resume_experiment(self, exp_id, port=None, debug=False): | |
| """ | |
| Resume a stopped experiment with specified experiment id | |
| Parameters | |
| ---------- | |
| exp_id: str | |
| Experiment id. | |
| port: int | |
| The port of restful server, bigger than 1024. | |
| debug: boolean | |
| Set debug mode. | |
| """ | |
| cmd = 'nnictl resume {}'.format(exp_id).split(' ') | |
| if port: | |
| cmd += '--port {}'.format(port).split(' ') | |
| if debug: | |
| cmd += ['--debug'] | |
| self._exec_command(cmd, port) | |
| def view_experiment(self, exp_id, port=None): | |
| """ | |
| View a stopped experiment with specified experiment id. | |
| Parameters | |
| ---------- | |
| exp_id: str | |
| Experiment id. | |
| port: int | |
| The port of restful server, bigger than 1024. | |
| """ | |
| cmd = 'nnictl view {}'.format(exp_id).split(' ') | |
| if port: | |
| cmd += '--port {}'.format(port).split(' ') | |
| self._exec_command(cmd, port) | |
| def connect_experiment(self, endpoint): | |
| """ | |
| Connect to an existing experiment. | |
| Parameters | |
| ---------- | |
| endpoint: str | |
| The endpoint of nni rest server, i.e, the url of Web UI. Should be a format like `http://ip:port`. | |
| """ | |
| if self._endpoint is not None: | |
| raise RuntimeError('This instance has been connected to an experiment.') | |
| self._endpoint = endpoint | |
| try: | |
| self._exp_id = self.get_experiment_profile()['id'] | |
| except TypeError: | |
| raise RuntimeError('Invalid experiment endpoint.') | |
| self._port = int(re.search(r':[0-9]+', self._endpoint).group().replace(':', '')) | |
| def stop_experiment(self): | |
| """Stop the experiment. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| cmd = 'nnictl stop {}'.format(self._exp_id).split(' ') | |
| if _create_process(cmd) != 0: | |
| raise RuntimeError('Failed to stop experiment.') | |
| self._endpoint = None | |
| self._exp_id = None | |
| self._port = None | |
| def update_searchspace(self, filename): | |
| """ | |
| Update the experiment's search space. | |
| Parameters | |
| ---------- | |
| filename: str | |
| Path to the searchspace file. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| cmd = 'nnictl update searchspace {} --filename {}'.format(self._exp_id, filename).split(' ') | |
| if _create_process(cmd) != 0: | |
| raise RuntimeError('Failed to update searchspace.') | |
| def update_concurrency(self, value): | |
| """ | |
| Update an experiment's concurrency | |
| Parameters | |
| ---------- | |
| value: int | |
| New concurrency value. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| cmd = 'nnictl update concurrency {} --value {}'.format(self._exp_id, value).split(' ') | |
| if _create_process(cmd) != 0: | |
| raise RuntimeError('Failed to update concurrency.') | |
| def update_duration(self, value): | |
| """ | |
| Update an experiment's duration | |
| Parameters | |
| ---------- | |
| value: str | |
| Strings like '1m' for one minute or '2h' for two hours. | |
| SUFFIX may be 's' for seconds, 'm' for minutes, 'h' for hours or 'd' for days. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| cmd = 'nnictl update duration {} --value {}'.format(self._exp_id, value).split(' ') | |
| if _create_process(cmd) != 0: | |
| raise RuntimeError('Failed to update duration.') | |
| def update_trailnum(self, value): | |
| """ | |
| Update an experiment's maxtrialnum | |
| Parameters | |
| ---------- | |
| value: int | |
| New trailnum value. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| cmd = 'nnictl update trialnum {} --value {}'.format(self._exp_id, value).split(' ') | |
| if _create_process(cmd) != 0: | |
| raise RuntimeError('Failed to update trailnum.') | |
| def get_experiment_status(self): | |
| """ | |
| Return experiment status as a dict. | |
| Returns | |
| ---------- | |
| dict | |
| Experiment status. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| return _nni_rest_get(self._endpoint, STATUS_PATH) | |
| def get_trial_job(self, trial_job_id): | |
| """ | |
| Return a trial job. | |
| Parameters | |
| ---------- | |
| trial_job_id: str | |
| Trial job id. | |
| Returns | |
| ---------- | |
| nnicli.TrialJob | |
| A `nnicli.TrialJob` instance corresponding to `trial_job_id`. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| assert trial_job_id is not None | |
| trial_job = _nni_rest_get(self._endpoint, os.path.join(TRIAL_JOBS_PATH, trial_job_id)) | |
| return TrialJob(trial_job) | |
| def list_trial_jobs(self): | |
| """ | |
| Return information for all trial jobs as a list. | |
| Returns | |
| ---------- | |
| list | |
| List of `nnicli.TrialJob`. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| trial_jobs = _nni_rest_get(self._endpoint, TRIAL_JOBS_PATH) | |
| return [TrialJob(e) for e in trial_jobs] | |
| def get_job_statistics(self): | |
| """ | |
| Return trial job statistics information as a dict. | |
| Returns | |
| ---------- | |
| list | |
| Job statistics information. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| return _nni_rest_get(self._endpoint, JOB_STATISTICS_PATH) | |
| def get_job_metrics(self, trial_job_id=None): | |
| """ | |
| Return trial job metrics. | |
| Parameters | |
| ---------- | |
| trial_job_id: str | |
| trial job id. if this parameter is None, all trail jobs' metrics will be returned. | |
| Returns | |
| ---------- | |
| dict | |
| Each key is a trialJobId, the corresponding value is a list of `nnicli.TrialMetricData`. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| api_path = METRICS_PATH if trial_job_id is None else os.path.join(METRICS_PATH, trial_job_id) | |
| output = {} | |
| trail_metrics = _nni_rest_get(self._endpoint, api_path) | |
| for metric in trail_metrics: | |
| trial_id = metric["trialJobId"] | |
| if trial_id not in output: | |
| output[trial_id] = [TrialMetricData(metric)] | |
| else: | |
| output[trial_id].append(TrialMetricData(metric)) | |
| return output | |
| def export_data(self): | |
| """ | |
| Return exported information for all trial jobs. | |
| Returns | |
| ---------- | |
| list | |
| List of `nnicli.TrialResult`. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| trial_results = _nni_rest_get(self._endpoint, EXPORT_DATA_PATH) | |
| return [TrialResult(e) for e in trial_results] | |
| def get_experiment_profile(self): | |
| """ | |
| Return experiment profile as a dict. | |
| Returns | |
| ---------- | |
| dict | |
| The profile of the experiment. | |
| """ | |
| _check_endpoint(self._endpoint) | |
| return _nni_rest_get(self._endpoint, EXPERIMENT_PATH) | |