# File: forensics-grpo/code/src/open_r1/hard_sample_miner.py
# Repository page metadata: avatar "sdzt"; commit 33569f9 ("Add source code", verified); size 5 kB.
"""
Intra-video hard negative sample miner for GRPO training.
For each ground truth temporal segment, constructs 3 hard negatives:
- One with IoU in [0.5, 0.7) — medium difficulty
- One with IoU in [0.7, 0.85) — hard
- One with IoU in [0.85, 0.95) — very hard (boundary-level)
Perturbation strategies: shift / expand / shrink / asymmetric (random choice).
max_attempts=50 per bucket; returns None on failure (no random fill).
"""
import random
class IntraVideoHardSampleMiner:
    """Mine hard negative temporal segments around a ground-truth segment.

    Candidates are produced by randomly perturbing the ground-truth segment
    (shift / expand / shrink / asymmetric) and are kept only when their IoU
    with the ground truth falls inside a requested half-open bucket
    ``[iou_lo, iou_hi)``.

    Args:
        max_attempts: Maximum number of random candidates drawn per bucket.
        seed: Seed passed to the private ``random.Random`` instance.
    """

    def __init__(self, max_attempts=50, seed=None):
        self.max_attempts = max_attempts
        # Private generator so mining does not touch the global random state.
        self.rng = random.Random(seed)

    # ------------------------------------------------------------------ #
    # IoU helper
    # ------------------------------------------------------------------ #
    @staticmethod
    def compute_iou(s1, e1, s2, e2):
        """Return the temporal IoU of segments ``[s1, e1]`` and ``[s2, e2]``.

        The denominator is the span from the earliest start to the latest
        end. For overlapping segments that equals the union; for disjoint
        segments the intersection is 0, so the result is 0.0 either way.
        Returns 0.0 when that span is not positive.
        """
        inter = max(0.0, min(e1, e2) - max(s1, s2))
        union = max(e1, e2) - min(s1, s2)
        if union <= 0:
            return 0.0
        return inter / union

    # ------------------------------------------------------------------ #
    # Four perturbation strategies
    #
    # Each returns ``(new_start, new_end)`` clamped to ``[0, dur]``, or
    # ``(None, None)`` when the clamped segment is empty or inverted.
    # ------------------------------------------------------------------ #
    def _perturb_shift(self, gt_s, gt_e, dur):
        """Translate the whole segment by up to +/-60% of its length."""
        length = gt_e - gt_s
        d = self.rng.uniform(-length * 0.6, length * 0.6)
        ns, ne = max(0, gt_s + d), min(dur, gt_e + d)
        return (ns, ne) if ns < ne else (None, None)

    def _perturb_expand(self, gt_s, gt_e, dur):
        """Grow each side outward by an independent 0-80% of the length."""
        length = gt_e - gt_s
        left = self.rng.uniform(0, 0.8) * length
        right = self.rng.uniform(0, 0.8) * length
        ns, ne = max(0, gt_s - left), min(dur, gt_e + right)
        # Clamping to ``dur`` can invert the segment (e.g. dur < gt_s);
        # reject it the same way the other strategies do.
        return (ns, ne) if ns < ne else (None, None)

    def _perturb_shrink(self, gt_s, gt_e, dur):
        """Pull each side inward by an independent 0-45% of the length."""
        length = gt_e - gt_s
        ls = self.rng.uniform(0, 0.45) * length
        rs = self.rng.uniform(0, 0.45) * length
        ns, ne = gt_s + ls, gt_e - rs
        return (ns, ne) if ns < ne else (None, None)

    def _perturb_asymmetric(self, gt_s, gt_e, dur):
        """Move start and end independently by up to +/-60% of the length."""
        length = gt_e - gt_s
        ds = self.rng.uniform(-0.6, 0.6) * length
        de = self.rng.uniform(-0.6, 0.6) * length
        ns, ne = max(0, gt_s + ds), min(dur, gt_e + de)
        return (ns, ne) if ns < ne else (None, None)

    # ------------------------------------------------------------------ #
    # Core generation logic
    # ------------------------------------------------------------------ #
    def _try_generate(self, gt_s, gt_e, dur, iou_lo, iou_hi):
        """Draw up to ``max_attempts`` candidates for one IoU bucket.

        Returns:
            A dict with keys ``start``, ``end``, ``iou_with_gt`` and
            ``bucket`` for the first candidate whose IoU lies in
            ``[iou_lo, iou_hi)``, or ``None`` if no attempt succeeds.
        """
        # A zero-length or inverted ground truth cannot reach a positive
        # IoU, so skip the attempts instead of burning random draws.
        if gt_e <= gt_s:
            return None

        fns = [self._perturb_shift, self._perturb_expand,
               self._perturb_shrink, self._perturb_asymmetric]
        for _ in range(self.max_attempts):
            fn = self.rng.choice(fns)
            ns, ne = fn(gt_s, gt_e, dur)
            if ns is None:
                continue
            # Round BEFORE measuring IoU: the returned segment is the rounded
            # one, so the bucket check and the stored IoU must describe those
            # exact coordinates. For short segments a 0.01 shift is enough to
            # push the IoU across a bucket edge.
            ns, ne = round(ns, 2), round(ne, 2)
            iou = self.compute_iou(gt_s, gt_e, ns, ne)
            if iou_lo <= iou < iou_hi:
                return {
                    "start": ns,
                    "end": ne,
                    "iou_with_gt": round(iou, 4),
                    "bucket": f"{iou_lo}-{iou_hi}",
                }
        return None

    def mine(self, gt_start, gt_end, duration):
        """
        Returns list of 3 elements (may be None):
        [0] IoU in [0.5, 0.7) — medium
        [1] IoU in [0.7, 0.85) — hard
        [2] IoU in [0.85, 0.95) — very hard (boundary-level)
        """
        return [
            self._try_generate(gt_start, gt_end, duration, 0.5, 0.7),
            self._try_generate(gt_start, gt_end, duration, 0.7, 0.85),
            self._try_generate(gt_start, gt_end, duration, 0.85, 0.95),
        ]
# ====================================================================== #
# Verification
# ====================================================================== #
def verify_hard_sample_miner():
    """Mine negatives for a few fixed segments and re-check their buckets.

    For every mined negative the IoU is recomputed from the returned
    ``start``/``end`` and compared against the bounds parsed from its
    ``bucket`` label. A PASS/FAIL line is printed per negative, followed by
    an overall summary.

    Returns:
        True when every successfully mined negative lies in its bucket.
    """
    miner = IntraVideoHardSampleMiner(seed=42)
    # (gt_start, gt_end, video_duration) triples.
    segments = (
        (5.0, 15.0, 30.0),
        (0.0, 10.0, 20.0),
        (20.0, 30.0, 35.0),
        (10.0, 25.0, 60.0),
        (0.5, 1.5, 10.0),
    )
    all_ok = True
    for gt_s, gt_e, dur in segments:
        mined = miner.mine(gt_s, gt_e, dur)
        print(f"\nGT: [{gt_s:.2f}, {gt_e:.2f}], duration={dur:.2f}")
        for neg in mined:
            if neg is None:
                # A failed bucket is reported but does not count as a failure.
                print(" -> generation failed (None)")
                continue
            lo, hi = [float(bound) for bound in neg["bucket"].split("-")]
            actual = IntraVideoHardSampleMiner.compute_iou(
                gt_s, gt_e, neg["start"], neg["end"]
            )
            ok = lo <= actual and actual < hi
            all_ok = all_ok and ok
            print(f" [{'PASS' if ok else 'FAIL'}] neg=[{neg['start']:.2f}, {neg['end']:.2f}] "
                  f"stored_iou={neg['iou_with_gt']:.4f} verified_iou={actual:.4f} "
                  f"bucket={neg['bucket']}")
    print(f"\n{'=' * 50}")
    print(f"Overall: {'ALL PASSED' if all_ok else 'SOME FAILED'}")
    return all_ok
# Script entry point: run the bucket self-check when this module is executed
# directly. The boolean result is discarded, not turned into an exit code.
if __name__ == "__main__":
    verify_hard_sample_miner()