File size: 5,186 Bytes
33569f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
"""Smoke test the SPI augmenter on a real cached sample.

Verifies:
  - Augmented tensor has same shape as original
  - New solution intervals are valid (s < e, within video duration)
  - Forgery atom frames are intact (the frames of the first new interval must
    appear verbatim as a contiguous slice of the original video)
  - Disabling env var leaves data unchanged
"""
import json
import os
import sys
import torch

sys.path.insert(0, "/mnt/local-fast/zhangt/forensics_grpo")
sys.path.insert(0, "/mnt/local-fast/zhangt/forensics_grpo/src")
from open_r1.spi_aug import maybe_apply_spi


def make_data(sample_dir, solution=None):
    """Build an augmenter input dict from one cached sample directory.

    Args:
        sample_dir: Directory containing ``video_inputs.pt`` and
            ``video_kwargs.json``.
        solution: Optional iterable of ``(start, end)`` ground-truth
            intervals. When omitted, the placeholder ``[[16.9, 22.5]]`` is
            used and the caller is expected to replace it per sample.

    Returns:
        A dict with the keys ``video_inputs``, ``video_kwargs``,
        ``use_preprocessed`` and ``solution``, each wrapped in a one-element
        list (``solution`` is a list of ``[start, end]`` lists).
    """
    # NOTE: weights_only=False unpickles arbitrary objects; only point this at
    # cache files you trust.
    feats = torch.load(os.path.join(sample_dir, "video_inputs.pt"), weights_only=False)
    with open(os.path.join(sample_dir, "video_kwargs.json"), encoding="utf-8") as f:
        kw = json.load(f)
    if solution is None:
        solution = [[16.9, 22.5]]  # placeholder GT — replaced per sample
    # Replicate the structure data_loader's __getitem__ produces.
    return {
        "video_inputs": [feats],
        "video_kwargs": [kw],
        "use_preprocessed": [True],
        # Copy each interval so later in-place edits cannot touch the caller's GT.
        "solution": [list(interval) for interval in solution],
    }


def run_one(sample_dir, gt, n_trials=5):
    """Run the SPI augmenter smoke checks against one cached sample.

    Three checks are performed in order:
      1. With ``FORENSICS_SPI_AUG`` unset, the output carries no ``_spi`` key.
      2. With augmentation forced on (``FORENSICS_SPI_PROB=1.0``), every trial
         is augmented, keeps the video shape, and yields intervals satisfying
         ``0 <= s < e <= duration + 0.01``.
      3. The frames covered by the first augmented interval occur verbatim as
         a contiguous slice of the original video.

    Args:
        sample_dir: Directory containing ``video_inputs.pt`` and
            ``video_kwargs.json``.
        gt: List of ``[start, end]`` intervals passed as ``solution``
            (compared against ``T / fps``, so presumably seconds).
        n_trials: Number of augmentation trials for check 2.

    Raises:
        AssertionError: If any check fails.

    The two ``FORENSICS_SPI_*`` environment variables are restored to their
    prior state on exit, whether or not a check fails.
    """
    feats = torch.load(os.path.join(sample_dir, "video_inputs.pt"), weights_only=False)
    video = feats[0]
    T = video.shape[0]
    with open(os.path.join(sample_dir, "video_kwargs.json")) as f:
        kw = json.load(f)
    fps = kw["fps"][0]
    duration = T / fps

    print(f"\nSample: {sample_dir}")
    print(f"  T={T} frames, fps={fps:.3f}, duration={duration:.2f}s")
    print(f"  GT: {gt}")

    def fresh_data():
        # Shallow-copy the containers so an in-place edit by the augmenter
        # cannot corrupt the reference `feats` / `gt` used for comparison.
        return {
            "video_inputs": [list(feats)],
            "video_kwargs": [kw],
            "use_preprocessed": [True],
            "solution": [list(s) for s in gt],
        }

    env_keys = ("FORENSICS_SPI_AUG", "FORENSICS_SPI_PROB")
    saved_env = {key: os.environ.get(key) for key in env_keys}
    try:
        # Test 1: disabled by default
        os.environ.pop("FORENSICS_SPI_AUG", None)
        out = maybe_apply_spi(fresh_data())
        assert "_spi" not in out, "should not have augmented when env unset"
        print("  [test 1] disabled-by-default OK")

        # Test 2: enabled, force prob=1
        os.environ["FORENSICS_SPI_AUG"] = "true"
        os.environ["FORENSICS_SPI_PROB"] = "1.0"
        aug_count = 0
        for trial in range(n_trials):
            out = maybe_apply_spi(fresh_data())
            if out.get("_spi"):
                aug_count += 1
                new_video = out["video_inputs"][0][0]
                new_sol = out["solution"]
                assert new_video.shape == video.shape, \
                    f"shape mismatch: {new_video.shape} vs {video.shape}"
                # Validate intervals (small tolerance on the upper bound).
                for s, e in new_sol:
                    assert 0 <= s < e <= duration + 0.01, f"bad interval ({s},{e})"
                print(f"  [trial {trial}] new sol: {[(round(s,2), round(e,2)) for s,e in new_sol]}")
        print(f"  [test 2] aug applied {aug_count}/{n_trials} trials")
        assert aug_count == n_trials, f"expected all {n_trials} to augment, got {aug_count}"

        # Test 3: forgery FRAMES preserved — pick the first new interval, slice
        # the frames, verify they match SOME contiguous slice in original (the
        # forgery atom).
        out = maybe_apply_spi(fresh_data())
        # Without this guard an un-augmented output would trivially match the
        # original and the check would pass vacuously.
        assert out.get("_spi"), "test 3 needs an augmented sample but augmentation was not applied"
        new_video = out["video_inputs"][0][0]
        new_s, new_e = out["solution"][0]
        new_fs = int(round(new_s * fps))
        new_fe = max(new_fs + 1, int(round(new_e * fps)))
        aug_slice = new_video[new_fs:new_fe]
        # Find a matching slice of the same length in the original.
        L = aug_slice.shape[0]
        if L == 0:
            print("  [test 3] skipped (interval too short)")
        else:
            match_off = next(
                (off for off in range(T - L + 1)
                 if torch.equal(aug_slice, video[off:off + L])),
                None,
            )
            assert match_off is not None, \
                "forgery slice not found verbatim in original — frames corrupted"
            print(f"  [test 3] forgery frames intact, original offset={match_off} (frame), "
                  f"original time = {match_off/fps:.2f}s")
    finally:
        # Put the environment back so this function does not leak the forced
        # augmentation settings into whatever runs next in the process.
        for key, value in saved_env.items():
            if value is None:
                os.environ.pop(key, None)
            else:
                os.environ[key] = value


if __name__ == "__main__":
    import glob

    CACHE = "/mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0"
    samples = [
        # (sample_dir, gt) — picked from earlier inspection
        (os.path.join(CACHE, "test", "vidu",
         "99B6U+16.90=22.50=charades@test_add@99B6U@1412@vidu"), [[16.9, 22.5]]),
    ]
    # Also pick a couple of training samples programmatically. glob returns
    # matches in filesystem-dependent order, so sort before slicing to make the
    # chosen samples reproducible from run to run.
    extra = sorted(
        glob.glob(os.path.join(CACHE, "train", "*", "*", "video_inputs.pt"))
    )[:3]
    for p in extra:
        sd = os.path.dirname(p)
        # We don't have GT here without lookup; fabricate one mid-video for the test.
        feats = torch.load(p, weights_only=False)
        with open(os.path.join(sd, "video_kwargs.json")) as f:
            kw = json.load(f)
        fps = kw["fps"][0]
        T = feats[0].shape[0]
        dur = T / fps
        # Mid-video GT covering ~20% duration
        samples.append((sd, [[dur * 0.4, dur * 0.6]]))

    for sd, gt in samples:
        run_one(sd, gt, n_trials=3)
    print("\nALL TESTS PASSED")