"""Smoke test the SPI augmenter on a real cached sample.

Verifies:
- Augmented tensor has same shape as original
- New solution intervals are valid (s < e, within video duration)
- Forgery atom frames are intact (the frames at the first new interval must
  match a contiguous run of frames in the original video exactly)
- Leaving the FORENSICS_SPI_AUG env var unset disables augmentation (the
  output carries no `_spi` marker)
"""
import glob
import json
import os
import sys

import torch

# The project checkout must be on sys.path before the augmenter can be imported.
sys.path.insert(0, "/mnt/local-fast/zhangt/forensics_grpo")
sys.path.insert(0, "/mnt/local-fast/zhangt/forensics_grpo/src")
from open_r1.spi_aug import maybe_apply_spi  # noqa: E402
|
|
|
|
def make_data(sample_dir, solution=None):
    """Build a one-sample batch dict from a cached sample directory.

    Args:
        sample_dir: Directory holding ``video_inputs.pt`` and
            ``video_kwargs.json``.
        solution: Ground-truth intervals as a list of ``[start, end]`` pairs.
            Defaults to ``[[16.9, 22.5]]``, the value this helper previously
            hard-coded.

    Returns:
        A dict with the keys ``video_inputs``, ``video_kwargs``,
        ``use_preprocessed`` and ``solution``; each value is a one-element
        list, except ``solution`` which is the interval list itself.
    """
    # NOTE: weights_only=False unpickles arbitrary objects, so only point this
    # at cache directories you trust.
    feats = torch.load(os.path.join(sample_dir, "video_inputs.pt"), weights_only=False)
    with open(os.path.join(sample_dir, "video_kwargs.json")) as f:
        kw = json.load(f)

    if solution is None:
        # Fresh list per call so callers can mutate the result safely.
        solution = [[16.9, 22.5]]

    return {
        "video_inputs": [feats],
        "video_kwargs": [kw],
        "use_preprocessed": [True],
        "solution": solution,
    }
|
|
|
|
def run_one(sample_dir, gt, n_trials=5):
    """Run the SPI smoke checks against one cached sample.

    Three checks are performed, in order:

    1. With ``FORENSICS_SPI_AUG`` unset, ``maybe_apply_spi`` must not mark the
       output with ``_spi``.
    2. With ``FORENSICS_SPI_AUG=true`` and ``FORENSICS_SPI_PROB=1.0``, every
       one of ``n_trials`` calls must augment, keep the video shape, and
       return intervals satisfying ``0 <= s < e <= duration`` (10 ms slack).
    3. The frames at the first returned interval must appear verbatim as a
       contiguous run somewhere in the original video.

    Args:
        sample_dir: Directory holding ``video_inputs.pt`` and
            ``video_kwargs.json``.
        gt: Ground-truth intervals as a list of ``[start, end]`` pairs, in the
            same time unit as ``T / fps`` (seconds if fps is frames/second).
        n_trials: Number of augmentation calls made for check 2.

    Raises:
        AssertionError: If any check fails.
    """
    # NOTE: weights_only=False unpickles arbitrary objects; trusted caches only.
    feats = torch.load(os.path.join(sample_dir, "video_inputs.pt"), weights_only=False)
    # Keep a private copy as the reference: list(feats) is only a shallow copy,
    # so an augmenter that edits the tensor in place would otherwise corrupt
    # the very frames check 3 compares against.
    video = feats[0].clone()
    T = video.shape[0]
    with open(os.path.join(sample_dir, "video_kwargs.json")) as f:
        kw = json.load(f)
    fps = kw["fps"][0]
    duration = T / fps

    def fresh_batch():
        # Every call gets its own tensor and interval lists so trials stay
        # independent of one another.
        return {
            "video_inputs": [[video.clone(), *feats[1:]]],
            "video_kwargs": [kw],
            "use_preprocessed": [True],
            "solution": [list(s) for s in gt],
        }

    print(f"\nSample: {sample_dir}")
    print(f" T={T} frames, fps={fps:.3f}, duration={duration:.2f}s")
    print(f" GT: {gt}")

    # Remember the caller's environment so the toggles set below do not leak
    # out of this function.
    env_keys = ("FORENSICS_SPI_AUG", "FORENSICS_SPI_PROB")
    saved_env = {key: os.environ.get(key) for key in env_keys}
    try:
        # --- check 1: augmentation is off unless explicitly enabled ---
        os.environ.pop("FORENSICS_SPI_AUG", None)
        out = maybe_apply_spi(fresh_batch())
        assert "_spi" not in out, "should not have augmented when env unset"
        print(" [test 1] disabled-by-default OK")

        # --- check 2: with probability 1.0 every trial must augment ---
        os.environ["FORENSICS_SPI_AUG"] = "true"
        os.environ["FORENSICS_SPI_PROB"] = "1.0"
        aug_count = 0
        for trial in range(n_trials):
            out = maybe_apply_spi(fresh_batch())
            if out.get("_spi"):
                aug_count += 1
            new_video = out["video_inputs"][0][0]
            new_sol = out["solution"]
            assert new_video.shape == video.shape, \
                f"shape mismatch: {new_video.shape} vs {video.shape}"
            for s, e in new_sol:
                # 10 ms slack absorbs float rounding at the end of the clip.
                assert 0 <= s < e <= duration + 0.01, f"bad interval ({s},{e})"
            print(f" [trial {trial}] new sol: {[(round(s,2), round(e,2)) for s,e in new_sol]}")
        print(f" [test 2] aug applied {aug_count}/{n_trials} trials")
        assert aug_count == n_trials, f"expected all {n_trials} to augment, got {aug_count}"

        # --- check 3: frames at the new interval exist verbatim in the original ---
        out = maybe_apply_spi(fresh_batch())
        new_video = out["video_inputs"][0][0]
        new_s, new_e = out["solution"][0]
        new_fs = int(round(new_s * fps))
        # Always take at least one frame, even for sub-frame intervals.
        new_fe = max(new_fs + 1, int(round(new_e * fps)))
        aug_slice = new_video[new_fs:new_fe]

        L = aug_slice.shape[0]
        if L == 0:
            # Start index rounded to or past the last frame: nothing to compare.
            print(" [test 3] skipped (interval too short)")
        else:
            # Slide the slice over the original and take the first exact match.
            off = next(
                (o for o in range(T - L + 1) if torch.equal(aug_slice, video[o:o + L])),
                None,
            )
            assert off is not None, \
                "forgery slice not found verbatim in original — frames corrupted"
            print(f" [test 3] forgery frames intact, original offset={off} (frame), "
                  f"original time = {off/fps:.2f}s")
    finally:
        for key, value in saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value
|
|
|
|
if __name__ == "__main__":
    CACHE = "/mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0"

    # One known test sample with an explicit GT interval (the same 16.90/22.50
    # values appear in the directory name).
    samples = [
        (
            os.path.join(CACHE, "test", "vidu",
                         "99B6U+16.90=22.50=charades@test_add@99B6U@1412@vidu"),
            [[16.9, 22.5]],
        ),
    ]

    # Add up to three training samples. glob returns matches in arbitrary
    # order, so sort first to exercise the same samples on every run.
    extra = sorted(glob.glob(os.path.join(CACHE, "train", "*", "*", "video_inputs.pt")))[:3]
    for p in extra:
        sd = os.path.dirname(p)
        # The tensor is loaded only to derive the clip duration.
        feats = torch.load(p, weights_only=False)
        with open(os.path.join(sd, "video_kwargs.json")) as f:
            kw = json.load(f)
        fps = kw["fps"][0]
        T = feats[0].shape[0]
        dur = T / fps
        # Synthetic GT: the 40%-60% span of the clip stands in for a real label.
        s = dur * 0.4
        e = dur * 0.6
        samples.append((sd, [[s, e]]))

    for sd, gt in samples:
        run_one(sd, gt, n_trials=3)
    print("\nALL TESTS PASSED")
|
|