| """Exploration probe: does the v10_r2 model EVER recover the missed (shorter) |
| segment on under-counted multi-seg videos if we sample at high temperature? |
| |
| Distinguishes policy/exploration bottleneck (model CAN find it, just doesn't |
| at temp=0) from a perception wall (never finds it across N samples). |
| |
Reads /tmp/undercounted.json (records with sample_id, generator, gt, t0_pred, missed).
For each video, draws N high-temperature samples and records every sample's parsed segments.
Output: <out_dir>/rank_<r>.jsonl, one JSON record per video holding all N sampled predictions.
| """ |
| import argparse, json, os, sys, time |
| import torch |
| from transformers import AutoProcessor, GenerationConfig, Qwen2_5_VLForConditionalGeneration |
|
|
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) |
| from src.open_r1.reward import parse_segments |
| from src.open_r1.trainer.grpo_trainer_video_GT_soft import SYSTEM_PROMPT, get_question_template |
|
|
# Root of the precomputed per-video model inputs. main() looks for
# <CACHE>/test/<generator>/<sample_id>/video_inputs.pt and video_kwargs.json.
CACHE = "/mnt/local-fast/zhangt/forensics_grpo_cache_uniform3584_fps2.0"
|
|
|
|
def get_args():
    """Parse the probe's command-line options.

    Returns:
        argparse.Namespace with ``model_path`` (required), ``rank`` and
        ``world_size`` (round-robin sharding), ``device`` (CUDA index),
        ``out_dir``, ``n_samples``, ``temperature`` and ``max_new_tokens``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", required=True)
    # Integer sharding / placement options, registered in their original order.
    for flag, fallback in (("--rank", 0), ("--world_size", 1), ("--device", 0)):
        parser.add_argument(flag, type=int, default=fallback)
    parser.add_argument("--out_dir", default="probe_explore_v10r2")
    # Sampling controls.
    parser.add_argument("--n_samples", type=int, default=8)
    parser.add_argument("--temperature", type=float, default=1.0)
    parser.add_argument("--max_new_tokens", type=int, default=64)
    parsed = parser.parse_args()
    return parsed
|
|
|
|
def main():
    """Run the high-temperature exploration probe on this rank's shard of videos.

    Reads the under-counted video list from /tmp/undercounted.json and keeps
    every entry whose index ``i`` satisfies ``i % world_size == rank``. For each
    video whose cached ``video_inputs.pt`` exists, it draws ``n_samples``
    independent sampled generations. Each one is decoded and parsed with
    ``parse_segments``. One JSON line per successfully processed video is
    written to ``<out_dir>/rank_<rank>.jsonl``.

    Videos whose cached inputs are missing, or that raise while loading or
    generating, are logged as ``[skip]`` and produce no record.
    """
    args = get_args()
    device = f"cuda:{args.device}"
    os.makedirs(args.out_dir, exist_ok=True)
    # Set before get_question_template() is called below. Presumably it
    # disables chain-of-thought prompting in the template; verify in the trainer.
    os.environ["FORENSICS_COT"] = "false"

    with open("/tmp/undercounted.json") as f:
        videos = json.load(f)
    # Round-robin shard of the video list across ranks.
    videos = [v for i, v in enumerate(videos) if i % args.world_size == args.rank]
    print(f"[rank {args.rank}] {len(videos)} videos N={args.n_samples} temp={args.temperature}", flush=True)

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        args.model_path, torch_dtype=torch.bfloat16, use_sliding_window=True,
        attn_implementation="flash_attention_2", device_map=device)
    model.eval()
    model.config.use_cache = True
    if hasattr(model, "generation_config"):
        model.generation_config.use_cache = True
    processor = AutoProcessor.from_pretrained(args.model_path)
    question = get_question_template()
    gen_cfg = GenerationConfig(
        max_new_tokens=args.max_new_tokens, do_sample=True,
        temperature=args.temperature, top_p=1.0,
        pad_token_id=processor.tokenizer.pad_token_id, use_cache=True)

    # The prompt is identical for every video. The chat's video entry is only a
    # placeholder, and the actual frames are passed via ``videos=`` below.
    # So render it once.
    chat = [
        {"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]},
        {"role": "user", "content": [
            {"type": "video", "video": "placeholder"},
            {"type": "text", "text": question}]},
    ]
    text = processor.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)

    out_path = os.path.join(args.out_dir, f"rank_{args.rank}.jsonl")
    t0 = time.time()
    done = 0
    # ``with`` guarantees the output file is closed even if the loop aborts.
    with open(out_path, "w") as fout:
        for v in videos:
            sd = os.path.join(CACHE, "test", v["generator"], v["sample_id"])
            vp = os.path.join(sd, "video_inputs.pt")
            if not os.path.exists(vp):
                print(f"  [skip] {v['sample_id']}: missing {vp}", flush=True)
                continue
            try:
                # NOTE: weights_only=False unpickles arbitrary objects. This is
                # acceptable only because CACHE is a trusted, locally built cache.
                video_inputs = torch.load(vp, weights_only=False)
                with open(os.path.join(sd, "video_kwargs.json")) as f:
                    video_kwargs = json.load(f)
                inputs = processor(text=[text], videos=[video_inputs[0]], fps=[video_kwargs["fps"][0]],
                                   padding=True, return_tensors="pt", padding_side="left",
                                   add_special_tokens=False)
                inputs = {k: (val.to(device) if hasattr(val, "to") else val) for k, val in inputs.items()}
                # Batch size is 1, so generated tokens start right after the prompt.
                prompt_len = inputs["input_ids"].shape[1]
                sample_preds = []
                with torch.no_grad():
                    for _ in range(args.n_samples):
                        out_ids = model.generate(**inputs, generation_config=gen_cfg, use_cache=True)
                        txt = processor.tokenizer.decode(out_ids[0][prompt_len:], skip_special_tokens=True)
                        sample_preds.append(parse_segments(txt))
            except Exception as e:
                # Per-video boundary: log and move on so one bad video does not
                # kill the whole shard.
                print(f"  [skip] {v['sample_id']}: {type(e).__name__}: {e}", flush=True)
                continue
            rec = {"sample_id": v["sample_id"], "generator": v["generator"],
                   "gt": v["gt"], "t0_pred": v["t0_pred"], "missed": v["missed"],
                   "sample_preds": sample_preds}
            # Flush per record so partial results survive a crash or preemption.
            fout.write(json.dumps(rec) + "\n")
            fout.flush()
            done += 1
            if done % 5 == 0:
                r = done / max(1e-6, time.time() - t0)
                print(f"  rank={args.rank} done={done}/{len(videos)} rate={r:.2f}/s", flush=True)
    print(f"[rank {args.rank}] DONE done={done} elapsed={time.time()-t0:.0f}s", flush=True)
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|