| from data_configs import DATASETS |
| import argparse |
| import numpy as np |
| import json |
| from tqdm import tqdm |
| import os |
| import re |
| import pickle |
| import torch |
| from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor |
| from qwen_vl_utils import process_vision_info |
| import random |
|
|
# Process-wide memo of process_vision_info() results, filled and read by
# cached_process_vision_info(). Nothing in the code visible here ever evicts
# entries, so it grows with the number of distinct videos seen.
VIDEO_INFO_CACHE = {}
|
|
def get_args():
    """Parse the evaluation script's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with ``dataset``, ``split``, ``model_base``,
        ``batch_size``, ``checkpoint_dir``, ``resume`` and ``device``.
    """
    cli = argparse.ArgumentParser(
        description='Evaluation for training-free video temporal grounding (Single GPU Version)'
    )
    # (flag, add_argument keyword arguments), registered in declaration order.
    option_table = (
        ('--dataset', dict(default='charades', type=str, help='Specify the dataset.')),
        ('--split', dict(default='default', type=str, help='Specify the split.')),
        ("--model_base", dict(type=str, default="/path/to/qwen-model")),
        ("--batch_size", dict(type=int, default=1, help="Batch size")),
        ("--checkpoint_dir", dict(type=str, default="checkpoints", help="Directory to save checkpoints")),
        ("--resume", dict(action="store_true", help="Resume from checkpoint")),
        ("--device", dict(type=str, default="cuda:0", help="GPU device to use")),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
|
def calc_iou(candidates, gt):
    """Compute the temporal IoU of each candidate segment against one ground-truth segment.

    Args:
        candidates: Array-like of ``(start, end)`` rows. A single
            ``(start, end)`` pair is accepted as well. Only the first two
            columns are read.
        gt: Sequence whose first two items are the ground-truth start and end.

    Returns:
        1-D float array holding one IoU per candidate. Where the combined
        extent of candidate and ground truth has no positive length, the IoU
        is reported as 0.0 instead of NaN/inf.
    """
    candidates = np.asarray(candidates)
    if candidates.size == 0:
        return np.zeros(0, dtype=float)
    candidates = np.atleast_2d(candidates)

    start, end = candidates[:, 0], candidates[:, 1]
    s, e = gt[0], gt[1]

    # Negative overlap means the segments are disjoint, hence the clip.
    inter = (np.minimum(end, e) - np.maximum(start, s)).clip(min=0)
    union = np.maximum(end, e) - np.minimum(start, s)

    # Divide only where the union is positive; other entries keep the 0.0 fill,
    # which avoids 0/0 warnings for degenerate (zero-length) segments.
    iou = np.zeros(inter.shape, dtype=float)
    np.divide(inter, union, out=iou, where=union > 0)
    return iou
|
|
def cached_process_vision_info(messages, return_video_kwargs=False):
    """Call ``process_vision_info`` with memoisation for video-only requests.

    The result is stored in the module-level ``VIDEO_INFO_CACHE`` so that the
    same video request is only processed once per process.

    The cache key covers every video entry found in ``messages`` (all of its
    fields, not just the path) plus ``return_video_kwargs``. Requests that
    contain no video entry, or that also contain image entries, bypass the
    cache entirely: a path-only key would make unrelated requests collide.

    Args:
        messages: Chat messages; each is a dict whose ``'content'`` is either
            a plain string or a list of content parts (dicts).
        return_video_kwargs: Forwarded unchanged to ``process_vision_info``.

    Returns:
        Whatever ``process_vision_info`` returns for these arguments. Cached
        results are returned by reference, so callers must not mutate them.
    """
    video_entries = []
    has_image = False
    for msg in messages:
        parts = msg.get('content', [])
        if not isinstance(parts, (list, tuple)):
            # Plain-string content (e.g. a system prompt) carries no media.
            continue
        for part in parts:
            if not isinstance(part, dict):
                continue
            if 'video' in part:
                video_entries.append(part)
            elif 'image' in part or 'image_url' in part:
                has_image = True

    if not video_entries or has_image:
        # Nothing identifies this request reliably enough to cache it.
        return process_vision_info(messages, return_video_kwargs=return_video_kwargs)

    # repr() of the sorted items keeps the key hashable even when a field
    # holds a list; keys within one entry are unique so values are never compared.
    cache_key = (
        tuple(repr(sorted(entry.items())) for entry in video_entries),
        bool(return_video_kwargs),
    )
    if cache_key not in VIDEO_INFO_CACHE:
        VIDEO_INFO_CACHE[cache_key] = process_vision_info(
            messages, return_video_kwargs=return_video_kwargs
        )
    return VIDEO_INFO_CACHE[cache_key]
|
|
def inference(video_path, prompt, model, processor, max_new_tokens=2048, device="cuda:0"):
    """Ask the model one question about one video and return its decoded reply.

    Args:
        video_path: Value placed in the ``"video"`` field of the user message.
        prompt: Text placed after the video in the user message.
        model: Object exposing ``generate`` (presumably the model returned by
            ``setup_model`` -- confirm against caller).
        processor: Object exposing ``apply_chat_template``, ``batch_decode``
            and being callable on text/images/videos.
        max_new_tokens: Forwarded to ``model.generate``.
        device: Device the processed inputs are moved to.

    Returns:
        The first decoded sequence, with the prompt tokens stripped.
    """
    video_part = {
        "type": "video",
        "video": video_path,
        "total_pixels": 3584 * 28 * 28,
        "min_pixels": 200704,
    }
    text_part = {"type": "text", "text": prompt}
    conversation = [
        {"role": "system", "content": "You are a video analysis expert."},
        {"role": "user", "content": [video_part, text_part]},
    ]

    chat_text = processor.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    images, videos, extra_kwargs = cached_process_vision_info(
        conversation, return_video_kwargs=True
    )

    model_inputs = processor(
        text=[chat_text],
        images=images,
        videos=videos,
        fps=extra_kwargs['fps'],
        padding=True,
        return_tensors="pt",
    )
    model_inputs = model_inputs.to(device)

    with torch.no_grad():
        sequences = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # generate() output starts with the prompt tokens; keep only what follows them.
    completions = [
        sequences[row][len(model_inputs.input_ids[row]):]
        for row in range(len(sequences))
    ]
    decoded = processor.batch_decode(
        completions, skip_special_tokens=True, clean_up_tokenization_spaces=True
    )
    return decoded[0]
|
|
# "<number> to <number>" or "<number> and <number>", e.g. "12.54 to 17.83".
_TIME_RANGE_RE = re.compile(r"(\d+\.?\d*) (to|and) (\d+\.?\d*)")
# Contents of an <answer>...</answer> block; DOTALL lets the block span lines.
_ANSWER_BLOCK_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def parse_timestamp_output(output_string):
    """Extract a ``(start, end)`` time pair from the model's text output.

    Ranges inside ``<answer>...</answer>`` blocks take priority (the last
    block that contains a range wins). Only when no answer block yields a
    range is the whole text searched. Within the chosen text the last range
    is used.

    Args:
        output_string: Raw model output; ``None`` or empty is tolerated.

    Returns:
        ``(start, end)`` as floats, or ``(None, None)`` when no range of the
        form ``"<number> to <number>"`` / ``"<number> and <number>"`` exists.
    """
    if not output_string:
        return None, None

    answer_blocks = _ANSWER_BLOCK_RE.findall(output_string)
    # Search order: answer blocks from last to first, then the full text.
    for text in (*reversed(answer_blocks), output_string):
        ranges = _TIME_RANGE_RE.findall(text)
        if ranges:
            start_str, _, end_str = ranges[-1]
            # The pattern only captures digit runs with an optional dot,
            # which float() always accepts.
            return float(start_str), float(end_str)

    return None, None
|
|
| |
|
|
| |
|
|
| |
|
|
|
|
# Prompt text for a grounding query. process_work_items() substitutes the
# query sentence for the "[EVENT]" placeholder; the "start time to end time"
# answer format requested here is the shape parse_timestamp_output() matches.
GROUND_TEMPLATE = """To accurately pinpoint the event "[EVENT]" in the video, determine the precise time period of the event.

Provide the start and end times (in seconds, precise to two decimal places) in the format "start time to end time" within the <answer> </answer> tags. For example: <answer> "12.54 to 17.83" </answer>."""
|
|
|
|
def create_work_items(data):
    """Flatten per-video annotations into one shuffled list of query items.

    Args:
        data: Mapping of video id to an annotation dict that has a
            ``'sentences'`` sequence.

    Returns:
        A list with one dict per (video, sentence) pair, holding ``'vid'``,
        the annotation itself under ``'ann'``, and ``'sentence_idx'``. The
        list is shuffled in place with the global ``random`` state.
    """
    work_items = [
        {'vid': vid, 'ann': ann, 'sentence_idx': sentence_idx}
        for vid, ann in data.items()
        for sentence_idx in range(len(ann['sentences']))
    ]
    random.shuffle(work_items)
    return work_items
|
|
def setup_model(model_base, device):
    """Load the Qwen2.5-VL model and its processor from ``model_base``.

    Args:
        model_base: Path or identifier handed to both ``from_pretrained`` calls.
        device: Passed as ``device_map`` when loading the model.

    Returns:
        ``(model, processor)``.
    """
    print(f"Setting up model on device {device}")
    load_options = dict(
        torch_dtype=torch.bfloat16,
        use_sliding_window=True,
        attn_implementation="flash_attention_2",
        device_map=device,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_base, **load_options)
    processor = AutoProcessor.from_pretrained(model_base)
    return model, processor
|
|
def get_checkpoint_path(checkpoint_dir):
    """Return the checkpoint file path inside ``checkpoint_dir``, creating the directory.

    Args:
        checkpoint_dir: Directory that should hold the checkpoint; it is
            created (including parents) when it does not exist yet.

    Returns:
        ``<checkpoint_dir>/checkpoint.pkl``.
    """
    checkpoint_file = os.path.join(checkpoint_dir, "checkpoint.pkl")
    os.makedirs(checkpoint_dir, exist_ok=True)
    return checkpoint_file
|
|
# Keys every checkpoint written by this script is expected to carry.
_CHECKPOINT_KEYS = ('processed_items', 'ious', 'recall')


def load_checkpoint(checkpoint_path):
    """Load evaluation progress from ``checkpoint_path``.

    Loading is best-effort: a missing file, an unreadable/corrupt file, or a
    file whose content is not a dict with the expected keys all yield a fresh
    empty state instead of raising.

    Args:
        checkpoint_path: Path of the pickle file written by ``save_checkpoint``.

    Returns:
        Dict with ``'processed_items'`` (set), ``'ious'`` (list) and
        ``'recall'`` (array of three counters).
    """
    state = None
    try:
        # NOTE(security): pickle.load can execute arbitrary code while
        # unpickling -- only point this at checkpoints this script produced.
        with open(checkpoint_path, 'rb') as f:
            state = pickle.load(f)
    except FileNotFoundError:
        # No checkpoint yet: start from scratch without a warning.
        pass
    except Exception as e:
        print(f"Error loading checkpoint: {e}")
    else:
        has_expected_shape = isinstance(state, dict) and all(
            key in state for key in _CHECKPOINT_KEYS
        )
        if not has_expected_shape:
            print(f"Error loading checkpoint: unexpected content in {checkpoint_path}")
            state = None

    if state is None:
        return {'processed_items': set(), 'ious': [], 'recall': np.array([0, 0, 0])}
    return state
|
|
def save_checkpoint(checkpoint_path, state):
    """Atomically write ``state`` to ``checkpoint_path`` as a pickle.

    The data is first written to ``<checkpoint_path>.tmp`` and then moved
    over the target with ``os.replace``. If the process dies or pickling
    fails midway, the previous checkpoint is left untouched instead of being
    truncated.

    Args:
        checkpoint_path: Destination file.
        state: Picklable object to store.

    Raises:
        Whatever ``open``/``pickle.dump``/``os.replace`` raise; the temporary
        file is removed before the exception propagates.
    """
    tmp_path = f"{checkpoint_path}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(state, f)
            # Make sure the bytes are on disk before the rename publishes them.
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise
|
|
def _find_video_file(video_dir_path, vid, extensions=('mp4', 'mkv', 'webm')):
    """Return the first existing ``<video_dir_path>/<vid>.<ext>``, or ``None``."""
    for ext in extensions:
        path = os.path.join(video_dir_path, f"{vid}.{ext}")
        if os.path.isfile(path):
            return path
    return None


def _clamped_iou(pred_span, gt_span):
    """Return the temporal IoU of two ``(start, end)`` spans, clamped to be >= 0."""
    sp, ep = pred_span
    s, e = gt_span
    union = max(e, ep) - min(s, sp)
    if union <= 0:
        # Degenerate spans (all endpoints equal): no measurable overlap.
        return 0.0
    return max((min(e, ep) - max(s, sp)) / union, 0.0)


def process_work_items(work_items, video_dir_path, model_base, device, checkpoint_dir, resume=False):
    """Run grounding inference over ``work_items`` and accumulate IoU metrics.

    For every item the prompt is built from ``GROUND_TEMPLATE``, the model is
    queried, the predicted span is parsed and scored against the annotated
    span. Progress is checkpointed every five iterations and once more after
    the loop.

    Items whose video file cannot be found, or whose processing raises, are
    reported and left out of the metrics (and out of the processed set, so a
    resumed run retries them).

    Args:
        work_items: Dicts with ``'vid'``, ``'ann'`` and ``'sentence_idx'``;
            ``ann`` must provide ``'sentences'`` and ``'timestamps'``.
        video_dir_path: Directory searched for ``<vid>.mp4/.mkv/.webm``.
        model_base: Passed to ``setup_model``.
        device: Passed to ``setup_model`` and ``inference``.
        checkpoint_dir: Directory holding the checkpoint file.
        resume: When true and a checkpoint exists, continue from it.

    Returns:
        ``(ious, recall)``: the list of per-item IoUs and the array counting
        items whose IoU reached 0.3 / 0.5 / 0.7.
    """
    thresh = np.array([0.3, 0.5, 0.7])
    ious = []
    recall = np.array([0, 0, 0])
    processed_items = set()

    checkpoint_path = get_checkpoint_path(checkpoint_dir)
    if resume and os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        processed_items = checkpoint['processed_items']
        ious = checkpoint['ious']
        recall = checkpoint['recall']
        print(f"Resuming from checkpoint with {len(processed_items)} processed items")

    # processed_items is only non-empty when resuming, so this filter is a
    # no-op for fresh runs.
    remaining_items = [
        item for item in work_items
        if f"{item['vid']}_{item['sentence_idx']}" not in processed_items
    ]
    if not remaining_items:
        print("All items already processed")
        return ious, recall

    print(f"Processing {len(remaining_items)} out of {len(work_items)} items")

    # Load the model only once we know there is something to run.
    model, processor = setup_model(model_base, device)

    def write_checkpoint():
        # Best effort: a failed checkpoint write must not abort the evaluation.
        try:
            save_checkpoint(checkpoint_path, {
                'processed_items': processed_items,
                'ious': ious,
                'recall': recall,
            })
        except Exception as e:
            print(f"Error saving checkpoint: {e}")

    pbar = tqdm(remaining_items)
    for idx, item in enumerate(pbar):
        vid = item['vid']
        ann = item['ann']
        sentence_idx = item['sentence_idx']
        item_id = f"{vid}_{sentence_idx}"

        video_path = _find_video_file(video_dir_path, vid)
        if video_path is None:
            print(f"Video file not found for {vid} in {video_dir_path}, skipping {item_id}")
        else:
            try:
                sentence = ann['sentences'][sentence_idx].strip().lower()
                if sentence.endswith("."):
                    sentence = sentence[:-1]
                prompt = GROUND_TEMPLATE.replace('[EVENT]', sentence)

                ans = inference(video_path, prompt, model, processor, device=device)

                sp, ep = parse_timestamp_output(ans)
                gt_span = ann['timestamps'][sentence_idx]
                print(f"Parsed times: {sp}, {ep}")
                print(f"Ground truth: {gt_span}")
                print('-' * 50)

                if (sp is not None) and (ep is not None):
                    iou_ = _clamped_iou((sp, ep), gt_span)
                else:
                    # Unparseable answers count as a complete miss.
                    iou_ = 0.0

                ious.append(iou_)
                recall += (thresh <= iou_)
                processed_items.add(item_id)

                miou = sum(ious) / len(ious)
                pbar.set_postfix({"mIoU": miou, 'recall': str(recall / len(ious))})
            except Exception as e:
                print(f"Error processing {vid}_{sentence_idx}: {e}")

        # Checkpoint on a fixed cadence regardless of how this item went, so
        # a skipped or failed item cannot swallow a scheduled save.
        if (idx + 1) % 5 == 0:
            write_checkpoint()

    # Always persist the final state, even when the last item was skipped.
    write_checkpoint()

    print('=== final result ===')
    if not ious:
        print('No items were evaluated')
        return ious, recall

    print('mIoU:', sum(ious) / len(ious))
    for th, r in zip(thresh, recall):
        print(f'R@{th}:', r / len(ious))

    return ious, recall
|
|
def evaluate(data, args):
    """Evaluate the annotations in ``data`` using the settings in ``args``.

    Args:
        data: Annotation mapping handed to ``create_work_items``.
        args: Namespace providing ``dataset``, ``model_base``, ``device``,
            ``checkpoint_dir`` and ``resume``.

    Returns:
        The ``(ious, recall)`` pair produced by ``process_work_items``.
    """
    video_dir_path = DATASETS[args.dataset]['video_path']
    return process_work_items(
        create_work_items(data),
        video_dir_path,
        args.model_base,
        args.device,
        args.checkpoint_dir,
        args.resume,
    )
|
|
if __name__ == '__main__':
    # Script entry point: validate the CLI choices, load annotations, evaluate.
    args = get_args()

    # Explicit checks rather than assert: asserts are stripped under
    # `python -O`, which would turn a bad flag into an obscure KeyError.
    if args.dataset not in DATASETS:
        raise SystemExit(
            f"Unknown dataset '{args.dataset}'. Available: {', '.join(map(str, DATASETS))}"
        )
    dataset = DATASETS[args.dataset]
    if args.split not in dataset['splits']:
        raise SystemExit(
            f"Unknown split '{args.split}' for dataset '{args.dataset}'. "
            f"Available: {', '.join(map(str, dataset['splits']))}"
        )

    print('evaluate', args.dataset, args.split)

    annotation_file = dataset['splits'][args.split]['annotation_file']
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(annotation_file, encoding='utf-8') as f:
        data = json.load(f)

    evaluate(data, args)