File size: 4,910 Bytes
33569f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""Exploration probe: does the v10_r2 model EVER recover the missed (shorter)
segment on under-counted multi-seg videos if we sample at high temperature?

Distinguishes policy/exploration bottleneck (model CAN find it, just doesn't
at temp=0) from a perception wall (never finds it across N samples).

Reads /tmp/undercounted.json: a JSON list of records with keys sample_id,
generator, gt, t0_pred, and missed; each rank keeps every world_size-th entry.
For each video: draw N high-temperature samples and record every sample's
parsed segments.
Output: <out_dir>/rank_<r>.jsonl, one JSON record per video holding all N
sampled predictions.
"""
import argparse, json, os, sys, time
import torch
from transformers import AutoProcessor, GenerationConfig, Qwen2_5_VLForConditionalGeneration

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from src.open_r1.reward import parse_segments
from src.open_r1.trainer.grpo_trainer_video_GT_soft import SYSTEM_PROMPT, get_question_template

# Root of the on-disk per-sample cache. main() reads
# <CACHE>/test/<generator>/<sample_id>/video_inputs.pt and video_kwargs.json
# from it. NOTE(review): absolute, machine-specific path — adjust per host.
CACHE = "/mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0"


def get_args():
    """Build the command-line parser and return the parsed namespace.

    ``--model_path`` is mandatory. ``--rank``/``--world_size`` select this
    process's shard of the video list, ``--device`` picks the CUDA ordinal,
    ``--out_dir`` is where the per-rank JSONL lands, and the remaining flags
    control sampling (number of draws, temperature, generation length).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True)
    # Sharding and placement: all integer-valued.
    for flag, default in (("--rank", 0), ("--world_size", 1), ("--device", 0)):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--out_dir", default="probe_explore_v10r2")
    # Sampling controls.
    parser.add_argument("--n_samples", type=int, default=8)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--max_new_tokens", type=int, default=64)
    namespace = parser.parse_args()
    return namespace


def _build_prompt_text(processor, question):
    """Render the chat prompt text shared by every video.

    The system prompt, the question template and the ``"placeholder"`` video
    entry do not depend on the sample, so the rendered text is identical for
    every video and only needs to be built once.
    """
    chat = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [
            {"type": "video", "video": "placeholder"},
            {"type": "text", "text": question}]},
    ]
    return processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)


def _sample_segments(model, processor, text, sample_dir, gen_cfg, n_samples, device):
    """Load one cached video and return ``n_samples`` parsed segment predictions.

    Reads ``video_inputs.pt`` and ``video_kwargs.json`` from *sample_dir*,
    builds the model inputs, draws ``n_samples`` generations with *gen_cfg*
    and parses each decoded continuation with ``parse_segments``. Any error
    (I/O, processor or generation) propagates to the caller.
    """
    # weights_only=False permits arbitrary pickled objects: only load caches
    # you produced yourself.
    video_inputs = torch.load(os.path.join(sample_dir, "video_inputs.pt"), weights_only=False)
    with open(os.path.join(sample_dir, "video_kwargs.json"), encoding="utf-8") as f:
        video_kwargs = json.load(f)
    inputs = processor(text=[text], videos=[video_inputs[0]], fps=[video_kwargs["fps"][0]],
                       padding=True, return_tensors="pt", padding_side="left",
                       add_special_tokens=False)
    inputs = {k: (val.to(device) if hasattr(val, "to") else val) for k, val in inputs.items()}
    # generate() returns prompt + continuation; slice off the prompt tokens.
    prompt_len = inputs["input_ids"].shape[1]
    preds = []
    with torch.no_grad():
        for _ in range(n_samples):
            # use_cache is already set on gen_cfg.
            out_ids = model.generate(**inputs, generation_config=gen_cfg)
            txt = processor.tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True)
            preds.append(parse_segments(txt))
    return preds


def main():
    """Run the high-temperature exploration probe on this rank's shard.

    Loads the under-counted video list from ``/tmp/undercounted.json``. It
    keeps the entries whose index ``i`` satisfies ``i % world_size == rank``
    and draws ``n_samples`` sampled generations per video. For each video
    processed successfully it writes one JSON line to
    ``<out_dir>/rank_<rank>.jsonl``.

    Videos whose cache directory lacks ``video_inputs.pt``, or that raise
    during loading or inference, are logged and skipped.
    """
    args = get_args()
    device = f"cuda:{args.device}"
    os.makedirs(args.out_dir, exist_ok=True)
    # Set before get_question_template() is called below.
    # NOTE(review): assumes the template reads this variable at call time.
    # SYSTEM_PROMPT was bound at import time, before this assignment, so it
    # cannot be affected by it — confirm that is intended.
    os.environ["FORENSICS_COT"] = "false"

    with open("/tmp/undercounted.json", encoding="utf-8") as f:
        videos = json.load(f)
    # Round-robin sharding across ranks.
    videos = [v for i, v in enumerate(videos) if i % args.world_size == args.rank]
    print(f"[rank {args.rank}] {len(videos)} videos  N={args.n_samples} temp={args.temperature}", flush=True)

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16, use_sliding_window=True,
        attn_implementation="flash_attention_2", device_map=device)
    model.eval()
    model.config.use_cache = True
    if hasattr(model, "generation_config"):
        model.generation_config.use_cache = True
    processor = AutoProcessor.from_pretrained(args.model_path)
    question = get_question_template()
    gen_cfg = GenerationConfig(
        max_new_tokens=args.max_new_tokens, do_sample=True,
        temperature=args.temperature, top_p=1.0,
        pad_token_id=processor.tokenizer.pad_token_id, use_cache=True)
    # The prompt is video-independent: render it once, outside the loop.
    text = _build_prompt_text(processor, question)

    out_path = os.path.join(args.out_dir, f"rank_{args.rank}.jsonl")
    t0 = time.time()
    done = 0
    skipped = 0
    with open(out_path, "w", encoding="utf-8") as fout:
        for v in videos:
            sd = os.path.join(CACHE, "test", v["generator"], v["sample_id"])
            if not os.path.exists(os.path.join(sd, "video_inputs.pt")):
                print(f"  [skip] {v['sample_id']}: missing video_inputs.pt", flush=True)
                skipped += 1
                continue
            try:
                sample_preds = _sample_segments(model, processor, text, sd, gen_cfg,
                                                args.n_samples, device)
            except Exception as e:
                # Best-effort probe: log the failure and move on to the next video.
                print(f"  [skip] {v['sample_id']}: {type(e).__name__}: {e}", flush=True)
                skipped += 1
                continue
            rec = {"sample_id": v["sample_id"], "generator": v["generator"],
                   "gt": v["gt"], "t0_pred": v["t0_pred"], "missed": v["missed"],
                   "sample_preds": sample_preds}
            # Flush each record so partial results survive a crash or kill.
            fout.write(json.dumps(rec) + "\n")
            fout.flush()
            done += 1
            if done % 5 == 0:
                r = done / max(1e-6, time.time() - t0)
                print(f"  rank={args.rank} done={done}/{len(videos)} rate={r:.2f}/s", flush=True)
    print(f"[rank {args.rank}] DONE done={done} elapsed={time.time()-t0:.0f}s skipped={skipped}", flush=True)


# Script entry point: main() parses the CLI arguments and runs the probe for one rank.
if __name__ == "__main__":
    main()