from data_configs import DATASETS import argparse import numpy as np import json from tqdm import tqdm import os import re import pickle import torch from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor from qwen_vl_utils import process_vision_info import random VIDEO_INFO_CACHE = {} def get_args(): parser = argparse.ArgumentParser(description='Evaluation for training-free video temporal grounding (Single GPU Version)') parser.add_argument('--dataset', default='charades', type=str, help='Specify the dataset.') parser.add_argument('--split', default='default', type=str, help='Specify the split.') parser.add_argument("--model_base", type=str, default="/path/to/qwen-model") parser.add_argument("--batch_size", type=int, default=1, help="Batch size") parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory to save checkpoints") parser.add_argument("--resume", action="store_true", help="Resume from checkpoint") parser.add_argument("--device", type=str, default="cuda:0", help="GPU device to use") return parser.parse_args() def calc_iou(candidates, gt): start, end = candidates[:,0], candidates[:,1] s, e = gt[0], gt[1] inter = np.minimum(end, e) - np.maximum(start, s) union = np.maximum(end, e) - np.minimum(start, s) return inter.clip(min=0) / union def cached_process_vision_info(messages, return_video_kwargs=False): global VIDEO_INFO_CACHE video_path = None for msg in messages: for content in msg.get('content', []): if isinstance(content, dict) and 'video' in content: video_path = content['video'] break cache_key = f"{video_path}_{return_video_kwargs}" if cache_key in VIDEO_INFO_CACHE: return VIDEO_INFO_CACHE[cache_key] result = process_vision_info(messages, return_video_kwargs=return_video_kwargs) VIDEO_INFO_CACHE[cache_key] = result return result def inference(video_path, prompt, model, processor, max_new_tokens=2048, device="cuda:0"): messages = [ {"role": "system", "content": "You are a video analysis expert."}, {"role": "user", "content": [ {"type": "video", "video": video_path, "total_pixels": 3584 * 28 * 28, "min_pixels": 200704, }, {"type": "text", "text": prompt}, ] }, ] text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) image_inputs, video_inputs, video_kwargs = cached_process_vision_info(messages, return_video_kwargs=True) fps_inputs = video_kwargs['fps'] inputs = processor(text=[text], images=image_inputs, videos=video_inputs, fps=fps_inputs, padding=True, return_tensors="pt") inputs = inputs.to(device) with torch.no_grad(): output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens) generated_ids = [output_ids[i][len(inputs.input_ids[i]):] for i in range(len(output_ids))] output_text = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) return output_text[0] def parse_timestamp_output(output_string): matches = re.findall(r"(\d+\.?\d*) (to|and) (\d+\.?\d*)", output_string) if not matches: answer_match = re.search(r"(.*?)", output_string) if answer_match: answer_content = answer_match.group(1).strip() answer_matches = re.findall(r"(\d+\.?\d*) (to|and) (\d+\.?\d*)", answer_content) if answer_matches: last_match = answer_matches[-1] return float(last_match[0]), float(last_match[2]) return None, None last_match = matches[-1] start_time_str = last_match[0] end_time_str = last_match[2] try: start_time = float(start_time_str) end_time = float(end_time_str) return start_time, end_time except ValueError: return None, None # GROUND_TEMPLATE = """To accurately pinpoint the event "[EVENT]" in the video, determine the precise time period of the event. # Output your thought process within the tags, including analysis with either specific timestamps (xx.xx) or time ranges (xx.xx to xx.xx) in tags. # Then, provide the start and end times (in seconds, precise to two decimal places) in the format "start time to end time" within the tags. For example: 12.54 to 17.83.""" GROUND_TEMPLATE = """To accurately pinpoint the event "[EVENT]" in the video, determine the precise time period of the event. Provide the start and end times (in seconds, precise to two decimal places) in the format "start time to end time" within the tags. For example: "12.54 to 17.83" .""" def create_work_items(data): work_items = [] for vid, ann in data.items(): for i in range(len(ann['sentences'])): work_items.append({ 'vid': vid, 'ann': ann, 'sentence_idx': i }) # 随机打乱列表 random.shuffle(work_items) return work_items def setup_model(model_base, device): print(f"Setting up model on device {device}") model = Qwen2_5_VLForConditionalGeneration.from_pretrained( model_base, torch_dtype=torch.bfloat16, use_sliding_window=True, attn_implementation="flash_attention_2", device_map=device ) processor = AutoProcessor.from_pretrained(model_base) return model, processor def get_checkpoint_path(checkpoint_dir): os.makedirs(checkpoint_dir, exist_ok=True) return os.path.join(checkpoint_dir, "checkpoint.pkl") def load_checkpoint(checkpoint_path): if os.path.exists(checkpoint_path): try: with open(checkpoint_path, 'rb') as f: return pickle.load(f) except Exception as e: print(f"Error loading checkpoint: {e}") return {'processed_items': set(), 'ious': [], 'recall': np.array([0, 0, 0])} def save_checkpoint(checkpoint_path, state): with open(checkpoint_path, 'wb') as f: pickle.dump(state, f) def process_work_items(work_items, video_dir_path, model_base, device, checkpoint_dir, resume=False): ious = [] thresh = np.array([0.3, 0.5, 0.7]) recall = np.array([0, 0, 0]) # 加载检查点(如果需要恢复) checkpoint_path = get_checkpoint_path(checkpoint_dir) processed_items = set() if resume and os.path.exists(checkpoint_path): checkpoint = load_checkpoint(checkpoint_path) processed_items = checkpoint['processed_items'] ious = checkpoint['ious'] recall = checkpoint['recall'] print(f"Resuming from checkpoint with {len(processed_items)} processed items") model, processor = setup_model(model_base, device) item_ids = [f"{item['vid']}_{item['sentence_idx']}" for item in work_items] remaining_items = [(i, item) for i, (item, item_id) in enumerate(zip(work_items, item_ids)) if not resume or item_id not in processed_items] if not remaining_items: print("All items already processed") return ious, recall print(f"Processing {len(remaining_items)} out of {len(work_items)} items") pbar = tqdm(remaining_items) for idx, (_, item) in enumerate(pbar): vid = item['vid'] ann = item['ann'] sentence_idx = item['sentence_idx'] item_id = f"{vid}_{sentence_idx}" sentence = ann['sentences'][sentence_idx].strip().lower() if sentence.endswith("."): sentence = sentence[:-1] prompt = GROUND_TEMPLATE.replace('[EVENT]', sentence) # 确定视频路径 duration = ann['duration'] if 'duration' in ann else ann['video_duration'] video_path = None for ext in ['mp4', 'mkv', 'webm']: path = os.path.join(video_dir_path, f"{vid}.{ext}") if os.path.isfile(path): video_path = path break # 处理视频 if video_path: try: ans = inference(video_path, prompt, model, processor, device=device) # print('prompt', prompt) # print('ans', ans) sp, ep = parse_timestamp_output(ans) print(f"Parsed times: {sp}, {ep}") print(f"Ground truth: {ann['timestamps'][sentence_idx]}") print('-' * 50) if (sp is not None) and (ep is not None): s, e = ann['timestamps'][sentence_idx] iou_ = (min(e, ep) - max(s, sp)) / (max(e, ep) - min(s, sp)) ious.append(max(iou_, 0)) recall += (thresh <= iou_) else: ious.append(0) processed_items.add(item_id) if (idx + 1) % 5 == 0 or idx == len(remaining_items) - 1: state = { 'processed_items': processed_items, 'ious': ious, 'recall': recall } save_checkpoint(checkpoint_path, state) miou = sum(ious) / len(ious) if ious else 0 recall_str = str(recall / len(ious) if ious else [0, 0, 0]) pbar.set_postfix({"mIoU": miou, 'recall': recall_str}) except Exception as e: print(f"Error processing {vid}_{sentence_idx}: {e}") print('=== final result ===') # if ious: print('mIoU:', sum(ious) / len(ious)) for th, r in zip(thresh, recall): print(f'R@{th}:', r / len(ious)) return ious, recall def evaluate(data, args): dataset = DATASETS[args.dataset] video_dir_path = dataset['video_path'] work_items = create_work_items(data) ious, recall = process_work_items( work_items, video_dir_path, args.model_base, args.device, args.checkpoint_dir, args.resume ) return ious, recall if __name__=='__main__': args = get_args() assert args.dataset in DATASETS dataset = DATASETS[args.dataset] assert args.split in dataset['splits'] print('evaluate', args.dataset, args.split) # load data with open(dataset['splits'][args.split]['annotation_file']) as f: data = json.load(f) evaluate(data, args)