"""
Intra-video hard negative sample miner for GRPO training.
For each ground truth temporal segment, constructs 3 hard negatives:
- One with IoU in [0.5, 0.7) — medium difficulty
- One with IoU in [0.7, 0.85) — hard
- One with IoU in [0.85, 0.95) — very hard (boundary-level)
Perturbation strategies: shift / expand / shrink / asymmetric (random choice).
Up to max_attempts (default 50) perturbations are tried per bucket; a bucket that cannot be filled yields None (no random fallback).
"""
import random
class IntraVideoHardSampleMiner:
    """Mine hard-negative temporal segments by perturbing a ground-truth segment.

    Each negative is produced by randomly applying one of four perturbations
    (shift / expand / shrink / asymmetric) to the ground-truth segment until
    the IoU of the perturbed segment with the ground truth lands in a requested
    half-open bucket ``[lo, hi)``.

    Args:
        max_attempts: Number of random perturbations tried per bucket before
            giving up and returning ``None`` for that bucket.
        seed: Optional seed for the miner's private ``random.Random`` instance,
            which makes mining reproducible without touching global state.
    """

    # Default IoU buckets: medium, hard, very hard (boundary-level).
    DEFAULT_BUCKETS = ((0.5, 0.7), (0.7, 0.85), (0.85, 0.95))

    # Decimal places kept for the returned start/end boundaries.
    ROUND_DIGITS = 2

    def __init__(self, max_attempts=50, seed=None):
        self.max_attempts = max_attempts
        self.rng = random.Random(seed)

    # ------------------------------------------------------------------ #
    # IoU helper
    # ------------------------------------------------------------------ #
    @staticmethod
    def compute_iou(s1, e1, s2, e2):
        """Return the temporal IoU of segments ``[s1, e1]`` and ``[s2, e2]``.

        The denominator is the hull ``max(e1, e2) - min(s1, s2)``. It equals
        the union whenever the segments overlap or touch. For disjoint segments
        the intersection is 0, so the result is 0 either way. Returns 0.0 when
        the denominator is non-positive.
        """
        inter = max(0.0, min(e1, e2) - max(s1, s2))
        union = max(e1, e2) - min(s1, s2)
        if union <= 0:
            return 0.0
        return inter / union

    # ------------------------------------------------------------------ #
    # Four perturbation strategies
    # Each returns (start, end) or (None, None) if the result is empty.
    # ------------------------------------------------------------------ #
    def _perturb_shift(self, gt_s, gt_e, dur):
        """Translate the segment by up to +/-60% of its length, clipped to [0, dur]."""
        length = gt_e - gt_s
        d = self.rng.uniform(-length * 0.6, length * 0.6)
        ns, ne = max(0.0, gt_s + d), min(dur, gt_e + d)
        return (ns, ne) if ns < ne else (None, None)

    def _perturb_expand(self, gt_s, gt_e, dur):
        """Grow each side independently by up to 80% of the length, clipped to [0, dur]."""
        length = gt_e - gt_s
        left = self.rng.uniform(0, 0.8) * length
        right = self.rng.uniform(0, 0.8) * length
        ns, ne = max(0.0, gt_s - left), min(dur, gt_e + right)
        # Same empty-segment guard as the other strategies.
        return (ns, ne) if ns < ne else (None, None)

    def _perturb_shrink(self, gt_s, gt_e, dur):
        """Trim each side independently by up to 45% of the length."""
        length = gt_e - gt_s
        ls = self.rng.uniform(0, 0.45) * length
        rs = self.rng.uniform(0, 0.45) * length
        ns, ne = gt_s + ls, gt_e - rs
        return (ns, ne) if ns < ne else (None, None)

    def _perturb_asymmetric(self, gt_s, gt_e, dur):
        """Move each boundary independently by up to +/-60% of the length, clipped to [0, dur]."""
        length = gt_e - gt_s
        ds = self.rng.uniform(-0.6, 0.6) * length
        de = self.rng.uniform(-0.6, 0.6) * length
        ns, ne = max(0.0, gt_s + ds), min(dur, gt_e + de)
        return (ns, ne) if ns < ne else (None, None)

    # ------------------------------------------------------------------ #
    # Core generation logic
    # ------------------------------------------------------------------ #
    def _try_generate(self, gt_s, gt_e, dur, iou_lo, iou_hi):
        """Sample perturbations until one has IoU in ``[iou_lo, iou_hi)``.

        Candidate boundaries are rounded *before* the IoU test. The returned
        ``start``/``end`` are therefore exactly the values whose IoU was checked,
        and ``iou_with_gt`` describes them.

        Returns:
            A dict with keys ``start``, ``end``, ``iou_with_gt`` and ``bucket``
            (formatted ``"{iou_lo}-{iou_hi}"``). Returns ``None`` if the ground
            truth is empty/inverted, or if no candidate landed in the bucket
            within ``self.max_attempts`` tries.
        """
        if gt_e <= gt_s:
            # An empty or inverted ground truth has zero intersection with any
            # segment, so sampling cannot meaningfully produce a hard negative.
            return None
        strategies = (self._perturb_shift, self._perturb_expand,
                      self._perturb_shrink, self._perturb_asymmetric)
        for _ in range(self.max_attempts):
            ns, ne = self.rng.choice(strategies)(gt_s, gt_e, dur)
            if ns is None:
                continue
            ns = round(ns, self.ROUND_DIGITS)
            ne = round(ne, self.ROUND_DIGITS)
            if ns >= ne:
                # Rounding collapsed a very short candidate.
                continue
            iou = self.compute_iou(gt_s, gt_e, ns, ne)
            if iou_lo <= iou < iou_hi:
                return {
                    "start": ns,
                    "end": ne,
                    "iou_with_gt": round(iou, 4),
                    "bucket": f"{iou_lo}-{iou_hi}",
                }
        return None

    def mine(self, gt_start, gt_end, duration, buckets=None):
        """Mine one hard negative per IoU bucket for a ground-truth segment.

        Args:
            gt_start: Ground-truth segment start.
            gt_end: Ground-truth segment end.
            duration: Video length; negatives are clipped to ``[0, duration]``.
            buckets: Optional sequence of ``(lo, hi)`` half-open IoU ranges.
                Defaults to ``DEFAULT_BUCKETS``:
                [0] IoU in [0.5, 0.7)   -- medium
                [1] IoU in [0.7, 0.85)  -- hard
                [2] IoU in [0.85, 0.95) -- very hard (boundary-level)

        Returns:
            A list with one entry per bucket, in bucket order. Each entry is a
            negative dict (see ``_try_generate``) or ``None`` on failure.
        """
        if buckets is None:
            buckets = self.DEFAULT_BUCKETS
        return [
            self._try_generate(gt_start, gt_end, duration, lo, hi)
            for lo, hi in buckets
        ]
# ====================================================================== #
# Verification
# ====================================================================== #
def verify_hard_sample_miner():
    """Mine negatives for a fixed set of cases and check each one's bucket.

    For every mined negative, the IoU is recomputed from its stored
    ``start``/``end`` and compared against the ``[lo, hi)`` range parsed from
    its ``bucket`` label. Failed generations (``None``) are reported but do not
    count as failures. Results are printed per ground-truth case.

    Returns:
        True if every non-None negative passed, False otherwise.
    """
    miner = IntraVideoHardSampleMiner(seed=42)
    test_cases = (
        (5.0, 15.0, 30.0),
        (0.0, 10.0, 20.0),
        (20.0, 30.0, 35.0),
        (10.0, 25.0, 60.0),
        (0.5, 1.5, 10.0),
    )
    failures = 0
    for start, end, duration in test_cases:
        mined = miner.mine(start, end, duration)
        print(f"\nGT: [{start:.2f}, {end:.2f}], duration={duration:.2f}")
        for candidate in mined:
            if candidate is None:
                print(" -> generation failed (None)")
                continue
            verified = IntraVideoHardSampleMiner.compute_iou(
                start, end, candidate["start"], candidate["end"]
            )
            low, high = (float(part) for part in candidate["bucket"].split("-"))
            passed = low <= verified < high
            if not passed:
                failures += 1
            status = "PASS" if passed else "FAIL"
            print(f" [{status}] neg=[{candidate['start']:.2f}, {candidate['end']:.2f}] "
                  f"stored_iou={candidate['iou_with_gt']:.4f} verified_iou={verified:.4f} "
                  f"bucket={candidate['bucket']}")
    all_ok = failures == 0
    banner = "=" * 50
    print(f"\n{banner}")
    print(f"Overall: {'ALL PASSED' if all_ok else 'SOME FAILED'}")
    return all_ok
if __name__ == "__main__":
    # Report the verification outcome as the process exit status so scripted
    # or CI runs can detect a failure (0 = all passed, 1 = some failed).
    raise SystemExit(0 if verify_hard_sample_miner() else 1)