| import importlib |
| import subprocess |
|
|
| import ray |
|
|
| from slime.utils.http_utils import is_port_available |
|
|
|
|
def load_function(path):
    """
    Resolve a dotted path to the object it names.

    The text before the last "." is imported as a module and the text after
    it is looked up as an attribute on that module.

    :param path: Dotted location of the target, e.g. "module.submodule.function".
    :return: The attribute found on the imported module.
    """
    module_name, _sep, attr_name = path.rpartition(".")
    return getattr(importlib.import_module(module_name), attr_name)
|
|
|
|
class SingletonMeta(type):
    """
    Metaclass that hands out one cached instance per class.

    The first call to a class using this metaclass constructs the instance;
    every later call returns that same object and ignores its arguments.
    """

    # Registry mapping each class to its cached instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Return the cached instance of ``cls``, constructing it on first use."""
        if cls in cls._instances:
            return cls._instances[cls]

        cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

    def clear_instances(cls):
        """
        Rebind ``_instances`` on ``cls`` to a fresh empty dict.

        NOTE(review): this assigns a new dict on the object it is invoked on
        rather than emptying the existing registry in place, so the effect is
        scoped to ``cls`` -- confirm that is the intended scope.
        """
        cls._instances = {}
|
|
|
|
| def exec_command(cmd: str, capture_output: bool = False) -> str | None: |
| print(f"EXEC: {cmd}", flush=True) |
|
|
| try: |
| result = subprocess.run( |
| ["bash", "-c", cmd], |
| shell=False, |
| check=True, |
| capture_output=capture_output, |
| **(dict(text=True) if capture_output else {}), |
| ) |
| except subprocess.CalledProcessError as e: |
| if capture_output: |
| print(f"{e.stdout=} {e.stderr=}") |
| raise |
|
|
| if capture_output: |
| print(f"Captured stdout={result.stdout} stderr={result.stderr}") |
| return result.stdout |
|
|
|
|
def get_current_node_ip():
    """
    Return the node IP address reported by Ray, without enclosing brackets.

    Any leading or trailing "[" / "]" characters are stripped from the value
    (presumably to unwrap bracketed IPv6 literals -- confirm).
    """
    raw_ip = ray._private.services.get_node_ip_address()
    return raw_ip.strip("[]")
|
|
|
|
def get_free_port(start_port=10000, consecutive=1):
    """
    Find the lowest port at or above ``start_port`` that begins a free run.

    Args:
        start_port: First candidate port to examine.
        consecutive: How many adjacent ports, starting at the candidate, must
            all be reported available by ``is_port_available``.

    Returns:
        The first port of the free run.
    """
    candidate = start_port
    while True:
        for offset in range(consecutive):
            if not is_port_available(candidate + offset):
                break
        else:
            # Every port in the window was available.
            return candidate
        candidate += 1
|
|
|
|
| def should_run_periodic_action( |
| rollout_id: int, |
| interval: int | None, |
| num_rollout_per_epoch: int | None = None, |
| num_rollout: int | None = None, |
| ) -> bool: |
| """ |
| Return True when a periodic action (eval/save/checkpoint) should run. |
| |
| Args: |
| rollout_id: The current rollout index (0-based). |
| interval: Desired cadence; disables checks when None. |
| num_rollout_per_epoch: Optional epoch boundary to treat as a trigger. |
| """ |
| if interval is None: |
| return False |
|
|
| if num_rollout is not None and rollout_id == num_rollout - 1: |
| return True |
|
|
| step = rollout_id + 1 |
| return (step % interval == 0) or (num_rollout_per_epoch is not None and step % num_rollout_per_epoch == 0) |
|
|
|
|
class Box:
    """Minimal wrapper that holds one object and exposes it read-only as ``inner``."""

    def __init__(self, inner):
        """Store ``inner`` without copying or transforming it."""
        self._inner = inner

    @property
    def inner(self):
        """The wrapped object, exactly as it was passed to the constructor."""
        return self._inner
|
|
|
|
| from collections import defaultdict |
| from collections.abc import Callable, Iterable |
| from typing import Any |
|
|
| import torch |
|
|
|
|
| |
def group_by(iterable, key=None):
    """
    Bucket items by key without requiring sorted input (unlike itertools.groupby).

    Args:
        iterable: Items to group.
        key: Optional callable computing the group key; when None each item
            is used as its own key.

    Returns:
        A plain dict mapping each key to the list of its items, in the order
        they were encountered.
    """
    groups = defaultdict(list)
    if key is None:
        for element in iterable:
            groups[element].append(element)
    else:
        for element in iterable:
            groups[key(element)].append(element)
    return dict(groups)
|
|
|
|
| |
| def chunk_named_params_by_size(named_params: Iterable[tuple[str, torch.Tensor]], chunk_size: int): |
| return _chunk_by_size( |
| named_params, |
| compute_size=lambda named_weight: named_weight[1].nbytes, |
| chunk_size=chunk_size, |
| ) |
|
|
|
|
| def _chunk_by_size(objects: Iterable[Any], compute_size: Callable[[Any], int], chunk_size: int): |
| bucket: list[Any] = [] |
| bucket_size = 0 |
|
|
| for obj in objects: |
| obj_size = compute_size(obj) |
|
|
| if bucket and (bucket_size + obj_size) >= chunk_size: |
| yield bucket |
| bucket = [] |
| bucket_size = 0 |
|
|
| bucket.append(obj) |
| bucket_size += obj_size |
|
|
| if bucket: |
| yield bucket |
|
|