| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| """ |
| Group-wise helpers for RL training utilities. |
| |
| Public API: |
| - as_torch_index(index, device=None) -> torch.LongTensor |
| - group_mean_std(scores, gidx, eps=1e-6, device=None) -> (mean_g, std_g, count_g) |
| |
| Default device policy: |
| - If `device` is None: |
| * In pytest (detected by env "PYTEST_CURRENT_TEST"): use CPU. |
| * Else if CUDA is available: use CUDA. |
| * Else: use CPU. |
| - You can override via env "VERL_FORCE_DEVICE" (e.g., "cuda:0" / "cpu"). |
| |
| Notes: |
| - as_torch_index: canonicalizes arbitrary group labels to a contiguous 1-D torch.long |
| tensor in range [0..G-1]. Robust to torch/numpy/list/tuple, ints/floats/bools, |
| numeric strings, UUIDs, mixed object arrays. Near-integer floats (|x-round(x)|<=1e-6) |
| are rounded; otherwise factorization is applied. |
| - group_mean_std: pure-PyTorch per-group mean/std with Bessel correction for variance |
| (denominator max(count-1, 1)). Singleton groups fall back to mean=0, std=1 for |
| compatibility with common “native” conventions. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import os |
| from typing import Any, Optional |
|
|
| import numpy as np |
| import torch |
|
|
| from verl.utils.device import get_device_name |
|
|
| __all__ = ["as_torch_index", "group_mean_std"] |
|
|
|
|
def _resolve_device(explicit: Optional[torch.device | str]) -> torch.device:
    """
    Resolve the torch device that the group-wise helpers should run on.

    Priority:
      1) `explicit`, when it is not None.
      2) The VERL_FORCE_DEVICE environment variable, when it holds a non-blank
         value (surrounding whitespace is ignored).
      3) CPU when running under pytest (PYTEST_CURRENT_TEST is in the environment).
      4) Whatever `get_device_name()` from `verl.utils.device` reports.

    Args:
        explicit: Device requested by the caller, or None to apply the policy above.

    Returns:
        The resolved `torch.device`.
    """
    if explicit is not None:
        return torch.device(explicit)

    # Strip so that a blank or padded value (e.g. `VERL_FORCE_DEVICE=" "` or
    # " cuda:0 ") is treated as unset / parsed cleanly instead of making
    # torch.device() raise on an invalid device string.
    forced = os.getenv("VERL_FORCE_DEVICE", "").strip()
    if forced:
        return torch.device(forced)

    # pytest exports PYTEST_CURRENT_TEST while a test is running.
    if "PYTEST_CURRENT_TEST" in os.environ:
        return torch.device("cpu")

    return torch.device(get_device_name())
|
|
|
|
def _to_1d_numpy_object_array(x: Any) -> np.ndarray:
    """
    Best-effort conversion of arbitrary input into a 1-D numpy array.

    `np.asarray` is tried first so homogeneous inputs keep a native dtype; inputs
    it rejects fall back to an object-dtype array. Lazy or non-sequence iterables
    (generators, `map`/`zip` objects, dict views, sets) are materialized, because
    `np.asarray` would otherwise wrap the iterable itself in a 0-d object array
    and the result would contain a single element. Strings, bytes and dicts are
    treated as single labels, and multi-dimensional inputs are flattened.

    Args:
        x: Any scalar, sequence, iterable or array-like.

    Returns:
        A numpy array with `ndim == 1`.
    """
    try:
        arr = np.asarray(x)
    except Exception:
        # e.g. ragged nested sequences on recent NumPy versions
        try:
            arr = np.array(list(x), dtype=object)
        except Exception:
            arr = np.array([x], dtype=object)

    if arr.ndim == 0 and arr.dtype == object:
        item = arr.item()
        # A 0-d object array means NumPy stored the Python object itself.
        # Expand it when it is an iterable of labels rather than a single label.
        if hasattr(item, "__iter__") and not isinstance(item, (str, bytes, dict)):
            items = list(item)
            try:
                arr = np.asarray(items)
            except Exception:
                arr = np.array(items, dtype=object)

    if arr.ndim != 1:
        arr = arr.reshape(-1)
    return arr
|
|
|
|
def as_torch_index(index: Any, device: torch.device | str | None = None) -> torch.Tensor:
    """
    Convert arbitrary group labels to a 1-D torch.long tensor of group indices.

    Conversion rules, applied in order:
      * Integer / bool tensors and integer numpy arrays are cast to int64 and
        returned with their values unchanged (they are NOT re-numbered).
      * Finite floats that are all within 1e-6 of an integer are rounded.
      * Other float tensors are factorized on their string form; other float
        arrays (fractional, NaN or inf values) are factorized by value.
      * Anything else that converts to int64 without losing information
        (bools, numeric strings, ...) is converted directly.
      * Everything remaining is factorized with `np.unique`, yielding contiguous
        labels 0..G-1; labels that cannot be ordered are factorized on `str(label)`.

    Args:
        index: Any iterable of labels or tensor/ndarray.
        device: Target device; if None, resolved via _resolve_device().

    Returns:
        torch.LongTensor with shape (N,)
    """
    target = _resolve_device(device)
    near_int_atol = 1e-6

    def to_long(values: np.ndarray) -> torch.Tensor:
        # torch.from_numpy rejects negative strides (e.g. a reversed view),
        # so hand it a C-contiguous buffer.
        return torch.from_numpy(np.ascontiguousarray(values)).to(device=target)

    if isinstance(index, torch.Tensor):
        t = index.reshape(-1)
        if t.dtype in (
            torch.int64,
            torch.int32,
            torch.int16,
            torch.int8,
            torch.uint8,
            torch.bool,
        ):
            return t.to(device=target, dtype=torch.long)

        if t.dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16):
            t64 = t.to(dtype=torch.float64)
            rounded = torch.round(t64)
            # inf compares equal to round(inf) but has no int64 representation,
            # so require finiteness before taking the rounding shortcut.
            if bool(torch.isfinite(t64).all()) and torch.allclose(t64, rounded, rtol=0.0, atol=near_int_atol):
                return rounded.to(device=target, dtype=torch.long)

        # Remaining tensors are factorized on the string form of their values.
        # A single tolist() avoids one host sync per element.
        arr = np.array([str(v) for v in t.tolist()], dtype=object)
    else:
        arr = _to_1d_numpy_object_array(index)

    # Native integer arrays: values are used as-is.
    if arr.dtype != object and np.issubdtype(arr.dtype, np.integer):
        return to_long(arr.astype(np.int64, copy=False))

    # Native float arrays: round when near-integer, otherwise factorize by value.
    # Casting fractional floats to int64 would truncate and merge distinct labels
    # (e.g. 0.2 and 0.7 would both become 0).
    if arr.dtype != object and np.issubdtype(arr.dtype, np.floating):
        arr64 = arr.astype(np.float64, copy=False)
        rounded = np.rint(arr64)
        if np.isfinite(arr64).all() and np.allclose(arr64, rounded, rtol=0.0, atol=near_int_atol):
            return to_long(rounded.astype(np.int64))
        _, inv = np.unique(arr64, return_inverse=True)
        return to_long(inv.astype(np.int64, copy=False))

    # Direct int64 coercion for bools, numeric strings, etc. Complex arrays are
    # skipped because the cast would silently drop the imaginary part.
    coerced = None
    if arr.dtype.kind != "c":
        try:
            coerced = arr.astype(np.int64)
        except Exception:
            # Not integer-like (e.g. "abc", None, UUIDs): factorize below.
            coerced = None

    # Object arrays may hold fractional floats that int() truncates; treat that
    # as a failed coercion so distinct labels are not merged.
    if coerced is not None and arr.dtype == object:
        truncated = any(
            isinstance(value, (float, np.floating)) and abs(value - as_int) > near_int_atol
            for value, as_int in zip(arr.tolist(), coerced.tolist())
        )
        if truncated:
            coerced = None

    if coerced is not None:
        return to_long(coerced)

    # Factorization fallback.
    labels = arr if arr.dtype == object else arr.astype(object)
    try:
        _, inv = np.unique(labels, return_inverse=True)
    except Exception:
        # Labels of mixed, non-orderable types: compare their string forms.
        as_text = np.array([str(v) for v in labels], dtype=object)
        _, inv = np.unique(as_text, return_inverse=True)

    return to_long(inv.astype(np.int64, copy=False))
|
|
|
|
@torch.no_grad()
def group_mean_std(
    scores: torch.Tensor,
    gidx: torch.Tensor,
    eps: float = 1e-6,
    device: torch.device | str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute per-group mean/std/count in pure PyTorch.

        mean_g = sum_g(x) / count_g
        std_g  = sqrt( max( sum_g((x - mean_g)^2) / max(count_g - 1, 1), eps ) )

    The variance is accumulated from centred values (two passes) rather than as
    sum(x^2) - sum(x)^2 / count, which loses all precision in float32 when the
    scores are large compared with their spread.

    Groups with at most one member (singletons, and indices below max(gidx) that
    never occur) fall back to mean=0, std=1.

    Args:
        scores: (N,) float tensor.
        gidx  : (N,) long/int tensor with group indices (0..G-1).
        eps   : Numerical floor for variance.
        device: Target device; if None, resolved via _resolve_device().

    Returns:
        mean_g: (G,) float32
        std_g : (G,) float32
        count : (G,) float32
        where G = max(gidx) + 1; three independent empty tensors when N == 0.

    Raises:
        ValueError: if `scores` and `gidx` differ in length, or `gidx` contains
            a negative index.
    """
    target = _resolve_device(device)

    scores = scores.reshape(-1).to(device=target, dtype=torch.float32)
    gidx = gidx.reshape(-1).to(device=target, dtype=torch.long)

    if scores.numel() != gidx.numel():
        raise ValueError(f"scores and gidx length mismatch: {scores.numel()} vs {gidx.numel()}")

    if gidx.numel() == 0:
        # Separate tensors so an in-place edit of one result cannot alias the others.
        return (
            torch.empty(0, device=target, dtype=torch.float32),
            torch.empty(0, device=target, dtype=torch.float32),
            torch.empty(0, device=target, dtype=torch.float32),
        )

    # Fail early with a clear message instead of an index_add_ error.
    if int(gidx.min().item()) < 0:
        raise ValueError("gidx must contain non-negative group indices")
    num_groups = int(gidx.max().item()) + 1

    def group_sum(values: torch.Tensor) -> torch.Tensor:
        # Scatter-add `values` into one float32 bucket per group.
        buckets = torch.zeros(num_groups, device=target, dtype=torch.float32)
        return buckets.index_add_(0, gidx, values)

    count = group_sum(torch.ones_like(scores))
    mean = group_sum(scores) / count.clamp_min(1.0)

    # Second pass: squared deviations from each element's own group mean.
    centred = scores - mean[gidx]
    var = group_sum(centred * centred) / (count - 1.0).clamp_min(1.0)
    std = torch.sqrt(torch.clamp(var, min=eps))

    # Groups without at least two samples have no meaningful spread.
    degenerate = count <= 1.0
    mean = torch.where(degenerate, torch.zeros_like(mean), mean)
    std = torch.where(degenerate, torch.ones_like(std), std)

    return mean, std, count
|
|