File size: 7,698 Bytes
1faccd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
# Copyright 2024 Bytedance Ltd. and/or its affiliates
# Copyright 2023-2024 SGLang Team
# Copyright 2025 ModelBest Inc. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Group-wise helpers for RL training utilities.

Public API:
    - as_torch_index(index, device=None) -> torch.LongTensor
    - group_mean_std(scores, gidx, eps=1e-6, device=None) -> (mean_g, std_g, count_g)

Default device policy:
    - If `device` is None:
        * In pytest (detected by env "PYTEST_CURRENT_TEST"): use CPU.
        * Else if CUDA is available: use CUDA.
        * Else: use CPU.
    - You can override via env "VERL_FORCE_DEVICE" (e.g., "cuda:0" / "cpu").

Notes:
- as_torch_index: converts arbitrary group labels to a 1-D torch.long tensor. Accepts
  torch/numpy/list/tuple inputs holding ints/floats/bools, numeric strings, UUIDs, or
  mixed object arrays. Integer-valued labels, including near-integer floats
  (|x-round(x)|<=1e-6, which are rounded), keep their numeric value and are not
  renumbered; labels with no integer reading (e.g. UUIDs) are factorized into
  contiguous ids in [0..G-1].
- group_mean_std: pure-PyTorch per-group mean/std with Bessel correction for variance
  (denominator max(count-1, 1)). Singleton groups fallback to mean=0, std=1 for
  compatibility with common “native” conventions.
"""

from __future__ import annotations

import os
from typing import Any, Optional

import numpy as np
import torch

from verl.utils.device import get_device_name

__all__ = ["as_torch_index", "group_mean_std"]


def _resolve_device(explicit: Optional[torch.device | str]) -> torch.device:
    """
    Resolve device according to policy described in the module docstring.
    Priority:
      1) explicit argument
      2) VERL_FORCE_DEVICE env
      3) pytest detection -> cpu
      4) cuda if available, else cpu
    """
    if explicit is not None:
        return torch.device(explicit)

    forced = os.getenv("VERL_FORCE_DEVICE")
    if forced:
        return torch.device(forced)

    # Heuristic: pytest sets PYTEST_CURRENT_TEST
    if "PYTEST_CURRENT_TEST" in os.environ:
        return torch.device("cpu")

    return torch.device(get_device_name())


def _to_1d_numpy_object_array(x: Any) -> np.ndarray:
    """
    Best-effort conversion of arbitrary input into a 1-D numpy array.

    Lazy or unordered iterables (generators, iterators, sets, dict views) are
    materialized into a list first: `np.asarray` does not iterate them and would
    instead wrap the whole iterable as one opaque element of a 0-d object array.
    Sequences and array-likes (anything exposing `__array__`) go to `np.asarray`
    unchanged.

    Args:
        x: A scalar, sequence, iterable or array-like of labels.

    Returns:
        A 1-D array. numpy chooses the dtype; inputs it cannot convert
        (e.g. ragged nesting) fall back to object dtype, and multi-dimensional
        results are flattened.
    """
    from collections.abc import Iterable, Sequence

    if isinstance(x, Iterable) and not isinstance(x, Sequence) and not hasattr(x, "__array__"):
        x = list(x)

    try:
        arr = np.asarray(x)
    except Exception:
        # Deliberately broad: this helper must never raise on odd inputs.
        try:
            arr = np.array(list(x), dtype=object)
        except Exception:
            # Not iterable at all: treat the input as a single label.
            arr = np.array([x], dtype=object)

    if arr.ndim != 1:
        arr = arr.reshape(-1)
    return arr


def as_torch_index(index: Any, device: torch.device | str | None = None) -> torch.Tensor:
    """
    Convert arbitrary group labels to a 1-D torch.long tensor of group ids.

    Integer-like labels keep their numeric value and are NOT renumbered:
    integer/bool tensors and arrays, numeric strings, and floats within 1e-6 of
    an integer (which are rounded). Every other kind of label (fractional or
    non-finite floats, UUIDs, arbitrary strings, mixed objects) is factorized
    with `np.unique`, giving contiguous ids in 0..G-1.

    Args:
        index: Any iterable of labels or tensor/ndarray.
        device: Target device; if None, resolved via _resolve_device().

    Returns:
        torch.LongTensor with shape (N,)
    """
    target = _resolve_device(device)

    if isinstance(index, torch.Tensor):
        # ---------- Fast path: torch.Tensor ----------
        t = index.reshape(-1)
        if t.dtype in (torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool):
            return t.to(device=target, dtype=torch.long)

        if t.dtype in (torch.float16, torch.float32, torch.float64, torch.bfloat16):
            t64 = t.to(dtype=torch.float64)
            rounded = torch.round(t64)
            # allclose() treats inf == inf as close, but inf has no integer value,
            # so require finiteness before casting to long.
            if bool(torch.isfinite(t64).all()) and torch.allclose(t64, rounded, rtol=0.0, atol=1e-6):
                return rounded.to(device=target, dtype=torch.long)

        # Fractional floats and any other dtype: factorize the string form.
        arr = np.array([str(x.item()) for x in t], dtype=object)
    else:
        # ---------- Non-torch: go through numpy ----------
        arr = _to_1d_numpy_object_array(index)
        kind = arr.dtype.kind

        # Pure integers (signed / unsigned)
        if kind in ("i", "u"):
            return torch.from_numpy(arr.astype(np.int64, copy=False)).to(device=target)

        # Floats nearly equal to integers
        if kind == "f":
            arr64 = arr.astype(np.float64, copy=False)
            rounded = np.rint(arr64)
            if np.isfinite(arr64).all() and np.allclose(arr64, rounded, rtol=0.0, atol=1e-6):
                return torch.from_numpy(rounded.astype(np.int64)).to(device=target)

        # Integer coercion for bools, numeric strings and object arrays.
        # Float and complex arrays are skipped on purpose: astype(int64) would
        # succeed by truncating (1.5 and 1.7 both become 1) and merge distinct
        # groups.
        if kind not in ("f", "c"):
            try:
                coerced = arr.astype(np.int64)
            except Exception:
                # Deliberately broad: anything not readable as int is factorized.
                coerced = None

            if coerced is not None and arr.dtype == object:
                # int() also truncates floats stored inside object arrays, so
                # accept the coercion only if it did not change any value.
                try:
                    as_float = arr.astype(np.float64)
                except Exception:
                    as_float = None
                if as_float is not None and not np.allclose(as_float, coerced, rtol=0.0, atol=1e-6):
                    coerced = None

            if coerced is not None:
                return torch.from_numpy(coerced).to(device=target)

    # ---------- Factorization (UUIDs / mixed types / arbitrary labels) ----------
    try:
        _, inv = np.unique(arr, return_inverse=True)
    except Exception:
        # Mixed object arrays may not be sortable; compare string forms instead.
        sarr = np.array([str(x) for x in arr], dtype=object)
        _, inv = np.unique(sarr, return_inverse=True)

    inv = np.asarray(inv).reshape(-1).astype(np.int64, copy=False)
    return torch.from_numpy(inv).to(device=target)


@torch.no_grad()
def group_mean_std(
    scores: torch.Tensor,
    gidx: torch.Tensor,
    eps: float = 1e-6,
    device: torch.device | str | None = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Compute per-group mean/std/count in pure PyTorch.

    mean_g = sum_g / count_g
    std_g  = sqrt( max( sum_g((x - mean_g)^2) / max(count_g - 1, 1), eps ) )

    The variance is accumulated from squared deviations around the group mean
    (two passes). This is mathematically equal to (sum2 - sum^2/count) but does
    not lose precision in float32 when the scores are large relative to their
    spread.

    Groups with at most one member (including unused ids below the maximum)
    fall back to mean=0, std=1.

    Args:
        scores: (N,) float tensor.
        gidx  : (N,) long/int tensor with non-negative group indices.
        eps   : Numerical floor for variance.
        device: Target device; if None, resolved via _resolve_device().

    Returns:
        mean_g: (G,) float32
        std_g : (G,) float32
        count : (G,) float32
        where G = max(gidx) + 1, or three empty tensors when the input is empty.

    Raises:
        ValueError: if `scores` and `gidx` differ in length, or `gidx` contains
            a negative index.
    """
    target = _resolve_device(device)

    scores = scores.reshape(-1).to(device=target, dtype=torch.float32)
    gidx = gidx.reshape(-1).to(device=target, dtype=torch.long)

    if scores.numel() != gidx.numel():
        raise ValueError(f"scores and gidx length mismatch: {scores.numel()} vs {gidx.numel()}")

    if gidx.numel() == 0:
        # Three independent empty tensors, so callers can mutate one safely.
        return (
            torch.empty(0, device=target, dtype=torch.float32),
            torch.empty(0, device=target, dtype=torch.float32),
            torch.empty(0, device=target, dtype=torch.float32),
        )

    # Fail early with a clear message; index_add_ would otherwise raise an
    # opaque out-of-range error.
    if int(gidx.min().item()) < 0:
        raise ValueError("gidx must contain non-negative group indices")
    G = int(gidx.max().item()) + 1

    def _group_sum(values: torch.Tensor) -> torch.Tensor:
        # Scatter-add `values` into one bucket per group id.
        return torch.zeros(G, device=target, dtype=torch.float32).index_add_(0, gidx, values)

    count = _group_sum(torch.ones_like(scores, dtype=torch.float32))
    mean = _group_sum(scores) / count.clamp_min(1.0)

    # Second pass: squared deviations from each element's own group mean.
    centered = scores - mean[gidx]
    sq_dev = _group_sum(centered * centered)

    var = sq_dev / (count - 1.0).clamp_min(1.0)  # Bessel correction
    std = torch.sqrt(torch.clamp(var, min=eps))

    # Singleton (or empty) groups: mean=0, std=1
    single = count <= 1.0
    mean = mean.masked_fill(single, 0.0)
    std = std.masked_fill(single, 1.0)

    return mean, std, count