"""Smoke test the SPI augmenter on real cached samples.

For each sample this verifies:
  - With FORENSICS_SPI_AUG unset (even with FORENSICS_SPI_PROB=1.0),
    maybe_apply_spi does not augment and leaves the video and solution
    unchanged.
  - With FORENSICS_SPI_AUG=true and FORENSICS_SPI_PROB=1.0 every call
    augments; the augmented tensor has the same shape as the original and every
    new solution interval satisfies 0 <= s < e <= duration (+0.01 s slack).
  - The frames covered by the first new solution interval appear verbatim as a
    contiguous slice of the original video (forgery frames are intact).

Checks use ``assert``; do not run this script under ``python -O``.
"""

import glob
import json
import os
import sys

import torch

sys.path.insert(0, "/mnt/local-fast/zhangt/forensics_grpo")
sys.path.insert(0, "/mnt/local-fast/zhangt/forensics_grpo/src")

from open_r1.spi_aug import maybe_apply_spi  # noqa: E402  (needs sys.path above)

# Environment variables this script overrides; run_one restores them on exit.
_SPI_ENV_VARS = ("FORENSICS_SPI_AUG", "FORENSICS_SPI_PROB")


def _load_sample(sample_dir):
    """Load ``video_inputs.pt`` and ``video_kwargs.json`` from *sample_dir*.

    Returns ``(feats, kw)``. This script treats ``feats[0]`` as the video
    tensor (frames along dim 0) and ``kw["fps"][0]`` as its frame rate.
    """
    # weights_only=False unpickles arbitrary objects: only use on trusted caches.
    feats = torch.load(os.path.join(sample_dir, "video_inputs.pt"), weights_only=False)
    with open(os.path.join(sample_dir, "video_kwargs.json")) as f:
        kw = json.load(f)
    return feats, kw


def _fresh_inputs(feats):
    """Return a copy of *feats* with every tensor cloned.

    Each maybe_apply_spi call gets its own tensors, so an augmenter that
    modifies its input in place cannot corrupt the reference video or
    leak into later trials.
    """
    return [x.clone() if isinstance(x, torch.Tensor) else x for x in feats]


def _make_batch(feats, kw, gt):
    """Build the batch dict passed to maybe_apply_spi (fresh inputs, copied GT)."""
    return {
        "video_inputs": [_fresh_inputs(feats)],
        "video_kwargs": [kw],
        "use_preprocessed": [True],
        "solution": [list(s) for s in gt],
    }


def make_data(sample_dir):
    """Build a batch dict for *sample_dir* with a placeholder solution.

    Intended to replicate the structure produced by the data loader's
    ``__getitem__``. Not used by the checks below.
    """
    feats, kw = _load_sample(sample_dir)
    return {
        "video_inputs": [feats],
        "video_kwargs": [kw],
        "use_preprocessed": [True],
        "solution": [[16.9, 22.5]],  # placeholder GT — replaced per sample
    }


def _check_disabled(feats, kw, gt, video):
    """Test 1: with FORENSICS_SPI_AUG unset, nothing is augmented or changed."""
    os.environ.pop("FORENSICS_SPI_AUG", None)
    # Force prob=1 so only the AUG switch can be what prevents augmentation.
    os.environ["FORENSICS_SPI_PROB"] = "1.0"
    out = maybe_apply_spi(_make_batch(feats, kw, gt))
    assert "_spi" not in out, "should not have augmented when env unset"
    assert torch.equal(out["video_inputs"][0][0], video), \
        "video changed although augmentation is disabled"
    assert [list(s) for s in out["solution"]] == [list(s) for s in gt], \
        "solution changed although augmentation is disabled"
    print(" [test 1] disabled-by-default OK")


def _check_enabled(feats, kw, gt, video, duration, n_trials):
    """Test 2: with prob=1 every call augments, keeps shape, yields valid intervals."""
    aug_count = 0
    for trial in range(n_trials):
        out = maybe_apply_spi(_make_batch(feats, kw, gt))
        if not out.get("_spi"):
            continue
        aug_count += 1
        new_video = out["video_inputs"][0][0]
        new_sol = out["solution"]
        assert new_video.shape == video.shape, \
            f"shape mismatch: {new_video.shape} vs {video.shape}"
        for s, e in new_sol:
            # +0.01 s slack absorbs float rounding at the end of the video.
            assert 0 <= s < e <= duration + 0.01, f"bad interval ({s},{e})"
        rounded = [(round(s, 2), round(e, 2)) for s, e in new_sol]
        print(f" [trial {trial}] new sol: {rounded}")
    print(f" [test 2] aug applied {aug_count}/{n_trials} trials")
    assert aug_count == n_trials, f"expected all {n_trials} to augment, got {aug_count}"


def _check_forgery_frames(feats, kw, gt, video, fps):
    """Test 3: frames of the first new interval match a contiguous original slice."""
    out = maybe_apply_spi(_make_batch(feats, kw, gt))
    # Without augmentation the check would pass trivially at the original offset.
    assert out.get("_spi"), "expected augmentation for the forgery-frame check"
    new_video = out["video_inputs"][0][0]
    new_s, new_e = out["solution"][0]
    new_fs = int(round(new_s * fps))
    new_fe = max(new_fs + 1, int(round(new_e * fps)))
    aug_slice = new_video[new_fs:new_fe]
    L = aug_slice.shape[0]
    # new_fe >= new_fs + 1, so L == 0 only when new_fs is at/after the last frame.
    if L == 0:
        print(" [test 3] skipped (interval too short)")
        return
    T = video.shape[0]
    offset = next(
        (off for off in range(T - L + 1) if torch.equal(aug_slice, video[off:off + L])),
        None,
    )
    assert offset is not None, "forgery slice not found verbatim in original — frames corrupted"
    print(f" [test 3] forgery frames intact, original offset={offset} (frame), "
          f"original time = {offset/fps:.2f}s")


def run_one(sample_dir, gt, n_trials=5):
    """Run the three SPI checks on one cached sample.

    Args:
        sample_dir: directory holding video_inputs.pt and video_kwargs.json.
        gt: list of [start, end] ground-truth intervals, in seconds.
        n_trials: number of augmentation calls made by test 2.

    Raises:
        AssertionError: if any check fails.

    The FORENSICS_SPI_* environment variables are restored on exit.
    """
    feats, kw = _load_sample(sample_dir)
    # Independent reference copy: immune to in-place changes by the augmenter.
    video = feats[0].clone()
    T = video.shape[0]
    fps = kw["fps"][0]
    duration = T / fps
    print(f"\nSample: {sample_dir}")
    print(f" T={T} frames, fps={fps:.3f}, duration={duration:.2f}s")
    print(f" GT: {gt}")

    saved_env = {name: os.environ.get(name) for name in _SPI_ENV_VARS}
    try:
        _check_disabled(feats, kw, gt, video)
        os.environ["FORENSICS_SPI_AUG"] = "true"
        os.environ["FORENSICS_SPI_PROB"] = "1.0"
        _check_enabled(feats, kw, gt, video, duration, n_trials)
        _check_forgery_frames(feats, kw, gt, video, fps)
    finally:
        for name, value in saved_env.items():
            if value is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = value


def _pick_extra_samples(cache, k=3):
    """Return up to *k* (sample_dir, gt) pairs from ``cache/train``.

    GT is not looked up; a mid-video interval covering 40%-60% of the
    duration is fabricated instead.
    """
    # sorted() so the same samples are chosen on every run / filesystem.
    paths = sorted(glob.glob(os.path.join(cache, "train", "*", "*", "video_inputs.pt")))[:k]
    samples = []
    for p in paths:
        sd = os.path.dirname(p)
        feats, kw = _load_sample(sd)
        dur = feats[0].shape[0] / kw["fps"][0]
        samples.append((sd, [[dur * 0.4, dur * 0.6]]))
    return samples


if __name__ == "__main__":
    CACHE = "/mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0"
    samples = [
        # (sample_dir, gt) — picked from earlier inspection
        (
            os.path.join(CACHE, "test", "vidu",
                         "99B6U+16.90=22.50=charades@test_add@99B6U@1412@vidu"),
            [[16.9, 22.5]],
        ),
    ]
    # Also pick a few training samples programmatically.
    samples.extend(_pick_extra_samples(CACHE, k=3))
    for sd, gt in samples:
        run_one(sd, gt, n_trials=3)
    print("\nALL TESTS PASSED")