# JustinTX's picture
# Add files using upload-large-folder tool
# d7b3a74 verified
import importlib
import subprocess
import ray
from slime.utils.http_utils import is_port_available
def load_function(path):
    """
    Load an attribute (typically a function) from a dotted import path.

    The path is split on its last ".": the prefix is imported with
    ``importlib.import_module`` and the final component is fetched from the
    resulting module with ``getattr``.

    :param path: The path to the function, e.g. "module.submodule.function".
    :return: The function object.
    :raises ValueError: If ``path`` has no module component (e.g. "function").
    :raises ImportError: If the module part cannot be imported.
    :raises AttributeError: If the module has no attribute with the final name.
    """
    module_path, _, attr = path.rpartition(".")
    if not module_path:
        # Without this check importlib fails with a bare "Empty module name"
        # error that does not mention which path was bad.
        raise ValueError(f"Invalid function path {path!r}: expected 'module.attribute'")
    module = importlib.import_module(module_path)
    return getattr(module, attr)
class SingletonMeta(type):
    """
    Metaclass that gives each class using it at most one live instance.

    The first ``Cls(*args, **kwargs)`` call constructs the object normally and
    caches it. Every later call returns the cached object; its arguments are
    ignored and ``__init__`` does not run again.
    """

    # Class -> instance registry, shared by every class that uses this
    # metaclass unless a class attribute of the same name shadows it.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            return registry[cls]
        created = super().__call__(*args, **kwargs)
        cls._instances[cls] = created
        return created

    def clear_instances(cls):
        # NOTE(review): this rebinds ``_instances`` on the object it is called
        # on instead of emptying the shared dict.
        # - ``SingletonMeta.clear_instances(SingletonMeta)`` resets the shared
        #   registry for all classes.
        # - ``SomeClass.clear_instances()`` gives SomeClass (and subclasses that
        #   inherit the attribute) a fresh private registry, while the shared
        #   dict still references the old instance.
        # Confirm which behavior callers rely on.
        setattr(cls, "_instances", {})
def exec_command(cmd: str, capture_output: bool = False) -> str | None:
print(f"EXEC: {cmd}", flush=True)
try:
result = subprocess.run(
["bash", "-c", cmd],
shell=False,
check=True,
capture_output=capture_output,
**(dict(text=True) if capture_output else {}),
)
except subprocess.CalledProcessError as e:
if capture_output:
print(f"{e.stdout=} {e.stderr=}")
raise
if capture_output:
print(f"Captured stdout={result.stdout} stderr={result.stderr}")
return result.stdout
def get_current_node_ip():
    """
    Return this node's IP address as reported by Ray, with any surrounding
    ``[`` / ``]`` characters removed (IPv6 addresses may come back bracketed).

    NOTE(review): relies on Ray's private ``ray._private.services`` module.
    """
    return ray._private.services.get_node_ip_address().strip("[]")
def get_free_port(start_port=10000, consecutive=1, max_port=65535):
    """
    Find the lowest port ``p >= start_port`` such that ``p``, ``p + 1``, ...,
    ``p + consecutive - 1`` are all reported available by ``is_port_available``.

    :param start_port: First candidate port.
    :param consecutive: Number of adjacent ports that must all be free.
    :param max_port: Highest port number allowed in the range (65535 is the
        largest valid TCP/UDP port).
    :return: The first port of the free range.
    :raises RuntimeError: If no such range exists with every port <= ``max_port``.
    """
    port = start_port
    while port + consecutive - 1 <= max_port:
        busy = next(
            (candidate for candidate in range(port, port + consecutive) if not is_port_available(candidate)),
            None,
        )
        if busy is None:
            return port
        # Every window starting at or below `busy` contains it, so the next
        # possible start is just past it; this avoids re-probing known-bad windows.
        port = busy + 1
    raise RuntimeError(f"No {consecutive} consecutive free port(s) found in [{start_port}, {max_port}]")
def should_run_periodic_action(
rollout_id: int,
interval: int | None,
num_rollout_per_epoch: int | None = None,
num_rollout: int | None = None,
) -> bool:
"""
Return True when a periodic action (eval/save/checkpoint) should run.
Args:
rollout_id: The current rollout index (0-based).
interval: Desired cadence; disables checks when None.
num_rollout_per_epoch: Optional epoch boundary to treat as a trigger.
"""
if interval is None:
return False
if num_rollout is not None and rollout_id == num_rollout - 1:
return True
step = rollout_id + 1
return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0)
class Box:
    """
    Minimal wrapper holding a single object.

    The wrapped value is readable through the ``inner`` property; no setter is
    defined, so ``box.inner = ...`` raises AttributeError.
    """

    def __init__(self, inner):
        self._inner = inner

    inner = property(
        fget=lambda self: self._inner,
        doc="The object this box was constructed with.",
    )
from collections import defaultdict
from collections.abc import Callable, Iterable
from typing import Any
import torch
# details: https://stackoverflow.com/questions/773/how-do-i-use-itertools-groupby
def group_by(iterable, key=None):
    """
    Group items by ``key(item)`` (or by the item itself when ``key`` is None).

    Unlike ``itertools.groupby``, the input does not need to be sorted: equal
    keys are merged wherever they occur. Returns a plain dict whose keys appear
    in first-seen order and whose values keep the items in input order.
    """
    key_of = key if key is not None else (lambda element: element)
    groups = {}
    for element in iterable:
        groups.setdefault(key_of(element), []).append(element)
    return groups
# TODO fsdp can also use this
def chunk_named_params_by_size(named_params: Iterable[tuple[str, torch.Tensor]], chunk_size: int):
    """
    Lazily group ``(name, tensor)`` pairs into lists by accumulated byte size.

    Each pair is sized by ``tensor.nbytes``. The greedy packing (and its exact
    boundary rule) is delegated to ``_chunk_by_size``; the result is a generator
    of lists of the original pairs.
    """

    def tensor_bytes(named_tensor):
        # Position 1 holds the tensor; the name does not count toward the size.
        return named_tensor[1].nbytes

    return _chunk_by_size(named_params, compute_size=tensor_bytes, chunk_size=chunk_size)
def _chunk_by_size(objects: Iterable[Any], compute_size: Callable[[Any], int], chunk_size: int):
bucket: list[Any] = []
bucket_size = 0
for obj in objects:
obj_size = compute_size(obj)
if bucket and (bucket_size + obj_size) >= chunk_size:
yield bucket
bucket = []
bucket_size = 0
bucket.append(obj)
bucket_size += obj_size
if bucket:
yield bucket