# updated-policy / training / segment_grpo.py — srinjoyd, commit 19f7f7b ("init")
"""
Segment-level GRPO loss — framework-agnostic core.
We optimise Phase 1 and Phase 2 as TWO separate GRPO problems, joined
only by the cross-phase reward `r_cross` which is added to the Phase-1
return *with stop-gradient on the Phase-2 path*.
Why segment-level? In typical GRPO, one trajectory of length L tokens
gets one scalar reward — but our episodes are 8-16k tokens (P1 + P2)
and the credit structure is fundamentally bimodal: a single bad P1
choice should bias every P1 token's update, but should not propagate
through the (already-trained) P2 policy. Segment-level GRPO solves
this exactly: segments of each phase are grouped separately (same
prompt, same phase), advantages are normalized within each such group,
and `r_cross` is bolted onto the P1 group return as an additive
constant per trajectory (with no gradient flowing back through P2 — the
trainer enforces this by simply not letting P2 parameters appear in the
P1 group's loss graph).
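This module provides:
  - Segment         : dataclass describing one (phase, trajectory, return)
  - GRPOGroup       : a group of K segments collected for the same prompt
  - grpo_advantages : per-segment advantages within a group
  - grpo_loss       : final scalar loss given log-probs from the model
  - attach_r_cross  : copies Phase-1 segments with weighted r_cross attached
  - group_by_prompt : chunks a flat segment list into GRPOGroups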
The trainer wires this up like:

    for batch in dataloader:                      # batch = list[GRPOGroup]
        for group in batch:
            adv = grpo_advantages(group)
            logp = model.logp_per_token(group)    # framework-specific
            loss = grpo_loss(logp, adv,
                             ref_logps=ref.logp_per_token(group),
                             beta=0.04)
            loss.backward()
"""
from __future__ import annotations
import math
from dataclasses import dataclass, field
from typing import List, Optional, Sequence
@dataclass
class Segment:
    """A single scored trajectory produced by one phase.

    The GRPO helpers read ``terminal_reward`` and ``r_cross`` from this
    record and sum them to obtain the segment's return.
    """

    # Unique identifier for this segment.
    segment_id: str
    # Producing phase; expected to be 1 or 2.
    phase: int
    # Step dicts making up the trajectory; tokens are handled downstream.
    trajectory: list
    # Scalar reward earned by the segment itself.
    terminal_reward: float
    # Cross-phase credit. Intended to be non-zero only on phase-1
    # segments; the trainer must make sure no gradient flows from this
    # term through phase-2 model parameters.
    r_cross: float = field(default=0.0)
    # Flag recording that r_cross must be stop-gradiented w.r.t. phase 2.
    stop_gradient_through_p2: bool = field(default=True)
@dataclass
class GRPOGroup:
    """A group of K segments sampled for one prompt (same task, same phase).

    Group-relative normalization operates on ``returns``.
    """

    prompt_id: str
    segments: List[Segment] = field(default_factory=list)

    @property
    def returns(self) -> List[float]:
        """Return ``terminal_reward + r_cross`` for each segment, in order.

        ``r_cross`` is summed for every segment regardless of phase;
        segments that never had it set carry the default of 0.0.
        """
        totals: List[float] = []
        for seg in self.segments:
            totals.append(seg.terminal_reward + seg.r_cross)
        return totals
# ──────────────────────────────────────────────────────────────────────
# Group-relative advantage normalization (the GRPO step)
# ──────────────────────────────────────────────────────────────────────
def grpo_advantages(group: GRPOGroup, eps: float = 1e-6) -> List[float]:
    """
    Standardize each segment's return against the rest of its group.

        A_i = (R_i - mean(R)) / (std(R) + eps)

    std uses the sample (n - 1) denominator, with the denominator floored
    at 1 so a single-segment group yields an advantage of 0. A group whose
    returns are all equal also yields all-zero advantages (divided by eps).
    An empty group yields an empty list.
    """
    returns = group.returns
    count = len(returns)
    if count == 0:
        return []
    mean = sum(returns) / count
    squared_dev = sum((value - mean) ** 2 for value in returns)
    scale = math.sqrt(squared_dev / max(count - 1, 1)) + eps
    return [(value - mean) / scale for value in returns]
# ──────────────────────────────────────────────────────────────────────
# Loss (numerical core; trainer wraps in tensor ops)
# ──────────────────────────────────────────────────────────────────────
def grpo_loss(
    logps: Sequence[Sequence[float]],
    advantages: Sequence[float],
    ref_logps: Optional[Sequence[Sequence[float]]] = None,
    beta: float = 0.04,
    clip: float = 0.2,
) -> float:
    """
    Scalar GRPO objective (negated for minimization).

    logps[i]      : per-token log-prob sequence for segment i under the
                    current policy
    ref_logps[i]  : same sequence under the reference policy; used both
                    as the ratio denominator and for the KL penalty
    advantages[i] : segment-level advantage from `grpo_advantages`

    Per token, with ratio_{i,t} = exp(logp_{i,t} - ref_logp_{i,t}) and
    d_{i,t} = ref_logp_{i,t} - logp_{i,t}:

        L_pg(i,t) = -min(ratio_{i,t} * A_i,
                         clip(ratio_{i,t}, 1-c, 1+c) * A_i)
        L_kl(i,t) = beta * (exp(d_{i,t}) - d_{i,t} - 1)

    L_kl is the "k3" estimator of KL[policy || ref] used by GRPO: it is
    >= 0 for every token, is 0 exactly when logp == ref_logp, and grows as
    the policy drifts from the reference in either direction.

    The result is the mean of L_pg + L_kl over all tokens of all
    non-empty segments (token-level averaging, so longer segments carry
    proportionally more weight).

    Missing data is handled leniently: a segment without an advantage
    entry gets A_i = 0; a missing ref sequence (or a ref sequence shorter
    than the policy one) falls back to the policy log-prob, giving ratio 1
    and zero KL for those tokens. Empty input returns 0.0.

    This implementation runs on plain floats so unit tests can verify the
    numerics; the trainer reimplements the same arithmetic in tensors and
    should use the same non-negative KL estimator.
    """
    if not logps:
        return 0.0
    lower, upper = 1 - clip, 1 + clip
    total = 0.0
    n_tokens = 0
    for i, seq in enumerate(logps):
        if not seq:
            continue
        adv = advantages[i] if i < len(advantages) else 0.0
        if ref_logps is not None and i < len(ref_logps):
            ref_seq = ref_logps[i]
        else:
            ref_seq = seq  # no reference: ratio 1, KL 0 for this segment
        for t, lp in enumerate(seq):
            ref_lp = ref_seq[t] if t < len(ref_seq) else lp
            # NOTE(review): the ratio is taken against the reference policy,
            # not a separate behaviour ("old") policy as in PPO-style GRPO.
            ratio = math.exp(lp - ref_lp)
            unclipped = ratio * adv
            clipped = max(min(ratio, upper), lower) * adv
            policy_term = -min(unclipped, clipped)
            # k3 estimator: non-negative, so minimizing it pulls the policy
            # toward the reference. A bare `ref_lp - lp` term is negative in
            # expectation under policy samples and would reward divergence.
            d = ref_lp - lp
            kl_term = beta * (math.exp(d) - d - 1.0)
            total += policy_term + kl_term
            n_tokens += 1
    return total / max(n_tokens, 1)
# ──────────────────────────────────────────────────────────────────────
# Cross-phase wiring
# ──────────────────────────────────────────────────────────────────────
def attach_r_cross(
    p1_segments: List[Segment],
    r_cross_per_episode: List[float],
    weight: float = 1.0,
) -> List[Segment]:
    """
    Return copies of the Phase-1 segments with weighted `r_cross` attached.

    Segments and r_cross values are paired positionally. Each output is a
    copy of its input with ``r_cross = float(rc) * float(weight)`` and
    ``stop_gradient_through_p2 = True``; all other fields are carried over
    unchanged, and the input segments are not mutated.

    `weight` is the curriculum-driven warmup factor (0 → 1 across the
    first ~500 Stage-4 steps).

    NOTE: `phase` is not checked here; callers are expected to pass only
    Phase-1 segments.

    Raises:
        ValueError: if the two lists have different lengths.
    """
    # replace() copies every dataclass field, so fields added to the
    # segment type later are preserved instead of silently dropped.
    from dataclasses import replace

    if len(p1_segments) != len(r_cross_per_episode):
        raise ValueError(
            f"r_cross length mismatch: segs={len(p1_segments)} "
            f"r_cross={len(r_cross_per_episode)}")
    return [
        replace(s, r_cross=float(rc) * float(weight),
                stop_gradient_through_p2=True)
        for s, rc in zip(p1_segments, r_cross_per_episode)
    ]
def group_by_prompt(segments: List[Segment], group_size: int) -> List[GRPOGroup]:
    """
    Split `segments` into consecutive groups of `group_size`, keeping order.

    The last group holds any remainder and may be smaller than
    `group_size`. Each group's `prompt_id` is the `segment_id` of its
    first segment. No content-based grouping happens here: in production,
    the sampler arranges that all segments in a group came from the same
    prompt (same task seed), so within-group standardization is meaningful.

    Raises:
        ValueError: if `group_size` < 1. A non-positive size would
            otherwise degrade to singleton groups, whose within-group
            advantages are always zero.
    """
    if group_size < 1:
        raise ValueError(f"group_size must be >= 1, got {group_size}")
    items = list(segments)  # accept any iterable, then chunk by slicing
    return [
        GRPOGroup(prompt_id=items[start].segment_id,
                  segments=items[start:start + group_size])
        for start in range(0, len(items), group_size)
    ]