Spaces:
Running
Running
| from abc import ABCMeta, abstractmethod | |
| from functools import wraps | |
| from threading import Lock | |
| from typing import Optional, Any, Mapping, Type, Callable | |
| from uuid import uuid4, UUID | |
| import requests | |
| from requests.exceptions import RequestException | |
| from .base import _BEFORE_HOOK_TYPE, _AFTER_HOOK_TYPE, _ERROR_HOOK_TYPE | |
| from .task import Task, _task_complete, _task_fail | |
| from ..base import random_token, ControllableContext, get_http_engine_class, get_values_from_response | |
| from ..config import DEFAULT_CHANNEL, DEFAULT_SLAVE_PORT, DEFAULT_REQUEST_RETRIES, DEFAULT_REQUEST_RETRY_WAITING | |
| from ..exception import get_slave_exception_by_error | |
# Attribute names under which connection objects in this module register their
# task-complete / task-fail trigger callables (via `setattr` in `__init_triggers`);
# the module-level `_connection_task_*` / `_proxy_task_*` helpers look them up with `getattr`.
_COMPLETE_TRIGGER_NAME = '__TASK_COMPLETE__'
_FAIL_TRIGGER_NAME = '__TASK_FAIL__'
class _ISlaveConnection(ControllableContext, metaclass=ABCMeta):
    """
    Overview:
        Abstract base of the connection classes, such as `SlaveConnection` and `SlaveConnectionProxy`,
        which are used widely in interaction module.
        Subclasses have to implement `connect`, `disconnect` and `new_task`;
        `start` and `close` are provided here as aliases of `connect` and `disconnect`.
    Example:
        - The following code shows a sample to correctly use slave connection

        >>> connection = master.new_connection('cnn1', '127.0.0.1', 2333)
        >>> connection.connect()
        >>> try:
        >>>     pass # do anything you like
        >>> finally:
        >>>     connection.disconnect()

        - Another simple structure of the code above

        >>> with master.new_connection('cnn1', '127.0.0.1', 2333) as connection:
        >>>     pass # do anything you like, connect and disconnect will be done automatically
    """

    # The three interface methods are declared abstract so that the `ABCMeta` metaclass
    # actually enforces the interface: an incomplete subclass fails at instantiation
    # instead of at the first call.

    @abstractmethod
    def connect(self):
        """
        Overview:
            Connect to slave end.
        """
        raise NotImplementedError  # pragma: no cover

    @abstractmethod
    def disconnect(self):
        """
        Overview:
            Disconnect from slave end.
        """
        raise NotImplementedError  # pragma: no cover

    @abstractmethod
    def new_task(self, data: Optional[Mapping[str, Any]] = None):
        """
        Overview:
            Create a new task for the slave end.
        Arguments:
            - data (:obj:`Optional[Mapping[str, Any]]`): Data of the new task
        Returns:
            - task: Object representing the new task. The implementations in this module \
                (`SlaveConnection.new_task` and the proxy that forwards to it) return a `Task` instance.
        """
        raise NotImplementedError  # pragma: no cover

    def start(self):
        """
        Overview:
            Alias for `connect`, for supporting context manager.
        """
        self.connect()

    def close(self):
        """
        Overview:
            Alias for `disconnect`, for supporting context manager.
        """
        self.disconnect()
class SlaveConnection(_ISlaveConnection, metaclass=ABCMeta):
    """
    Overview:
        Slave connection object, which need to directly interact with slave end.
        It owns the http engine used to talk to the slave, the connected flag and
        the table of tasks that have been created but not yet completed or failed.
        All state changes are serialized by one internal lock.
    """

    def __init__(
        self,
        host: str,
        port: Optional[int] = None,
        https: bool = False,
        channel: Optional[int] = None,
        my_address: Optional[str] = None,
        token: Optional[str] = None,
        request_retries: Optional[int] = None,
        request_retry_waiting: Optional[float] = None,
    ):
        """
        Overview:
            Constructor of `SlaveConnection`
        Arguments:
            - host (:obj:`str`): Host of the slave server
            - port (:obj:`Optional[int]`): Port of the slave server (None means `DEFAULT_SLAVE_PORT`)
            - https (:obj:`bool`): Use https or not
            - channel (:obj:`Optional[int]`): Channel id for the slave client \
                (a falsy value means `DEFAULT_CHANNEL`); sent as the `Channel` header.
            - my_address (:obj:`Optional[str]`): The address of current server. It is sent to the slave in \
                the connect request, **so please make sure the address can be achieved by slave**. \
                (Automatic detection of the local address for `None` is not performed in this class.)
            - token (:obj:`Optional[str]`): Token of this connection, it is a token for authenticate to the \
                connection (a falsy value means this token would be randomly generated); sent as the \
                `Token` header.
            - request_retries (:obj:`Optional[int]`): Max times for request retries \
                (None means `DEFAULT_REQUEST_RETRIES`, negative values are clamped to 0)
            - request_retry_waiting (:obj:`Optional[float]`): Sleep time before requests' retrying \
                (None means `DEFAULT_REQUEST_RETRY_WAITING`, negative values are clamped to 0.0, unit: second)
        """
        # meta info part
        self.__channel = channel or DEFAULT_CHANNEL
        self.__my_address = my_address
        self.__token = token or random_token()

        # request part
        # The header values are callables, so they are read from this object lazily.
        self.__http_engine = get_http_engine_class(
            headers={
                'Channel': lambda: str(self.__channel),
                'Token': lambda: self.__token,
            },
            http_error_gene=get_slave_exception_by_error,
        )()(host, port or DEFAULT_SLAVE_PORT, https)

        # Only `None` selects the default. An explicit 0 is a legal value (the clamp below
        # admits it), so it must not be replaced by an `or` fallback.
        if request_retries is None:
            request_retries = DEFAULT_REQUEST_RETRIES
        if request_retry_waiting is None:
            request_retry_waiting = DEFAULT_REQUEST_RETRY_WAITING
        self.__request_retries = max(request_retries, 0)
        self.__request_retry_waiting = max(request_retry_waiting, 0.0)

        # threading part
        self.__lock = Lock()
        self.__is_connected = False

        # task part
        self.__tasks = {}  # task id (UUID) -> Task that has not completed/failed yet
        self.__init_triggers()

    def __request(self, method: str, path: str, data: Optional[Mapping[str, Any]] = None) -> requests.Response:
        """
        Overview:
            Send one request through the http engine, using this connection's retry settings.
        """
        return self.__http_engine.request(
            method,
            path,
            data,
            retries=self.__request_retries,
            retry_waiting=self.__request_retry_waiting,
        )

    def is_connected(self) -> bool:
        """
        Overview:
            Check connection status
        Returns:
            - connected (:obj:`bool`): Whether this connection is still alive
        """
        with self.__lock:
            return self.__is_connected

    def _before_connect(self) -> Mapping[str, Any]:
        """
        Overview:
            Hook producing the `data` part of the connect request (a falsy result is sent as `{}`).
        """
        pass  # pragma: no cover

    def _after_connect(
            self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
    ) -> Any:
        """
        Overview:
            Hook called with the values extracted from a connect response; its result is returned by `connect`.
        """
        pass  # pragma: no cover

    def _error_connect(self, error: RequestException) -> Any:
        """
        Overview:
            Hook called when the connect request raises `RequestException`; re-raises by default.
        """
        raise error  # pragma: no cover

    def __connect(self):
        """
        Overview:
            Post the connect request (caller holds the lock) and mark the connection as connected on success.
        """
        try:
            response = self.__request(
                'POST', '/connect', {
                    'master': {
                        'address': self.__my_address,
                    },
                    'data': (self._before_connect() or {})
                }
            )
        except RequestException as err:
            # The connected flag is left untouched when the request fails.
            return self._error_connect(err)
        else:
            self.__is_connected = True
            return self._after_connect(*get_values_from_response(response))

    def connect(self):
        """
        Overview:
            Connect to slave end.
        Returns:
            - result (:obj:`Any`): Result of `_after_connect` on success, of `_error_connect` on request failure.
        """
        with self.__lock:
            return self.__connect()

    def _before_disconnect(self) -> Mapping[str, Any]:
        """
        Overview:
            Hook producing the `data` part of the disconnect request (a falsy result is sent as `{}`).
        """
        pass  # pragma: no cover

    def _after_disconnect(
            self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
    ) -> Any:
        """
        Overview:
            Hook called with the values extracted from a disconnect response; \
            its result is returned by `disconnect`.
        """
        pass  # pragma: no cover

    def _error_disconnect(self, error: RequestException) -> Any:
        """
        Overview:
            Hook called when the disconnect request raises `RequestException`; re-raises by default.
        """
        raise error  # pragma: no cover

    def __disconnect(self):
        """
        Overview:
            Send the disconnect request (caller holds the lock) and clear the connected flag on success.
        """
        try:
            response = self.__request('DELETE', '/disconnect', {
                'data': self._before_disconnect() or {},
            })
        except RequestException as err:
            # The connected flag is left untouched when the request fails.
            return self._error_disconnect(err)
        else:
            self.__is_connected = False
            return self._after_disconnect(*get_values_from_response(response))

    def disconnect(self):
        """
        Overview:
            Disconnect from slave end.
        Returns:
            - result (:obj:`Any`): Result of `_after_disconnect` on success, \
                of `_error_disconnect` on request failure.
        """
        with self.__lock:
            return self.__disconnect()

    def _before_new_task(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
        """
        Overview:
            Hook handed to each new `Task` as `before_task_start`; returns the data unchanged by default.
        """
        return data  # pragma: no cover

    def _after_new_task(
            self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
    ) -> Any:
        """
        Overview:
            Hook handed to each new `Task` as `after_task_start`; does nothing by default.
        """
        pass  # pragma: no cover

    def _error_new_task(self, error: RequestException) -> Any:
        """
        Overview:
            Hook handed to each new `Task` as `error_task_start`; re-raises by default.
        """
        raise error  # pragma: no cover

    def new_task(self, data: Optional[Mapping[str, Any]] = None) -> Task:
        """
        Overview:
            Create a `Task` with a freshly generated uuid4 id, register it in the table of pending tasks
            and return it. The task receives this connection's http engine and the three new-task hooks.
        Arguments:
            - data (:obj:`Optional[Mapping[str, Any]]`): Data of the new task
        Returns:
            - task (:obj:`Task`): The newly created task object
        """
        with self.__lock:
            _uuid = uuid4()
            _task = Task(
                http_engine=self.__http_engine,
                data=data,
                task_id=_uuid,
                before_task_start=self._before_new_task,
                after_task_start=self._after_new_task,
                error_task_start=self._error_new_task,
            )

            self.__tasks[_uuid] = _task
            return _task

    def __task_complete(self, task_id: UUID, task_result: Mapping[str, Any]):
        """
        Overview:
            Pass the result to `_task_complete` for the registered task, then drop the task from the table.
        """
        _task = self.__tasks[task_id]
        _task_complete(_task, task_result)
        del self.__tasks[task_id]

    def __task_fail(self, task_id: UUID, task_result: Mapping[str, Any]):
        """
        Overview:
            Pass the result to `_task_fail` for the registered task, then drop the task from the table.
        """
        _task = self.__tasks[task_id]
        _task_fail(_task, task_result)
        del self.__tasks[task_id]

    def __task_complete_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
        """
        Overview:
            Locked entry point for task completion.
        Raises:
            - KeyError: If `task_id` is not a pending task of this connection.
        """
        with self.__lock:
            if task_id in self.__tasks:
                return self.__task_complete(task_id, task_result)
            else:
                raise KeyError("Task {uuid} not found in this connection.".format(uuid=repr(str(task_id))))

    def __task_fail_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
        """
        Overview:
            Locked entry point for task failure.
        Raises:
            - KeyError: If `task_id` is not a pending task of this connection.
        """
        with self.__lock:
            if task_id in self.__tasks:
                return self.__task_fail(task_id, task_result)
            else:
                raise KeyError("Task {uuid} not found in this connection.".format(uuid=repr(str(task_id))))

    def __init_triggers(self):
        """
        Overview:
            Expose the two name-mangled trigger methods under the module-level trigger names,
            so that module-level helpers can reach them with `getattr`.
        """
        setattr(self, _COMPLETE_TRIGGER_NAME, self.__task_complete_trigger)
        setattr(self, _FAIL_TRIGGER_NAME, self.__task_fail_trigger)
def _connection_task_complete(connection: SlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
    """
    Overview:
        Invoke the task-complete trigger stored on a `SlaveConnection` under `_COMPLETE_TRIGGER_NAME`.
    Arguments:
        - connection (:obj:`SlaveConnection`): Connection owning the task
        - task_id (:obj:`UUID`): Id of the task
        - task_result (:obj:`Mapping[str, Any]`): Result reported for the task
    Returns:
        - result: Whatever the trigger returns
    """
    trigger = getattr(connection, _COMPLETE_TRIGGER_NAME)
    return trigger(task_id, task_result)
def _connection_task_fail(connection: SlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
    """
    Overview:
        Invoke the task-fail trigger stored on a `SlaveConnection` under `_FAIL_TRIGGER_NAME`.
    Arguments:
        - connection (:obj:`SlaveConnection`): Connection owning the task
        - task_id (:obj:`UUID`): Id of the task
        - task_result (:obj:`Mapping[str, Any]`): Result reported for the task
    Returns:
        - result: Whatever the trigger returns
    """
    trigger = getattr(connection, _FAIL_TRIGGER_NAME)
    return trigger(task_id, task_result)
class SlaveConnectionProxy(_ISlaveConnection):
    """
    Overview:
        Proxy class for `SlaveConnection` class, which wraps the original methods.
        Every forwarded call is made while holding the proxy's own lock, and optional
        callbacks are run after `connect` and `disconnect`.
    """

    def __init__(
        self,
        connection: SlaveConnection,
        after_connect: Optional[Callable] = None,
        after_disconnect: Optional[Callable] = None
    ):
        """
        Overview:
            Constructor of `SlaveConnectionProxy`
        Arguments:
            - connection (:obj:`SlaveConnection`): Slave connection object
            - after_connect (:obj:`Optional[Callable]`): Behaviour going to be executed after connection \
                established; called as `after_connect(connection=<this proxy>)`
            - after_disconnect (:obj:`Optional[Callable]`): Behaviour going to be executed after connection \
                killed; called as `after_disconnect(connection=<this proxy>)`
        """
        self.__connection = connection
        self.__lock = Lock()
        self.__after_connect = after_connect
        self.__after_disconnect = after_disconnect

        self.__init_triggers()

    def is_connected(self) -> bool:
        """
        Overview:
            Check connection status
        Returns:
            - connected (:obj:`bool`): Whether this connection is still alive
        """
        with self.__lock:
            status = self.__connection.is_connected
            # `SlaveConnection.is_connected` is a regular method here, so it has to be called:
            # returning the attribute itself would hand back an always-truthy bound method.
            # The `callable` guard keeps the result correct if a connection exposes it as a property.
            return status() if callable(status) else status

    def connect(self):
        """
        Overview:
            Connect the wrapped connection, then run the `after_connect` callback (if any).
        Returns:
            - result: Whatever the wrapped connection's `connect` returns
        """
        with self.__lock:
            result = self.__connection.connect()
            # The callback runs while the proxy lock is still held.
            if self.__after_connect is not None:
                self.__after_connect(connection=self)
            return result

    def disconnect(self):
        """
        Overview:
            Disconnect the wrapped connection, then run the `after_disconnect` callback (if any).
        Returns:
            - result: Whatever the wrapped connection's `disconnect` returns
        """
        with self.__lock:
            result = self.__connection.disconnect()
            # The callback runs while the proxy lock is still held.
            if self.__after_disconnect is not None:
                self.__after_disconnect(connection=self)
            return result

    def new_task(self, data: Optional[Mapping[str, Any]] = None):
        """
        Overview:
            Create a new task through the wrapped connection.
        Arguments:
            - data (:obj:`Optional[Mapping[str, Any]]`): Data of the new task
        Returns:
            - task: Whatever the wrapped connection's `new_task` returns
        """
        with self.__lock:
            return self.__connection.new_task(data)

    def __task_complete_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
        """
        Overview:
            Forward a task completion to the wrapped connection under the proxy lock.
        """
        with self.__lock:
            return _connection_task_complete(self.__connection, task_id, task_result)

    def __task_fail_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
        """
        Overview:
            Forward a task failure to the wrapped connection under the proxy lock.
        """
        with self.__lock:
            return _connection_task_fail(self.__connection, task_id, task_result)

    def __init_triggers(self):
        """
        Overview:
            Expose the two name-mangled trigger methods under the module-level trigger names,
            so that module-level helpers can reach them with `getattr`.
        """
        setattr(self, _COMPLETE_TRIGGER_NAME, self.__task_complete_trigger)
        setattr(self, _FAIL_TRIGGER_NAME, self.__task_fail_trigger)
def _proxy_task_complete(proxy: SlaveConnectionProxy, task_id: UUID, task_result: Mapping[str, Any]):
    """
    Overview:
        Invoke the task-complete trigger stored on a `SlaveConnectionProxy` under `_COMPLETE_TRIGGER_NAME`.
    Arguments:
        - proxy (:obj:`SlaveConnectionProxy`): Proxy of the connection owning the task
        - task_id (:obj:`UUID`): Id of the task
        - task_result (:obj:`Mapping[str, Any]`): Result reported for the task
    Returns:
        - result: Whatever the trigger returns
    """
    trigger = getattr(proxy, _COMPLETE_TRIGGER_NAME)
    return trigger(task_id, task_result)
def _proxy_task_fail(proxy: SlaveConnectionProxy, task_id: UUID, task_result: Mapping[str, Any]):
    """
    Overview:
        Invoke the task-fail trigger stored on a `SlaveConnectionProxy` under `_FAIL_TRIGGER_NAME`.
    Arguments:
        - proxy (:obj:`SlaveConnectionProxy`): Proxy of the connection owning the task
        - task_id (:obj:`UUID`): Id of the task
        - task_result (:obj:`Mapping[str, Any]`): Result reported for the task
    Returns:
        - result: Whatever the trigger returns
    """
    trigger = getattr(proxy, _FAIL_TRIGGER_NAME)
    return trigger(task_id, task_result)
def _slave_task_complete(connection: _ISlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
    """
    Overview:
        Report a task as completed on either kind of connection object.
    Arguments:
        - connection (:obj:`_ISlaveConnection`): A `SlaveConnection` or a `SlaveConnectionProxy`
        - task_id (:obj:`UUID`): Id of the task
        - task_result (:obj:`Mapping[str, Any]`): Result reported for the task
    Returns:
        - result: Whatever the selected helper returns
    Raises:
        - TypeError: If `connection` is neither a `SlaveConnection` nor a `SlaveConnectionProxy`
    """
    # Pick the helper first (same check order as before), then make one call.
    if isinstance(connection, SlaveConnection):
        dispatch = _connection_task_complete
    elif isinstance(connection, SlaveConnectionProxy):
        dispatch = _proxy_task_complete
    else:
        message = "{expect1} or {expect2} expected, but {actual} found.".format(
            expect1=SlaveConnection.__name__,
            expect2=SlaveConnectionProxy.__name__,
            actual=type(connection).__name__,
        )
        raise TypeError(message)

    return dispatch(connection, task_id, task_result)
def _slave_task_fail(connection: _ISlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
    """
    Overview:
        Report a task as failed on either kind of connection object.
    Arguments:
        - connection (:obj:`_ISlaveConnection`): A `SlaveConnection` or a `SlaveConnectionProxy`
        - task_id (:obj:`UUID`): Id of the task
        - task_result (:obj:`Mapping[str, Any]`): Result reported for the task
    Returns:
        - result: Whatever the selected helper returns
    Raises:
        - TypeError: If `connection` is neither a `SlaveConnection` nor a `SlaveConnectionProxy`
    """
    # Pick the helper first (same check order as before), then make one call.
    if isinstance(connection, SlaveConnection):
        dispatch = _connection_task_fail
    elif isinstance(connection, SlaveConnectionProxy):
        dispatch = _proxy_task_fail
    else:
        message = "{expect1} or {expect2} expected, but {actual} found.".format(
            expect1=SlaveConnection.__name__,
            expect2=SlaveConnectionProxy.__name__,
            actual=type(connection).__name__,
        )
        raise TypeError(message)

    return dispatch(connection, task_id, task_result)
| def _default_wrap(func: Callable) -> Callable: | |
| def _new_func(*args, **kwargs): | |
| if func: | |
| return func(*args, **kwargs) | |
| else: | |
| return None | |
| return _new_func | |
def _get_connection_class(
    before_new_task: Optional[_BEFORE_HOOK_TYPE] = None,
    after_new_task: Optional[_AFTER_HOOK_TYPE] = None,
    error_new_task: Optional[_ERROR_HOOK_TYPE] = None,
    before_connect: Optional[_BEFORE_HOOK_TYPE] = None,
    after_connect: Optional[_AFTER_HOOK_TYPE] = None,
    error_connect: Optional[_ERROR_HOOK_TYPE] = None,
    before_disconnect: Optional[_BEFORE_HOOK_TYPE] = None,
    after_disconnect: Optional[_AFTER_HOOK_TYPE] = None,
    error_disconnect: Optional[_ERROR_HOOK_TYPE] = None,
) -> Type[SlaveConnection]:
    """
    Overview:
        Build a `SlaveConnection` subclass whose nine hook methods delegate to the given callables.
    Arguments:
        - before_new_task, after_new_task, error_new_task: Hooks used for new tasks
        - before_connect, after_connect, error_connect: Hooks used for connecting
        - before_disconnect, after_disconnect, error_disconnect: Hooks used for disconnecting
    Returns:
        - connection_class (:obj:`Type[SlaveConnection]`): The generated subclass
    Note:
        A hook left as `None` is replaced by a callable returning `None`. The "before" hooks then
        yield `{}`, the "after" hooks yield `None`, and the "error" hooks return `None` without re-raising.
    """
    # Wrap every hook once; a missing hook becomes a do-nothing callable.
    new_task_before = _default_wrap(before_new_task)
    new_task_after = _default_wrap(after_new_task)
    new_task_error = _default_wrap(error_new_task)
    connect_before = _default_wrap(before_connect)
    connect_after = _default_wrap(after_connect)
    connect_error = _default_wrap(error_connect)
    disconnect_before = _default_wrap(before_disconnect)
    disconnect_after = _default_wrap(after_disconnect)
    disconnect_error = _default_wrap(error_disconnect)

    class _Connection(SlaveConnection):

        def _before_connect(self) -> Mapping[str, Any]:
            payload = connect_before()
            return payload or {}

        def _after_connect(
                self, status_code: int, success: bool, code: int, message: Optional[str],
                data: Optional[Mapping[str, Any]]
        ) -> Any:
            return connect_after(status_code, success, code, message, data)

        def _error_connect(self, error: RequestException) -> Any:
            # NOTE(review): without an error hook the exception is not re-raised here,
            # unlike SlaveConnection's own default - confirm this is intended.
            return connect_error(error)

        def _before_disconnect(self) -> Mapping[str, Any]:
            payload = disconnect_before()
            return payload or {}

        def _after_disconnect(
                self, status_code: int, success: bool, code: int, message: Optional[str],
                data: Optional[Mapping[str, Any]]
        ) -> Any:
            return disconnect_after(status_code, success, code, message, data)

        def _error_disconnect(self, error: RequestException) -> Any:
            return disconnect_error(error)

        def _before_new_task(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
            # NOTE(review): without a hook this yields {} rather than `data`,
            # unlike SlaveConnection's own default - confirm this is intended.
            payload = new_task_before(data)
            return payload or {}

        def _after_new_task(
                self, status_code: int, success: bool, code: int, message: Optional[str],
                data: Optional[Mapping[str, Any]]
        ) -> Any:
            return new_task_after(status_code, success, code, message, data)

        def _error_new_task(self, error: RequestException) -> Any:
            return new_task_error(error)

    return _Connection