Spaces:
Running
Running
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| import json | |
| import threading | |
| import time | |
| from abc import ABC, abstractmethod | |
| from queue import Empty, Queue | |
| from .log_utils import LogType, nni_log | |
| from .commands import CommandType | |
# Seconds the receive loop sleeps between polls; also the timeout of each
# blocking wait on the send queue in the send loop.
INTERVAL_SECONDS = 0.5


class BaseChannel(ABC):
    """Queue-backed command channel with background send/receive threads.

    Subclasses provide the transport by overriding the ``_inner_*`` hooks;
    the defaults here do nothing, so the base class is a no-op channel.

    Message layout used by send(), receive() and _fetch_message():
    a 16-byte header (2 bytes of command value followed by a 14-digit,
    zero-padded decimal payload length) and then the JSON payload encoded
    as UTF-8.
    """

    def __init__(self, args):
        """Store the arguments object.

        args: object that provides node_count and node_id here, and
              runner_id and exp_id when open() is called.
        """
        self.is_keep_parsed = args.node_count > 1
        self.args = args
        self.node_id = self.args.node_id

    def _inner_send(self, message):
        """Transport hook: deliver one encoded message. No-op by default."""
        pass

    def _inner_receive(self):
        """Transport hook: return an iterable of complete messages (or None)."""
        return []

    def _inner_open(self):
        """Transport hook: open the underlying connection. No-op by default."""
        pass

    def _inner_close(self):
        """Transport hook: close the underlying connection. No-op by default."""
        pass

    def open(self):
        """Start the worker threads, open the transport and announce readiness."""
        # initialize receive, send threads.
        # NOTE(review): both threads are started before _inner_open() runs, so
        # the hooks may be invoked on a transport that is not open yet. The
        # order is kept as-is because subclasses may rely on it -- confirm.
        self.is_running = True

        self.receive_queue = Queue()
        self.receive_thread = threading.Thread(target=self._receive_loop)
        self.receive_thread.start()

        self.send_queue = Queue()
        self.send_thread = threading.Thread(target=self._send_loop)
        self.send_thread.start()

        self._inner_open()

        client_info = {
            "isReady": True,
            "runnerId": self.args.runner_id,
            "expId": self.args.exp_id,
        }
        nni_log(LogType.Info, 'Channel: send ready information %s' % client_info)
        self.send(CommandType.Initialized, client_info)

    def close(self):
        """Ask the worker loops to stop and close the transport (best effort)."""
        self.is_running = False
        try:
            self._inner_close()
        except Exception as err:
            # ignore any error on closing
            print("error on closing channel: %s" % err)

    def send(self, command, data):
        """Queue a command for the Training Service.

        command: CommandType object; its ``value`` is written as the command
                 part of the header (the receiver reads it back as 2 bytes).
        data: dict payload. It is copied before the "node" key is added, so
              the caller's dict is left unmodified.

        The message is only placed on the send queue; the send thread
        delivers it later through _inner_send(). Use sent() to check whether
        the queue has been drained.
        """
        payload = dict(data)
        payload["node"] = self.node_id
        encoded = json.dumps(payload).encode('utf8')
        message = b'%b%014d%b' % (command.value, len(encoded), encoded)
        self.send_queue.put(message)

    def sent(self):
        """Return True when no message is waiting in the send queue."""
        return self.send_queue.qsize() == 0

    def received(self):
        """Return True when at least one message is waiting to be read."""
        return self.receive_queue.qsize() > 0

    def receive(self):
        """Receive a command from Training Service without blocking.

        Returns a tuple of command (CommandType) and the JSON-decoded
        payload. Returns (None, None) when nothing is queued or when the
        queued message is malformed (the problem is logged).
        """
        command = None
        data = None

        try:
            command_content = self.receive_queue.get(False)
            if command_content is not None:
                if len(command_content) < 16:
                    # invalid header
                    nni_log(LogType.Error, 'incorrect command is found, command must be greater than 16 bytes!')
                    return None, None
                header = command_content[:16]
                command = CommandType(header[:2])
                length = int(header[2:])
                if len(command_content) - 16 != length:
                    nni_log(LogType.Error, 'incorrect command length, length {}, actual data length is {}, header {}.'
                            .format(length, len(command_content) - 16, header))
                    return None, None
                data = command_content[16:16 + length]
                data = json.loads(data.decode('utf8'))
                if self.node_id is None:
                    nni_log(LogType.Info, 'Received command, header: [%s], data: [%s]' % (header, data))
                else:
                    nni_log(LogType.Info, 'Received command(%s), header: [%s], data: [%s]' % (self.node_id, header, data))
        except Empty:
            # do nothing, if no command received.
            pass
        except Exception as identifier:
            nni_log(LogType.Error, 'meet unhandled exception in base_channel: %s' % identifier)
            # Drop any half-parsed state (e.g. a command with an undecoded
            # payload) so a malformed message is never handed to the caller.
            command, data = None, None
        return command, data

    def _fetch_message(self, buffer, has_new_line=False):
        """Split complete messages off the front of buffer.

        buffer: accumulated raw input (indexed and sliced like bytes).
        has_new_line: when True, every message is expected to be followed by
                      one extra terminator byte, which is consumed.

        Returns (messages, remaining_buffer). Each message is header plus
        payload without the terminator; remaining_buffer holds the trailing
        bytes that do not form a complete message yet.
        """
        messages = []

        while len(buffer) >= 16:
            header = buffer[:16]
            length = int(header[2:])

            message_length = length + 16
            total_length = message_length
            if has_new_line:
                total_length += 1

            # break, if buffer is too short.
            if len(buffer) < total_length:
                break

            data = buffer[16:message_length]
            # 10 is the byte value of '\n'. The offending byte is read from
            # the buffer being parsed; the message is still accepted.
            if has_new_line and 10 != buffer[total_length - 1]:
                nni_log(LogType.Error, 'end of message should be \\n, but got {}'.format(buffer[total_length - 1]))
            buffer = buffer[total_length:]
            messages.append(header + data)

        return messages, buffer

    def _receive_loop(self):
        """Poll _inner_receive() and queue its messages until is_running is cleared."""
        while self.is_running:
            messages = self._inner_receive()
            if messages is not None:
                for message in messages:
                    self.receive_queue.put(message)
            time.sleep(INTERVAL_SECONDS)

    def _send_loop(self):
        """Deliver queued messages via _inner_send() until is_running is cleared."""
        while self.is_running:
            message = None
            try:
                # no sleep, since it's a block call with INTERVAL_SECONDS second timeout
                message = self.send_queue.get(True, INTERVAL_SECONDS)
            except Empty:
                # do nothing, if no command received.
                pass
            if message is not None:
                self._inner_send(message)