# codewraith / data /source_files /clean /319db52d59c7.py
# slenk's picture
# Upload folder using huggingface_hub
# eeef81e verified
import json
import inspect
import hashlib
from _plotly_utils.utils import PlotlyJSONEncoder
from dash.long_callback.managers import BaseLongCallbackManager
class CeleryLongCallbackManager(BaseLongCallbackManager):
    """Long callback manager that delegates work to a Celery application.

    Jobs are submitted as Celery tasks; results and progress values are read
    from and written to the Celery application's result backend as JSON.
    """

    def __init__(self, celery_app, cache_by=None, expire=None):
        """
        Long callback manager that runs callback logic on a celery task queue,
        and stores results using a celery result backend.

        :param celery_app:
            A celery.Celery application instance that must be configured with a
            result backend. See the celery documentation for information on
            configuration options.
        :param cache_by:
            A list of zero-argument functions. When provided, caching is enabled and
            the return values of these functions are combined with the callback
            function's input arguments and source code to generate cache keys.
        :param expire:
            If provided, a cache entry will be removed when it has not been accessed
            for ``expire`` seconds. If not provided, the lifetime of cache entries
            is determined by the default behavior of the celery result backend.
        :raises ImportError:
            If celery cannot be imported.
        :raises ValueError:
            If ``celery_app`` is not a ``celery.Celery`` instance, or if its
            result backend is a ``DisabledBackend``.
        """
        # celery is an optional extra, so it is imported lazily and a missing
        # installation is reported with an installation hint.
        try:
            import celery  # pylint: disable=import-outside-toplevel,import-error
            from celery.backends.base import (  # pylint: disable=import-outside-toplevel,import-error
                DisabledBackend,
            )
        except ImportError as missing_imports:
            raise ImportError(
                """\
CeleryLongCallbackManager requires extra dependencies which can be installed doing
$ pip install "dash[celery]"\n"""
            ) from missing_imports

        if not isinstance(celery_app, celery.Celery):
            raise ValueError("First argument must be a celery.Celery object")

        # Results and progress are exchanged through the backend, so a
        # disabled backend cannot work.
        if isinstance(celery_app.backend, DisabledBackend):
            raise ValueError("Celery instance must be configured with a result backend")

        super().__init__(cache_by)
        self.handle = celery_app
        self.expire = expire

    def terminate_job(self, job):
        """Request termination of the task identified by ``job``.

        Does nothing when ``job`` is ``None``. Always returns ``None``.
        """
        if job is not None:
            self.handle.control.terminate(job)

    def terminate_unhealthy_job(self, job):
        """Terminate ``job`` if its task status is ``FAILURE`` or ``REVOKED``.

        Returns ``False`` when no termination was requested; otherwise returns
        whatever ``terminate_job`` returns (``None``).
        """
        task = self.get_task(job)
        unhealthy = task and task.status in ("FAILURE", "REVOKED")
        if not unhealthy:
            return False
        return self.terminate_job(job)

    def job_running(self, job):
        """Report whether the task for ``job`` is in a not-yet-finished state.

        When no task can be looked up, the falsy lookup result itself
        (``None``) is returned rather than ``False``.
        """
        active_states = ("PENDING", "RECEIVED", "STARTED", "RETRY", "PROGRESS")
        task = self.get_task(job)
        if not task:
            return task
        return task.status in active_states

    def make_job_fn(self, fn, progress, args_deps):
        """Build the Celery task wrapping ``fn`` by delegating to ``_make_job_fn``."""
        return _make_job_fn(fn, self.handle, progress, args_deps)

    def get_task(self, job):
        """Return the ``AsyncResult`` for ``job``, or ``None`` if ``job`` is falsy."""
        return self.handle.AsyncResult(job) if job else None

    def clear_cache_entry(self, key):
        """Delete ``key`` from the Celery result backend."""
        self.handle.backend.delete(key)

    def call_job_fn(self, key, job_fn, args):
        """Queue ``job_fn`` and return the id of the submitted task.

        The job receives the result key, the matching progress key and ``args``.
        """
        # _make_progress_key is not defined in this class; presumably it is
        # inherited from BaseLongCallbackManager.
        progress_key = self._make_progress_key(key)
        submitted = job_fn.delay(key, progress_key, args)
        return submitted.task_id

    def get_progress(self, key):
        """Return the decoded progress value stored for ``key``, or ``None``.

        ``None`` is returned when the backend holds nothing (or an empty
        value) under the progress key.
        """
        raw_progress = self.handle.backend.get(self._make_progress_key(key))
        return json.loads(raw_progress) if raw_progress else None

    def result_ready(self, key):
        """Return ``True`` when the backend holds a value under ``key``."""
        stored = self.handle.backend.get(key)
        return stored is not None

    def get_result(self, key, job):
        """Fetch and decode the result stored under ``key``.

        Returns ``None`` without any cleanup when no result is stored yet.
        Otherwise the result is JSON-decoded, the cache entry is either removed
        (caching disabled) or has its expiration refreshed (caching enabled and
        ``expire`` set), the progress entry is removed, and ``job`` is
        terminated.
        """
        raw_result = self.handle.backend.get(key)
        if raw_result is None:
            return None

        decoded = json.loads(raw_result)

        # cache_by is not assigned in this class; presumably the base class
        # initializer stores it.
        if self.cache_by is None:
            # Not caching: the stored result is single-use.
            self.clear_cache_entry(key)
        elif self.expire:
            # Caching: set/update the expiration time on access.
            self.handle.backend.expire(key, self.expire)

        # Progress data is no longer needed once the result has been read.
        self.clear_cache_entry(self._make_progress_key(key))

        self.terminate_job(job)
        return decoded
def _make_job_fn(fn, celery_app, progress, args_deps):
cache = celery_app.backend
# Hash function source and module to create a unique (but stable) celery task name
fn_source = inspect.getsource(fn)
fn_str = fn_source
fn_hash = hashlib.sha1(fn_str.encode("utf-8")).hexdigest()
@celery_app.task(name=f"long_callback_{fn_hash}")
def job_fn(result_key, progress_key, user_callback_args, fn=fn):
def _set_progress(progress_value):
cache.set(progress_key, json.dumps(progress_value, cls=PlotlyJSONEncoder))
maybe_progress = [_set_progress] if progress else []
if isinstance(args_deps, dict):
user_callback_output = fn(*maybe_progress, **user_callback_args)
elif isinstance(args_deps, (list, tuple)):
user_callback_output = fn(*maybe_progress, *user_callback_args)
else:
user_callback_output = fn(*maybe_progress, user_callback_args)
cache.set(result_key, json.dumps(user_callback_output, cls=PlotlyJSONEncoder))
return job_fn