# forensics-grpo/code/test_spi_aug.py
# Source: Hugging Face file viewer ("Add source code", commit 33569f9, 5.19 kB).
"""Smoke test the SPI augmenter on a real cached sample.
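Verifies:
- Augmented tensor has the same shape as the original
- New solution intervals are valid (0 <= s < e, within the video duration
  plus a 0.01 s tolerance)
- Forgery atom frames are intact (the frames of the first augmented interval
  must appear verbatim as a contiguous slice of the original video)
- With the FORENSICS_SPI_AUG env var unset, the output carries no "_spi" marker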
"""
import json
import os
import sys
import torch
sys.path.insert(0, "/mnt/local-fast/zhangt/forensics_grpo")
sys.path.insert(0, "/mnt/local-fast/zhangt/forensics_grpo/src")
from open_r1.spi_aug import maybe_apply_spi
def make_data(sample_dir):
    """Load one cached sample and wrap it in a single-item batch dict.

    Reads ``video_inputs.pt`` and ``video_kwargs.json`` from ``sample_dir`` and
    returns a dict with the keys ``video_inputs``, ``video_kwargs``,
    ``use_preprocessed`` and ``solution``, each holding a one-element list.
    Per the original author's note, this layout replicates what the data
    loader's ``__getitem__`` produces.
    """
    tensor_path = os.path.join(sample_dir, "video_inputs.pt")
    kwargs_path = os.path.join(sample_dir, "video_kwargs.json")

    cached_inputs = torch.load(tensor_path, weights_only=False)
    with open(kwargs_path) as handle:
        cached_kwargs = json.load(handle)

    batch = {}
    batch["video_inputs"] = [cached_inputs]
    batch["video_kwargs"] = [cached_kwargs]
    batch["use_preprocessed"] = [True]
    # Placeholder ground truth; callers are expected to replace it per sample.
    batch["solution"] = [[16.9, 22.5]]
    return batch
# Environment variables this smoke test toggles; they are restored on exit.
_SPI_ENV_KEYS = ("FORENSICS_SPI_AUG", "FORENSICS_SPI_PROB")


def _fresh_batch(feats, kw, gt):
    """Build a new input dict for ``maybe_apply_spi``.

    The feature list and each ground-truth interval are shallow-copied so that
    list-level edits made by the augmenter do not carry over to the next call.
    The tensors themselves are shared, not cloned.
    """
    return {
        "video_inputs": [list(feats)],
        "video_kwargs": [kw],
        "use_preprocessed": [True],
        "solution": [list(interval) for interval in gt],
    }


def run_one(sample_dir, gt, n_trials=5):
    """Run the SPI augmenter smoke checks against one cached sample.

    Args:
        sample_dir: Directory containing ``video_inputs.pt`` and
            ``video_kwargs.json``.
        gt: Ground-truth intervals as a list of ``[start, end]`` pairs; they
            are compared against the duration computed as ``frames / fps``.
        n_trials: Number of augmentation attempts made in test 2.

    Raises:
        AssertionError: If any of the three checks fails.

    The ``FORENSICS_SPI_AUG`` / ``FORENSICS_SPI_PROB`` environment variables
    are modified while the checks run and restored to their previous values
    afterwards, even if a check fails.
    """
    feats = torch.load(os.path.join(sample_dir, "video_inputs.pt"), weights_only=False)
    video = feats[0]
    T = video.shape[0]
    with open(os.path.join(sample_dir, "video_kwargs.json"), encoding="utf-8") as f:
        kw = json.load(f)
    fps = kw["fps"][0]
    duration = T / fps
    print(f"\nSample: {sample_dir}")
    print(f" T={T} frames, fps={fps:.3f}, duration={duration:.2f}s")
    print(f" GT: {gt}")

    # Snapshot the variables we are about to touch so the caller's environment
    # is left exactly as we found it.
    saved_env = {key: os.environ.get(key) for key in _SPI_ENV_KEYS}
    try:
        # Test 1: disabled by default
        os.environ.pop("FORENSICS_SPI_AUG", None)
        data = {"video_inputs": [feats], "video_kwargs": [kw],
                "use_preprocessed": [True], "solution": gt}
        out = maybe_apply_spi(data)
        assert "_spi" not in out, "should not have augmented when env unset"
        print(" [test 1] disabled-by-default OK")

        # Test 2: enabled, force prob=1
        os.environ["FORENSICS_SPI_AUG"] = "true"
        os.environ["FORENSICS_SPI_PROB"] = "1.0"
        aug_count = 0
        for trial in range(n_trials):
            out = maybe_apply_spi(_fresh_batch(feats, kw, gt))
            if not out.get("_spi"):
                continue
            aug_count += 1
            new_video = out["video_inputs"][0][0]
            new_sol = out["solution"]
            assert new_video.shape == video.shape, \
                f"shape mismatch: {new_video.shape} vs {video.shape}"
            # Validate intervals (0.01 s of slack on the upper bound).
            for s, e in new_sol:
                assert 0 <= s < e <= duration + 0.01, f"bad interval ({s},{e})"
            print(f" [trial {trial}] new sol: {[(round(s,2), round(e,2)) for s,e in new_sol]}")
        print(f" [test 2] aug applied {aug_count}/{n_trials} trials")
        assert aug_count == n_trials, f"expected all {n_trials} to augment, got {aug_count}"

        # Test 3: forgery FRAMES preserved — take the first new interval, slice
        # its frames, and verify they match SOME contiguous slice of the
        # original (the forgery atom).
        out = maybe_apply_spi(_fresh_batch(feats, kw, gt))
        # Without this guard an un-augmented output would match the original
        # trivially and the check below would pass without testing anything.
        assert out.get("_spi"), "test 3 input was not augmented despite FORENSICS_SPI_PROB=1.0"
        new_video = out["video_inputs"][0][0]
        new_sol = out["solution"]
        new_s, new_e = new_sol[0]
        new_fs = int(round(new_s * fps))
        new_fe = max(new_fs + 1, int(round(new_e * fps)))
        aug_slice = new_video[new_fs:new_fe]
        # Find a matching slice of the same length in the original.
        L = aug_slice.shape[0]
        if L == 0:
            # new_fe is always > new_fs, so an empty slice means the start
            # frame index fell at or beyond the end of the video.
            print(" [test 3] skipped (interval too short)")
        else:
            match_found = False
            for off in range(T - L + 1):
                if torch.equal(aug_slice, video[off:off + L]):
                    match_found = True
                    print(f" [test 3] forgery frames intact, original offset={off} (frame), "
                          f"original time = {off/fps:.2f}s")
                    break
            assert match_found, "forgery slice not found verbatim in original — frames corrupted"
    finally:
        for key, previous in saved_env.items():
            if previous is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = previous
if __name__ == "__main__":
    import glob

    # Cache root. Defaults to the original hard-coded location; pass a
    # directory as the first command-line argument to use a different one.
    CACHE = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "/mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0"
    )
    samples = [
        # (sample_dir, gt) — picked from earlier inspection
        (os.path.join(CACHE, "test", "vidu",
                      "99B6U+16.90=22.50=charades@test_add@99B6U@1412@vidu"), [[16.9, 22.5]]),
    ]
    # Also pick a couple of training samples programmatically. glob returns
    # paths in arbitrary order, so sort first to make the choice reproducible.
    extra = sorted(glob.glob(os.path.join(CACHE, "train", "*", "*", "video_inputs.pt")))[:3]
    for p in extra:
        sd = os.path.dirname(p)
        # We don't have GT here without lookup; fabricate one mid-video for the test.
        feats = torch.load(p, weights_only=False)
        with open(os.path.join(sd, "video_kwargs.json"), encoding="utf-8") as f:
            kw = json.load(f)
        fps = kw["fps"][0]
        T = feats[0].shape[0]
        dur = T / fps
        # Mid-video GT covering ~20% duration
        s = dur * 0.4
        e = dur * 0.6
        samples.append((sd, [[s, e]]))
    for sd, gt in samples:
        run_one(sd, gt, n_trials=3)
    print("\nALL TESTS PASSED")