shinka-backup/ccevolve/baselines/thetaevolve/slime/utils/reloadable_process_group.py
JustinTX's picture
Add files using upload-large-folder tool
d7b3a74 verified
import logging
import os
from contextlib import contextmanager
import torch
import torch.distributed as dist
from slime.utils.memory_utils import available_memory, clear_memory, print_memory
# Module-level logger named after this module.
logger = logging.getLogger(name=__name__)

# pid -> the ``torch.distributed.new_group`` captured (before being replaced)
# when ``monkey_patch_torch_dist`` ran in that process.
old_new_group_dict = dict()
def monkey_patch_torch_dist():
    """Patch ``torch.distributed`` in place so NCCL groups become reloadable.

    Effects (applied at most once per process):

    * ``dist.new_group`` is replaced by a wrapper that returns a
      ``ReloadableProcessGroup`` around the real group. Gloo groups and
      single-rank groups are returned unchanged.
    * The rank/size query functions, collectives and p2p functions listed
      below are replaced by wrappers that unwrap any ``ReloadableProcessGroup``
      argument to its inner group and run the call inside
      ``_wrap_low_level_call()``.
    * ``dist.P2POp`` construction unwraps ``ReloadableProcessGroup`` arguments
      and maps the patched ``dist.isend`` / ``dist.irecv`` back to the
      originals.

    The unpatched ``new_group`` is stored as ``dist.old_new_group`` and in
    ``old_new_group_dict[os.getpid()]``; ``ReloadableProcessGroup`` uses it to
    rebuild groups later.
    """
    pid = os.getpid()
    if pid in old_new_group_dict:
        # Already patched in this very process: nothing to do.
        assert dist.old_new_group == old_new_group_dict[pid]
        return

    inherited_original = getattr(dist, "old_new_group", None)
    if inherited_original is not None:
        # The module is already patched even though this pid is not
        # registered, e.g. the patch was inherited through fork(). Patching
        # again would wrap the wrappers (nested ReloadableProcessGroups,
        # double registration), so only register the original for this pid.
        logger.info(f"torch.distributed already patched; registering pid {pid}")
        old_new_group_dict[pid] = inherited_original
        return

    logger.info("Applying monkey patch to torch.distributed")
    old_new_group = dist.new_group
    old_new_group_dict[pid] = old_new_group
    dist.old_new_group = old_new_group

    def unwrap(value):
        """Return the inner group of a ReloadableProcessGroup, else ``value``."""
        return value.group if isinstance(value, ReloadableProcessGroup) else value

    def new_group(*args, **kwargs):
        group = old_new_group(*args, **kwargs)

        # ``backend`` is the third positional parameter of ``dist.new_group``.
        # (Passing it both positionally and by keyword already failed above.)
        backend = args[2] if len(args) >= 3 else kwargs.get("backend")
        if backend == "gloo":
            # skip non-nccl (gloo) groups.
            return group

        # ``ranks`` is the first positional parameter; None means all ranks.
        if args and args[0] is not None:
            ranks = args[0]
        elif kwargs.get("ranks") is not None:
            ranks = kwargs["ranks"]
        else:
            ranks = list(range(dist.get_world_size()))

        if len(ranks) == 1:
            return group
        return ReloadableProcessGroup(group, ranks)

    dist.new_group = new_group

    def get_new_function(func):
        """Wrap ``func`` so group arguments are unwrapped before the call."""

        def new_function(*args, **kwargs):
            args = tuple(unwrap(arg) for arg in args)
            kwargs = {key: unwrap(value) for key, value in kwargs.items()}
            with _wrap_low_level_call():
                return func(*args, **kwargs)

        return new_function

    # Group-aware query functions, collectives and blocking p2p.
    for fn_name in (
        "get_rank",
        "get_world_size",
        "get_backend",
        "get_global_rank",
        "get_group_rank",
        "get_process_group_ranks",
        "all_reduce",
        "all_gather",
        "all_gather_into_tensor",
        "all_gather_object",
        "all_to_all",
        "all_to_all_single",
        "broadcast",
        "reduce",
        "reduce_scatter",
        "reduce_scatter_tensor",
        "scatter",
        "gather",
        "barrier",
        "send",
        "recv",
        "_coalescing_manager",
    ):
        setattr(dist, fn_name, get_new_function(getattr(dist, fn_name)))

    # p2p: keep the originals so P2POp can be handed them instead of the
    # patched wrappers.
    old_isend = dist.isend
    old_irecv = dist.irecv
    dist.isend = get_new_function(old_isend)
    dist.irecv = get_new_function(old_irecv)

    def get_new_p2pop_function(func):
        """Wrap a P2POp constructor to unwrap groups and restore isend/irecv."""

        def new_function(*args, **kwargs):
            def convert(arg):
                if isinstance(arg, ReloadableProcessGroup):
                    return arg.group
                # Identity checks (not ``==``) so arbitrary arguments such as
                # tensors never have their ``__eq__`` invoked. Presumably torch
                # checks ``op`` against its own isend/irecv, hence the mapping.
                if arg is dist.isend:
                    return old_isend
                if arg is dist.irecv:
                    return old_irecv
                return arg

            converted_args = [convert(arg) for arg in args]
            converted_kwargs = {k: convert(v) for k, v in kwargs.items()}
            return func(*converted_args, **converted_kwargs)

        return new_function

    dist.P2POp.__new__ = get_new_p2pop_function(dist.P2POp.__new__)
    dist.P2POp.__init__ = get_new_p2pop_function(dist.P2POp.__init__)
class ReloadableProcessGroup(torch.distributed.ProcessGroup):
    """A ``ProcessGroup`` proxy whose inner group can be destroyed and rebuilt.

    The wrapped group lives in ``self.group``. ``destroy_process_groups`` tears
    down the inner groups of all instances created in the current process and
    sets them to ``None``. ``reload_process_groups`` recreates them from the
    recorded ``ranks`` with ``backend="nccl"`` via the unpatched ``new_group``.
    """

    # pid -> list of ReloadableProcessGroup instances created in that process.
    GROUPS = {}

    def __init__(self, group, ranks):
        super().__init__(
            rank=dist.get_rank(group),
            size=dist.get_world_size(group),
        )
        self.group = group
        # Everything needed to rebuild the group later (currently only ranks).
        self.group_info = {
            "ranks": ranks,
        }
        ReloadableProcessGroup.GROUPS.setdefault(os.getpid(), []).append(self)

    def __getattr__(self, name):
        """Delegate attributes not found on the proxy to the inner group."""
        # ``__getattr__`` runs only after normal lookup fails. If ``group``
        # itself is missing (e.g. an instance created without ``__init__``),
        # looking it up here would re-enter ``__getattr__`` forever.
        if name == "group":
            raise AttributeError(f"{type(self).__name__} has no inner 'group' attribute")
        return getattr(self.group, name)

    @staticmethod
    def destroy_process_groups():
        """Destroy the inner group of every instance registered for this pid.

        Instances whose inner group is already ``None`` are skipped. A
        ``ValueError`` from ``dist.destroy_process_group`` is logged as a
        warning. In both the success and ``ValueError`` cases the inner
        reference is then cleared to ``None``.
        """
        for reloadable_group in ReloadableProcessGroup.GROUPS.get(os.getpid(), []):
            if reloadable_group.group is None:
                continue
            try:
                dist.destroy_process_group(reloadable_group.group)
            except ValueError as e:
                logger.warning(
                    f"Process group already invalid/destroyed; skipping cleanup. Exception: {e}",
                    exc_info=True,
                )
            reloadable_group.group = None

    @staticmethod
    def reload_process_groups():
        """Recreate the inner group of every destroyed instance for this pid.

        Instances that still hold an inner group are left untouched.

        Raises:
            RuntimeError: if a group must be rebuilt but no unpatched
                ``new_group`` is known (``monkey_patch_torch_dist`` was never
                applied).
        """
        pid = os.getpid()
        reloadable_groups = ReloadableProcessGroup.GROUPS.get(pid, [])
        logger.info(f"Reloading {len(reloadable_groups)} process groups in pid {pid}")
        # Fall back to the original kept on the module itself, e.g. when the
        # patch was inherited from a parent process and this pid is unregistered.
        old_new_group = old_new_group_dict.get(pid) or getattr(dist, "old_new_group", None)
        for reloadable_group in reloadable_groups:
            if reloadable_group.group is not None:
                continue
            if old_new_group is None:
                raise RuntimeError(
                    "ReloadableProcessGroup: cannot reload process groups because "
                    "monkey_patch_torch_dist() has not been applied."
                )
            # NOTE(review): only ranks are preserved; any custom timeout,
            # pg_options or backend of the original group is not reapplied.
            reloadable_group.group = old_new_group(ranks=reloadable_group.group_info["ranks"], backend="nccl")

    def rank(self) -> int:
        """Rank of this process within the inner group."""
        return self.group.rank()

    def size(self) -> int:
        """Number of ranks in the inner group."""
        return self.group.size()

    def name(self) -> str:
        """Name reported by the inner group."""
        return self.group.name()

    def shutdown(self) -> None:
        """Shut down the inner group if it exists."""
        if self.group is not None:
            self.group.shutdown()

    def abort(self) -> None:
        """Abort the inner group if it exists."""
        if self.group is not None:
            self.group.abort()

    def _fwd(self, method, *args, **kwargs):
        """Call ``method`` on the inner group inside ``_wrap_low_level_call()``.

        Raises:
            RuntimeError: if the inner group has been destroyed and not reloaded.
        """
        inner = self.group
        if inner is None:
            raise RuntimeError(
                "ReloadableProcessGroup: inner PG is None, call reload_process_groups() first."
            )
        with _wrap_low_level_call():
            return getattr(inner, method)(*args, **kwargs)

    # Explicit forwarders for the ProcessGroup operations torch calls on the
    # group object.
    def barrier(self, *a, **kw):
        return self._fwd("barrier", *a, **kw)

    def broadcast(self, *a, **kw):
        return self._fwd("broadcast", *a, **kw)

    def allreduce(self, *a, **kw):
        return self._fwd("allreduce", *a, **kw)

    def allreduce_coalesced(self, *a, **kw):
        return self._fwd("allreduce_coalesced", *a, **kw)

    def reduce(self, *a, **kw):
        return self._fwd("reduce", *a, **kw)

    def allgather(self, *a, **kw):
        return self._fwd("allgather", *a, **kw)

    def _allgather_base(self, *a, **kw):
        return self._fwd("_allgather_base", *a, **kw)

    def allgather_coalesced(self, *a, **kw):
        return self._fwd("allgather_coalesced", *a, **kw)

    def allgather_into_tensor_coalesced(self, *a, **kw):
        return self._fwd("allgather_into_tensor_coalesced", *a, **kw)

    def gather(self, *a, **kw):
        return self._fwd("gather", *a, **kw)

    def scatter(self, *a, **kw):
        return self._fwd("scatter", *a, **kw)

    def reduce_scatter(self, *a, **kw):
        return self._fwd("reduce_scatter", *a, **kw)

    def _reduce_scatter_base(self, *a, **kw):
        return self._fwd("_reduce_scatter_base", *a, **kw)

    def reduce_scatter_tensor_coalesced(self, *a, **kw):
        return self._fwd("reduce_scatter_tensor_coalesced", *a, **kw)

    def alltoall_base(self, *a, **kw):
        return self._fwd("alltoall_base", *a, **kw)

    def alltoall(self, *a, **kw):
        return self._fwd("alltoall", *a, **kw)

    def send(self, *a, **kw):
        return self._fwd("send", *a, **kw)

    def recv(self, *a, **kw):
        return self._fwd("recv", *a, **kw)

    def recv_anysource(self, *a, **kw):
        return self._fwd("recv_anysource", *a, **kw)

    def _start_coalescing(self, *a, **kw):
        return self._fwd("_start_coalescing", *a, **kw)

    def _end_coalescing(self, *a, **kw):
        return self._fwd("_end_coalescing", *a, **kw)

    def _get_backend_name(self):
        return self._fwd("_get_backend_name")

    def _get_backend(self, *a, **kw):
        return self._fwd("_get_backend", *a, **kw)

    def _set_default_backend(self, *a, **kw):
        return self._fwd("_set_default_backend", *a, **kw)

    @property
    def bound_device_id(self):
        """``bound_device_id`` of the inner group."""
        return self.group.bound_device_id

    @bound_device_id.setter
    def bound_device_id(self, dev):
        self.group.bound_device_id = dev
def destroy_process_groups():
    """Module-level shortcut that tears down this process's reloadable groups.

    Forwards directly to ``ReloadableProcessGroup.destroy_process_groups``.
    """
    teardown = ReloadableProcessGroup.destroy_process_groups
    teardown()
def reload_process_groups():
    """Module-level shortcut that rebuilds this process's reloadable groups.

    Forwards directly to ``ReloadableProcessGroup.reload_process_groups``.
    """
    rebuild = ReloadableProcessGroup.reload_process_groups
    rebuild()
@contextmanager
def _wrap_low_level_call(min_free_gb=5):
    """Guard a low-level ``torch.distributed`` call with memory housekeeping.

    Before the wrapped call, ``clear_memory()`` runs if
    ``available_memory()["free_GB"]`` is below ``min_free_gb``. If the wrapped
    call (or that pre-check) raises, the memory state reported by
    ``print_memory`` is attached to the exception as a note. The exception is
    then re-raised unchanged.

    Args:
        min_free_gb: threshold compared against ``free_GB`` (presumably
            gigabytes, per the key name). Defaults to the previous hard-coded 5.
    """
    try:
        mem_info = available_memory()
        if mem_info["free_GB"] < min_free_gb:
            clear_memory()
        yield
    except Exception as e:
        try:
            mem_info = print_memory("after torch distributed error")
        except Exception as mem_err:
            # After a failed collective the memory probe itself may fail;
            # never let diagnostics replace the original error.
            mem_info = f"<unavailable: {mem_err!r}>"
        # ``BaseException.add_note`` only exists on Python 3.11+.
        if hasattr(e, "add_note"):
            e.add_note(f"{mem_info=}")
        raise