Spaces:
Sleeping
Sleeping
| import json | |
| import sys | |
| import time | |
| import traceback | |
| from abc import abstractmethod | |
| from functools import wraps | |
| from threading import Thread, Event, Lock | |
| from typing import Optional, Callable, Any, Mapping | |
| from uuid import UUID | |
| import requests | |
| from flask import Flask, request | |
| from .action import ConnectionRefuse, DisconnectionRefuse, TaskRefuse, TaskFail | |
| from ..base import random_token, ControllableService, get_http_engine_class, split_http_address, success_response, \ | |
| failure_response, DblEvent | |
| from ..config import DEFAULT_SLAVE_PORT, DEFAULT_CHANNEL, GLOBAL_HOST, DEFAULT_HEARTBEAT_SPAN, MIN_HEARTBEAT_SPAN, \ | |
| DEFAULT_REQUEST_RETRIES, DEFAULT_REQUEST_RETRY_WAITING | |
| from ..exception import SlaveErrorCode, get_slave_exception_by_error, get_master_exception_by_error | |
class Slave(ControllableService):
    r"""
    Overview:
        Interaction slave client.
        It serves a flask http app that accepts connection, disconnection and task requests from a master,
        sends heartbeats to the connected master from a background thread, and processes received tasks
        (one at a time) in another background thread, reporting each outcome back to the master.
        Subclasses must implement ``_process_task`` and may override the other ``_before_*`` / ``_lost_*`` hooks.
    """

    def __init__(
            self,
            host: Optional[str] = None,
            port: Optional[int] = None,
            heartbeat_span: Optional[float] = None,
            request_retries: Optional[int] = None,
            request_retry_waiting: Optional[float] = None,
            channel: Optional[int] = None
    ):
        """
        Overview:
            Constructor of Slave class
        Arguments:
            - host (:obj:`Optional[str]`): Host of the slave server, based on flask (None means `0.0.0.0`)
            - port (:obj:`Optional[int]`): Port of the slave server, based on flask (None means `7236`)
            - heartbeat_span (:obj:`Optional[float]`): Time span of heartbeat packages in seconds \
                (None means `3.0`, minimum is `0.2`)
            - request_retries (:obj:`Optional[int]`): Max times for request retries (None means `5`)
            - request_retry_waiting (:obj:`Optional[float]`): Sleep time before requests' retrying (None means `1.0`)
            - channel (:obj:`Optional[int]`): Channel id for the slave client, please make sure that channel id is \
                equal to the master client's channel id, or the connection cannot be established. (None means `0`, \
                but 0 channel is not recommended to be used in production)
        Notes:
            Only `None` selects the default for `heartbeat_span`, `request_retries`, `request_retry_waiting` \
            and `channel`; explicit zero values are kept (and then clamped to their minimums).
        """
        # server part
        self.__host = host or GLOBAL_HOST
        # `or` is kept for the port on purpose: port 0 (OS-assigned) would break the self-ping engine below,
        # which must be built with the real port.
        self.__port = port or DEFAULT_SLAVE_PORT
        self.__flask_app_value = None  # built lazily and cached by ``__flask_app``
        self.__run_app_thread = Thread(target=self.__run_app, name='slave_run_app')

        # heartbeat part
        # `is None` instead of `or`, so an explicit 0 is honoured (and clamped) rather than replaced by the default.
        self.__heartbeat_span = max(
            DEFAULT_HEARTBEAT_SPAN if heartbeat_span is None else heartbeat_span,
            MIN_HEARTBEAT_SPAN,
        )
        self.__heartbeat_thread = Thread(target=self.__heartbeat, name='slave_heartbeat')
        self.__request_retries = max(DEFAULT_REQUEST_RETRIES if request_retries is None else request_retries, 0)
        self.__request_retry_waiting = max(
            DEFAULT_REQUEST_RETRY_WAITING if request_retry_waiting is None else request_retry_waiting,
            0.0,
        )

        # task part
        self.__has_task = DblEvent()  # opened when a task is accepted, closed once it has been processed
        self.__task_lock = Lock()  # makes the "already has a task?" check-and-accept in ``__new_task`` atomic
        self.__task_id = None
        self.__task_data = None
        self.__task_thread = Thread(target=self.__task, name='slave_task')

        # self-connection part
        # The token header is given as a lambda because ``__self_token`` is only assigned after the engine is
        # built; presumably the engine evaluates header callables per request.
        self.__self_http_engine = get_http_engine_class(
            headers={
                'Token': lambda: self.__self_token,
            },
            http_error_gene=get_slave_exception_by_error,
        )()(self.__host, self.__port, False)  # TODO: Confirm how to ping itself (e.g. 'localhost' vs bind host)
        self.__self_token = random_token()

        # master-connection part
        self.__channel = DEFAULT_CHANNEL if channel is None else channel
        self.__connected = DblEvent()  # open while a master is connected
        self.__master_token = None
        self.__master_address = None
        self.__master_http_engine = None

        # global part
        self.__shutdown_event = Event()  # set once the self shutdown request has been accepted
        self.__lock = Lock()  # serialises the public start / ping / shutdown operations

    # master connection
    def __register_master(self, token: str, address: str):
        """Remember the master's token and address, and build the http engine used to talk to it."""
        self.__master_token = token
        self.__master_address = address
        self.__master_http_engine = get_http_engine_class(
            headers={
                'Channel': lambda: str(self.__channel),
                'Token': lambda: self.__master_token,
            },
            http_error_gene=get_master_exception_by_error,
        )()(*split_http_address(self.__master_address))

    def __unregister_master(self):
        """Forget the current master's token, address and http engine."""
        self.__master_token = None
        self.__master_address = None
        self.__master_http_engine = None

    def __open_master_connection(self, token: str, address: str):
        """Register the master first, then mark this slave as connected."""
        self.__register_master(token, address)
        self.__connected.open()

    def __close_master_connection(self):
        """Unregister the master, then mark this slave as disconnected."""
        self.__unregister_master()
        self.__connected.close()

    # server part
    def __generate_app(self) -> Flask:
        """Build the flask app with the master-facing and the self-facing routes."""
        app = Flask(__name__)

        # master apis (``/connect`` does not require the token to match: no master is registered yet)
        app.route('/connect', methods=['POST'])(self.__check_master_request(self.__connect, False))
        app.route('/disconnect', methods=['DELETE'])(self.__check_master_request(self.__disconnect, True))
        app.route('/task/new', methods=['POST'])(self.__check_master_request(self.__new_task, True))

        # self apis
        app.route('/ping', methods=['GET'])(self.__check_self_request(self.__self_ping))
        app.route('/shutdown', methods=['DELETE'])(self.__check_self_request(self.__self_shutdown))

        return app

    def __flask_app(self) -> Flask:
        """Return the flask app, generating it on first use and caching it afterwards."""
        if self.__flask_app_value is None:
            self.__flask_app_value = self.__generate_app()
        return self.__flask_app_value

    def __run_app(self):
        """Target of the server thread: run the flask app (blocks until the server stops)."""
        self.__flask_app().run(
            host=self.__host,
            port=self.__port,
        )

    # both method checkers
    # NOTE: every wrapper below uses ``@wraps(func)``. Flask derives the endpoint name from the view function's
    # ``__name__``; without it, every route would be registered as ``_func``, and registering the second route
    # raises an "overwriting an existing endpoint function" error.
    def __check_shutdown(self, func: Callable[[], Any]) -> Callable[[], Any]:
        """Wrap ``func`` so it is refused (401) once the shutdown of this slave has started."""

        @wraps(func)
        def _func():
            if self.__shutdown_event.is_set():
                return failure_response(
                    code=SlaveErrorCode.SYSTEM_SHUTTING_DOWN, message='System has already been shutting down.'
                ), 401
            else:
                return func()

        return _func

    # server method checkers (master)
    def __check_master_request(self,
                               func: Callable[[str, Mapping[str, Any]], Any],
                               need_match: bool = True) -> Callable[[], Any]:
        """Compose the shutdown, channel and master-token checks around a master-facing handler."""
        return self.__check_shutdown(self.__check_channel(self.__check_master_token(func, need_match)))

    # noinspection DuplicatedCode
    def __check_channel(self, func: Callable[[], Any]) -> Callable[[], Any]:
        """Wrap ``func`` so it only runs when the request's ``Channel`` header equals this slave's channel."""

        @wraps(func)
        def _func():
            raw_channel = request.headers.get('Channel', None)
            if not raw_channel:
                return failure_response(code=SlaveErrorCode.CHANNEL_NOT_FOUND, message='Channel not found.'), 400

            try:
                channel = int(raw_channel)
            except ValueError:
                # A non-integer header can never match; answer it like a mismatch instead of failing with a 500.
                return failure_response(
                    code=SlaveErrorCode.CHANNEL_INVALID, message='Channel not match with this endpoint.'
                ), 403

            if channel != self.__channel:
                return failure_response(
                    code=SlaveErrorCode.CHANNEL_INVALID, message='Channel not match with this endpoint.'
                ), 403
            else:
                return func()

        return _func

    def __check_master_token(self,
                             func: Callable[[str, Mapping[str, Any]], Any],
                             need_match: bool = True) -> Callable[[], Any]:
        """
        Wrap a master-facing handler: require a ``Token`` header (and, when ``need_match`` is true, require it to
        equal the registered master's token), then call ``func(token, json_body)``.
        """

        @wraps(func)
        def _func():
            master_token = request.headers.get('Token', None)
            if master_token is None:
                return failure_response(
                    code=SlaveErrorCode.MASTER_TOKEN_NOT_FOUND, message='Master token not found.'
                ), 400
            elif need_match and (master_token != self.__master_token):
                return failure_response(
                    code=SlaveErrorCode.MASTER_TOKEN_INVALID, message='Master not match with this endpoint.'
                ), 403
            else:
                return func(master_token, json.loads(request.data.decode()))

        return _func

    # server method checkers (self)
    # noinspection DuplicatedCode
    def __check_self_request(self, func: Callable[[], Any]) -> Callable[[], Any]:
        """Compose the shutdown and self-token checks around a self-facing handler."""
        return self.__check_shutdown(self.__check_slave_token(func))

    def __check_slave_token(self, func: Callable[[], Any]) -> Callable[[], Any]:
        """Wrap ``func`` so it only runs when the ``Token`` header equals this slave's own random token."""

        @wraps(func)
        def _func():
            slave_token = request.headers.get('Token', None)
            if slave_token is None:
                return failure_response(code=SlaveErrorCode.SELF_TOKEN_NOT_FOUND, message='Slave token not found.'), 400
            elif slave_token != self.__self_token:
                return failure_response(
                    code=SlaveErrorCode.SELF_TOKEN_INVALID, message='Slave token not match with this endpoint.'
                ), 403
            else:
                return func()

        return _func

    # server methods (self)
    # noinspection PyMethodMayBeStatic
    def __self_ping(self):
        """Handler of ``GET /ping``: liveness probe used by ``ping`` / ``start``."""
        return success_response(message='PONG!')

    def __self_shutdown(self):
        """
        Handler of ``DELETE /shutdown``: flag the shutdown (stopping the worker loops) and stop the server.
        NOTE(review): relies on the ``werkzeug.server.shutdown`` environ hook, which newer Werkzeug releases
        (2.1+) no longer provide; there this raises ``RuntimeError`` and the shutdown flag is not set.
        """
        _shutdown_func = request.environ.get('werkzeug.server.shutdown')
        if _shutdown_func is None:
            raise RuntimeError('Not running with the Werkzeug Server')

        self.__shutdown_event.set()
        _shutdown_func()

        return success_response(message='Shutdown request received, this server will be down later.')

    # server methods (master)
    # noinspection PyUnusedLocal
    def __connect(self, token: str, data: Mapping[str, Any]):
        """
        Handler of ``POST /connect``: accept a master unless one is already connected or
        ``_before_connection`` refuses it. Expects ``data`` to contain ``master.address`` and ``data``.
        """
        if self.__connected.is_open():
            return failure_response(
                code=SlaveErrorCode.SLAVE_ALREADY_CONNECTED, message='This slave already connected.'
            ), 400
        else:
            _master_info, _connection_data = data['master'], data['data']

            try:
                self._before_connection(_connection_data)
            except ConnectionRefuse as err:
                return err.get_response()
            else:
                self.__open_master_connection(token, _master_info['address'])
                return success_response(message='Connect success.')

    # noinspection PyUnusedLocal
    def __new_task(self, token: str, data: Mapping[str, Any]):
        """
        Handler of ``POST /task/new``: accept a task (``task.id`` is a UUID string, ``data`` its payload) unless
        one is already pending or ``_before_task`` refuses it. The task thread then picks it up.
        """
        with self.__task_lock:
            if self.__has_task.is_open():
                return failure_response(code=SlaveErrorCode.TASK_ALREADY_EXIST, message='Already has a task.'), 400
            else:
                _task_info, _task_data = data['task'], data['data']
                _task_id = _task_info['id']

                try:
                    self._before_task(_task_data)
                except TaskRefuse as err:
                    return err.get_response()
                else:
                    # id and data are stored before the event opens, so the task thread always sees them set
                    self.__task_id = UUID(_task_id)
                    self.__task_data = _task_data
                    self.__has_task.open()
                    return success_response(message='Task received!')

    # noinspection PyUnusedLocal
    def __disconnect(self, token: str, data: Mapping[str, Any]):
        """
        Handler of ``DELETE /disconnect``: drop the current master unless none is connected or
        ``_before_disconnection`` refuses it.
        """
        if self.__connected.is_close():
            return failure_response(
                code=SlaveErrorCode.SLAVE_NOT_CONNECTED, message='This slave not connected yet.'
            ), 400
        else:
            _disconnection_data = data['data']

            try:
                self._before_disconnection(_disconnection_data)
            except DisconnectionRefuse as err:
                return err.get_response()
            else:
                self.__close_master_connection()
                return success_response(message='Disconnect success.')

    # heartbeat part
    def __heartbeat(self):
        """
        Target of the heartbeat thread: every ``heartbeat_span`` seconds, while connected, send a heartbeat to the
        master; on a request failure, call ``_lost_connection`` and drop the connection.
        """
        _last_time = time.time()
        while not self.__shutdown_event.is_set():
            if self.__connected.is_open():
                # NOTE(review): check-then-act -- a concurrent ``/disconnect`` can unregister the master between
                # the check above and the request below; confirm whether that needs handling.
                try:
                    self.__master_heartbeat()
                except requests.exceptions.RequestException as err:
                    # the hook runs before closing, so it still gets the master address
                    self._lost_connection(self.__master_address, err)
                    self.__close_master_connection()
                    traceback.print_exception(*sys.exc_info(), file=sys.stderr)

            # schedule against an absolute clock so the request time does not accumulate as drift
            _last_time += self.__heartbeat_span
            time.sleep(max(_last_time - time.time(), 0))

    # task part
    def __task(self):
        """
        Target of the task thread: wait for an accepted task, run ``_process_task`` on it, free the task slot and
        report completion (or ``TaskFail``) to the master. Unexpected exceptions are printed, not reported.
        """
        while not self.__shutdown_event.is_set():
            self.__has_task.wait_for_open(timeout=1.0)
            if self.__has_task.is_open():
                # Snapshot the task before the slot is released: once ``__has_task`` is closed, ``__new_task``
                # may accept another task and overwrite ``__task_id`` / ``__task_data``.
                task_id, task_data = self.__task_id, self.__task_data

                # noinspection PyBroadException
                try:
                    result = self._process_task(task_data)
                except TaskFail as fail:
                    self.__has_task.close()
                    self.__report_task(self.__master_task_fail, task_id, fail.result)
                except Exception:
                    self.__has_task.close()
                    traceback.print_exception(*sys.exc_info(), file=sys.stderr)
                else:
                    self.__has_task.close()
                    self.__report_task(self.__master_task_complete, task_id, result)

    def __report_task(self, report: Callable[[UUID, Any], Any], task_id: UUID, result: Any):
        """Send a task outcome to the master without letting a reporting failure kill the task thread."""
        # noinspection PyBroadException
        try:
            report(task_id, result)
        except Exception:
            # e.g. the master is unreachable, or it was disconnected meanwhile (its http engine is then None);
            # log it and keep serving new tasks
            traceback.print_exception(*sys.exc_info(), file=sys.stderr)

    # self request operations
    def __self_request(self, method: Optional[str] = 'GET', path: Optional[str] = None) -> requests.Response:
        """Send a request to this slave's own server, with the configured retry policy."""
        return self.__self_http_engine.request(
            method,
            path,
            retries=self.__request_retries,
            retry_waiting=self.__request_retry_waiting,
        )

    def __ping_once(self):
        """Ping this slave's own server once (raises on failure)."""
        return self.__self_request('GET', '/ping')

    def __ping_until_started(self):
        """
        Ping this slave's own server until it answers.
        NOTE(review): there is no upper bound; if the server thread dies (e.g. the port is taken), this loops forever.
        """
        while True:
            try:
                self.__ping_once()
            except (requests.exceptions.BaseHTTPError, requests.exceptions.RequestException):
                time.sleep(0.2)
            else:
                break

    def __shutdown(self):
        """Ask this slave's own server to shut down."""
        self.__self_request('DELETE', '/shutdown')

    # master request operations
    def __master_request(
            self,
            method: Optional[str] = 'GET',
            path: Optional[str] = None,
            data: Optional[Mapping[str, Any]] = None
    ) -> requests.Response:
        """Send a request to the connected master, with the configured retry policy."""
        return self.__master_http_engine.request(
            method,
            path,
            data,
            retries=self.__request_retries,
            retry_waiting=self.__request_retry_waiting,
        )

    def __master_heartbeat(self):
        """Send one heartbeat to the master."""
        return self.__master_request('GET', '/slave/heartbeat')

    def __master_task_complete(self, task_id: UUID, result: Mapping[str, Any]):
        """Tell the master that task ``task_id`` succeeded with ``result`` (``None`` is sent as ``{}``)."""
        return self.__master_request(
            'PUT', '/slave/task/complete', data={
                'task': {
                    'id': str(task_id)
                },
                'result': result or {},
            }
        )

    def __master_task_fail(self, task_id: UUID, result: Mapping[str, Any]):
        """Tell the master that task ``task_id`` failed with ``result`` (``None`` is sent as ``{}``)."""
        return self.__master_request(
            'PUT', '/slave/task/fail', data={
                'task': {
                    'id': str(task_id)
                },
                'result': result or {},
            }
        )

    # public methods
    def ping(self) -> bool:
        """
        Overview:
            Ping the current http server and check whether it is still running properly.
        Returns:
            - output (:obj:`bool`): Whether the http server runs properly. \
                `True` means it runs properly, otherwise `False` is returned.
        """
        with self.__lock:
            try:
                self.__ping_once()
            except (requests.exceptions.BaseHTTPError, requests.exceptions.RequestException):
                return False
            else:
                return True

    def start(self):
        """
        Overview:
            Start current slave client
            Here are the steps executed inside in order:
                1. Start the task-processing thread
                2. Start the heartbeat thread
                3. Start the http server thread
                4. Wait until the http server is online (can be pinged)
        """
        with self.__lock:
            self.__task_thread.start()
            self.__heartbeat_thread.start()
            self.__run_app_thread.start()

            self.__ping_until_started()

    def shutdown(self):
        """
        Overview:
            Shutdown current slave client.
            A shutdown request will be sent to the http server, and the shutdown signal will be applied to the \
            threads; the server will be down soon (you can use the `join` method to wait until that time).
        """
        with self.__lock:
            self.__shutdown()

    def join(self):
        """
        Overview:
            Wait until current slave client is down completely.
            Here are the steps executed inside in order:
                1. Wait until the http server thread is down
                2. Wait until the heartbeat thread is down
                3. Wait until the task-processing thread is down
            This method deliberately does not hold the instance lock: it blocks until `shutdown` has been \
            called, and `shutdown` (or `ping`) must still be able to acquire that lock from another thread \
            or a signal handler.
        """
        self.__run_app_thread.join()
        self.__heartbeat_thread.join()
        self.__task_thread.join()

    # inherit method
    def _before_connection(self, data: Mapping[str, Any]):
        """
        Overview:
            Behaviours that will be executed before the connection is established.
        Arguments:
            - data (:obj:`Mapping[str, Any]`): Connection data when connect to this slave, sent from master.
        Raises:
            - `ConnectionRefuse` After raising this, the connection from the master end will be refused, \
                and no new connection will be established.
        """
        pass

    def _before_disconnection(self, data: Mapping[str, Any]):
        """
        Overview:
            Behaviours that will be executed before the disconnection is executed.
        Arguments:
            - data (:obj:`Mapping[str, Any]`): Disconnection data when disconnect with this slave, sent from master.
        Raises:
            - `DisconnectionRefuse` After raising this, the disconnection request will be refused, \
                and the current connection will still exist.
        """
        pass

    def _before_task(self, data: Mapping[str, Any]):
        """
        Overview:
            Behaviours that will be executed before the task is executed.
        Arguments:
            - data (:obj:`Mapping[str, Any]`): Data of the task
        Raises:
            - `TaskRefuse` After raising this, the new task will be refused.
        """
        pass

    def _lost_connection(self, master_address: str, err: requests.exceptions.RequestException):
        """
        Overview:
            Behaviours that will be executed after the connection is lost (a heartbeat request failed).
        Arguments:
            - master_address (:obj:`str`): String address of the master end
            - err (:obj:`requests.exceptions.RequestException`): Http exception of this connection loss
        """
        pass

    def _process_task(self, task: Mapping[str, Any]):
        """
        Overview:
            Execute the task. This protected method must be implemented in the subclass.
        Arguments:
            - task (:obj:`Mapping[str, Any]`): Data of the task
        Raises:
            - `TaskFail` After raising this, the task will be recognized as failed, \
                and the master will receive the failure signal.
        Example:
            - A successful task with return value (the return value will be received in master end)
            >>> def _process_task(self, task):
            ...     print('this is task data :', task)
            ...     return str(task)
            - A failed task with data (the data will be received in master end)
            >>> def _process_task(self, task):
            ...     print('this is task data :', task)
            ...     raise TaskFail(task)  # this is a failed task
            - A failed task with data and message (both will be received in master end)
            >>> def _process_task(self, task):
            ...     print('this is task data :', task)
            ...     raise TaskFail(task, 'this is message')  # this is a failed task with message
        """
        raise NotImplementedError