|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import abc |
|
|
import logging |
|
|
import threading |
|
|
import time |
|
|
from contextlib import contextmanager |
|
|
from inspect import getframeinfo, stack |
|
|
from typing import Any, Optional |
|
|
|
|
|
|
|
|
# Public names exported by ``from <this module> import *``.
__all__ = [
    "TimerRequest",
    "TimerClient",
    "RequestQueue",
    "TimerServer",
    "configure",
    "expires",
]

# Module-level logger used by the classes and functions defined in this file.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
class TimerRequest:
    """
    Data object representing a countdown timer acquisition or release.
    It is the message exchanged between the ``TimerClient`` and the
    ``TimerServer``.

    A negative ``expiration_time`` should be interpreted as a "release"
    request.

    .. note:: the type of ``worker_id`` is implementation specific.
              It is whatever the TimerServer and TimerClient implementations
              use to uniquely identify a worker.

    :param worker_id: identifier of the worker that owns the timer
    :param scope_id: name of the timer scope on that worker
    :param expiration_time: time at which the timer expires; negative
                            values denote a release request
    """

    # No per-instance ``__dict__``: only these three attributes may be set.
    __slots__ = ["worker_id", "scope_id", "expiration_time"]

    def __init__(self, worker_id: Any, scope_id: str, expiration_time: float):
        self.worker_id = worker_id
        self.scope_id = scope_id
        self.expiration_time = expiration_time

    def __eq__(self, other):
        """
        Two requests are equal when all three fields are equal. Comparing
        against an object that is not a ``TimerRequest`` yields ``False``.

        Because ``__eq__`` is defined without ``__hash__``, instances are
        unhashable (they are mutable, so this is intentional).
        """
        if isinstance(other, TimerRequest):
            return (
                self.worker_id == other.worker_id
                and self.scope_id == other.scope_id
                and self.expiration_time == other.expiration_time
            )
        return False

    def __repr__(self) -> str:
        """
        Returns a constructor-like representation so that requests are
        readable in log messages and assertion failures.
        """
        return (
            f"{type(self).__name__}("
            f"worker_id={self.worker_id!r}, "
            f"scope_id={self.scope_id!r}, "
            f"expiration_time={self.expiration_time!r})"
        )
|
|
|
|
|
|
|
|
class TimerClient(abc.ABC):
    """
    Worker-side interface for countdown timers. Implementations acquire
    and release timers by talking to the ``TimerServer``.
    """

    @abc.abstractmethod
    def acquire(self, scope_id: str, expiration_time: float) -> None:
        """
        Starts a timer, identified by ``scope_id`` and expiring at
        ``expiration_time``, on behalf of the worker that owns this client.
        Implementations typically register the timer with the TimerServer.
        """

    @abc.abstractmethod
    def release(self, scope_id: str):
        """
        Cancels the timer registered under ``scope_id`` for the worker that
        this client represents. Once this method has been called, the
        countdown timer on that scope is no longer in effect.
        """
|
|
|
|
|
|
|
|
class RequestQueue(abc.ABC):
    """
    Consumer-side queue of timer acquisition/release requests.
    """

    @abc.abstractmethod
    def size(self) -> int:
        """
        Reports how many requests are queued at the moment of the call.

        The queue may grow between this call and a later ``get``, but it
        should never shrink until ``get`` is called. In other words, the
        following should hold::

            size = q.size()
            res = q.get(size, timeout=0)
            assert size == len(res)

        -- or --

        ::

            size = q.size()
            res = q.get(size * 2, timeout=1)
            assert size <= len(res) <= size * 2
        """

    @abc.abstractmethod
    def get(self, size: int, timeout: float) -> list[TimerRequest]:
        """
        Returns at most ``size`` timer requests, blocking for no more than
        ``timeout`` seconds while waiting for them.
        """
|
|
|
|
|
|
|
|
class TimerServer(abc.ABC):
    """
    Entity that monitors active timers and expires them
    in a timely fashion. This server is responsible for
    reaping workers that have expired timers.

    Subclasses implement timer bookkeeping (``register_timers``,
    ``clear_timers``, ``get_expired_timers``) and the reaping action
    (``_reap_worker``); this base class drives them from a background
    watchdog thread controlled with ``start()`` and ``stop()``.
    """

    def __init__(
        self, request_queue: RequestQueue, max_interval: float, daemon: bool = True
    ):
        """
        :param request_queue: Consumer ``RequestQueue``
        :param max_interval: max time (in seconds) to wait
                             for an item in the request_queue
        :param daemon: whether to run the watchdog thread as a daemon
        """
        super().__init__()
        self._request_queue = request_queue
        self._max_interval = max_interval
        self._daemon = daemon
        self._watchdog_thread: Optional[threading.Thread] = None
        # Polled by the watchdog loop; set by ``stop()``, cleared by ``start()``.
        self._stop_signaled = False

    @abc.abstractmethod
    def register_timers(self, timer_requests: list[TimerRequest]) -> None:
        """
        Processes the incoming timer requests and registers them with the server.
        The timer request can either be an acquire-timer or release-timer request.
        Timer requests with a negative expiration_time should be interpreted
        as a release-timer request.
        """

    @abc.abstractmethod
    def clear_timers(self, worker_ids: set[Any]) -> None:
        """
        Clears all timers for the given ``worker_ids``.
        """

    @abc.abstractmethod
    def get_expired_timers(self, deadline: float) -> dict[Any, list[TimerRequest]]:
        """
        Returns all expired timers for each worker_id. An expired timer
        is a timer for which the expiration_time is less than or equal to
        the provided deadline.
        """

    @abc.abstractmethod
    def _reap_worker(self, worker_id: Any) -> bool:
        """
        Reaps the given worker. Returns True if the worker has been
        successfully reaped, False otherwise. If any uncaught exception
        is thrown from this method, the worker is considered reaped
        and all associated timers will be removed.
        """

    def _reap_worker_no_throw(self, worker_id: Any) -> bool:
        """
        Wraps ``_reap_worker(worker_id)``, if an uncaught exception is
        thrown, then it considers the worker as reaped.
        """
        try:
            return self._reap_worker(worker_id)
        except Exception:
            logger.exception(
                "Uncaught exception thrown from _reap_worker(), "
                "check that the implementation correctly catches exceptions",
            )
            return True

    def _watchdog_loop(self):
        """
        Body of the watchdog thread: runs ``_run_watchdog()`` repeatedly
        until a stop is signaled. Exceptions from a single pass are logged
        and do not terminate the loop.
        """
        while not self._stop_signaled:
            try:
                self._run_watchdog()
            except Exception:
                logger.exception("Error running watchdog")

    def _run_watchdog(self):
        """
        Performs one watchdog pass: drains pending requests from the queue,
        registers them, reaps every worker that has an expired timer and
        clears the timers of the workers that were reaped.
        """
        # Always ask for at least one request so that ``get`` may wait (up to
        # ``max_interval``) when the queue is currently empty.
        batch_size = max(1, self._request_queue.size())
        timer_requests = self._request_queue.get(batch_size, self._max_interval)
        self.register_timers(timer_requests)
        now = time.time()
        reaped_worker_ids = set()
        for worker_id, expired_timers in self.get_expired_timers(now).items():
            logger.info(
                "Reaping worker_id=[%s]. Expired timers: %s",
                worker_id,
                self._get_scopes(expired_timers),
            )
            if self._reap_worker_no_throw(worker_id):
                logger.info("Successfully reaped worker=[%s]", worker_id)
                reaped_worker_ids.add(worker_id)
            else:
                # Timers of this worker are kept, so it shows up as expired
                # again on the next pass.
                logger.error(
                    "Error reaping worker=[%s]. Will retry on next watchdog.", worker_id
                )
        self.clear_timers(reaped_worker_ids)

    def _get_scopes(self, timer_requests):
        """
        Returns the ``scope_id`` of every request in ``timer_requests``.
        """
        return [r.scope_id for r in timer_requests]

    def start(self) -> None:
        """
        Starts the watchdog thread.

        Calling ``start()`` while the watchdog thread is already running is
        a no-op, and the server can be started again after ``stop()``.
        """
        running = self._watchdog_thread
        if running is not None and running.is_alive():
            # A second loop would consume from the same queue and leave the
            # first thread unreachable from ``stop()``.
            logger.warning("Watchdog thread is already running, doing nothing")
            return
        logger.info(
            "Starting %s... max_interval=%s, daemon=%s",
            type(self).__name__,
            self._max_interval,
            self._daemon,
        )
        # Clear a stop request left over from an earlier ``stop()``; otherwise
        # the new thread would leave its loop immediately.
        self._stop_signaled = False
        self._watchdog_thread = threading.Thread(
            target=self._watchdog_loop, daemon=self._daemon
        )
        logger.info("Starting watchdog thread...")
        self._watchdog_thread.start()

    def stop(self) -> None:
        """
        Signals the watchdog loop to stop and waits up to ``max_interval``
        seconds for the watchdog thread to finish.
        """
        logger.info("Stopping %s", type(self).__name__)
        self._stop_signaled = True
        if self._watchdog_thread:
            logger.info("Stopping watchdog thread...")
            self._watchdog_thread.join(self._max_interval)
            self._watchdog_thread = None
        else:
            logger.info("No watchdog thread running, doing nothing")
|
|
|
|
|
|
|
|
# Module-level default timer client. Assigned by ``configure`` and used by
# ``expires`` when the caller does not pass an explicit ``client``.
_timer_client: Optional[TimerClient] = None
|
|
|
|
|
|
|
|
def configure(timer_client: TimerClient):
    """
    Installs ``timer_client`` as the module-wide default timer client.
    This must happen before ``expires`` is used.
    """
    global _timer_client
    _timer_client = timer_client
    client_type_name = type(timer_client).__name__
    logger.info("Timer client configured to: %s", client_type_name)
|
|
|
|
|
|
|
|
def _caller_scope() -> str:
    """
    Builds the default scope id, ``"<filename>#<lineno>"``, for the code
    that entered ``expires``.
    """
    # Frames that execute contextlib code are recognised by the file that
    # defines ``contextmanager``.
    contextlib_file = contextmanager.__code__.co_filename
    frames = stack(0)
    try:
        # frames[0] is this helper and frames[1] is the ``expires`` generator.
        # The generator is resumed from contextlib (``__enter__`` and possibly
        # other contextlib helpers), so those frames are skipped to reach the
        # code that actually opened the ``with`` block.
        candidates = frames[2:]
        for info in candidates:
            if info.frame.f_code.co_filename != contextlib_file:
                return f"{info.filename}#{info.lineno}"
        # Only contextlib frames were found: use the outermost frame.
        outermost = (candidates or frames)[-1]
        return f"{outermost.filename}#{outermost.lineno}"
    finally:
        # Drop the frame references promptly to avoid reference cycles.
        del frames


@contextmanager
def expires(
    after: float, scope: Optional[str] = None, client: Optional[TimerClient] = None
):
    """
    Acquires a countdown timer that expires in ``after`` seconds from now,
    unless the code-block that it wraps is finished within the timeframe.
    When the timer expires, this worker is eligible to be reaped. The
    exact meaning of "reaped" depends on the client implementation. In
    most cases, reaping means to terminate the worker process.
    Note that the worker is NOT guaranteed to be reaped at exactly
    ``time.time() + after``, but rather the worker is "eligible" for being
    reaped and the ``TimerServer`` that the client talks to will ultimately
    make the decision when and how to reap the workers with expired timers.

    :param after: number of seconds from now until the timer expires
    :param scope: identifier of the timer; defaults to
                  ``"<filename>#<lineno>"`` of the calling code
    :param client: timer client to use; defaults to the client installed
                   with ``configure``
    :raises RuntimeError: if ``client`` is not given and no client has been
                          configured

    Usage::

        torch.distributed.elastic.timer.configure(LocalTimerClient())
        with expires(after=10):
            torch.distributed.all_reduce(...)
    """
    if client is None:
        if _timer_client is None:
            raise RuntimeError("Configure timer client before using countdown timers.")
        client = _timer_client
    if scope is None:
        # grab the caller file + lineno
        scope = _caller_scope()
    expiration = time.time() + after
    client.acquire(scope, expiration)
    try:
        yield
    finally:
        # Always release, even when the wrapped block raises.
        client.release(scope)
|
|
|