|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
import multiprocessing as mp |
|
|
import os |
|
|
import signal |
|
|
import time |
|
|
from queue import Empty |
|
|
from typing import Any, Dict, List, Set, Tuple |
|
|
|
|
|
from .api import RequestQueue, TimerClient, TimerRequest, TimerServer |
|
|
|
|
|
__all__ = ['LocalTimerClient', 'MultiprocessingRequestQueue', 'LocalTimerServer'] |
|
|
|
|
|
class LocalTimerClient(TimerClient):
    """
    Client side of ``LocalTimerServer``. This client is meant to be used
    on the same host that the ``LocalTimerServer`` is running on and uses
    pid to uniquely identify a worker. This is particularly useful in situations
    where one spawns a subprocess (trainer) per GPU on a host with multiple
    GPU devices.
    """

    def __init__(self, mp_queue):
        super().__init__()
        # Queue onto which this client puts its ``TimerRequest`` objects.
        self._mp_queue = mp_queue

    def acquire(self, scope_id, expiration_time):
        """
        Enqueue a ``TimerRequest`` for ``scope_id`` carrying
        ``expiration_time``, identified by the calling process's pid.
        """
        self._mp_queue.put(TimerRequest(os.getpid(), scope_id, expiration_time))

    def release(self, scope_id):
        """
        Enqueue a ``TimerRequest`` for ``scope_id`` with an expiration
        time of ``-1``, identified by the calling process's pid.
        """
        # ``LocalTimerServer.register_timers`` drops the matching timer when
        # it sees a negative expiration time.
        self._mp_queue.put(TimerRequest(os.getpid(), scope_id, -1))
|
|
|
|
|
|
|
|
class MultiprocessingRequestQueue(RequestQueue):
    """
    A ``RequestQueue`` backed by python ``multiprocessing.Queue``
    """

    def __init__(self, mp_queue: mp.Queue):
        super().__init__()
        self._mp_queue = mp_queue

    def size(self) -> int:
        """
        Return whatever the underlying queue's ``qsize()`` reports.
        """
        return self._mp_queue.qsize()

    def get(self, size, timeout: float) -> List[TimerRequest]:
        """
        Collect up to ``size`` requests from the underlying queue.

        ``timeout`` (seconds) is a budget shared by the whole batch, not a
        per-item timeout. Collection stops as soon as ``size`` items were
        read, the queue stayed empty for the remaining budget, or the
        budget ran out.

        Returns:
            The requests read so far, in the order they were dequeued
            (possibly an empty list).
        """
        requests = []
        # Track the budget on the monotonic clock: the wall clock
        # (``time.time()``) may be stepped forwards or backwards, which
        # would shorten or extend the effective timeout.
        deadline = time.monotonic() + timeout
        wait = timeout
        for _ in range(size):
            try:
                r = self._mp_queue.get(block=True, timeout=wait)
            except Empty:
                break

            requests.append(r)
            # Whatever is left of the shared budget bounds the next get().
            wait = deadline - time.monotonic()
            if wait <= 0:
                break

        return requests
|
|
|
|
|
|
|
|
class LocalTimerServer(TimerServer):
    """
    Server that works with ``LocalTimerClient``. Clients are expected to be
    subprocesses to the parent process that is running this server. Each host
    in the job is expected to start its own timer server locally and each
    server instance manages timers for local workers (running on processes
    on the same host).
    """

    def __init__(
        self, mp_queue: mp.Queue, max_interval: float = 60, daemon: bool = True
    ):
        request_queue = MultiprocessingRequestQueue(mp_queue)
        super().__init__(request_queue, max_interval, daemon)
        # Active timers keyed by (worker_id, scope_id).
        self._timers: Dict[Tuple[Any, str], TimerRequest] = {}

    def register_timers(self, timer_requests: List[TimerRequest]) -> None:
        """
        Apply each request to the timer table: a negative expiration time
        removes the ``(worker_id, scope_id)`` entry (if any); anything else
        stores the request under that key, replacing a previous one.
        """
        for req in timer_requests:
            key = (req.worker_id, req.scope_id)
            if req.expiration_time < 0:
                # Negative expiration: drop the timer, ignoring unknown keys.
                self._timers.pop(key, None)
            else:
                self._timers[key] = req

    def clear_timers(self, worker_ids: Set[int]) -> None:
        """
        Remove every timer whose worker id is contained in ``worker_ids``.
        """
        # Collect the keys first so the dict is not mutated while iterated.
        stale = [
            (pid, scope_id)
            for (pid, scope_id) in self._timers
            if pid in worker_ids
        ]
        for key in stale:
            self._timers.pop(key)

    def get_expired_timers(self, deadline: float) -> Dict[Any, List[TimerRequest]]:
        """
        Return a mapping of worker id to the list of that worker's requests
        whose ``expiration_time`` is less than or equal to ``deadline``.
        Workers without such requests do not appear in the result.
        """
        by_worker: Dict[Any, List[TimerRequest]] = {}
        overdue = (
            req for req in self._timers.values() if req.expiration_time <= deadline
        )
        for req in overdue:
            by_worker.setdefault(req.worker_id, []).append(req)
        return by_worker

    def _reap_worker(self, worker_id: int) -> bool:
        """
        Send ``SIGKILL`` to the process with pid ``worker_id``.

        Returns ``True`` when the signal was sent or the process does not
        exist, and ``False`` when any other exception was raised (the error
        is logged).
        """
        try:
            os.kill(worker_id, signal.SIGKILL)
        except ProcessLookupError:
            logging.info(f"Process with pid={worker_id} does not exist. Skipping")
            return True
        except Exception as e:
            logging.error(f"Error terminating pid={worker_id}", exc_info=e)
            return False
        else:
            return True
|
|
|