Spaces:
Running
Running
| # Copyright (c) Microsoft Corporation. | |
| # Licensed under the MIT license. | |
| import asyncio | |
| import os | |
| import websockets | |
| from .base_channel import BaseChannel | |
| from .log_utils import LogType, nni_log | |
class WebChannel(BaseChannel):
    """Channel that exchanges messages with NNI manager over a WebSocket.

    The synchronous ``_inner_*`` hooks drive the asyncio-based ``websockets``
    client by running its coroutines to completion on an event loop.
    """

    def __init__(self, args):
        """Store connection settings; the socket itself is opened in ``_inner_open``.

        Parameters
        ----------
        args
            Object exposing ``node_id``, ``nnimanager_ip`` and
            ``nnimanager_port`` attributes.
        """
        self.node_id = args.node_id
        self.args = args
        self.client = None
        # Bytes received but not yet split into complete messages.
        self.in_cache = b""
        # Seconds to wait when connecting and when closing the connection.
        self.timeout = 10
        # Set every attribute before delegating to the base class so the
        # connection state is in place whatever the base initialiser does.
        self._event_loop = None
        super().__init__(args)

    def _inner_open(self):
        """Connect to ``ws://<nnimanager_ip>:<nnimanager_port>``.

        Terminates the process with exit code 1 if the connection is not
        established within ``self.timeout`` seconds.
        """
        url = "ws://{}:{}".format(self.args.nnimanager_ip, self.args.nnimanager_port)
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # No current loop in this thread (non-main thread, or newer Python
            # versions): create one instead of failing.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
        self._event_loop = loop
        try:
            connect = asyncio.wait_for(websockets.connect(url), self.timeout)
            self.client = loop.run_until_complete(connect)
        except asyncio.TimeoutError:
            nni_log(LogType.Error, 'connect to %s timeout! Please make sure NNIManagerIP configured correctly, and accessable.' % url)
            os._exit(1)
        else:
            nni_log(LogType.Info, 'WebChannel: connected with info %s' % url)

    def _inner_close(self):
        """Close the WebSocket and release the event loop reference.

        Safe to call when the channel was never opened or is already closed.
        """
        client, loop = self.client, self._event_loop
        # Drop the references first so other methods stop using the connection.
        self.client = None
        self._event_loop = None
        if loop is None or loop.is_closed():
            return
        if client is not None:
            self._close_client(client, loop)
        if loop.is_running():
            # The loop may be running in another thread; stop() must then be
            # scheduled through the thread-safe entry point.
            loop.call_soon_threadsafe(loop.stop)

    def _close_client(self, client, loop):
        """Run ``client.close()`` to completion on ``loop`` (best effort).

        ``close()`` is a coroutine, so it has to be driven by the loop;
        merely calling it would not close anything. Failures are logged
        rather than raised so shutdown can continue.
        """
        try:
            if loop.is_running():
                # Loop is busy (presumably in another thread): hand the
                # coroutine over and wait a bounded time for it to finish.
                pending = asyncio.run_coroutine_threadsafe(client.close(), loop)
                try:
                    pending.result(self.timeout)
                finally:
                    # No-op once the future has completed.
                    pending.cancel()
            else:
                loop.run_until_complete(client.close())
        except Exception as err:
            nni_log(LogType.Error, 'WebChannel: error while closing connection: %r' % err)

    def _inner_send(self, message):
        """Send one message over the WebSocket, blocking until it is written."""
        # NOTE(review): a temporary loop is used instead of self._event_loop,
        # presumably because that loop may be blocked in recv() on another
        # thread -- confirm against BaseChannel's threading model.
        loop = asyncio.new_event_loop()
        try:
            loop.run_until_complete(self.client.send(message))
        finally:
            # Every loop owns a selector and a self-pipe; close it so repeated
            # sends do not leak file descriptors.
            loop.close()

    def _inner_receive(self):
        """Block for the next frame and return the complete messages decoded so far.

        Returns an empty list when the channel is not connected.
        """
        client, loop = self.client, self._event_loop
        if client is None or loop is None:
            return []
        received = loop.run_until_complete(client.recv())
        # Text frames arrive as str and binary frames as bytes; normalise to
        # bytes so the cache always holds a consistent type.
        if isinstance(received, str):
            received = received.encode("utf8")
        self.in_cache += received
        messages, self.in_cache = self._fetch_message(self.in_cache)
        return messages