File size: 11,633 Bytes
33569f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
"""Evaluate a trained forensics-GRPO model on the AF test split.

Adapted from evaluate.py (single-span Charades-style grounding) for:
- multi-segment localisation (list of (s, e) tuples per video)
- the forensics CoT prompt template (FORENSICS_COT toggle preserved)
- cached video_inputs.pt to avoid re-decoding
- multi-GPU sharding (one process per device)
- multiple matching metrics: soft_F1, mean_F1@{0.5,0.75,0.85,0.95}, hungarian_IoU

Output:
  <out_dir>/rank_<r>.jsonl     one record per evaluated test video on this rank
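  <out_dir>/summary.json       aggregate metrics (overall + per-generator); NOT written by
                               this script — presumably produced by a separate step that
                               merges the rank_<r>.jsonl files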
"""
import argparse
import json
import os
import random
import sys
import time

import torch
from transformers import (
    AutoProcessor,
    GenerationConfig,
    Qwen2_5_VLForConditionalGeneration,
)

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.open_r1.data_loader import TEST_GENERATORS, build_examples
from src.open_r1.reward import (
    hungarian_iou_reward,
    mean_f1_at_tiou,
    parse_segments,
    soft_f1,
)
from src.open_r1.trainer.grpo_trainer_video_GT_soft import (
    SYSTEM_PROMPT,
    get_question_template,
)
from src.open_r1.verifier import (
    ForensicsVerifier,
    format_verifier_scores,
    sample_id_from_video_path,
)

# Hard-coded, machine-local data locations shared by every rank.
_DATA_ROOT = "/mnt/local-fast/zhangt"
ANNOT = f"{_DATA_ROOT}/annot/annot"  # annotation dir passed to build_examples(annot_dir=...)
VROOT = f"{_DATA_ROOT}/video"  # video root passed to build_examples(video_root=...)
# Pre-decoded cache (<CACHE>/test/<generator>/<sample_id>/video_inputs.pt); the suffix
# presumably matches the max_pixels=3584*28*28 / fps=2.0 used in the prompt — confirm.
CACHE = f"{_DATA_ROOT}/forensics_grpo_cache_uniform3584_fps2.0"


def get_args():
    """Parse the command line for one evaluation rank.

    Returns:
        argparse.Namespace holding the model path, sharding settings
        (rank, world_size, device), the output directory, prompt toggles
        (cot, cot_variant, verifier_context), verifier checkpoint/cache paths,
        and decoding limits (max_new_tokens, temperature, limit).
    """
    verifier_dir = "/mnt/local-fast/zhangt/forensics_verifier_clip_l14"
    # (flag, add_argument keyword options), listed in --help display order.
    option_specs = [
        ("--model_path", {"required": True}),
        ("--rank", {"type": int, "default": 0}),
        ("--world_size", {"type": int, "default": 1}),
        ("--device", {
            "type": int,
            "default": 0,
            "help": "cuda device index (set CUDA_VISIBLE_DEVICES to pin physical GPU)",
        }),
        ("--out_dir", {"default": "eval_outputs/stage2_verifier_grounded"}),
        ("--cot", {
            "choices": ["true", "false"],
            "default": "true",
            "help": "Use CoT prompt template ('true') or no-CoT ('false').",
        }),
        ("--cot_variant", {
            "choices": ["descriptive", "counterfactual", "counterfactual_parsimonious"],
            "default": "descriptive",
            "help": "CoT prompt variant; must match the variant used at training time.",
        }),
        ("--verifier_context", {
            "choices": ["true", "false"],
            "default": "false",
            "help": "If true, inject external verifier per-second scores into the prompt.",
        }),
        ("--verifier_ckpt", {"default": f"{verifier_dir}/verifier_temporal_best.pt"}),
        ("--verifier_cache", {"default": verifier_dir}),
        ("--max_new_tokens", {"type": int, "default": 640}),
        ("--temperature", {
            "type": float,
            "default": 0.0,
            "help": "Greedy if 0 else sample with this temp.",
        }),
        ("--limit", {
            "type": int,
            "default": 0,
            "help": "Cap number of videos per rank (0=all)",
        }),
    ]

    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()


def load_cached(sample_dir):
    """Load one sample's pre-decoded video inputs and their processor kwargs.

    Args:
        sample_dir: directory containing ``video_inputs.pt`` and
            ``video_kwargs.json``.

    Returns:
        ``(feats, kwargs)``: the object saved in ``video_inputs.pt`` and the
        value parsed from ``video_kwargs.json``.
    """
    pt_path, json_path = (
        os.path.join(sample_dir, name) for name in ("video_inputs.pt", "video_kwargs.json")
    )
    # weights_only=False unpickles arbitrary objects, so only point this at a
    # trusted, locally generated cache.
    video_feats = torch.load(pt_path, weights_only=False)
    with open(json_path, "r") as fh:
        processor_kwargs = json.load(fh)
    return video_feats, processor_kwargs


def _gencond_question(mode, question, generator, generators, wrong_rng, matched_rng, prob):
    """Return the user-question text for one sample under a gencond eval mode.

    The prepended sentence is kept byte-identical to the trainer's injection
    (single source of truth: grpo_trainer_video_GT_soft.py:make_conversation_video).

    Args:
        mode: one of "none", "matched", "correct", "wrong".
        question: base prompt from get_question_template().
        generator: this sample's true generator name.
        generators: candidate names from which the "wrong" generator is drawn.
        wrong_rng: RNG advanced once per call, only in "wrong" mode.
        matched_rng: RNG advanced once per call, only in "matched" mode.
        prob: probability of prepending the true generator in "matched" mode.

    Returns:
        The (possibly prefixed) question string.
    """
    def _prefixed(name):
        return f"The forged segments in this video were generated by {name}. " + question

    if mode == "correct":
        return _prefixed(generator)
    if mode == "wrong":
        others = [g for g in generators if g != generator]
        return _prefixed(wrong_rng.choice(others))
    # Short-circuit: matched_rng is only consumed in "matched" mode.
    if mode == "matched" and matched_rng.random() < prob:
        return _prefixed(generator)
    return question


def _build_chat(video_path, q_text):
    """Build the chat messages for one sample, mirroring the trainer's prompt.

    Kept BYTE-IDENTICAL to the trainer's make_conversation_video
    (grpo_trainer_video_GT_soft.py:664-677):
      - system content is a plain string (NOT a list of blocks)
      - video block carries the same max_pixels/min_pixels/fps/max_frames
    The video tensor is supplied separately via processor(videos=[...]); the
    "video" key value is unused by apply_chat_template, but the kwargs are
    preserved verbatim to remove any divergence risk.
    """
    return [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": [
                {"type": "video",
                 "video": video_path,
                 "max_pixels": 3584 * 28 * 28,
                 "min_pixels": 200704,
                 "fps": 2.0,
                 "max_frames": 64,
                },
                {"type": "text", "text": q_text},
            ],
        },
    ]


def _score_record(sample_id, ex, output_text):
    """Parse the model output, score it against ``ex["solution"]``, and build the JSONL record."""
    pred = parse_segments(output_text)
    gt = [tuple(s) for s in ex["solution"]]
    return {
        "sample_id": sample_id,
        "generator": ex["generator"],
        "gt": gt,
        "pred": pred,
        "output_text": output_text,
        "soft_F1": soft_f1(pred, gt),
        "mean_F1_tIoU": mean_f1_at_tiou(pred, gt),
        "hungarian_iou": hungarian_iou_reward(pred, gt),
        "n_pred": len(pred),
        "n_gt": len(gt),
        "parse_failed": len(pred) == 0,
    }


def main():
    """Evaluate this rank's shard of the AF test split.

    The function proceeds in these steps:
      1. Read CLI options via get_args().
      2. Load the model and processor onto ``cuda:<device>``.
      3. Shard the test examples round-robin by rank.
      4. For each video with a cached ``video_inputs.pt``, generate one response.
      5. Score the response and append one JSON line to
         ``<out_dir>/rank_<rank>.jsonl``.

    Environment:
        FORENSICS_GENCOND_MODE: none|matched|correct|wrong (default "none").
        FORENSICS_GENCOND_PROB: conditioning probability for "matched" (default 0.5).

    Raises:
        ValueError: unknown FORENSICS_GENCOND_MODE, or "wrong" mode with fewer
            than two distinct TEST_GENERATORS (every sample would otherwise fail).
    """
    args = get_args()
    device = f"cuda:{args.device}"
    os.makedirs(args.out_dir, exist_ok=True)

    # Toggle CoT/no-CoT prompt template via env var (the template fn reads it).
    os.environ["FORENSICS_COT"] = args.cot
    os.environ["FORENSICS_COT_VARIANT"] = args.cot_variant

    # v14 generator-conditional eval mode (env var FORENSICS_GENCOND_MODE):
    #   none    — plain prompt (default; deployment-equivalent, the headline ship number)
    #   matched — with probability FORENSICS_GENCOND_PROB prepend the correct gen name,
    #             else use the plain prompt (mirrors the training distribution; this is
    #             the "what the model was optimized for" reading)
    #   correct — always prepend the true generator name (oracle upper bound)
    #   wrong   — always prepend a deterministic *other* generator name (token control)
    # Validated before the slow model load so a bad setting fails fast.
    gencond_mode = os.getenv("FORENSICS_GENCOND_MODE", "none").lower()
    if gencond_mode not in ("none", "matched", "correct", "wrong"):
        raise ValueError(f"FORENSICS_GENCOND_MODE must be none|matched|correct|wrong, got {gencond_mode!r}")
    if gencond_mode == "wrong" and len(set(TEST_GENERATORS)) < 2:
        raise ValueError(
            "FORENSICS_GENCOND_MODE=wrong needs at least two distinct TEST_GENERATORS, "
            f"got {list(TEST_GENERATORS)!r}"
        )
    # matched-mode prob MUST equal the FORENSICS_GENCOND_PROB used at training
    # time (the 3-way driver script enforces this by exporting both from the
    # same value).
    gencond_prob = float(os.getenv("FORENSICS_GENCOND_PROB", "0.5"))

    print(f"[rank {args.rank}/{args.world_size}] device={device}  model={args.model_path}", flush=True)
    print(f"  cot={args.cot}  cot_variant={args.cot_variant}  max_new_tokens={args.max_new_tokens}  temp={args.temperature}", flush=True)
    print(f"  FORENSICS_GENCOND_MODE = {gencond_mode}"
          + (f"   PROB = {gencond_prob}" if gencond_mode == "matched" else ""), flush=True)
    if args.verifier_context == "true":
        # Nothing below injects verifier scores into the prompt; warn loudly so a
        # verifier-free run is not mistaken for a verifier-grounded one.
        print("  WARNING: --verifier_context true is not implemented in this script; "
              "verifier scores are NOT injected into the prompt.", file=sys.stderr, flush=True)

    t0 = time.time()
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path,
        torch_dtype=torch.bfloat16,
        use_sliding_window=True,
        attn_implementation="flash_attention_2",
        device_map=device,
    )
    model.eval()
    processor = AutoProcessor.from_pretrained(args.model_path)
    print(f"  loaded model+processor in {time.time()-t0:.1f}s", flush=True)

    # Build all test examples, then shard round-robin by rank.
    examples = build_examples(
        annot_dir=ANNOT, video_root=VROOT, generators=TEST_GENERATORS,
        split_prefix="test", preprocessed_data_path=CACHE, require_video_exists=True,
    )
    examples = [ex for i, ex in enumerate(examples) if i % args.world_size == args.rank]
    if args.limit > 0:
        examples = examples[: args.limit]
    print(f"  rank {args.rank} processes {len(examples)} test videos", flush=True)

    question = get_question_template()  # CoT or no-CoT depending on env var
    # Deterministic per-rank RNGs so reruns of the same mode produce identical
    # prompts (wrong: which other gen; matched: which samples are conditioned).
    wrong_rng = random.Random(0xC0FFEE + args.rank)
    matched_rng = random.Random(0xBEEF00 + args.rank)

    # NOTE(review): temperature is floored at 1e-6 rather than omitted for greedy
    # runs. This presumably keeps decoding near-greedy if HF re-enables sampling
    # while merging the model's saved generation_config (see next comment) —
    # confirm before changing.
    gen_cfg = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        do_sample=args.temperature > 0,
        temperature=max(args.temperature, 1e-6),
        pad_token_id=processor.tokenizer.pad_token_id,
        use_cache=True,
    )
    # Force use_cache=True in 3 places (HF merges model defaults into GenerationConfig and
    # silently overrides use_cache=False when gradient_checkpointing was on at train time).
    model.config.use_cache = True
    if hasattr(model, "generation_config"):
        model.generation_config.use_cache = True

    out_path = os.path.join(args.out_dir, f"rank_{args.rank}.jsonl")
    t_start = time.time()
    done = failed = 0
    # `with` guarantees the JSONL is closed even if scoring raises mid-run.
    with open(out_path, "w", encoding="utf-8") as fout:
        for ex in examples:
            sample_id = os.path.splitext(os.path.basename(ex["video_path"]))[0]
            sample_dir = os.path.join(CACHE, "test", ex["generator"], sample_id)
            if not os.path.exists(os.path.join(sample_dir, "video_inputs.pt")):
                failed += 1
                print(f"  [skip] {sample_id}: no cached video_inputs.pt in {sample_dir}", flush=True)
                continue

            try:
                video_inputs, video_kwargs = load_cached(sample_dir)
                # Called after load_cached so the RNGs advance exactly once per
                # successfully loaded sample (keeps prompts reproducible).
                q_text = _gencond_question(
                    gencond_mode, question, ex["generator"], TEST_GENERATORS,
                    wrong_rng, matched_rng, gencond_prob,
                )
                text = processor.apply_chat_template(
                    _build_chat(ex["video_path"], q_text),
                    tokenize=False,
                    add_generation_prompt=True,
                )

                inputs = processor(
                    text=[text],
                    videos=[video_inputs[0]],
                    fps=[video_kwargs["fps"][0]],
                    padding=True,
                    return_tensors="pt",
                    padding_side="left",
                    add_special_tokens=False,
                )
                inputs = {k: v.to(device) if hasattr(v, "to") else v for k, v in inputs.items()}

                with torch.no_grad():
                    out_ids = model.generate(**inputs, generation_config=gen_cfg, use_cache=True)
                # Drop the prompt tokens; keep only newly generated ones.
                gen_ids = out_ids[0][inputs["input_ids"].shape[1]:]
                output_text = processor.tokenizer.decode(gen_ids, skip_special_tokens=True)
            except Exception as e:
                # Best-effort: one bad sample must not abort the whole shard.
                failed += 1
                print(f"  [skip] {sample_id}: {type(e).__name__}: {e}", flush=True)
                continue

            fout.write(json.dumps(_score_record(sample_id, ex, output_text)) + "\n")
            fout.flush()
            done += 1

            if done % 20 == 0:
                elapsed = time.time() - t_start
                rate = done / max(1e-6, elapsed)
                remaining = (len(examples) - done - failed) / max(1e-6, rate)
                print(
                    f"  rank={args.rank} done={done} fail={failed} "
                    f"rate={rate:.2f}/s eta={remaining/60:.1f}min",
                    flush=True,
                )

    print(f"[rank {args.rank}] DONE  done={done} failed={failed} elapsed={time.time()-t_start:.0f}s", flush=True)


if __name__ == "__main__":
    main()