"""Single-GPU evaluation driver for video temporal grounding ("tg") and
multiple-choice video QA ("mcq") with Qwen2.5-VL.

Predictions are appended as JSON lines to
``<output_dir>/<datatype>_<curr_idx>_<total_idx>.jsonl``.
"""

import argparse
import json
import os
import re
import time

import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

from src.vllm_inference.data import build_dataloader
from src.vllm_inference.vllm_infer import vllmWrapper

# "<start> to <end>" / "<start> and <end>" spans in generated grounding text.
_TG_SPAN_RE = re.compile(r"(\d+\.?\d*) (to|and) (\d+\.?\d*)")
# "(A)" ... "(Z)" option letters in generated multiple-choice text.
_MCQ_OPTION_RE = re.compile(r"\(([A-Z])\)")

# Dataset -> task routing used by the ``__main__`` entry point. Multiple-choice
# datasets win when both kinds are requested in one run.
_MCQ_DATASETS = frozenset({"mvbench", "videomme", "tempcompass"})
_TG_DATASETS = frozenset({"tvgbench", "tvgbench_filter", "charades", "activitynet"})


def get_args():
    """Parse the command-line options of the evaluation script."""
    parser = argparse.ArgumentParser(
        description="Evaluation for training-free video temporal grounding (Single GPU Version)"
    )
    # Overwritten from --datasets in the ``__main__`` block.
    parser.add_argument(
        "--datatype",
        default="tg",
        type=str,
        help="Specify the dataset.",
        choices=["tg", "mcq"],
    )
    parser.add_argument(
        "--model_base", type=str, default="../pretrained_models/Qwen2.5-VL-7B-Instruct"
    )
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument(
        "--output_dir",
        type=str,
        default="checkpoints",
        help="Directory to save checkpoints",
    )
    parser.add_argument(
        "--device", type=str, default="cuda:0", help="GPU device to use"
    )
    parser.add_argument(
        "--pipeline_parallel_size", type=int, default=1, help="GPU nodes"
    )
    parser.add_argument("--split", type=str, default="train", help="dataset type")
    parser.add_argument("--max_new_tokens", type=int, default=128)
    # Data sharding (help text means "data shard"): both values are passed to
    # the dataloader and used in the output file name.
    parser.add_argument("--curr_idx", type=int, default=0, help="数据分片")
    parser.add_argument("--total_idx", type=int, default=1, help="数据分片")
    parser.add_argument(
        "--total_pixels", type=int, default=3584 * 28 * 28, help="total_pixels"
    )
    # Required: the ``__main__`` block routes on this list and cannot handle None.
    parser.add_argument(
        "--datasets",
        nargs="+",
        type=str,
        required=True,
        help="dataset names",
        choices=[
            "charades",
            "activitynet",
            "videomme",
            "mvbench",
            "tvgbench_filter",
            "tvgbench",
            "egoschema",
            "tempcompass",
        ],
    )
    parser.add_argument(
        "--use_r1_thinking_prompt", action="store_true", help="On R1 SHOUD BE TRUE!"
    )
    parser.add_argument("--use_vllm_inference", action="store_true")
    parser.add_argument("--prompt_type", type=str, default="r1", help="Prompt type")
    parser.add_argument(
        "--use_nothink", action="store_true", help="Use no think prompt"
    )
    parser.add_argument(
        "--use_prepared_video",
        action="store_true",
        help="Use video cache in ./video_cache",
    )
    return parser.parse_args()


def build_model(args):
    """Load the processor and model selected by ``args``.

    A ``vllmWrapper`` is built when ``--use_vllm_inference`` is set and the task
    generates text (``tg``, or ``mcq`` on a non-train split); otherwise a
    transformers ``Qwen2_5_VLForConditionalGeneration`` is loaded on
    ``args.device``.

    Returns:
        ``(model, processor)``.
    """
    processor = AutoProcessor.from_pretrained(args.model_base, use_fast=True)
    if args.datatype == "tg":
        processor.tokenizer.padding_side = "left"

    generates_text = args.datatype == "tg" or (
        args.datatype == "mcq" and args.split != "train"
    )
    if generates_text and args.use_vllm_inference:
        model = vllmWrapper(args)
    else:
        model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            args.model_base,
            torch_dtype="auto",
            device_map=args.device,
            attn_implementation="flash_attention_2",
        )
    # NOTE(review): also called on vllmWrapper; assumes it exposes ``eval()``.
    model.eval()
    return model, processor


@torch.no_grad()
def inference(model, inputs):
    """Run one forward pass and return the logits at each sample's last token.

    Args:
        model: transformers model exposing ``.device`` and returning ``.logits``.
        inputs: dict of model inputs; tensor values are moved to ``model.device``
            and written back into ``inputs``.

    Returns:
        Tensor of shape ``(batch, vocab)``.
    """
    for key, value in list(inputs.items()):
        if isinstance(value, torch.Tensor):
            inputs[key] = value.to(model.device)

    logits = model(**inputs).logits
    bsz, seq_len, _ = logits.shape
    if "attention_mask" in inputs:
        # Position of the last attended token. NOTE(review): this is only the
        # final real token under right padding; with left padding (set for
        # datatype "tg") it would be seq_len - 1 instead.
        pred_token_indices = torch.sum(inputs["attention_mask"], dim=-1) - 1
    else:
        pred_token_indices = torch.full((bsz,), seq_len - 1, device=logits.device)
    return logits[torch.arange(bsz, device=logits.device), pred_token_indices, :]


def extract_answer(output_string, datatype):
    """Parse the final answer out of generated text.

    Args:
        output_string: text produced by the model.
        datatype: ``"tg"`` or ``"mcq"``.

    Returns:
        For ``"tg"``: ``[start, end]`` as floats, taken from the last
        "X to Y" / "X and Y" span, or ``[None, None]`` if there is none.
        For ``"mcq"``: zero-based index of the last "(A)".."(Z)" letter, or
        None if there is none. None for any other ``datatype``.
    """
    if datatype == "tg":
        # Searching the whole text already covers any span inside an answer
        # section, so no second, narrower search is needed.
        matches = _TG_SPAN_RE.findall(output_string)
        if not matches:
            return [None, None]
        start, _, end = matches[-1]
        # The regex only matches digit strings, so float() cannot fail here.
        return [float(start), float(end)]
    if datatype == "mcq":
        letters = _MCQ_OPTION_RE.findall(output_string)
        if letters:
            return ord(letters[-1]) - ord("A")
    return None


@torch.no_grad()
def calc_prob(logits, options_token_ids):
    """Softmax each sample's logits restricted to the option token ids.

    Args:
        logits: ``(batch, vocab)`` tensor, e.g. the output of :func:`inference`.
        options_token_ids: nested list of token ids, as built in :func:`main`.

    Returns:
        List with one probability tensor per sample (softmax over dim 1).
    """
    # NOTE(review): each row is indexed with the *entire* nested id list, so
    # every result is 2-D. That lines up with "one option list per sample" only
    # when batch_size == 1 -- verify before evaluating larger batches.
    return [
        F.softmax(logits[i, options_token_ids], dim=1)
        for i in range(logits.shape[0])
    ]


def _load_finished_qids(path):
    """Return the set of ``qid`` values already recorded in JSONL file ``path``."""
    finished = set()
    if not os.path.exists(path):
        return finished
    with open(path, "r") as g:
        for line in g:
            if not line.strip():
                continue
            try:
                finished.add(json.loads(line)["qid"])
            except (json.JSONDecodeError, KeyError, TypeError) as e:
                # Skip an unreadable record (e.g. one truncated by a crash)
                # instead of discarding every qid after it.
                print(e)
    return finished


def _duration_at(batch_itm, i):
    """Return the ``i``-th duration of the batch, or None if it carries none."""
    return batch_itm["duration"][i] if "duration" in batch_itm else None


def _write_record(f, record):
    """Append ``record`` as one JSON line and flush so it survives a crash."""
    f.write(json.dumps(record) + "\n")
    f.flush()


@torch.no_grad()
def main(args):
    """Run inference over the requested datasets and append one record per sample.

    Records go to ``<output_dir>/<datatype>_<curr_idx>_<total_idx>.jsonl``.
    qids already in that file are handed to the dataloader as
    ``already_finished`` (presumably so they are skipped on resume).

    Three modes:

    * ``tg``: generate text and parse a ``[start, end]`` span.
    * ``mcq`` on a non-train split: generate text and parse an option index.
    * otherwise: score option tokens from the last-token logits.

    The inference-loop wall time is printed and written to
    ``<output_dir>/timing_summary_vllm.txt``.
    """
    os.makedirs(args.output_dir, exist_ok=True)
    output_file = os.path.join(
        args.output_dir, f"{args.datatype}_{args.curr_idx}_{args.total_idx}.jsonl"
    )
    already_finished = _load_finished_qids(output_file)

    model, processor = build_model(args)
    dataloader_args = {
        "batch_size": args.batch_size,
        "already_finished": already_finished,
        "curr_idx": args.curr_idx,
        "total_idx": args.total_idx,
        "split": args.split,
        "num_workers": min(8, args.batch_size),
        "dataset_names": args.datasets,
        "use_prepared_video": args.use_prepared_video,
        "total_pixels": args.total_pixels,
        "use_r1_thinking_prompt": args.use_r1_thinking_prompt,
        "prompt_type": args.prompt_type,
        "use_nothink": args.use_nothink,
    }
    dataloader = build_dataloader(processor, args.datatype, **dataloader_args)

    program_start_time = time.perf_counter()
    # NOTE(review): both generate branches appear to target the vllmWrapper API
    # (e.g. the ``answer_prompt`` kwarg); build_model only creates that wrapper
    # when --use_vllm_inference is set. Confirm the transformers model is never
    # routed into these branches.
    with open(output_file, "a") as f:
        for batch_itm in tqdm(dataloader):
            if args.datatype == "tg":
                output_texts = model.generate(
                    batch_itm["inputs"],
                    max_new_tokens=args.max_new_tokens,
                )
                for i, target in enumerate(batch_itm["timestamps"]):
                    _write_record(
                        f,
                        {
                            "qid": batch_itm["qid"][i],
                            "pred": extract_answer(output_texts[i], args.datatype),
                            "target": list(target),
                            "duration": _duration_at(batch_itm, i),
                            "output_text": output_texts[i],
                        },
                    )
            elif args.datatype == "mcq" and args.split != "train":
                output_texts = model.generate(
                    batch_itm["inputs"],
                    max_new_tokens=args.max_new_tokens,
                    answer_prompt=dataloader.dataset.answer_prompt,
                )
                for i, target in enumerate(batch_itm["answer"]):
                    _write_record(
                        f,
                        {
                            "qid": batch_itm["qid"][i],
                            # Parsed option index (None when no "(X)" is found).
                            "pred": extract_answer(output_texts[i], args.datatype),
                            "target": target,
                            "duration": _duration_at(batch_itm, i),
                            "output_text": output_texts[i],
                        },
                    )
            else:
                logits = inference(model, batch_itm["inputs"])
                # Fetch the vocab once: on fast tokenizers each ``.vocab`` access
                # rebuilds the full mapping.
                vocab = processor.tokenizer.vocab
                options_token_ids = [
                    [vocab[word] for word in word_list]
                    for word_list in batch_itm["options"]
                ]
                probs = calc_prob(logits, options_token_ids)
                for i in range(len(logits)):
                    _write_record(
                        f,
                        {
                            "qid": batch_itm["qid"][i],
                            "pred": probs[i].argmax().item(),
                            "target": batch_itm["answer"][i],
                            "duration": _duration_at(batch_itm, i),
                            "probs": probs[i].cpu().tolist(),
                        },
                    )

    total_program_duration = time.perf_counter() - program_start_time
    print("\n--- Timing Summary ---")
    print(f"Total program execution time: {total_program_duration:.2f} seconds")

    summary_path = os.path.join(args.output_dir, "timing_summary_vllm.txt")
    with open(summary_path, "w", encoding="utf-8") as summary_file:
        summary_file.write("\n--- Timing Summary ---\n")
        summary_file.write(
            f"Total program execution time: {total_program_duration:.2f} seconds\n"
        )


if __name__ == "__main__":
    from src.vllm_inference.utils import monkey_patch

    # Patch before argument parsing and model construction, as in the original.
    monkey_patch()
    args = get_args()
    requested = set(args.datasets)
    if requested & _MCQ_DATASETS:
        args.datatype = "mcq"
    elif requested & _TG_DATASETS:
        args.datatype = "tg"
    else:
        # NOTE(review): "egoschema" is an accepted --datasets choice but is not
        # routed to either task, so requesting it alone ends up here.
        raise ValueError("Unsupported dataset type. Please check your datasets.")
    main(args)