from __future__ import annotations

import math
import random
from collections.abc import Sequence
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only needed for type annotations; `from __future__ import annotations` makes
    # annotations lazy, so the runtime import (and its dependency chain) is avoided.
    from linalg_zero.grpo.types import CurriculumConfig, Task


class ShuffleBagSampler:
    """
    Coverage-guaranteeing sampler ("shuffle bag").

    Repeated calls to `sample_batch()` return indices from the eligible pool such that:
    - No index repeats until all currently-eligible indices have been returned once
      (a "cycle"), unless `batch_size` exceeds the pool size.
    - If the eligible pool grows over time (curriculum), newly added indices are
      injected into the current cycle so they are surfaced before any repeats.

    This is deterministic given the same `seed` and the same sequence of eligible pools.
    """

    def __init__(self, *, seed: int = 0, shuffle: bool = True) -> None:
        self._rng = random.Random(int(seed))
        self._shuffle = bool(shuffle)
        # Current eligible pool: membership set plus the caller-provided order.
        self._pool_set: set[int] = set()
        self._pool_order: list[int] = []
        # Cycle state: indices already emitted this cycle, and those still pending
        # (kept as both a list, for ordered popping, and a set, for O(1) membership).
        self._seen_in_cycle: set[int] = set()
        self._remaining: list[int] = []
        self._remaining_set: set[int] = set()

    def _reset_cycle(self) -> None:
        """Begin a fresh cycle covering the whole pool (shuffled if enabled)."""
        self._seen_in_cycle.clear()
        self._remaining = list(self._pool_order)
        if self._shuffle and len(self._remaining) > 1:
            self._rng.shuffle(self._remaining)
        self._remaining_set = set(self._remaining)

    def _update_pool(self, *, eligible: Sequence[int]) -> None:
        """
        Reconcile internal state with the caller-provided eligible pool.

        Raises:
            ValueError: if `eligible` is empty.
        """
        eligible_list = list(eligible)
        eligible_set = set(eligible_list)
        if not eligible_set:
            raise ValueError("Cannot sample from an empty index set.")
        if not self._pool_set:
            # First call: adopt the pool wholesale and start the first cycle.
            self._pool_set = eligible_set
            self._pool_order = eligible_list
            self._reset_cycle()
            return
        added = [idx for idx in eligible_list if idx not in self._pool_set]
        removed = self._pool_set - eligible_set
        if not added and not removed:
            return
        if removed:
            self._pool_set.difference_update(removed)
            self._seen_in_cycle.difference_update(removed)
            if self._remaining:
                self._remaining = [idx for idx in self._remaining if idx not in removed]
                self._remaining_set.difference_update(removed)
        if added:
            self._pool_set.update(added)
            # Inject newly-eligible tasks into the current cycle so we don't defer
            # them until the next cycle.
            for idx in added:
                if idx in self._seen_in_cycle or idx in self._remaining_set:
                    continue
                pos = self._rng.randrange(len(self._remaining) + 1)
                self._remaining.insert(pos, idx)
                self._remaining_set.add(idx)
        self._pool_order = eligible_list
        # If we removed indices, it's possible to end up with an empty remaining set
        # mid-cycle. Start a fresh cycle in that case.
        if not self._remaining:
            self._reset_cycle()

    def sample_batch(self, *, eligible: Sequence[int], batch_size: int) -> list[int]:
        """
        Sample a batch of indices with coverage guarantees.

        If `batch_size > len(eligible)`, repeats are unavoidable; this will still
        maximize coverage by cycling through the pool as many times as needed.
        """
        if batch_size <= 0:
            return []
        self._update_pool(eligible=eligible)
        out: list[int] = []
        while len(out) < batch_size:
            if not self._remaining:
                self._reset_cycle()
            idx = self._remaining.pop()
            self._remaining_set.remove(idx)
            self._seen_in_cycle.add(idx)
            out.append(idx)
        return out


def _clamp01(value: float) -> float:
    """Clamp `value` into the closed interval [0.0, 1.0]."""
    return float(max(0.0, min(1.0, value)))


def _deterministic_counts_from_probs(*, probs: Sequence[float], total: int) -> list[int]:
    """
    Deterministically convert a probability vector into integer counts summing to `total`.

    Uses largest-remainder rounding: floor each expected count, then hand the leftover
    units to the buckets with the largest fractional remainders (ties broken by lower
    index). Negative entries are treated as 0; a degenerate (all-zero) distribution
    falls back to uniform.
    """
    if total <= 0:
        return [0 for _ in probs]
    cleaned: list[float] = [max(0.0, float(p)) for p in probs]
    mass = sum(cleaned)
    if mass <= 0.0:
        # Fall back to uniform allocation if the distribution degenerates.
        cleaned = [1.0 for _ in probs]
        mass = float(len(cleaned))
    normed = [p / mass for p in cleaned]
    expected = [p * total for p in normed]
    # Deterministic rounding: first take the floor of each expected count.
    counts = [math.floor(e) for e in expected]
    remainder = total - sum(counts)
    if remainder <= 0:
        return counts
    # Then allocate the leftover `remainder` to buckets with the largest fractional
    # remainders (`expected - floor(expected)`), not to buckets with the largest
    # already-integer `counts`.
    fractional = [(expected[i] - counts[i], i) for i in range(len(counts))]
    fractional.sort(key=lambda item: (-item[0], item[1]))
    for _, idx in fractional[:remainder]:
        counts[idx] += 1
    return counts


class ToolCallsMixtureSampler:
    """
    Deterministic per-step mixture sampler over tool-call "difficulty" buckets.

    - Buckets are defined by `len(task.actions)` (teacher tool calls).
    - Each bucket uses its own ShuffleBagSampler to maximize coverage within the bucket.
    - Per-step mixture weights are computed from a Gaussian centered on a target
      tool-call count that increases linearly with `difficulty`.
    """

    def __init__(
        self,
        *,
        tasks: Sequence[Task],
        indices: Sequence[int],
        curriculum: CurriculumConfig,
        seed: int = 0,
        shuffle: bool = True,
    ) -> None:
        if curriculum.metric != "tool_calls":
            raise ValueError(f"Unknown curriculum metric: {curriculum.metric!r}")
        self._tasks = tasks
        self._curriculum = curriculum
        self._seed = int(seed)
        self._shuffle = bool(shuffle)
        # Number of `sample_batch` calls so far; drives the per-step output shuffle.
        self._calls = 0
        base_indices = list(indices)
        if not base_indices:
            raise ValueError("Cannot create curriculum sampler from an empty index set.")
        # Group indices by teacher tool-call count.
        buckets: dict[int, list[int]] = {}
        for idx in base_indices:
            tool_calls = len(tasks[idx].actions)
            buckets.setdefault(tool_calls, []).append(idx)
        self._min_tool_calls = min(buckets.keys())
        max_tool_calls_seen = max(buckets.keys())
        initial_tool_calls = max(self._min_tool_calls, int(curriculum.initial_max_tool_calls))
        requested_final = (
            max_tool_calls_seen
            if curriculum.final_max_tool_calls is None
            else min(int(curriculum.final_max_tool_calls), max_tool_calls_seen)
        )
        requested_final = max(self._min_tool_calls, requested_final)
        if requested_final < initial_tool_calls:
            raise ValueError(
                "Invalid curriculum: final_max_tool_calls must be >= initial_max_tool_calls "
                f"(got initial={initial_tool_calls}, final={requested_final})."
            )
        self._final_tool_calls = requested_final
        # Drop buckets beyond the final curriculum level entirely.
        filtered_keys = sorted([k for k in buckets if k <= self._final_tool_calls])
        if not filtered_keys:
            raise ValueError("Curriculum sampler has no buckets after applying final_max_tool_calls filter.")
        self._bucket_keys = filtered_keys
        self._bucket_ordered: dict[int, list[int]] = {}
        self._bucket_samplers: dict[int, ShuffleBagSampler] = {}
        for tool_calls in self._bucket_keys:
            bucket = list(buckets[tool_calls])
            if self._shuffle and len(bucket) > 1:
                # Per-bucket deterministic shuffle; the seed offset keeps buckets decorrelated.
                rng = random.Random(self._seed + (tool_calls + 1) * 1_000_003)
                rng.shuffle(bucket)
            self._bucket_ordered[tool_calls] = bucket
            self._bucket_samplers[tool_calls] = ShuffleBagSampler(
                seed=self._seed + (tool_calls + 1) * 2_000_033,
                shuffle=self._shuffle,
            )

    def _target_tool_calls(self, *, difficulty: float) -> float:
        """Linear interpolation of the target tool-call count as difficulty grows 0 -> 1."""
        difficulty = _clamp01(difficulty)
        start = float(max(self._min_tool_calls, int(self._curriculum.initial_max_tool_calls)))
        end = float(max(start, self._final_tool_calls))
        return float(max(self._min_tool_calls, min(end, start + difficulty * (end - start))))

    def _mixture_probs(self, *, difficulty: float) -> list[float]:
        """
        Per-bucket mixture probabilities for the current difficulty.

        Gaussian weights centered on the target tool-call count, normalized via a
        max-log shift for numerical stability, with optional floor probability for
        the easiest bucket (`mixture_min_prob_easiest`).
        """
        target = self._target_tool_calls(difficulty=difficulty)
        sigma = float(getattr(self._curriculum, "mixture_sigma", 0.0))
        if sigma <= 0.0:
            # Hard assignment to the nearest bucket when sigma is disabled.
            nearest = min(self._bucket_keys, key=lambda tc: (abs(float(tc) - target), tc))
            return [1.0 if tc == nearest else 0.0 for tc in self._bucket_keys]
        denom = 2.0 * sigma * sigma
        if denom == 0.0:
            # Extremely small positive sigma can underflow `sigma * sigma` to 0.0.
            nearest = min(self._bucket_keys, key=lambda tc: (abs(float(tc) - target), tc))
            return [1.0 if tc == nearest else 0.0 for tc in self._bucket_keys]
        log_weights = [-((float(tc) - target) ** 2) / denom for tc in self._bucket_keys]
        max_log = max(log_weights)
        weights = [math.exp(w - max_log) for w in log_weights]
        mass = sum(weights)
        if mass <= 0.0:
            # Extremely small sigma can cause underflow; fall back to nearest-bucket.
            nearest = min(self._bucket_keys, key=lambda tc: (abs(float(tc) - target), tc))
            probs = [1.0 if tc == nearest else 0.0 for tc in self._bucket_keys]
        else:
            probs = [w / mass for w in weights]
        floor = float(getattr(self._curriculum, "mixture_min_prob_easiest", 0.0))
        floor = float(max(0.0, min(1.0, floor)))
        if floor <= 0.0:
            return probs
        p_easy = probs[0]
        if p_easy >= floor or p_easy >= 1.0:
            return probs
        remaining = 1.0 - p_easy
        if remaining <= 0.0:
            return [1.0] + [0.0 for _ in probs[1:]]
        # Raise the easiest bucket to `floor` and rescale the rest so the vector
        # still sums to 1.
        scale = (1.0 - floor) / remaining
        adjusted = [floor] + [p * scale for p in probs[1:]]
        return adjusted

    def sample_batch(self, *, difficulty: float, batch_size: int) -> list[int]:
        """
        Sample `batch_size` task indices for one training step at the given difficulty.

        Raises:
            RuntimeError: if the per-bucket allocations do not add up to `batch_size`.
        """
        if batch_size <= 0:
            return []
        difficulty = _clamp01(difficulty)
        probs = self._mixture_probs(difficulty=difficulty)
        # Convert continuous mixture probabilities into an exact integer allocation for
        # this step. Note: when a bucket's expected count is <1 (e.g., early
        # 3-tool-call exposure with small `p3`), deterministic rounding can "flicker"
        # between 0 and 1 across steps; once the expected count crosses 1 (as
        # difficulty increases), that bucket will appear every step thereafter.
        counts = _deterministic_counts_from_probs(probs=probs, total=batch_size)
        frac = self._curriculum.fraction_at_start + difficulty * (
            self._curriculum.fraction_at_end - self._curriculum.fraction_at_start
        )
        frac = _clamp01(frac)
        out: list[int] = []
        for tool_calls, count in zip(self._bucket_keys, counts, strict=True):
            if count <= 0:
                continue
            ordered = self._bucket_ordered[tool_calls]
            # Expose only a difficulty-dependent prefix of each bucket (at least one task).
            k = math.ceil(frac * len(ordered))
            if k <= 0:
                k = 1
            eligible = ordered[:k]
            sampler = self._bucket_samplers[tool_calls]
            out.extend(sampler.sample_batch(eligible=eligible, batch_size=count))
        if len(out) != batch_size:
            raise RuntimeError(f"Mixture sampler produced {len(out)} indices, expected {batch_size}.")
        # Mix difficulties within the step so the batch isn't ordered by bucket.
        # Deterministic given seed and call order (which is driven by the training step).
        rng = random.Random(self._seed + 9_000_001 + self._calls)
        rng.shuffle(out)
        self._calls += 1
        return out


def get_task_indices(  # noqa: C901
    *,
    task_ids: list[int] | None,
    start_index: int,
    end_index: int,
    tasks_length: int | None = None,
    tasks: Sequence[Task] | None = None,
    curriculum: CurriculumConfig | None = None,
    difficulty: float | None = None,
    seed: int = 0,
) -> list[int]:
    """
    Return a list of task indices, optionally filtered by a curriculum.

    - If `task_ids` is provided, it always wins (curriculum is ignored).
    - Otherwise uses `[start_index, end_index)` (or full length if `end_index == -1`).
    - If `curriculum.enabled` and `difficulty` is provided, returns a deterministic
      subset that grows monotonically with `difficulty` (0.0 -> easiest subset,
      1.0 -> full set).
    """
    if task_ids:
        return list(task_ids)
    if tasks_length is None:
        if tasks is None:
            raise ValueError("Must provide `tasks_length` or `tasks` when task_ids is not set.")
        tasks_length = len(tasks)
    actual_start = max(0, start_index)
    actual_end = tasks_length if end_index == -1 else min(end_index, tasks_length)
    base_indices = list(range(actual_start, max(actual_start, actual_end)))
    if curriculum is None or not curriculum.enabled or difficulty is None:
        return base_indices
    if tasks is None:
        raise ValueError("Curriculum selection requires `tasks` to be provided.")
    if curriculum.metric != "tool_calls":
        raise ValueError(f"Unknown curriculum metric: {curriculum.metric!r}")
    difficulty = float(max(0.0, min(1.0, difficulty)))
    # Measure each candidate task's tool-call count and track the observed range.
    tool_calls_by_index: dict[int, int] = {}
    max_tool_calls_seen = 0
    min_tool_calls_seen: int | None = None
    for idx in base_indices:
        tool_calls = len(tasks[idx].actions)
        tool_calls_by_index[idx] = tool_calls
        max_tool_calls_seen = max(max_tool_calls_seen, tool_calls)
        min_tool_calls_seen = tool_calls if min_tool_calls_seen is None else min(min_tool_calls_seen, tool_calls)
    initial_max = max(0, int(curriculum.initial_max_tool_calls))
    final_max = (
        max_tool_calls_seen if curriculum.final_max_tool_calls is None else int(curriculum.final_max_tool_calls)
    )
    if final_max < initial_max:
        final_max = initial_max
    # Discretize difficulty into curriculum levels; each level unlocks one more
    # tool-call count.
    levels = max(1, final_max - initial_max + 1)
    level = min(levels - 1, math.floor(difficulty * levels))
    allowed_max = initial_max + level
    if min_tool_calls_seen is not None:
        allowed_max = max(allowed_max, min_tool_calls_seen)
    # Fraction of tasks to expose within each included difficulty bucket.
    frac = curriculum.fraction_at_start + difficulty * (curriculum.fraction_at_end - curriculum.fraction_at_start)
    frac = float(max(0.0, min(1.0, frac)))
    buckets: dict[int, list[int]] = {}
    for idx in base_indices:
        tool_calls = tool_calls_by_index[idx]
        if tool_calls <= allowed_max:
            buckets.setdefault(tool_calls, []).append(idx)
    ordered_buckets: dict[int, list[int]] = {}
    for tool_calls in sorted(buckets.keys()):
        bucket = list(buckets[tool_calls])
        rng = random.Random(seed + (tool_calls + 1) * 1_000_003)
        rng.shuffle(bucket)
        ordered_buckets[tool_calls] = bucket
    # Stable "easiest-first" ordering so that when new buckets are unlocked,
    # previously selected tasks remain selected.
    easiest_first: list[int] = []
    for tool_calls in sorted(ordered_buckets.keys()):
        easiest_first.extend(ordered_buckets[tool_calls])
    # Baseline floor (picked from easiest-first) to keep selection non-empty and monotonic.
    min_total = max(0, int(curriculum.min_total_tasks))
    baseline = set(easiest_first[: min(min_total, len(easiest_first))])
    # Fraction-based selection (per bucket), chosen as prefixes of deterministic per-bucket order.
    fraction_selected: set[int] = set()
    for tool_calls in sorted(ordered_buckets.keys()):
        bucket = ordered_buckets[tool_calls]
        k = math.ceil(frac * len(bucket))
        if k <= 0:
            continue
        fraction_selected.update(bucket[:k])
    selected_set = baseline | fraction_selected
    if not selected_set and easiest_first:
        selected_set = {easiest_first[0]}
    return [idx for idx in easiest_first if idx in selected_set]


def sample_indices_to_length(*, indices: Sequence[int], length: int, rng: random.Random) -> list[int]:
    """Sample `length` indices, using replacement if needed."""
    if length <= 0:
        return []
    if not indices:
        raise ValueError("Cannot sample from an empty index set.")
    if len(indices) >= length:
        return list(rng.sample(list(indices), length))
    return list(rng.choices(list(indices), k=length))