Spaces:
Sleeping
Sleeping
| import json | |
| import inspect | |
| import hashlib | |
| from _plotly_utils.utils import PlotlyJSONEncoder | |
| from dash.long_callback.managers import BaseLongCallbackManager | |
| class CeleryLongCallbackManager(BaseLongCallbackManager): | |
| def __init__(self, celery_app, cache_by=None, expire=None): | |
| """ | |
| Long callback manager that runs callback logic on a celery task queue, | |
| and stores results using a celery result backend. | |
| :param celery_app: | |
| A celery.Celery application instance that must be configured with a | |
| result backend. See the celery documentation for information on | |
| configuration options. | |
| :param cache_by: | |
| A list of zero-argument functions. When provided, caching is enabled and | |
| the return values of these functions are combined with the callback | |
| function's input arguments and source code to generate cache keys. | |
| :param expire: | |
| If provided, a cache entry will be removed when it has not been accessed | |
| for ``expire`` seconds. If not provided, the lifetime of cache entries | |
| is determined by the default behavior of the celery result backend. | |
| """ | |
| try: | |
| import celery # pylint: disable=import-outside-toplevel,import-error | |
| from celery.backends.base import ( # pylint: disable=import-outside-toplevel,import-error | |
| DisabledBackend, | |
| ) | |
| except ImportError as missing_imports: | |
| raise ImportError( | |
| """\ | |
| CeleryLongCallbackManager requires extra dependencies which can be installed doing | |
| $ pip install "dash[celery]"\n""" | |
| ) from missing_imports | |
| if not isinstance(celery_app, celery.Celery): | |
| raise ValueError("First argument must be a celery.Celery object") | |
| if isinstance(celery_app.backend, DisabledBackend): | |
| raise ValueError("Celery instance must be configured with a result backend") | |
| super().__init__(cache_by) | |
| self.handle = celery_app | |
| self.expire = expire | |
| def terminate_job(self, job): | |
| if job is None: | |
| return | |
| self.handle.control.terminate(job) | |
| def terminate_unhealthy_job(self, job): | |
| task = self.get_task(job) | |
| if task and task.status in ("FAILURE", "REVOKED"): | |
| return self.terminate_job(job) | |
| return False | |
| def job_running(self, job): | |
| future = self.get_task(job) | |
| return future and future.status in ( | |
| "PENDING", | |
| "RECEIVED", | |
| "STARTED", | |
| "RETRY", | |
| "PROGRESS", | |
| ) | |
| def make_job_fn(self, fn, progress, args_deps): | |
| return _make_job_fn(fn, self.handle, progress, args_deps) | |
| def get_task(self, job): | |
| if job: | |
| return self.handle.AsyncResult(job) | |
| return None | |
| def clear_cache_entry(self, key): | |
| self.handle.backend.delete(key) | |
| def call_job_fn(self, key, job_fn, args): | |
| task = job_fn.delay(key, self._make_progress_key(key), args) | |
| return task.task_id | |
| def get_progress(self, key): | |
| progress_key = self._make_progress_key(key) | |
| progress_data = self.handle.backend.get(progress_key) | |
| if progress_data: | |
| return json.loads(progress_data) | |
| return None | |
| def result_ready(self, key): | |
| return self.handle.backend.get(key) is not None | |
| def get_result(self, key, job): | |
| # Get result value | |
| result = self.handle.backend.get(key) | |
| if result is None: | |
| return None | |
| result = json.loads(result) | |
| # Clear result if not caching | |
| if self.cache_by is None: | |
| self.clear_cache_entry(key) | |
| else: | |
| if self.expire: | |
| # Set/update expiration time | |
| self.handle.backend.expire(key, self.expire) | |
| self.clear_cache_entry(self._make_progress_key(key)) | |
| self.terminate_job(job) | |
| return result | |
| def _make_job_fn(fn, celery_app, progress, args_deps): | |
| cache = celery_app.backend | |
| # Hash function source and module to create a unique (but stable) celery task name | |
| fn_source = inspect.getsource(fn) | |
| fn_str = fn_source | |
| fn_hash = hashlib.sha1(fn_str.encode("utf-8")).hexdigest() | |
| def job_fn(result_key, progress_key, user_callback_args, fn=fn): | |
| def _set_progress(progress_value): | |
| cache.set(progress_key, json.dumps(progress_value, cls=PlotlyJSONEncoder)) | |
| maybe_progress = [_set_progress] if progress else [] | |
| if isinstance(args_deps, dict): | |
| user_callback_output = fn(*maybe_progress, **user_callback_args) | |
| elif isinstance(args_deps, (list, tuple)): | |
| user_callback_output = fn(*maybe_progress, *user_callback_args) | |
| else: | |
| user_callback_output = fn(*maybe_progress, user_callback_args) | |
| cache.set(result_key, json.dumps(user_callback_output, cls=PlotlyJSONEncoder)) | |
| return job_fn | |