forensics-grpo / code /src /open_r1 /reward.py
sdzt's picture
Add source code
33569f9 verified
Raw
History Blame Contribute Delete
55 kB
"""Reward functions for ActivityForensics GRPO training.
The model localises ALL AI-manipulated time intervals in a video. Each
sample's GT and prediction are LISTS of (start, end) tuples (in seconds, possibly empty).
Rewards:
- forensics_iou_reward: legacy continuous soft_F1 reward (set-level recall+prec).
- forensics_format_reward: regex match for "<answer> s1 to e1; s2 to e2 </answer>".
- hungarian_iou_reward: Hungarian-matched gIoU with FP/FN penalties (multi-span).
- anomaly_alignment_reward: per-second forgery score gap (in-segment vs out).
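- combined_forgery_reward: alpha * hungarian_iou + beta * alignment (+ optional gamma * CoT-tag alignment and delta * format terms; parse failures are gated to a fixed penalty).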
Per-generator stats are logged whenever the dataset feeds the `generator`
column through (trl exposes all dataset columns as kwargs).
"""
import math
import os
import re
from collections import defaultdict
from datetime import datetime
from typing import Optional, Sequence, Tuple
import numpy as np
from scipy.optimize import linear_sum_assignment
# ---------------------------------------------------------------------------
# Robust answer parser
# ---------------------------------------------------------------------------
# Unsigned decimal number: digits with an optional fractional part (no sign,
# no exponent, no bare-fraction form such as ".5").
_NUM = r"\d+(?:\.\d+)?"
# Captures the body of each <answer>...</answer> block (non-greedy, so several
# blocks are matched separately; case-insensitive; may span newlines).
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL | re.IGNORECASE)
# Captures the body of each <think>...</think> block, same flags as above.
_THINK_RE = re.compile(r"<think>(.*?)</think>", re.DOTALL | re.IGNORECASE)
# Range separators: "to", "-", "–", "—", "~", "until", with optional surrounding
# whitespace. There are no word boundaries, so "5to10" also matches.
_RANGE_SEP = r"(?:\s*(?:to|-|–|—|~|until)\s*)"
# One "<start> <sep> <end>" pair; group 1 = start, group 2 = end.
_SEGMENT_RE = re.compile(rf"({_NUM}){_RANGE_SEP}({_NUM})", re.IGNORECASE)
# <timestep>start to end</timestep> tags emitted inside <think> per the CoT prompt.
# Accepts the same separators as _RANGE_SEP; group 1 = start, group 2 = end.
_TIMESTEP_RE = re.compile(
    rf"<timestep>\s*({_NUM}){_RANGE_SEP}({_NUM})\s*</timestep>",
    re.IGNORECASE | re.DOTALL,
)
# Strict format check (used by the non-CoT branch of forensics_format_reward).
# Matches an <answer> block whose entire body is "s1 to e1" pairs joined by ";"
# or ","; the word "to" needs whitespace on both sides. Whitespace (including
# newlines) may surround a ";"/"," separator, but a bare newline on its own is
# NOT accepted as a separator between two segments.
_FULL_FORMAT_RE = re.compile(
    rf"<answer>\s*{_NUM}\s+to\s+{_NUM}(?:\s*[;,]\s*{_NUM}\s+to\s+{_NUM})*\s*</answer>",
    re.IGNORECASE | re.DOTALL,
)
def parse_segments(output_string: str):
    """Return the (start, end) intervals written in the final <answer> block.

    Only the LAST <answer>...</answer> block is inspected. Inside it, every
    "<number> <sep> <number>" pair is read, where <sep> may be "to", "until",
    "-", "–", "—" or "~"; any text between pairs is ignored. Pairs whose end
    is not strictly greater than their start are dropped.

    Returns an empty list when the input is None, contains no <answer> block,
    or the block holds no valid pair.
    """
    if output_string is None:
        return []
    blocks = _ANSWER_RE.findall(output_string)
    if not blocks:
        return []
    spans = []
    for match in _SEGMENT_RE.finditer(blocks[-1]):
        try:
            start = float(match.group(1))
            end = float(match.group(2))
        except ValueError:
            continue
        if end > start:
            spans.append((start, end))
    return spans
def parse_timestep_tags(output_string: str):
    """Collect <timestep>start to end</timestep> intervals from the first <think> block.

    The stage-2 CoT prompt asks the model to tag candidate intervals this way
    while reasoning. Only the first <think>...</think> block is searched; tags
    anywhere else are ignored, as are tags whose end does not exceed their
    start. Returns [] for empty/None input, a missing <think> block, or no tags.
    """
    if not output_string:
        return []
    think = _THINK_RE.search(output_string)
    if think is None:
        return []
    found = []
    for tag in _TIMESTEP_RE.finditer(think.group(1)):
        try:
            lo, hi = float(tag.group(1)), float(tag.group(2))
        except ValueError:
            continue
        if hi > lo:
            found.append((lo, hi))
    return found
def _mass_iou_1d(set_a, set_b, resolution: float = 0.5) -> float:
    """Rasterised set-IoU of two interval collections.

    Both collections are painted onto a shared boolean grid of cells of width
    ``resolution`` starting at the smallest start. The score is
    (#cells covered by both) / (#cells covered by either), so it does not
    depend on how many intervals each side contains.

    Returns 0.0 when either side is empty, when the combined extent has
    non-positive length, or when no cell ends up covered.
    """
    if not set_a or not set_b:
        return 0.0
    combined = [*set_a, *set_b]
    origin = min(s for s, _ in combined)
    extent_end = max(e for _, e in combined)
    if extent_end <= origin:
        return 0.0
    n_cells = max(1, int((extent_end - origin) / resolution) + 1)

    def _paint(segments):
        # Mark every cell touched by any interval in `segments`.
        grid = np.zeros(n_cells, dtype=bool)
        for start, end in segments:
            first = max(0, int((start - origin) / resolution))
            # +0.999 approximates a ceiling so a partially covered last cell counts.
            stop = min(n_cells, int((end - origin) / resolution + 0.999))
            if first < stop:
                grid[first:stop] = True
        return grid

    grid_a = _paint(set_a)
    grid_b = _paint(set_b)
    union = int(np.logical_or(grid_a, grid_b).sum())
    if union <= 0:
        return 0.0
    overlap = int(np.logical_and(grid_a, grid_b).sum())
    return float(overlap / union)
# ---------------------------------------------------------------------------
# Set-level matching metrics
# ---------------------------------------------------------------------------
def _iou_1d(a, b):
    """Temporal IoU of two (start, end) intervals.

    The denominator is the length of the hull spanning both intervals. That
    equals the true union whenever they overlap; for disjoint intervals the
    overlap is zero, so the result is 0.0 either way. Returns 0.0 when the hull
    has non-positive length.
    """
    (a_start, a_end), (b_start, b_end) = a, b
    overlap = max(0.0, min(a_end, b_end) - max(a_start, b_start))
    hull = max(a_end, b_end) - min(a_start, b_start)
    return overlap / hull if hull > 0 else 0.0
def soft_f1(preds, gts, beta=1.0):
    """Threshold-free F-beta between a predicted and a ground-truth interval set.

    precision = mean over preds of the best IoU against any GT
    recall    = mean over GTs of the best IoU against any pred
    F_beta    = (1 + b^2) * P * R / (b^2 * P + R)

    beta=1.0 is the original F1 and reproduces prior behaviour bit-for-bit.
    beta>1 up-weights recall: used by v12 on multi-segment GTs (K_gt>1) to
    counter the model's measured under-prediction (predicts 1.50 segs vs 2.10
    GT, 59.5% of multi-seg videos miss segments). Single-seg keeps beta=1 so
    its precision deterrent against spurious extra spans is untouched.

    Returns 0.0 if either set is empty or the F-beta denominator is not positive.
    """
    if not preds or not gts:
        return 0.0
    best_for_pred = [max(_iou_1d(p, g) for g in gts) for p in preds]
    best_for_gt = [max(_iou_1d(g, p) for p in preds) for g in gts]
    precision = sum(best_for_pred) / len(best_for_pred)
    recall = sum(best_for_gt) / len(best_for_gt)
    b2 = beta * beta
    denom = b2 * precision + recall
    if denom > 0:
        return (1 + b2) * precision * recall / denom
    return 0.0
def mean_f1_at_tiou(preds, gts, thresholds=(0.5, 0.75, 0.85, 0.95)):
    """Average F1 over multiple strict IoU thresholds (AF eval-style).

    For each threshold tau:
      - GTs are visited in decreasing order of their best IoU against any
        prediction.
      - Each GT greedily takes the highest-IoU prediction not yet used.
      - If that IoU is > tau (strict), it is a TP and the prediction is consumed.
        Otherwise nothing is consumed, so a later GT may still take that
        prediction.
    F1@tau is computed from the TP count, and the result is the mean over
    ``thresholds``.

    Returns 0.0 when ``preds`` or ``gts`` is empty, or when ``thresholds`` is
    empty (instead of raising ZeroDivisionError).
    """
    if not preds or not gts:
        return 0.0
    thresholds = tuple(thresholds)
    if not thresholds:
        return 0.0
    # The GT visiting order does not depend on tau, so compute it once.
    gt_order = sorted(
        range(len(gts)),
        key=lambda i: max((_iou_1d(p, gts[i]) for p in preds), default=0.0),
        reverse=True,
    )
    f1s = []
    for tau in thresholds:
        used = set()
        tp = 0
        for i in gt_order:
            g = gts[i]
            best, best_j = 0.0, -1
            for j, p in enumerate(preds):
                if j in used:
                    continue
                iou = _iou_1d(p, g)
                if iou > best:
                    best, best_j = iou, j
            # Only a match above tau consumes the prediction (see docstring).
            if best_j >= 0 and best > tau:
                used.add(best_j)
                tp += 1
        precision = tp / len(preds)
        recall = tp / len(gts)
        if precision + recall > 0:
            f1 = 2 * precision * recall / (precision + recall)
        else:
            f1 = 0.0
        f1s.append(f1)
    return sum(f1s) / len(f1s)
def combined_iou_score(preds, gts, alpha: float = 0.5):
    """Blend the two set-level metrics into one per-sample score.

    Returns ``alpha * soft_f1(preds, gts) + (1 - alpha) * mean_f1_at_tiou(preds, gts)``,
    which lies in [0, 1] when alpha is in [0, 1].
    """
    soft = soft_f1(preds, gts)
    strict = mean_f1_at_tiou(preds, gts)
    return alpha * soft + (1.0 - alpha) * strict
# ---------------------------------------------------------------------------
# Set-decomposed advantages helper (ICLR method A)
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Hungarian-matched gIoU reward + anomaly alignment (forgery-aware GRPO)
# ---------------------------------------------------------------------------
def _giou_1d(p: Tuple[float, float], g: Tuple[float, float]) -> float:
    """1-D generalized IoU between two (start, end) intervals.

    GIoU = IoU - (|C| - |P u G|) / |C|, where C is the smallest interval
    enclosing both. The range is [-1, 1].
      - When the intervals overlap (or touch), the enclosure equals the union,
        so GIoU equals plain IoU.
      - When they are disjoint, GIoU is negative and falls further as they move
        apart, which gives a signal even for predictions that miss entirely.

    Returns 0.0 when the union has zero length (both intervals degenerate).
    """
    s_p, e_p = p
    s_g, e_g = g
    inter = max(0.0, min(e_p, e_g) - max(s_p, s_g))
    # True 1-D union |P| + |G| - |P n G|. It must NOT be the enclosing hull:
    # the hull is exactly |C|, which would make the penalty term below always 0.
    union = (e_p - s_p) + (e_g - s_g) - inter
    if union <= 0:
        return 0.0
    iou = inter / union
    enclosure = max(e_p, e_g) - min(s_p, s_g)
    if enclosure <= 0:
        return iou
    return iou - (enclosure - union) / enclosure
def hungarian_iou_reward(
    pred_segments: Sequence[Tuple[float, float]],
    gt_segments: Sequence[Tuple[float, float]],
    lambda_fp: float = 0.2,
    lambda_fn: float = 0.4,
) -> float:
    """Permutation-invariant multi-span reward via Hungarian matching on gIoU.

    Predictions and GT segments are matched 1-to-1 to maximise the total
    gIoU (from _giou_1d). Then

        score = sum(matched gIoU) / max(K_p, K_g)
                - lambda_fp * unmatched_pred / (K_p + 1)
                - lambda_fn * unmatched_gt   / (K_g + 1)

    Special cases:
      - Both empty -> 1.0 (perfect agreement on "nothing to find").
      - Only preds empty -> -lambda_fn * K_g / (K_g + 1).
      - Only GT empty -> -lambda_fp * K_p / (K_p + 1).
    """
    n_pred, n_gt = len(pred_segments), len(gt_segments)
    if n_pred == 0 and n_gt == 0:
        return 1.0
    if n_pred == 0:
        return -lambda_fn * n_gt / (n_gt + 1)
    if n_gt == 0:
        return -lambda_fp * n_pred / (n_pred + 1)
    giou = np.array(
        [[_giou_1d(p, g) for g in gt_segments] for p in pred_segments],
        dtype=float,
    )
    # linear_sum_assignment minimises, so negate to maximise total gIoU.
    rows, cols = linear_sum_assignment(-giou)
    matched_total = float(sum(giou[r, c] for r, c in zip(rows, cols)))
    n_matched = len(rows)
    return (
        matched_total / max(n_pred, n_gt)
        - lambda_fp * (n_pred - n_matched) / (n_pred + 1)
        - lambda_fn * (n_gt - n_matched) / (n_gt + 1)
    )
def anomaly_alignment_reward(
    pred_segments: Sequence[Tuple[float, float]],
    anomaly_scores: np.ndarray,
    fps_to_groups: float = 1.0,
) -> float:
    """Mean anomaly score inside predicted segments minus mean outside.

    Args:
        pred_segments: (start, end) intervals in seconds. Intervals with
            end <= start are ignored.
        anomaly_scores: per-group scores indexed along axis 0 (length T).
            Any array-like is accepted (e.g. a plain list delivered through a
            dataset column) and converted with ``np.asarray``. Normalised to
            [0, 1] is preferred but not required.
        fps_to_groups: groups per source second (1.0 at fps=2 +
            temporal_stride=2).

    Returns:
        ``mean(scores inside) - mean(scores outside)``. Returns 0.0 when
        degenerate: no scores (empty or 0-d input), empty preds, preds covering
        no group (e.g. entirely past the end), or preds covering every group.
    """
    scores = np.asarray(anomaly_scores)
    # A 0-d input has no time axis; treat it like an empty score track.
    if scores.ndim == 0:
        return 0.0
    T = int(scores.shape[0])
    if T == 0 or len(pred_segments) == 0:
        return 0.0
    inside = np.zeros(T, dtype=bool)
    for s, e in pred_segments:
        if e <= s:
            continue
        s_idx = max(0, int(s * fps_to_groups))
        # +0.999 approximates a ceiling so a partially covered last group counts.
        e_idx = min(T, int(e * fps_to_groups + 0.999))
        if s_idx < e_idx:
            inside[s_idx:e_idx] = True
    if not inside.any() or inside.all():
        return 0.0
    a_in = float(scores[inside].mean())
    a_out = float(scores[~inside].mean())
    return a_in - a_out
def combined_forgery_reward(
    pred_text: str,
    gt_segments: Sequence[Tuple[float, float]],
    anomaly_scores: Optional[np.ndarray] = None,
    alpha: float = 1.0,
    beta: float = 0.5,
    gamma_verifier_cot: float = 0.0,
    delta_format: float = 0.0,
    lambda_fp: float = 0.2,
    lambda_fn: float = 0.4,
    lambda_format: float = 1.0,
    fps_to_groups: float = 1.0,
):
    """Forgery-aware top-level reward for one completion.

        total = alpha              * hungarian_iou_reward(answer, gt)         # GT-anchored (primary)
              + beta               * anomaly_alignment_reward(answer, scores) # verifier-anchored shaping
              + gamma_verifier_cot * anomaly_alignment_reward(think, scores)  # CoT verifier-grounding
              + delta_format       * forensics_format_reward(text)            # CoT structural reward

    Format gate: if the <answer> block yields no segment while the GT is
    non-empty, no term is evaluated and the reward is ``-lambda_format``.

    The two alignment terms are computed only when ``anomaly_scores`` is given
    and their weight is non-zero. The format term is computed only when
    ``delta_format`` is non-zero. A skipped term is reported as 0.0.

    `delta_format` (M3): the forgery-aware path bypasses script-level reward
    functions, so the format / CoT-structure signal must be folded in here.
    `gamma_verifier_cot` (M3): rewards alignment between the model's
    <timestep> tags inside <think> and the verifier's per-frame heatmap —
    reasoning grounded in independent forensic evidence.

    Returns:
        (total, diagnostics). diagnostics holds r_iou, r_ano, r_cot_ver, r_fmt,
        K_pred, K_gt and parse_failed.
    """
    segments = parse_segments(pred_text)
    n_gt = len(gt_segments)
    if not segments and n_gt > 0:
        return -lambda_format, {
            "r_iou": -lambda_format,
            "r_ano": 0.0,
            "r_cot_ver": 0.0,
            "r_fmt": 0.0,
            "K_pred": 0,
            "K_gt": n_gt,
            "parse_failed": True,
        }

    r_iou = hungarian_iou_reward(
        segments, list(gt_segments), lambda_fp=lambda_fp, lambda_fn=lambda_fn
    )
    have_scores = anomaly_scores is not None
    r_ano = (
        anomaly_alignment_reward(segments, anomaly_scores, fps_to_groups=fps_to_groups)
        if have_scores and beta != 0.0
        else 0.0
    )
    r_cot_ver = (
        anomaly_alignment_reward(
            parse_timestep_tags(pred_text), anomaly_scores, fps_to_groups=fps_to_groups
        )
        if have_scores and gamma_verifier_cot != 0.0
        else 0.0
    )
    r_fmt = forensics_format_reward([pred_text])[0] if delta_format != 0.0 else 0.0

    total = (
        alpha * r_iou
        + beta * r_ano
        + gamma_verifier_cot * r_cot_ver
        + delta_format * r_fmt
    )
    diagnostics = {
        "r_iou": r_iou,
        "r_ano": r_ano,
        "r_cot_ver": r_cot_ver,
        "r_fmt": r_fmt,
        "K_pred": len(segments),
        "K_gt": n_gt,
        "parse_failed": False,
    }
    return total, diagnostics
# ---------------------------------------------------------------------------
# Set-decomposed advantages helper (legacy; kept for ablation)
# ---------------------------------------------------------------------------
# One "<start> <sep> <end>" pair, used to locate each segment's character span
# inside an <answer> block for token masking. The pattern is identical to
# _SEGMENT_RE, so the spans found here are exactly the pairs parse_segments()
# reads from the same text.
_SEG_TEXT_RE = re.compile(rf"({_NUM}){_RANGE_SEP}({_NUM})", re.IGNORECASE)
def build_per_segment_token_mask(
    completion_text: str,
    n_tokens: int,
    tokenizer,
    max_K: int = 8,
):
    """For a single decoded completion, return a (max_K, n_tokens) bool mask and K_pred.

    mask[k, t] is True iff token t overlaps the characters of the k-th parsed
    segment inside the model's LAST <answer>...</answer> block. This is the
    same block parse_segments() reads, so row k lines up with
    parse_segments(text)[k].

    Caveats:
      - A segment with no overlapping token among the first n_tokens gets no
        row, which shifts the rows of later segments.
      - At most max_K segments are masked.

    K_pred (the number of rows actually filled) is returned alongside.

    Used for decomposed-advantage GRPO: each segment's tokens get a credit
    signal computed from THAT segment's IoU vs Hungarian-matched GT, instead
    of all tokens sharing one scalar advantage.

    The completion_text must be the EXACT round-trip of the tokenized completion
    (i.e. tokenizer.decode(completion_ids, skip_special_tokens=False)). We then
    re-tokenize with ``return_offsets_mapping=True`` to locate segment text
    spans and map them back to token positions.
    NOTE(review): offset mapping presumably requires a Hugging Face "fast"
    tokenizer; confirm against the trainer.

    Returns (all-False mask, 0) if the text is empty, has no <answer> block, has
    no parseable segments, or the tokenizer call fails.
    """
    mask = np.zeros((max_K, n_tokens), dtype=bool)
    if not completion_text:
        return mask, 0
    answer_blocks = list(_ANSWER_RE.finditer(completion_text))
    if not answer_blocks:
        return mask, 0
    # Use the LAST block to stay consistent with parse_segments(); using the
    # first would misalign mask rows with the scored segments whenever the
    # completion contains more than one <answer> block.
    last_block = answer_blocks[-1]
    inner = last_block.group(1)
    inner_char_start = last_block.start(1)
    seg_char_spans = []
    for sm in _SEG_TEXT_RE.finditer(inner):
        try:
            s_val, e_val = float(sm.group(1)), float(sm.group(2))
        except ValueError:
            continue
        # Same validity filter as parse_segments(): keep only end > start.
        if e_val <= s_val:
            continue
        seg_char_spans.append(
            (inner_char_start + sm.start(), inner_char_start + sm.end())
        )
    if not seg_char_spans:
        return mask, 0
    try:
        enc = tokenizer(completion_text, return_offsets_mapping=True, add_special_tokens=False)
        offsets = enc["offset_mapping"]
    except Exception:
        # Deliberate best-effort: a tokenizer that cannot produce offsets
        # simply yields no per-segment credit instead of aborting training.
        return mask, 0
    K_pred = 0
    for cs, ce in seg_char_spans[:max_K]:
        # Tokens whose character span overlaps [cs, ce).
        tok_idxs = [
            j
            for j, (tok_start, tok_end) in enumerate(offsets)
            if tok_end > cs and tok_start < ce and j < n_tokens
        ]
        if not tok_idxs:
            continue
        mask[K_pred, tok_idxs] = True
        K_pred += 1
    return mask, K_pred
def hungarian_per_pred_iou(
    pred_segments: Sequence[Tuple[float, float]],
    gt_segments: Sequence[Tuple[float, float]],
    boundary_weight: float = 0.0,
    boundary_tau: float = 2.0,
) -> list:
    """Score each PREDICTED segment against its Hungarian-matched GT segment.

    Predictions are matched 1-to-1 to GT segments to maximise total IoU. A
    matched prediction's base score is that IoU, in [0, 1]. Unmatched
    predictions (when K_pred > K_gt) score 0.0.

    When ``boundary_weight`` > 0, a boundary-precision bonus tailored to the
    forensics task is mixed in. Manipulation boundaries carry the strongest
    artifact signal (blur, warp, abrupt transitions):

        boundary = exp(-|s_p - s_g| / tau) * exp(-|e_p - e_g| / tau)   in [0, 1]
        score    = (IoU + boundary_weight * boundary) / (1 + boundary_weight)

    ``boundary_tau`` is in seconds: a 2 s boundary error at tau=2 scales the
    bonus down by a factor of e.

    Returns a list with one score per predicted segment (empty for no preds).
    """
    n_pred, n_gt = len(pred_segments), len(gt_segments)
    if n_pred == 0:
        return []
    if n_gt == 0:
        return [0.0] * n_pred
    iou = np.array(
        [[_iou_1d(p, g) for g in gt_segments] for p in pred_segments],
        dtype=float,
    )
    rows, cols = linear_sum_assignment(-iou)
    partner = dict(zip(rows.tolist(), cols.tolist()))
    scores = []
    for i, (ps, pe) in enumerate(pred_segments):
        j = partner.get(i)
        if j is None:
            scores.append(0.0)
            continue
        base = float(iou[i, j])
        if boundary_weight > 0.0:
            gs, ge = gt_segments[j]
            boundary = float(
                math.exp(-abs(ps - gs) / boundary_tau)
                * math.exp(-abs(pe - ge) / boundary_tau)
            )
            scores.append((base + boundary_weight * boundary) / (1.0 + boundary_weight))
        else:
            scores.append(base)
    return scores
def per_segment_recall(completions, solution):
    """Per-GT-segment recall used by the decomposed-advantage GRPO trainer.

    For each rollout i, returns one value per GT segment k:

        R[k, i] = max over preds_i of IoU(pred, gt_k)   (0.0 when no preds)

    K = len(solution[i]), assumed identical across i within a query group.
    Result shape is (B, K) as nested lists; the inner list is empty when the
    GT is empty. Non-uniform K across the batch is allowed — the caller must
    handle that (in single-query GRPO it is always uniform).
    """
    results = []
    for content, sol in zip(completions, solution):
        gt = [tuple(x) for x in sol]
        pred = parse_segments(content)
        results.append([max((_iou_1d(g, p) for p in pred), default=0.0) for g in gt])
    return results
# ---------------------------------------------------------------------------
# Reward functions exposed to GRPO trainer
# ---------------------------------------------------------------------------
# Module-level accumulators for per-generator logging across reward calls.
# Keys are the per-sample generator labels passed in by the reward functions
# (may be None when no generator column is present).
_GEN_REWARD_SUM = defaultdict(float)
_GEN_REWARD_CNT = defaultdict(int)


def _log_per_generator(rewards, generators):
    """Accumulate per-generator reward totals and periodically print averages.

    Args:
        rewards: per-sample scalar rewards from one reward-function call.
        generators: per-sample generator labels aligned with ``rewards``; may
            contain None. A falsy value (None / empty) disables logging.

    Side effects: updates the module-level accumulators above. Whenever the
    cumulative sample count crosses a multiple of 50 during this call, prints
    one line with the running average and count per generator.
    """
    if not generators:
        return
    added = 0
    for r, g in zip(rewards, generators):
        _GEN_REWARD_SUM[g] += float(r)
        _GEN_REWARD_CNT[g] += 1
        added += 1
    total = sum(_GEN_REWARD_CNT.values())
    # Fire on *crossing* a multiple of 50, not only on landing exactly on one,
    # so batch sizes that do not divide 50 still get a regular summary.
    if added and total // 50 > (total - added) // 50:
        parts = []
        # key=str: labels may mix None with strings, which plain sorted() rejects.
        for g in sorted(_GEN_REWARD_CNT, key=str):
            avg = _GEN_REWARD_SUM[g] / max(1, _GEN_REWARD_CNT[g])
            parts.append(f"{g}={avg:.3f}(n={_GEN_REWARD_CNT[g]})")
        print(f"[per-gen rolling avg @ {total}] " + " ".join(parts))
def single_span_iou_reward(completions, solution, durations=None, generator=None, **kwargs):
    """TempSamp-R1 / Charades-style single-span IoU reward.

    Paired with FORENSICS_SPLIT_SINGLE_SPAN=true at dataset build time:
    each example has a length-1 GT, and the reward is the 1D IoU between
    the FIRST predicted span and the (sole) GT span. Extra predicted spans
    are ignored — this matches the original TempSamp-R1 setup, which
    assumes single-span output. Empty pred (or empty GT) → 0.

    Args:
        completions: decoded model outputs.
        solution: per-sample GT lists of (start, end); only the first is used.
        durations: unused; accepted so the trainer can pass it as a kwarg.
        generator: per-sample generator labels (list), or a single label that
            is broadcast to every sample; used for per-generator logging.

    When DEBUG_MODE=true and LOG_PATH is set, each sample is appended to
    LOG_PATH as UTF-8. Completions may contain non-ASCII text such as "–",
    which a locale-default encoding can fail to write.
    """
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    gens = generator if isinstance(generator, list) else [generator] * len(completions)
    debug_log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "true" else None
    for content, sol, gen in zip(completions, solution, gens):
        gt = [tuple(x) for x in sol]
        pred = parse_segments(content)
        if not pred or not gt:
            iou = 0.0
        else:
            iou = _iou_1d(pred[0], gt[0])
        rewards.append(iou)
        print(f"gen={gen} gt={gt} pred={pred}")
        print(f" single_span_iou={iou:.3f} [{current_time}]")
        if debug_log_path:
            with open(debug_log_path, "a", encoding="utf-8") as f:
                f.write(f"Generator: {gen}\n")
                f.write(f"Content: {content}\n")
                f.write(f"pred: {pred}\n")
                f.write(f"gt: {gt}\n")
                f.write(f"single_span_iou={iou:.3f}\n")
                f.write(f"------------- {current_time} -------------\n")
    _log_per_generator(rewards, gens)
    return rewards
def forensics_iou_reward(completions, solution, durations=None, generator=None, **kwargs):
    """Set-level IoU reward; `solution[i]` = list[(s, e)].

    Training reward = soft_F1 only (continuous, smooth gradient compatible with
    TempSamp-R1's asymmetric soft advantage transform). mean_F1@tIoU is kept
    for diagnostic logging but not used as the optimization signal.

    Env-var FORENSICS_IOU_REQUIRE_THINK (default false): when true, multiply
    iou by 0.3 if the completion has no <think> block. Closes the
    "drop CoT, output answer only" reward-hacking path that stage2_v7a hit
    on step 6 (cheat-mode iou ~0.42 vs CoT-mode iou ~0.23, tipping total
    reward toward cheat once k_prec/k_recall are added).

    Env-var FORENSICS_MULTISEG_RECALL_BETA (default 1.0): F-beta used for GTs
    with more than one segment. Env-var FORENSICS_OVERPRED_PENALTY (default
    0.0): penalty weight applied when K_pred > K_gt. Both are explained in
    the comments below.

    When DEBUG_MODE=true and LOG_PATH is set, per-sample details are appended
    to LOG_PATH as UTF-8, because completions may contain non-ASCII text.
    """
    require_think = os.getenv("FORENSICS_IOU_REQUIRE_THINK", "false").lower() in ("true", "1", "yes")
    rewards = []
    current_time = datetime.now().strftime("%d-%H-%M-%S-%f")
    gens = generator if isinstance(generator, list) else [generator] * len(completions)
    # v12: recall-weighted soft_F1 for multi-segment GTs only. Continuous F-beta
    # reshaping of the EXISTING reward (stays bounded in [0,1]) rather than an
    # additive term — this is what v11 got wrong: a constant additive penalty
    # (~0.13) became ~7x the converged within-group reward_std (~0.019), so after
    # GRPO's per-group normalisation it dominated the advantage and wrecked the
    # boundary gradient. A bounded reshape can't blow up that way.
    # Env-gated; FORENSICS_MULTISEG_RECALL_BETA=1.0 (default) reproduces v10_r2.
    multiseg_beta = float(os.getenv("FORENSICS_MULTISEG_RECALL_BETA", "1.0"))
    # Asymmetric over-prediction penalty. Measured failure mode: on
    # single-segment videos the model emits spurious extra segments, which
    # soft_F1's precision term penalises too weakly to suppress. This fires
    # ONLY when K_pred > K_gt, so correct multi-segment predictions (the
    # method's main win, +13.89 mIoU on multi-seg) are never penalised.
    # Env-gated; FORENSICS_OVERPRED_PENALTY=0.0 (default) reproduces prior behaviour.
    # Read once per call; it is loop-invariant.
    overpred_kappa = float(os.getenv("FORENSICS_OVERPRED_PENALTY", "0.0"))
    debug_log_path = os.getenv("LOG_PATH") if os.getenv("DEBUG_MODE") == "true" else None
    for content, sol, gen in zip(completions, solution, gens):
        gt = [tuple(x) for x in sol]
        pred = parse_segments(content)
        beta = multiseg_beta if len(gt) > 1 else 1.0
        sf = soft_f1(pred, gt, beta=beta)
        mf = mean_f1_at_tiou(pred, gt)  # diagnostic only
        reward = sf
        if require_think and not _THINK_RE.search(content or ""):
            reward = reward * 0.3
        if overpred_kappa > 0.0 and len(pred) > len(gt):
            reward = reward - overpred_kappa * (len(pred) - len(gt)) / (len(pred) + 1)
        rewards.append(reward)
        print(f"gen={gen} gt={gt} pred={pred}")
        print(f" soft_F1={sf:.3f} mean_F1@tIoU={mf:.3f} reward={reward:.3f} [{current_time}]")
        if debug_log_path:
            with open(debug_log_path, "a", encoding="utf-8") as f:
                f.write(f"Generator: {gen}\n")
                f.write(f"Content: {content}\n")
                f.write(f"pred: {pred}\n")
                f.write(f"gt: {gt}\n")
                f.write(f"soft_F1={sf:.3f} mean_F1@tIoU={mf:.3f} reward={reward:.3f}\n")
                f.write(f"------------- {current_time} -------------\n")
    _log_per_generator(rewards, gens)
    return rewards
def forensics_format_reward(completions, **kwargs):
    """Structured-output format reward.

    Stage-1 (FORENSICS_COT=false): 1.0 iff some <answer> block's whole body is
    "s1 to e1; ..." (checked with _FULL_FORMAT_RE), else 0.0.

    Stage-2 (FORENSICS_COT unset or anything else): graded score in [0, 1]
    based on *functional* parseability rather than a strict regex match. This
    avoids dead-zoning the format signal when the model puts <timestep> plus
    reasoning inside <answer> (v5 step ~27 onward exhibited this):
        0.0  if parse_segments(text) finds no interval
        0.4  base when at least one interval parses
        +0.3 if the <think> block is present and non-empty
        +0.3 if, additionally, a <timestep> tag appears INSIDE <think> — the
             gradient pushing against the "merge think+answer into one block"
             shortcut.

    Prints the batch's scores and returns them as a list of floats.
    """
    cot_mode = os.getenv("FORENSICS_COT", "true").lower() not in ("false", "0", "no")

    def _grade(text):
        if not cot_mode:
            return 1.0 if _FULL_FORMAT_RE.search(text) else 0.0
        if not parse_segments(text):
            return 0.0
        score = 0.4
        think = _THINK_RE.search(text)
        if think and think.group(1).strip():
            score += 0.3
            if _TIMESTEP_RE.search(think.group(1)):
                score += 0.3
        return score

    rewards = [_grade(content or "") for content in completions]
    print("format rewards:", rewards)
    return rewards
def forensics_cot_consistency_reward(completions, **kwargs):
    """Rasterised set-IoU between <timestep> tags in <think> and <answer> segments.

    Pressures CoT reasoning to be load-bearing:
      - intervals reasoned about must show up in the final answer;
      - final answer intervals must have been considered during reasoning.

    Each score is ``weight * _mass_iou_1d(think_tags, answer_segments)``, with
    weight from FORENSICS_COT_CONSIS_WEIGHT (default 1.0). A completion scores
    0.0 if it has no "<think>" text, no parsable <timestep> tags, or no
    parsable answer segments. The signal therefore appears only once the model
    adopts the CoT structure; IoU and format rewards drive learning before that.
    """
    weight = float(os.getenv("FORENSICS_COT_CONSIS_WEIGHT", "1.0"))
    rewards = []
    for content in completions:
        text = content or ""
        score = 0.0
        if "<think>" in text.lower():
            think_segs = parse_timestep_tags(text)
            answer_segs = parse_segments(text)
            if think_segs and answer_segs:
                score = weight * _mass_iou_1d(think_segs, answer_segs)
        rewards.append(score)
    if rewards:
        scores = np.array(rewards, dtype=float)
        print(f"cot_consis rewards (w={weight}): mean={scores.mean():.3f} nonzero={(scores>0).sum()}/{len(scores)} "
              f">0.5={(scores>0.5*weight).sum()}")
    return rewards
# Anchors that forensics_cot_counterfactual_reward requires after each
# <timestep> tag: "Observed:", "If authentic:" and "Ruled out:". Matching is
# case-insensitive, needs a word boundary before the first word, and allows
# whitespace before the colon (and between the two words of the last two).
_CF_OBSERVED_RE = re.compile(r"\bobserved\s*:", re.IGNORECASE)
_CF_AUTHENTIC_RE = re.compile(r"\bif\s+authentic\s*:", re.IGNORECASE)
_CF_RULED_OUT_RE = re.compile(r"\bruled\s+out\s*:", re.IGNORECASE)
def forensics_cot_counterfactual_reward(completions, **kwargs):
    """Counterfactual-structure reward for stage-2 CoT.

    Each <timestep>...</timestep> tag inside the first <think> block has a
    scope: the text from its closing tag to the next <timestep> (or the end of
    <think>). A tag is "well-justified" when its scope contains all three
    anchors — "Observed:", "If authentic:" and "Ruled out:". The score is

        r = weight * (#well-justified tags) / (#tags)

    with weight from FORENSICS_COT_CF_WEIGHT (default 1.0). A completion with
    no <think> block, or no <timestep> tag inside it, scores 0.0: the model
    has not adopted the structure yet, and format + IoU rewards drive learning.

    Pairs with:
      - format reward (structural wrapper)
      - cot_consis (<timestep> <-> <answer> alignment)
      - iou (final answer correctness)
    The anchors are easy to emit, but cot_consis discourages bogus timesteps and
    iou discourages misalignment with truth. Gaming all three at once therefore
    requires actually picking the right intervals.
    """
    weight = float(os.getenv("FORENSICS_COT_CF_WEIGHT", "1.0"))

    def _fully_anchored(scope):
        return bool(
            _CF_OBSERVED_RE.search(scope)
            and _CF_AUTHENTIC_RE.search(scope)
            and _CF_RULED_OUT_RE.search(scope)
        )

    rewards = []
    for content in completions:
        think_match = _THINK_RE.search(content or "")
        if think_match is None:
            rewards.append(0.0)
            continue
        body = think_match.group(1)
        tags = list(_TIMESTEP_RE.finditer(body))
        if not tags:
            rewards.append(0.0)
            continue
        scope_ends = [tag.start() for tag in tags[1:]] + [len(body)]
        justified = sum(
            1 for tag, end in zip(tags, scope_ends) if _fully_anchored(body[tag.end():end])
        )
        rewards.append(weight * justified / len(tags))
    if rewards:
        scores = np.array(rewards, dtype=float)
        print(f"cot_cf rewards (w={weight}): mean={scores.mean():.3f} full={(scores>=0.999*weight).sum()}/{len(scores)} "
              f"any={(scores>0).sum()}/{len(scores)}")
    return rewards
def forensics_cot_gt_alignment_reward(completions, solution, **kwargs):
    """GT-anchored CoT reward: each <timestep> in <think> must align with a GT segment.

    For every <timestep>start to end</timestep> emitted inside <think>, compute
    the max IoU against any GT segment. Reward = mean of these per-timestep IoUs.

    Why this matters: the previous cot_cf reward checked only the 3-anchor
    template ("Observed/If authentic/Ruled out"). That check is structural and
    incentivises padding the CoT with extra fabricated timesteps to harvest
    more template points. cot_gt_align directly penalises fabricated timesteps:
    every <timestep> the model invents that doesn't match a GT segment drags
    the mean down. Combined with cot_consis (timestep <-> answer set-IoU), this
    forces the CoT to be **few and grounded in truth**, not many and decorative.

    Returns 0.0 when any of these holds:
      - there is no <think> block;
      - there is no <timestep> tag inside <think>;
      - the GT list is empty. There is no vacuous-agreement 1.0 case: any
        timestep on an authentic video counts as fabricated, and a missing
        timestep already yields 0.0 above.
    """
    rewards = []
    for content, sol in zip(completions, solution):
        text = content or ""
        if not _THINK_RE.search(text):
            rewards.append(0.0)
            continue
        ts_segs = parse_timestep_tags(text)
        if not ts_segs:
            rewards.append(0.0)
            continue
        gts = [tuple(g) for g in sol]
        if not gts:
            rewards.append(0.0)
            continue
        per_ts_iou = [max(_iou_1d(ts, g) for g in gts) for ts in ts_segs]
        rewards.append(sum(per_ts_iou) / len(per_ts_iou))
    if rewards:
        arr = np.array(rewards, dtype=float)
        print(f"cot_gt_align rewards: mean={arr.mean():.3f} >0.5={(arr>0.5).sum()}/{len(arr)} "
              f">0.8={(arr>0.8).sum()}/{len(arr)}")
    return rewards
def forensics_cot_gt_alignment_hungarian_reward(completions, solution, **kwargs):
    """Hungarian-matched GT-anchored CoT reward.

    Drop-in replacement for cot_gt_align that removes its over-fragmentation
    pathology. cot_gt_align averages, over predicted timesteps, each one's
    max-IoU against any GT segment. That mean is non-decreasing in the number
    of predicted timesteps: a duplicate of a correct timestep contributes the
    same max-IoU and leaves the mean unchanged. An RL policy is therefore
    rationally incentivised to enumerate. Empirically observed: v3 trained
    with cot_gt_align emits 2.75 timesteps per video vs GT 1.21, with
    pred>1 in 95% of cases and F1@0.5 dropping from 46% to 30%.

    Here timesteps and GT segments are matched 1-to-1 (max total IoU), and the
    matched-IoU sum is divided by max(K_p, K_g). Surplus timesteps inflate the
    denominator without adding matched IoU, so fragmentation is strictly
    costly; missing timesteps are penalised through the same denominator.

    Score (K_p = #<timestep>, K_g = #GT), scaled by weight:
        no <think> block     -> 0.0
        both empty           -> 1.0 (vacuous agreement)
        exactly one is empty -> 0.0
        otherwise            -> sum(matched IoU) / max(K_p, K_g), in [0, 1]

    Env-var FORENSICS_COT_ALIGN_WEIGHT (default 1.0) sets ``weight``. Raise it
    to give the GT-anchored CoT signal more weight; this is the ONLY CoT reward
    that uses ground truth (cot_consis is self-IoU, gameable by copying
    <think> into <answer>).
    """
    weight = float(os.getenv("FORENSICS_COT_ALIGN_WEIGHT", "1.0"))
    rewards = []
    for content, sol in zip(completions, solution):
        text = content or ""
        if not _THINK_RE.search(text):
            rewards.append(0.0)
            continue
        tags = parse_timestep_tags(text)
        gts = [tuple(g) for g in sol]
        n_tags, n_gt = len(tags), len(gts)
        if n_tags == 0 and n_gt == 0:
            rewards.append(weight * 1.0)
        elif n_tags == 0 or n_gt == 0:
            rewards.append(0.0)
        else:
            iou = np.array([[_iou_1d(t, g) for g in gts] for t in tags], dtype=float)
            rows, cols = linear_sum_assignment(-iou)
            matched = float(sum(iou[r, c] for r, c in zip(rows, cols)))
            rewards.append(weight * matched / max(n_tags, n_gt))
    if rewards:
        scores = np.array(rewards, dtype=float)
        print(f"cot_gt_align_h rewards: mean={scores.mean():.3f} >0.5={(scores>0.5).sum()}/{len(scores)} "
              f">0.8={(scores>0.8).sum()}/{len(scores)}")
    return rewards
def forensics_kpred_reward(completions, solution, **kwargs):
    """Reward for predicting the right NUMBER of segments.

        r = exp(-|K_pred - K_gt| / tau)

    with tau from FORENSICS_KPRED_TAU (default 1.0). Range (0, 1]: exact count
    -> 1.0, off by 1 -> 0.37, off by 2 -> 0.14 (at tau=1).

    Addresses the observed bottleneck where the policy collapses to K=1 even
    when GT has K=2-3 manipulated segments (K_match_ratio ≈ 0.5 in stage1).
    """
    tau = float(os.getenv("FORENSICS_KPRED_TAU", "1.0"))
    rewards = [
        float(math.exp(-abs(len(parse_segments(content or "")) - len(sol)) / tau))
        for content, sol in zip(completions, solution)
    ]
    print("kpred rewards:", rewards)
    return rewards
def forensics_coverage_reward(completions, solution, **kwargs):
    """Fraction of GT segments hit by at least one prediction with IoU > tau.

        r = (# GT segments whose best-pred IoU > tau) / K_gt

    Addresses K-collapse: soft_F1 only weakly penalises missing a 2nd or 3rd
    GT segment (the miss is diluted across matched pairs), whereas coverage
    directly rewards recall of every GT segment regardless of K_pred padding.
    tau comes from FORENSICS_COVERAGE_TAU (default 0.3, same as the F1@0.3
    threshold).

    Edge cases: both empty -> 1.0; GT empty with predictions -> 0.0
    (precision is handled elsewhere); GT present without predictions -> 0.0.
    """
    tau = float(os.getenv("FORENSICS_COVERAGE_TAU", "0.3"))
    rewards = []
    for content, sol in zip(completions, solution):
        pred = parse_segments(content or "")
        gt = list(sol)
        if not gt:
            rewards.append(0.0 if pred else 1.0)
        elif not pred:
            rewards.append(0.0)
        else:
            hits = sum(1 for g in gt if max(_iou_1d(p, g) for p in pred) > tau)
            rewards.append(hits / len(gt))
    print("coverage rewards:", rewards)
    return rewards
def forensics_k_recall_reward(completions, solution, **kwargs):
    """Asymmetric count reward: penalise K_pred < K_gt only, never K_pred > K_gt.

        r = min(K_pred / K_gt, 1.0)            if K_gt > 0
        r = 1.0 if K_pred == 0 else 0.0        if K_gt == 0

    Replaces the symmetric kpred reward, which over-penalised over-prediction
    and pushed the policy to the trivially safe K=1 collapse (full run: 92.3%
    K=1 preds). Precision against over-prediction is left to the F1-side
    rewards.

    Env-var FORENSICS_K_REQUIRE_TIMESTEP (default false): when true, a
    completion without a <timestep> tag inside <think> scores 0.0. This stops
    k_recall from acting as a cheat-mode anchor (in v7a it gave +0.92 to "no
    think, K=1 answer" rollouts, swamping the CoT-mode signal).
    """
    require_ts = os.getenv("FORENSICS_K_REQUIRE_TIMESTEP", "false").lower() in ("true", "1", "yes")
    rewards = []
    for content, sol in zip(completions, solution):
        text = content or ""
        if require_ts and not parse_timestep_tags(text):
            rewards.append(0.0)
            continue
        k_pred, k_gt = len(parse_segments(text)), len(sol)
        if k_gt:
            rewards.append(min(k_pred / k_gt, 1.0))
        else:
            rewards.append(1.0 if k_pred == 0 else 0.0)
    print("k_recall rewards:", rewards)
    return rewards
def forensics_k_precision_reward(completions, solution, **kwargs):
    """Asymmetric segment-count reward: punish K_pred > K_gt, ignore K_pred < K_gt.

        r = min(K_gt / K_pred, 1.0)          if K_pred > 0
        r = 1.0 if K_gt == 0 else 0.0        if K_pred == 0

    This is the complement of k_recall. The two together form a symmetric K
    signal. Because it is a set-level reward rather than a per-segment one, it
    survives GRPO group normalisation, so over-prediction actually costs the
    rollout. It targets the r1r2 failure where asymmetric recall let K_pred
    grow unbounded on single-GT generators (vidu, ltx) and dragged strict F1
    down.

    FORENSICS_K_REQUIRE_TIMESTEP (default false): when truthy ("true"/"1"/"yes"),
    a completion without any <timestep> tag scores 0. Pair it with k_recall.

    Returns one float per completion.
    """
    flag = os.getenv("FORENSICS_K_REQUIRE_TIMESTEP", "false").lower()
    require_ts = flag in ("true", "1", "yes")
    scores = []
    for completion, gt in zip(completions, solution):
        text = completion or ""
        if require_ts and not parse_timestep_tags(text):
            scores.append(0.0)
            continue
        n_pred = len(parse_segments(text))
        n_gt = len(gt)
        if n_pred:
            scores.append(min(n_gt / n_pred, 1.0))
        else:
            # Nothing predicted: correct only when there is nothing to find.
            scores.append(0.0 if n_gt else 1.0)
    print("k_precision rewards:", scores)
    return scores
def _iou_1d(a, b):
s1, e1 = a; s2, e2 = b
inter = max(0.0, min(e1, e2) - max(s1, s2))
union = max(e1, e2) - min(s1, s2)
return inter / union if union > 0 else 0.0
def forensics_strict_boundary_reward(completions, solution, **kwargs):
    """Strict F1 averaged over several IoU thresholds; rewards tight boundaries.

    soft_F1 (iou_reward) is continuous and pays loose "in the ballpark"
    predictions about as much as tight ones. Empirically this drifts the
    policy toward over-extended intervals; the canonical failure was v1 with
    mIoU 0.48 but F1@0.5 0.27. This reward complements it by giving discrete
    credit only when a matched IoU clears a threshold.

    Score = mean over tau in thresholds of mean_f1_at_tiou(pred, gt), in [0, 1].

    FORENSICS_STRICT_TAUS (default "0.5,0.7") is a comma-separated threshold
    list. Blank entries are ignored, and an all-blank value falls back to
    (0.5, 0.7). Adding a higher threshold (e.g. "0.5,0.7,0.85") implicitly
    increases the weight on tight boundaries. This targets the 38.8% "K right,
    boundary drift" failure mode from stage1_decomp_boundary.

    Returns one float per completion.
    """
    raw_taus = os.getenv("FORENSICS_STRICT_TAUS", "0.5,0.7")
    taus = tuple(float(tok) for tok in raw_taus.split(",") if tok.strip()) or (0.5, 0.7)
    rewards = [
        mean_f1_at_tiou(
            parse_segments(content or ""),
            [tuple(seg) for seg in sol],
            thresholds=taus,
        )
        for content, sol in zip(completions, solution)
    ]
    print(f"strict_boundary rewards (taus={taus}): {rewards}")
    return rewards
def forensics_strict_edge_reward(completions, solution, **kwargs):
    """Sub-second boundary precision reward (replaces probe-based R1).

    strict_boundary at tau in {0.3, 0.5, 0.7} saturates as soon as IoU clears
    0.7, so no gradient remains to push toward F1@0.85/0.95. The v10 eval gap
    (F1@0.7=48 vs F1@0.85=26 vs F1@0.95=10) sits exactly in that silent zone.

    Predictions are matched one-to-one to GT segments with the Hungarian
    algorithm on -IoU. This yields min(K_pred, K_gt) pairs, and a matched pair
    may have IoU 0. For each matched (pred, gt):
        b = exp(-|ds|/sigma) * exp(-|de|/sigma)   in [0, 1]
    Both boundaries must hit (product, not mean). A single-boundary degenerate
    pred such as (4.0, 4.1) matched to (4.0, 14.0) gets exp(0)*exp(-19.8) ~ 0.

    F1-style aggregation handles K_pred != K_gt:
        TP_soft = sum of b over matched pairs
        P = TP_soft / K_pred,  R = TP_soft / K_gt,  F1 = 2PR / (P + R)

    Edge cases:
      - No GT and no prediction -> 1.0. This is the correct "nothing
        manipulated" answer, consistent with coverage, k_recall and
        k_precision.
      - Exactly one of them empty -> 0.0.

    sigma (FORENSICS_STRICT_EDGE_SIGMA, default 0.5 s): every sigma seconds of
    total boundary error |ds| + |de| shrinks b by a factor of e. At sigma = 0.5 s,
    b >= 0.9 requires |ds| + |de| <= sigma * ln(1/0.9) ~ 0.053 s, i.e. about
    0.026 s per boundary. This targets the 0.85/0.95 regime that
    strict_boundary cannot see.

    Returns:
        One float in [0, 1] per completion.

    Raises:
        ValueError: if sigma is not a positive number. Zero would divide by
            zero; a negative value would make b grow with error.
    """
    sigma = float(os.getenv("FORENSICS_STRICT_EDGE_SIGMA", "0.5"))
    if not sigma > 0:  # also rejects NaN
        raise ValueError(f"FORENSICS_STRICT_EDGE_SIGMA must be > 0, got {sigma}")
    rewards = []
    for content, sol in zip(completions, solution):
        preds = parse_segments(content or "")
        gts = [tuple(x) for x in sol]
        if not preds and not gts:
            # Correct empty answer on a GT-empty video.
            rewards.append(1.0)
            continue
        if not preds or not gts:
            rewards.append(0.0)
            continue
        K_p, K_g = len(preds), len(gts)
        # Hungarian assignment maximising total IoU (minimising -IoU).
        cost = -np.array(
            [[_iou_1d(p, g) for g in gts] for p in preds], dtype=float
        )
        row_ind, col_ind = linear_sum_assignment(cost)
        tp_soft = 0.0
        for r, c in zip(row_ind, col_ind):
            ps, pe = preds[r]
            gs, ge = gts[c]
            b = math.exp(-abs(ps - gs) / sigma) * math.exp(-abs(pe - ge) / sigma)
            tp_soft += float(b)
        P = tp_soft / K_p
        R = tp_soft / K_g
        f1 = 2 * P * R / (P + R) if (P + R) > 0 else 0.0
        rewards.append(float(f1))
    print(f"strict_edge rewards (sigma={sigma}): {rewards}")
    return rewards
# ---------------------------------------------------------------------------
# Binary probing reward — unified R1 (window coherence) / R3 (point forgery)
# ---------------------------------------------------------------------------
# Yes/no question asked by the R1 window probes. A "yes" answer means the
# clip looks internally coherent. _build_r1_probes expects "yes" for the two
# windows that end or start at a predicted boundary. It expects "no" for the
# window straddling that boundary.
R1_COHERENCE_QUESTION = (
    "Watch the following short video clip. Is it internally coherent — that "
    "is, does it show a continuous, consistent scene without any abrupt "
    "change in appearance, motion, lighting, object identity, or background?"
)
# Yes/no question asked by the R3 point probes. A "yes" answer means the clip
# looks AI-generated or manipulated. _build_r3_probes expects "yes" for the
# clips placed just inside a predicted segment and "no" for the clips just
# outside it.
R3_FORGERY_QUESTION = (
    "Watch the following short video clip. Does it show AI-generated or "
    "manipulated visual content (e.g., unrealistic motion, blurry boundaries, "
    "texture or lighting inconsistencies, or unnatural object behavior)?"
)
def _build_r1_probes(t1, t2, delta_s, duration):
    """Build the 6 R1 window probes for a predicted segment (t1, t2).

    Three windows are built for each boundary t, in order t1 then t2:
      - [t - delta_s, t]          expected "yes" (coherent, ends at boundary)
      - [t, t + delta_s]          expected "yes" (coherent, starts at boundary)
      - [t - delta_s/2, t + delta_s/2]  expected "no" (straddles boundary)
    Window starts are clamped at 0.0 and window ends at `duration`.

    Returns a list of (start_s, end_s, question, expected_answer) tuples.
    """
    half = delta_s / 2
    probes = []
    for boundary in (t1, t2):
        probes.append(
            (max(0.0, boundary - delta_s), boundary, R1_COHERENCE_QUESTION, "yes")
        )
        probes.append(
            (boundary, min(duration, boundary + delta_s), R1_COHERENCE_QUESTION, "yes")
        )
        probes.append(
            (
                max(0.0, boundary - half),
                min(duration, boundary + half),
                R1_COHERENCE_QUESTION,
                "no",
            )
        )
    return probes
def _build_r3_probes(t1, t2, point_window_s, duration):
    """Build the 4 R3 point probes around a predicted segment (t1, t2).

    Each probe is a clip `point_window_s` wide, centred 1 s before or after an
    endpoint. In order:
      - centred at t1 - 1  expected "no"  (just before the segment)
      - centred at t1 + 1  expected "yes" (just inside)
      - centred at t2 - 1  expected "yes" (just inside)
      - centred at t2 + 1  expected "no"  (just after)
    The 1 s offsets are fixed. For segments shorter than 2 s, the "inside"
    probes therefore land outside the segment.

    NOTE(review): clamping differs per probe, and this is kept as-is.
      - Every start is clamped at 0.0.
      - The ends of probes 1 and 3 are clamped at 0.0 only.
      - The ends of probes 2 and 4 are clamped at `duration`.
      - Starts are never clamped at `duration`, so a segment ending near the
        clip end can yield start > end.

    Returns a list of (start_s, end_s, question, expected_answer) tuples.
    """
    half = point_window_s / 2
    question = R3_FORGERY_QUESTION
    before_t1, after_t1 = t1 - 1, t1 + 1
    before_t2, after_t2 = t2 - 1, t2 + 1
    return [
        (max(0.0, before_t1 - half), max(0.0, before_t1 + half), question, "no"),
        (max(0.0, after_t1 - half), min(duration, after_t1 + half), question, "yes"),
        (max(0.0, before_t2 - half), max(0.0, before_t2 + half), question, "yes"),
        (max(0.0, after_t2 - half), min(duration, after_t2 + half), question, "no"),
    ]
def _build_probes(t1, t2, schedule, delta_s, point_window_s, duration):
    """Return the probe list for one predicted segment under `schedule`.

    "r1" builds the R1 window probes and "r3" builds the R3 point probes.
    "combined" builds the R1 probes followed by the R3 probes.

    Raises:
        ValueError: for any other schedule name.
    """
    if schedule not in ("r1", "r3", "combined"):
        raise ValueError(f"Unknown FORENSICS_PROBE_SCHEDULE: {schedule}")
    probes = []
    if schedule in ("r1", "combined"):
        probes += _build_r1_probes(t1, t2, delta_s, duration)
    if schedule in ("r3", "combined"):
        probes += _build_r3_probes(t1, t2, point_window_s, duration)
    return probes
# Module-level step counter so the reward can be cheaply skipped on most
# training steps (gated by FORENSICS_PROBE_INTERVAL_STEPS).
# - It is incremented once per binary_probing_reward call, *before* the gate is
#   checked. With interval N, the first call that actually probes is the N-th.
# - One call corresponds to one trainer step under bs=1 GRPO. Other batching or
#   gradient-accumulation setups may map calls to steps differently
#   (unverified).
# - This is plain per-process state. NOTE(review): in multi-process training
#   each process keeps its own count; assumes all ranks call the reward equally
#   often — confirm.
_PROBE_STEP_COUNTER = 0
def _unwrap(maybe_listed, key=None):
"""Helper: HF datasets wrap cached tensors and dicts as 1-element lists;
in reward_kwargs they're then repeated G times. This peels all list
wrappers — including the one Qwen's process_vision_info nests inside
video_kwargs[fps] (e.g. {"fps": [2.0]})."""
x = maybe_listed
while isinstance(x, list) and len(x) > 0:
x = x[0]
if key is not None and isinstance(x, dict):
v = x.get(key)
while isinstance(v, list) and len(v) > 0:
v = v[0]
return v
return x
def _gt_segment_count(sol):
    """Return the number of GT segments in `sol`, for the K hard gate.

    - A non-empty list/tuple whose first item is a list/tuple counts as a
      list of (start, end) pairs, so the count is its length.
    - An empty list/tuple means the video has no manipulated span, so the
      count is 0.
    - Anything else (e.g. a bare (start, end) pair) counts as 1.
    """
    if isinstance(sol, (list, tuple)):
        if not sol:
            return 0
        if isinstance(sol[0], (list, tuple)):
            return len(sol)
    return 1


def binary_probing_reward(
    completions,
    solution,
    video_inputs=None,
    video_kwargs=None,
    durations=None,
    use_preprocessed=None,
    **kwargs,
):
    """Binary probing reward: R1 (window coherence), R3 (point forgery), or
    both, scored by a frozen Qwen2.5-VL reference model.

    Per rollout:
      1. Parse predicted intervals with parse_segments.
      2. Drop the rollout (reward 0) if K_pred exceeds the K hard gate. Keep at
         most FORENSICS_PROBE_MAX_SEGS predicted segments.
      3. For each kept (t1, t2), build a probe schedule:
           - "r1": 6 window probes (yes, yes, no per boundary)
           - "r3": 4 point probes (N-Y-Y-N around endpoints)
           - "combined": both
      4. Run all probes through the frozen prober in batches. Each probe
         returns (P(yes), P(no)).
      5. reward = mean over the rollout's probes of P(expected answer).

    Reward is 0.0 for:
      - the whole batch, on gated steps or when video tensors are missing;
      - rollouts beyond the per-group cap;
      - rollouts with an empty prediction;
      - rollouts over the K gate;
      - rollouts with an unreadable video/fps;
      - rollouts for which no probe score was obtained.

    Env vars (all optional):
        FORENSICS_PROBE_MODEL             path to the frozen Qwen2.5-VL (not
                                          read here; presumably consumed by
                                          get_prober())
        FORENSICS_PROBE_SCHEDULE          "r1" | "r3" | "combined" (default r1)
        FORENSICS_PROBE_DELTA_S           R1 window width, sec (default 2.0)
        FORENSICS_PROBE_POINT_WINDOW_S    R3 clip width, sec (default 1.0)
        FORENSICS_PROBE_INTERVAL_STEPS    compute every N calls (default 4;
                                          <= 0 or 1 = every call)
        FORENSICS_PROBE_MAX_ROLLOUTS      only probe the first M rollouts of
                                          each group (default 2; 0 = no limit)
        FORENSICS_PROBE_NUM_GENERATIONS   group size G for that cap (default 4;
                                          <= 0 disables the cap)
        FORENSICS_PROBE_MAX_SEGS          max predicted segs probed per rollout
                                          (default 3)
        FORENSICS_PROBE_MAX_PROBES_PER_BATCH  prober batch size (default 16)
        FORENSICS_PROBE_K_HARD_GATE       zero reward when K_pred > K_gt * gate
                                          (default 2.0; 0 = no gate). An empty
                                          GT counts as K_gt = 0, so with the
                                          gate on, any prediction on a
                                          GT-empty video scores 0.

    Args:
        completions: model outputs, one per rollout.
        solution: per-rollout GT, either a list of (start, end) pairs or a
            bare (start, end) pair.
        video_inputs, video_kwargs, durations: per-rollout values, possibly
            list-wrapped (see _unwrap). video_kwargs must carry "fps".
        use_preprocessed: accepted for interface compatibility; unused.

    Returns:
        One float per completion.

    Raises:
        ValueError: propagated from _build_probes for an unknown schedule.
    """
    global _PROBE_STEP_COUNTER
    _PROBE_STEP_COUNTER += 1
    n_completions = len(completions)

    interval = int(os.environ.get("FORENSICS_PROBE_INTERVAL_STEPS", "4"))
    if interval > 0 and (_PROBE_STEP_COUNTER % interval) != 0:
        return [0.0] * n_completions

    if video_inputs is None or video_kwargs is None:
        # Trainer did not pass video tensors through reward_kwargs — bail out.
        return [0.0] * n_completions

    schedule = os.environ.get("FORENSICS_PROBE_SCHEDULE", "r1")
    delta_s = float(os.environ.get("FORENSICS_PROBE_DELTA_S", "2.0"))
    point_window_s = float(os.environ.get("FORENSICS_PROBE_POINT_WINDOW_S", "1.0"))
    max_rollouts = int(os.environ.get("FORENSICS_PROBE_MAX_ROLLOUTS", "2"))
    G = int(os.environ.get("FORENSICS_PROBE_NUM_GENERATIONS", "4"))
    max_segs = int(os.environ.get("FORENSICS_PROBE_MAX_SEGS", "3"))
    max_bs = int(os.environ.get("FORENSICS_PROBE_MAX_PROBES_PER_BATCH", "16"))
    k_hard_gate = float(os.environ.get("FORENSICS_PROBE_K_HARD_GATE", "2.0"))

    try:
        from src.open_r1.binary_prober import get_prober, slice_video_by_time
        prober = get_prober()
    except Exception as e:
        # Best-effort: a missing or broken prober disables this reward
        # instead of killing training.
        import traceback as _tb
        print(f"[binary_probing] prober init failed: {e}\n{_tb.format_exc()}")
        return [0.0] * n_completions

    rewards = [0.0] * n_completions
    # Each entry: (rollout_idx, clip, fps, question, expected_token)
    flat_probes = []
    for i, (content, sol) in enumerate(zip(completions, solution)):
        # Cap probing to the first `max_rollouts` of every group of G.
        # G <= 0 would make `i % G` raise / misbehave, so it disables the cap.
        if max_rollouts > 0 and G > 0 and (i % G) >= max_rollouts:
            continue
        pred_segs = parse_segments(content or "")
        if not pred_segs:
            continue
        # K hard gate (anti-hacking: prevents inflating reward via excess segs).
        # Empty GT -> K_gt = 0, so any prediction on a GT-empty video is gated.
        K_gt = _gt_segment_count(sol)
        if k_hard_gate > 0 and len(pred_segs) > K_gt * k_hard_gate:
            continue
        pred_segs = pred_segs[:max_segs]
        try:
            vi = _unwrap(video_inputs[i])  # assumed tensor (T, C, H, W) — TODO confirm
            fps = _unwrap(video_kwargs[i], key="fps")
            if fps is None or vi is None:
                continue
            fps = float(fps)
            dur_raw = durations[i] if durations is not None else None
            # Unknown duration -> effectively no upper clamp on probe windows.
            duration = float(_unwrap(dur_raw)) if dur_raw is not None else 1e9
        except Exception as e:
            print(f"[binary_probing] sample {i} unwrap failed: {e}")
            continue
        for (t1, t2) in pred_segs:
            for (s_s, s_e, question, expected) in _build_probes(
                t1, t2, schedule, delta_s, point_window_s, duration
            ):
                clip = slice_video_by_time(vi, fps, s_s, s_e)
                if clip is None:
                    continue
                flat_probes.append((i, clip, fps, question, expected))

    if not flat_probes:
        return rewards

    rollout_scores: list[list[float]] = [[] for _ in range(n_completions)]
    for start in range(0, len(flat_probes), max_bs):
        chunk = flat_probes[start:start + max_bs]
        clips = [c for (_, c, _, _, _) in chunk]
        fps_list = [f for (_, _, f, _, _) in chunk]
        questions = [q for (_, _, _, q, _) in chunk]
        try:
            results = prober.probe_batch(clips, fps_list, questions)
        except Exception as e:
            # A failed chunk only drops its own probes; other chunks still count.
            import traceback as _tb
            print(f"[binary_probing] prober.probe_batch failed: {e}\n{_tb.format_exc()}")
            continue
        for (idx, _, _, _, expected), (p_yes, p_no) in zip(chunk, results):
            score = p_yes if expected == "yes" else p_no
            rollout_scores[idx].append(float(score))

    for i, scores in enumerate(rollout_scores):
        if scores:
            rewards[i] = float(np.mean(scores))

    if os.environ.get("DEBUG_MODE", "false").lower() == "true":
        nz = sum(1 for r in rewards if r > 0)
        print(f"[binary_probing] step={_PROBE_STEP_COUNTER} schedule={schedule} "
              f"non-zero={nz}/{n_completions} mean={np.mean(rewards):.3f} "
              f"probes={len(flat_probes)}")
    return rewards
# Reward name -> reward callable. NOTE(review): presumably looked up by the
# training script from configured reward names — confirm against the caller.
# Insertion order matches the original literal.
REWARD_FUNCS_REGISTRY = dict(
    iou=forensics_iou_reward,
    single_span_iou=single_span_iou_reward,
    format=forensics_format_reward,
    cot_consis=forensics_cot_consistency_reward,
    kpred=forensics_kpred_reward,
    coverage=forensics_coverage_reward,
    k_recall=forensics_k_recall_reward,
    k_precision=forensics_k_precision_reward,
    cot_cf=forensics_cot_counterfactual_reward,
    cot_gt_align=forensics_cot_gt_alignment_reward,
    cot_gt_align_h=forensics_cot_gt_alignment_hungarian_reward,
    strict_boundary=forensics_strict_boundary_reward,
    strict_edge=forensics_strict_edge_reward,
    binary_probing=binary_probing_reward,
)
# ---------------------------------------------------------------------------
# Self-test
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Manual smoke test: prints reward values for hand-written cases. The
    # "expected" column is only a rough guide, not an assertion.
    cases = [
        # (completion, gt, expected_combined_approx)
        ("<answer>2.80 to 14.30</answer>", [(2.80, 14.30)], 1.0),
        ("<think>...</think><answer>10.50 to 16.40; 24.90 to 42.40</answer>",
         [(10.50, 16.40), (24.90, 42.40)], 1.0),
        ("<answer>3.0 to 13.0</answer>", [(2.80, 14.30)], 0.85),   # close
        ("<answer>50.0 to 60.0</answer>", [(2.80, 14.30)], 0.0),   # no overlap
        ("<answer>2.80 to 14.30; 100.0 to 110.0</answer>",         # extra FP
         [(2.80, 14.30)], 0.5),
        ("<answer>2.8 - 14.3</answer>", [(2.80, 14.30)], 1.0),      # dash separator
        ("<answer>2.8 ~ 14.3, 24.9 - 42.4</answer>",
         [(2.80, 14.30), (24.90, 42.40)], 1.0),                    # mixed separators
        ("<answer>0 to 50</answer>", [(2.80, 14.30)], 0.3),         # shotgun
        ("garbage", [(2.80, 14.30)], 0.0),
    ]
    print("=== iou reward ===")
    for comp, gt, expected in cases:
        r = forensics_iou_reward([comp], [gt], generator=["wan"])[0]
        f = forensics_format_reward([comp])[0]
        sf = soft_f1(parse_segments(comp), gt)
        mf = mean_f1_at_tiou(parse_segments(comp), gt)
        print(f" combined={r:.3f} (~{expected}) soft_F1={sf:.3f} mean_F1@τ={mf:.3f} fmt={f} | {comp[:55]}")

    # Set-level rewards from this module, batched over the same cases. Each
    # reward prints its own per-sample list, in `cases` order.
    print("\n=== set-level rewards (batched over the cases above) ===")
    batch_comps = [c for c, _, _ in cases]
    batch_gts = [g for _, g, _ in cases]
    for reward_fn in (
        forensics_coverage_reward,
        forensics_k_recall_reward,
        forensics_k_precision_reward,
        forensics_strict_boundary_reward,
        forensics_strict_edge_reward,
    ):
        reward_fn(batch_comps, batch_gts)

    print("\n=== cot-mode format + consistency reward ===")
    # Remember any pre-existing value so the demo restores the caller's env.
    _prev_cot = os.environ.get("FORENSICS_COT")
    os.environ["FORENSICS_COT"] = "true"
    try:
        cot_cases = [
            ("<answer>10.0 to 20.0</answer>",
             "no_think_no_timestep"),
            ("<think>I see something suspicious.</think><answer>10.0 to 20.0</answer>",
             "think_but_no_timestep"),
            ("<think>Around <timestep>10.0 to 20.0</timestep> motion looks fake.</think>"
             "<answer>10.0 to 20.0</answer>",
             "perfect_match"),
            ("<think><timestep>5.0 to 15.0</timestep></think><answer>20.0 to 30.0</answer>",
             "think_answer_mismatch"),
            ("<think><timestep>5.0 to 25.0</timestep></think><answer>10.0 to 20.0</answer>",
             "think_superset_of_answer"),
            ("<think><timestep>0.0 to 10.0</timestep><timestep>20.0 to 30.0</timestep></think>"
             "<answer>0.0 to 10.0; 20.0 to 30.0</answer>",
             "two_segments_match"),
        ]
        for comp, label in cot_cases:
            fmt = forensics_format_reward([comp])[0]
            consis = forensics_cot_consistency_reward([comp])[0]
            print(f" fmt={fmt:.2f} consis={consis:.3f} | {label}")
    finally:
        if _prev_cot is None:
            os.environ.pop("FORENSICS_COT", None)
        else:
            os.environ["FORENSICS_COT"] = _prev_cot