import importlib
import subprocess
from collections import defaultdict
from collections.abc import Callable, Iterable, Iterator
from typing import Any

import ray
import torch

from slime.utils.http_utils import is_port_available


def load_function(path):
    """
    Load a function from a module.

    :param path: The path to the function, e.g. "module.submodule.function".
        Everything before the last "." is imported as a module; the part after
        it is looked up as an attribute of that module.
    :return: The function object.
    :raises ImportError: If the module part cannot be imported.
    :raises AttributeError: If the module has no attribute with that name.
    :raises ValueError: If ``path`` contains no "." (empty module name).
    """
    module_path, _, attr = path.rpartition(".")
    module = importlib.import_module(module_path)
    return getattr(module, attr)


class SingletonMeta(type):
    """
    A metaclass for creating singleton classes.

    The first call to a class using this metaclass constructs the instance;
    every later call returns that same object and ignores its arguments.
    """

    # Registry mapping each class to its single instance. It lives on the
    # metaclass, so by default it is shared by every class using SingletonMeta.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            instance = super().__call__(*args, **kwargs)
            cls._instances[cls] = instance
        return cls._instances[cls]

    def clear_instances(cls):
        """
        Forget cached instances so the next call constructs a new one.

        NOTE(review): this rebinds ``_instances`` on ``cls`` itself rather than
        emptying the shared registry. Called as ``Foo.clear_instances()`` it
        gives ``Foo`` (and subclasses that inherit the attribute) a fresh empty
        registry, but the previous ``Foo`` instance is still referenced from
        ``SingletonMeta._instances`` and is therefore never released.
        """
        cls._instances = {}


def exec_command(cmd: str, capture_output: bool = False) -> str | None:
    """
    Echo ``cmd`` and run it through ``bash -c``.

    NOTE: ``cmd`` is interpreted by bash; never build it from untrusted input.

    :param cmd: Shell command line to execute.
    :param capture_output: If True, capture stdout/stderr as text, print them,
        and return stdout; otherwise the child writes to the inherited streams.
    :return: The captured stdout when ``capture_output`` is True, else None.
    :raises subprocess.CalledProcessError: If the command exits non-zero
        (any captured output is printed before re-raising).
    """
    print(f"EXEC: {cmd}", flush=True)
    try:
        result = subprocess.run(
            ["bash", "-c", cmd],
            shell=False,
            check=True,
            capture_output=capture_output,
            # Decode captured output to str only when we are capturing it.
            **(dict(text=True) if capture_output else {}),
        )
    except subprocess.CalledProcessError as e:
        if capture_output:
            print(f"{e.stdout=} {e.stderr=}")
        raise
    if capture_output:
        print(f"Captured stdout={result.stdout} stderr={result.stderr}")
    return result.stdout


def get_current_node_ip():
    """Return this node's IP address as reported by Ray, with IPv6 brackets removed."""
    # NOTE: relies on Ray's private ``_private.services`` module.
    address = ray._private.services.get_node_ip_address()
    # strip ipv6 address brackets, e.g. "[::1]" -> "::1"
    address = address.strip("[]")
    return address


def get_free_port(start_port=10000, consecutive=1):
    """
    Find the lowest port >= ``start_port`` such that ``port``, ``port + 1``,
    ..., ``port + consecutive - 1`` are all available.

    This only probes with ``is_port_available`` and returns a number; nothing
    here keeps the port reserved until the caller binds it.

    :param start_port: First port number to try.
    :param consecutive: How many consecutive ports must be free.
    :return: The first port of the free range.
    :raises RuntimeError: If no such range exists at or below port 65535.
    """
    max_port = 65535  # highest valid TCP/UDP port number
    port = start_port
    # Stop at the top of the valid port range instead of scanning forever or
    # probing port numbers that cannot exist.
    while port + consecutive - 1 <= max_port:
        if all(is_port_available(port + i) for i in range(consecutive)):
            return port
        port += 1
    raise RuntimeError(f"No {consecutive} consecutive free port(s) found in range [{start_port}, {max_port}]")


def should_run_periodic_action(
    rollout_id: int,
    interval: int | None,
    num_rollout_per_epoch: int | None = None,
    num_rollout: int | None = None,
) -> bool:
    """
    Return True when a periodic action (eval/save/checkpoint) should run.

    Args:
        rollout_id: The current rollout index (0-based).
        interval: Desired cadence; disables checks when None. Otherwise the
            action runs whenever ``rollout_id + 1`` is a multiple of it, so it
            must be non-zero.
        num_rollout_per_epoch: Optional epoch boundary to treat as a trigger:
            the action also runs whenever ``rollout_id + 1`` is a multiple of
            it. Must be non-zero when given.
        num_rollout: Optional total number of rollouts. When given, the last
            rollout (``rollout_id == num_rollout - 1``) always triggers, even
            if it falls off-cadence.

    Returns:
        True if the action should run for ``rollout_id``. Always False when
        ``interval`` is None, even on the last rollout or an epoch boundary.
    """
    if interval is None:
        return False
    if num_rollout is not None and rollout_id == num_rollout - 1:
        return True
    step = rollout_id + 1  # 1-based count of completed rollouts
    return (step % interval == 0) or (
        num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0
    )


class Box:
    """Minimal wrapper holding a single object, exposed via a read-only ``inner`` property."""

    def __init__(self, inner):
        self._inner = inner

    @property
    def inner(self):
        return self._inner


# details: https://stackoverflow.com/questions/773/how-do-i-use-itertools-groupby
def group_by(iterable, key=None):
    """
    Similar to itertools.groupby, but do not require iterable to be sorted.

    Returns a plain dict mapping each key (``key(item)``, or the item itself
    when ``key`` is None) to the list of items with that key, in input order.
    Keys must be hashable.
    """
    ret = defaultdict(list)
    for item in iterable:
        ret[key(item) if key is not None else item].append(item)
    return dict(ret)


# TODO fsdp can also use this
def chunk_named_params_by_size(
    named_params: Iterable[tuple[str, torch.Tensor]], chunk_size: int
) -> Iterator[list[tuple[str, torch.Tensor]]]:
    """
    Split ``(name, tensor)`` pairs into consecutive buckets bounded by tensor bytes.

    Sizes are measured with ``tensor.nbytes``, so ``chunk_size`` is in bytes.
    See ``_chunk_by_size`` for the exact bucketing rule.
    """
    return _chunk_by_size(
        named_params,
        compute_size=lambda named_weight: named_weight[1].nbytes,
        chunk_size=chunk_size,
    )


def _chunk_by_size(
    objects: Iterable[Any], compute_size: Callable[[Any], int], chunk_size: int
) -> Iterator[list[Any]]:
    """
    Greedily pack ``objects``, in order, into consecutive lists.

    The current bucket is yielded before adding the next object whenever that
    addition would bring its total to ``chunk_size`` or more. A bucket holding
    several objects therefore always totals strictly less than ``chunk_size``,
    and an object whose size alone is >= ``chunk_size`` ends up in its own
    bucket. Empty input yields nothing.
    """
    bucket: list[Any] = []
    bucket_size = 0
    for obj in objects:
        obj_size = compute_size(obj)
        if bucket and (bucket_size + obj_size) >= chunk_size:
            yield bucket
            bucket = []
            bucket_size = 0
        bucket.append(obj)
        bucket_size += obj_size
    if bucket:
        yield bucket