import importlib
import subprocess
import ray
from slime.utils.http_utils import is_port_available
def load_function(path):
    """
    Load a function from a module.

    :param path: The dotted path to the function, e.g. "module.submodule.function".
        Everything before the last dot is imported as a module; the part after
        the last dot is looked up as an attribute of that module.
    :return: The function object.
    :raises ValueError: If ``path`` has no module part (e.g. it contains no dot).
    :raises ImportError: If the module part cannot be imported.
    :raises AttributeError: If the module has no attribute with the given name.
    """
    module_path, _, attr = path.rpartition(".")
    if not module_path:
        # Without this guard, import_module("") fails with the opaque
        # "Empty module name"; name the offending path instead.
        raise ValueError(
            f"Cannot load {path!r}: expected a dotted path such as 'module.submodule.function'."
        )
    module = importlib.import_module(module_path)
    return getattr(module, attr)
class SingletonMeta(type):
    """
    Metaclass that hands back one cached instance per class.

    The first call to a class using this metaclass constructs the instance and
    stores it in ``_instances``; later calls return the stored instance and
    ignore their arguments.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Fast path: an instance was already built for this class.
        if cls in cls._instances:
            return cls._instances[cls]
        # Build first, then look the registry up again for the store and the
        # return, so the attribute is re-read after construction finishes.
        cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

    def clear_instances(cls):
        # Rebinds ``_instances`` on the class this is called on to a fresh,
        # empty dict (it does not empty the existing dict in place).
        cls._instances = {}
def exec_command(cmd: str, capture_output: bool = False) -> str | None:
print(f"EXEC: {cmd}", flush=True)
try:
result = subprocess.run(
["bash", "-c", cmd],
shell=False,
check=True,
capture_output=capture_output,
**(dict(text=True) if capture_output else {}),
)
except subprocess.CalledProcessError as e:
if capture_output:
print(f"{e.stdout=} {e.stderr=}")
raise
if capture_output:
print(f"Captured stdout={result.stdout} stderr={result.stderr}")
return result.stdout
def get_current_node_ip():
    """
    Return the current node's IP address as reported by Ray.

    Any leading or trailing square brackets (as used around IPv6 literals)
    are removed from the returned string.
    """
    return ray._private.services.get_node_ip_address().strip("[]")
def get_free_port(start_port=10000, consecutive=1):
    """
    Find the lowest port >= ``start_port`` that begins a free run of ports.

    :param start_port: First port number to try.
    :param consecutive: How many adjacent ports (port, port + 1, ...) must all
        be reported available by ``is_port_available``.
    :return: The first port of the free run.
    """
    candidate = start_port
    while True:
        for offset in range(consecutive):
            if not is_port_available(candidate + offset):
                # This window is blocked; stop probing it.
                break
        else:
            # Every port in the window was available.
            return candidate
        candidate += 1
def should_run_periodic_action(
rollout_id: int,
interval: int | None,
num_rollout_per_epoch: int | None = None,
num_rollout: int | None = None,
) -> bool:
"""
Return True when a periodic action (eval/save/checkpoint) should run.
Args:
rollout_id: The current rollout index (0-based).
interval: Desired cadence; disables checks when None.
num_rollout_per_epoch: Optional epoch boundary to treat as a trigger.
"""
if interval is None:
return False
if num_rollout is not None and rollout_id == num_rollout - 1:
return True
step = rollout_id + 1
return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0)
class Box:
    """Minimal wrapper that holds one object and exposes it through ``inner``."""

    def __init__(self, inner):
        # Stored under a private name; only a getter is defined for ``inner``.
        self._inner = inner

    @property
    def inner(self):
        """Return the object that was passed to the constructor."""
        return self._inner
from collections import defaultdict
from collections.abc import Callable, Iterable
from typing import Any
import torch
# details: https://stackoverflow.com/questions/773/how-do-i-use-itertools-groupby
def group_by(iterable, key=None):
    """
    Group items by key without requiring the input to be sorted.

    Unlike ``itertools.groupby``, equal keys do not have to be adjacent.

    :param iterable: Items to group.
    :param key: Optional function computing the group key for an item; when
        omitted, each item is its own key.
    :return: A dict mapping each key to the list of its items, with keys and
        items in first-seen order.
    """
    buckets = {}
    for element in iterable:
        label = element if key is None else key(element)
        buckets.setdefault(label, []).append(element)
    return buckets
# TODO fsdp can also use this
def chunk_named_params_by_size(named_params: Iterable[tuple[str, torch.Tensor]], chunk_size: int):
    """
    Split ``(name, tensor)`` pairs into consecutive chunks by tensor byte size.

    Each pair is measured by its tensor's ``nbytes``; the grouping itself is
    delegated to ``_chunk_by_size``.

    :param named_params: Iterable of ``(name, tensor)`` pairs.
    :param chunk_size: Size threshold passed through to ``_chunk_by_size``.
    :return: Whatever ``_chunk_by_size`` returns for these arguments.
    """

    def _tensor_nbytes(named_weight):
        # Element 1 of the pair is the tensor.
        return named_weight[1].nbytes

    return _chunk_by_size(named_params, compute_size=_tensor_nbytes, chunk_size=chunk_size)
def _chunk_by_size(objects: Iterable[Any], compute_size: Callable[[Any], int], chunk_size: int):
    """
    Lazily split ``objects`` into consecutive lists bounded by total size.

    A chunk is emitted as soon as adding the next object would bring its total
    size to ``chunk_size`` or beyond. A chunk always holds at least one object,
    so a single object at or above ``chunk_size`` is emitted on its own.

    :param objects: Items to split, consumed in order.
    :param compute_size: Function returning the size of one item.
    :param chunk_size: Size threshold for a chunk.
    :return: A generator yielding lists of items.
    """
    pending: list[Any] = []
    pending_size = 0
    for item in objects:
        item_size = compute_size(item)
        if pending and pending_size + item_size >= chunk_size:
            # Emit the finished chunk and start a new one with this item.
            yield pending
            pending, pending_size = [item], item_size
        else:
            pending.append(item)
            pending_size += item_size
    # Emit the trailing partial chunk, if any.
    if pending:
        yield pending
|