File size: 10,734 Bytes
33569f9 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 | from data_configs import DATASETS
import argparse
import numpy as np
import json
from tqdm import tqdm
import os
import re
import pickle
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
import random
# Module-level memo used by cached_process_vision_info: maps a cache key to the
# value returned by qwen_vl_utils.process_vision_info. It lives for the whole
# process, so its memory use grows with the number of distinct keys stored.
VIDEO_INFO_CACHE = dict()
def get_args(argv=None):
    """Parse the command-line options of the evaluation script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) parses
            ``sys.argv[1:]`` exactly as before; passing an explicit list lets
            the parser be reused from notebooks, other scripts, or tests.

    Returns:
        argparse.Namespace with ``dataset``, ``split``, ``model_base``,
        ``batch_size``, ``checkpoint_dir``, ``resume`` and ``device``.
    """
    parser = argparse.ArgumentParser(description='Evaluation for training-free video temporal grounding (Single GPU Version)')
    parser.add_argument('--dataset', default='charades', type=str, help='Specify the dataset.')
    parser.add_argument('--split', default='default', type=str, help='Specify the split.')
    parser.add_argument("--model_base", type=str, default="/path/to/qwen-model")
    # NOTE(review): --batch_size is parsed, but the evaluation loop in this file
    # sends one conversation per inference call and never reads this value.
    parser.add_argument("--batch_size", type=int, default=1, help="Batch size")
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Directory to save checkpoints")
    parser.add_argument("--resume", action="store_true", help="Resume from checkpoint")
    parser.add_argument("--device", type=str, default="cuda:0", help="GPU device to use")
    return parser.parse_args(argv)
def calc_iou(candidates, gt):
    """Compute temporal IoU between candidate intervals and one ground-truth interval.

    Args:
        candidates: array-like of shape (N, 2) holding (start, end) pairs. A
            single (start, end) pair is also accepted and treated as N=1.
        gt: (start, end) pair for the ground-truth interval.

    Returns:
        Float ndarray of shape (N,) with IoU values in [0, 1]. Non-overlapping
        pairs score 0. Degenerate cases whose joint span is zero (or negative)
        also score 0, instead of producing NaN/inf from a division by zero.
    """
    candidates = np.atleast_2d(np.asarray(candidates, dtype=float))
    start, end = candidates[:, 0], candidates[:, 1]
    s, e = float(gt[0]), float(gt[1])
    inter = (np.minimum(end, e) - np.maximum(start, s)).clip(min=0)
    # Hull length: it equals the union when the intervals overlap. When they
    # are disjoint, inter is 0, so the ratio is 0 either way.
    union = np.maximum(end, e) - np.minimum(start, s)
    iou = np.zeros_like(inter)
    np.divide(inter, union, out=iou, where=union > 0)
    return iou
def cached_process_vision_info(messages, return_video_kwargs=False, max_entries=None):
    """Memoised wrapper around ``process_vision_info``, keyed on the visual inputs.

    The cache key is built from two parts:
      * every visual content item (video / image dicts) in *messages*,
        including options such as ``total_pixels`` and ``min_pixels``;
      * *return_video_kwargs*.
    Two requests for the same file with different sampling settings therefore
    no longer share an entry.

    Messages with no visual content bypass the cache entirely. Previously they
    all shared the key ``"None_<flag>"`` and could be handed each other's
    results.

    Args:
        messages: chat messages in the format accepted by ``process_vision_info``.
        return_video_kwargs: forwarded unchanged to ``process_vision_info``.
        max_entries: optional cap on the number of cached results. When set,
            the oldest entries are evicted first. ``None`` (the default) keeps
            the cache unbounded, as before.

    Returns:
        Whatever ``process_vision_info`` returns for *messages*. Cached results
        are returned by reference, not copied.
    """
    visual_keys = ('video', 'image', 'image_url')
    visual_specs = []
    for msg in messages:
        content = msg.get('content', [])
        if not isinstance(content, (list, tuple)):
            continue  # e.g. a plain-string system prompt carries no visuals
        for item in content:
            if not isinstance(item, dict):
                continue
            if item.get('type') in visual_keys or any(k in item for k in visual_keys):
                # repr() makes the spec hashable and captures every option.
                visual_specs.append(tuple(sorted((str(k), repr(v)) for k, v in item.items())))

    if not visual_specs:
        return process_vision_info(messages, return_video_kwargs=return_video_kwargs)

    cache_key = (tuple(visual_specs), bool(return_video_kwargs))
    if cache_key in VIDEO_INFO_CACHE:
        return VIDEO_INFO_CACHE[cache_key]

    result = process_vision_info(messages, return_video_kwargs=return_video_kwargs)
    VIDEO_INFO_CACHE[cache_key] = result
    if max_entries is not None:
        # dicts preserve insertion order, so the first key is the oldest entry.
        while len(VIDEO_INFO_CACHE) > max(max_entries, 0):
            del VIDEO_INFO_CACHE[next(iter(VIDEO_INFO_CACHE))]
    return result
def inference(video_path, prompt, model, processor, max_new_tokens=2048, device="cuda:0"):
    """Run the model on one video plus a text prompt and return its reply.

    The conversation has two turns:
      * a fixed system turn;
      * a user turn holding the video entry followed by *prompt*.
    Visual inputs come from ``cached_process_vision_info``. The batch holds a
    single conversation, so only the first decoded sequence is returned.

    Args:
        video_path: path of the video file to analyse.
        prompt: user instruction text.
        model: generation model exposing ``generate``.
        processor: processor exposing ``apply_chat_template``, ``__call__``
            and ``batch_decode``.
        max_new_tokens: generation length limit.
        device: device that the processed inputs are moved to.

    Returns:
        The decoded text generated after the prompt.
    """
    video_entry = {
        "type": "video",
        "video": video_path,
        # Pixel-budget options read by qwen_vl_utils (presumably for frame
        # resizing — confirm there). 200704 == 256 * 28 * 28.
        "total_pixels": 3584 * 28 * 28,
        "min_pixels": 200704,
    }
    conversation = [
        {"role": "system", "content": "You are a video analysis expert."},
        {"role": "user", "content": [video_entry, {"type": "text", "text": prompt}]},
    ]

    chat_text = processor.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    images, videos, video_kwargs = cached_process_vision_info(conversation, return_video_kwargs=True)
    model_inputs = processor(
        text=[chat_text],
        images=images,
        videos=videos,
        fps=video_kwargs['fps'],
        padding=True,
        return_tensors="pt",
    ).to(device)

    with torch.no_grad():
        output_ids = model.generate(**model_inputs, max_new_tokens=max_new_tokens)

    # Strip the echoed prompt tokens so only newly generated tokens are decoded.
    completions = [out[len(inp):] for inp, out in zip(model_inputs.input_ids, output_ids)]
    decoded = processor.batch_decode(completions, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return decoded[0]
# Matches a "<start> <separator> <end>" span. Besides the original
# "12.54 to 17.83" / "1 and 2" forms, it tolerates:
#   * extra whitespace;
#   * an optional seconds unit ("5s", "5 sec", "5 seconds");
#   * dash / en-dash / tilde separators ("5.0-7.5").
_TIMESTAMP_PAIR_RE = re.compile(
    r"(\d+\.?\d*)(?:\s*(?:seconds?|secs?|s)\b)?"  # start time + optional unit
    r"\s*(?:\bto\b|\band\b|-|\u2013|~)\s*"        # separator
    r"(\d+\.?\d*)",                                # end time
    re.IGNORECASE,
)
# DOTALL so an answer that spans several lines is still captured.
_ANSWER_RE = re.compile(r"<answer>(.*?)</answer>", re.DOTALL)


def parse_timestamp_output(output_string):
    """Extract the predicted (start, end) times, in seconds, from a model reply.

    Search order:
      1. Spans inside ``<answer>...</answer>`` tags, because the prompt asks
         for the answer there. The last answer block containing a span wins,
         and within it the last span is used.
      2. If no answer block has a span, the whole reply is searched the same
         way, which keeps the original "last match anywhere" behaviour.

    The original fallback only re-ran the same pattern on a substring of text
    it had already searched, so it could never find anything.

    Args:
        output_string: decoded model output (may be empty or None).

    Returns:
        (start, end) as floats, or (None, None) if nothing parseable is found.
        The pair is returned as written; reversed spans are not reordered.
    """
    if not output_string:
        return None, None
    answer_blocks = _ANSWER_RE.findall(output_string)
    for region in [*reversed(answer_blocks), output_string]:
        spans = _TIMESTAMP_PAIR_RE.findall(region)
        if spans:
            start, end = spans[-1]
            return float(start), float(end)
    return None, None
# Earlier prompt variant, kept for reference only (it is not used). It also asked
# for <think> reasoning with <timestep> annotations before the answer:
# GROUND_TEMPLATE = """To accurately pinpoint the event "[EVENT]" in the video, determine the precise time period of the event.
# Output your thought process within the <think> </think> tags, including analysis with either specific timestamps (xx.xx) or time ranges (xx.xx to xx.xx) in <timestep> </timestep> tags.
# Then, provide the start and end times (in seconds, precise to two decimal places) in the format "start time to end time" within the <answer> </answer> tags. For example: 12.54 to 17.83."""

# Grounding prompt. "[EVENT]" is the placeholder for the query sentence. The
# model is asked to reply with "<start> to <end>" inside <answer> tags.
GROUND_TEMPLATE = (
    'To accurately pinpoint the event "[EVENT]" in the video, determine the precise time period of the event.\n'
    'Provide the start and end times (in seconds, precise to two decimal places) in the format '
    '"start time to end time" within the <answer> </answer> tags. '
    'For example: <answer> "12.54 to 17.83" </answer>.'
)
def create_work_items(data):
    """Flatten annotations into one work item per (video, sentence) pair, in random order.

    Args:
        data: mapping of video id -> annotation dict. Each annotation must
            provide a ``'sentences'`` sequence.

    Returns:
        List of ``{'vid', 'ann', 'sentence_idx'}`` dicts. ``'ann'`` is the same
        annotation object from *data*, not a copy. The list is shuffled with
        the module-level ``random`` generator, so the order varies between
        runs unless the caller seeds ``random``.
    """
    items = [
        {'vid': vid, 'ann': ann, 'sentence_idx': idx}
        for vid, ann in data.items()
        for idx in range(len(ann['sentences']))
    ]
    # Randomly shuffle the processing order.
    random.shuffle(items)
    return items
def setup_model(model_base, device):
    """Load the Qwen2.5-VL generation model and its processor from *model_base*.

    The model is loaded with these options:
      * bfloat16 weights;
      * ``attn_implementation="flash_attention_2"`` (presumably requires the
        flash-attn package to be installed — confirm in the environment);
      * ``use_sliding_window=True``, forwarded to ``from_pretrained``;
      * placed on *device* via ``device_map``.

    Args:
        model_base: local path or hub id of the checkpoint.
        device: device string, e.g. ``"cuda:0"``.

    Returns:
        (model, processor) tuple.
    """
    print(f"Setting up model on device {device}")
    load_kwargs = dict(
        torch_dtype=torch.bfloat16,
        use_sliding_window=True,
        attn_implementation="flash_attention_2",
        device_map=device,
    )
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(model_base, **load_kwargs)
    processor = AutoProcessor.from_pretrained(model_base)
    return model, processor
def get_checkpoint_path(checkpoint_dir):
    """Return ``<checkpoint_dir>/checkpoint.pkl``, creating *checkpoint_dir* if it is missing."""
    target = os.path.join(checkpoint_dir, "checkpoint.pkl")
    os.makedirs(checkpoint_dir, exist_ok=True)
    return target
def load_checkpoint(checkpoint_path):
    """Load a pickled evaluation state, or return a fresh empty state.

    If the file exists but cannot be unpickled, the error is printed and the
    fresh state is returned instead (best-effort recovery).

    Security: ``pickle.load`` can execute arbitrary code stored in the file.
    Only point this at checkpoints written by this script.

    Args:
        checkpoint_path: path of the pickle file.

    Returns:
        The unpickled object, or ``{'processed_items': set(), 'ious': [],
        'recall': np.array([0, 0, 0])}``.
    """
    if os.path.exists(checkpoint_path):
        try:
            with open(checkpoint_path, 'rb') as handle:
                state = pickle.load(handle)
        except Exception as err:
            print(f"Error loading checkpoint: {err}")
        else:
            return state
    return {'processed_items': set(), 'ious': [], 'recall': np.array([0, 0, 0])}
def save_checkpoint(checkpoint_path, state):
    """Atomically pickle *state* to *checkpoint_path*.

    The state is written in three steps:
      1. pickle it to a sibling ``<checkpoint_path>.tmp`` file;
      2. flush and fsync that file;
      3. move it over *checkpoint_path* with ``os.replace``.

    Writing in place means an interruption mid-dump leaves a truncated file,
    which ``load_checkpoint`` cannot unpickle and replaces with an empty state
    — discarding all progress. With this approach the previous checkpoint
    stays intact until the new one is complete.

    Args:
        checkpoint_path: destination file path.
        state: any picklable object.

    Raises:
        Whatever pickling or file I/O raises. The temporary file is removed
        and the previous checkpoint is left untouched.
    """
    tmp_path = f"{checkpoint_path}.tmp"
    try:
        with open(tmp_path, 'wb') as f:
            pickle.dump(state, f)
            f.flush()
            os.fsync(f.fileno())
        os.replace(tmp_path, checkpoint_path)
    except BaseException:
        try:
            os.remove(tmp_path)
        except OSError:
            pass  # temp file was never created or is already gone
        raise
def _find_video_file(video_dir_path, vid, extensions=('mp4', 'mkv', 'webm')):
    """Return the first existing ``<video_dir_path>/<vid>.<ext>`` path, or None."""
    for ext in extensions:
        candidate = os.path.join(video_dir_path, f"{vid}.{ext}")
        if os.path.isfile(candidate):
            return candidate
    return None


def _build_grounding_prompt(sentence):
    """Fill GROUND_TEMPLATE with the query sentence (stripped, lower-cased, one trailing '.' removed)."""
    query = sentence.strip().lower()
    if query.endswith("."):
        query = query[:-1]
    return GROUND_TEMPLATE.replace('[EVENT]', query)


def _interval_iou(pred, gt):
    """IoU of two (start, end) intervals, clamped to >= 0; 0.0 when their joint span is not positive."""
    (sp, ep), (s, e) = pred, gt
    span = max(e, ep) - min(s, sp)
    if span <= 0:
        return 0.0  # avoids ZeroDivisionError on degenerate intervals
    return max((min(e, ep) - max(s, sp)) / span, 0.0)


def _print_final_results(ious, recall, thresh):
    """Print mIoU and R@threshold over the evaluated items; handles the empty case."""
    print('=== final result ===')
    if not ious:
        print('No items were successfully evaluated; mIoU and recall are undefined.')
        return
    print('mIoU:', sum(ious) / len(ious))
    for th, r in zip(thresh, recall):
        print(f'R@{th}:', r / len(ious))


def process_work_items(work_items, video_dir_path, model_base, device, checkpoint_dir, resume=False):
    """Evaluate temporal grounding on *work_items* with a single model instance.

    For each (video, sentence) item:
      1. locate the video file;
      2. prompt the model through ``inference``;
      3. parse the reply with ``parse_timestamp_output``;
      4. score it by IoU against ``ann['timestamps'][sentence_idx]``.
    Progress is checkpointed every 5 items and after the last one, so a run
    can be continued with ``resume=True``.

    Args:
        work_items: dicts with ``'vid'``, ``'ann'`` and ``'sentence_idx'``
            keys, as built by ``create_work_items``.
        video_dir_path: directory searched for ``<vid>.mp4|mkv|webm``.
        model_base: model path or name passed to ``setup_model``.
        device: device string passed to ``setup_model`` and ``inference``.
        checkpoint_dir: directory holding ``checkpoint.pkl``.
        resume: if True and a checkpoint exists, skip the items it lists as
            done and continue from its accumulated metrics.

    Returns:
        ``(ious, recall)``:
          * ``ious`` — per-item IoU list; an unparseable answer counts as 0.
          * ``recall`` — array of counts (not rates) of items whose IoU
            reaches 0.3 / 0.5 / 0.7.
        Items whose video is missing, or whose processing raised, are not
        counted and are retried on the next resume.
    """
    thresh = np.array([0.3, 0.5, 0.7])
    ious = []
    recall = np.array([0, 0, 0])
    processed_items = set()

    # Restore accumulated state when resuming an earlier run.
    checkpoint_path = get_checkpoint_path(checkpoint_dir)
    if resume and os.path.exists(checkpoint_path):
        checkpoint = load_checkpoint(checkpoint_path)
        processed_items = checkpoint['processed_items']
        ious = checkpoint['ious']
        recall = checkpoint['recall']
        print(f"Resuming from checkpoint with {len(processed_items)} processed items")

    # processed_items is empty unless resuming, so a fresh run keeps every item.
    remaining_items = [
        item for item in work_items
        if f"{item['vid']}_{item['sentence_idx']}" not in processed_items
    ]
    if not remaining_items:
        print("All items already processed")
        _print_final_results(ious, recall, thresh)
        return ious, recall

    # Load the (expensive) model only once we know there is work left.
    model, processor = setup_model(model_base, device)
    print(f"Processing {len(remaining_items)} out of {len(work_items)} items")

    last_idx = len(remaining_items) - 1
    pbar = tqdm(remaining_items)
    for idx, item in enumerate(pbar):
        vid = item['vid']
        ann = item['ann']
        sentence_idx = item['sentence_idx']
        item_id = f"{vid}_{sentence_idx}"

        video_path = _find_video_file(video_dir_path, vid)
        if video_path is None:
            # Make gaps in the metrics visible instead of skipping silently.
            print(f"Video not found for {item_id} in {video_dir_path}; skipping")
        else:
            try:
                prompt = _build_grounding_prompt(ann['sentences'][sentence_idx])
                ans = inference(video_path, prompt, model, processor, device=device)
                sp, ep = parse_timestamp_output(ans)
                gt = ann['timestamps'][sentence_idx]
                print(f"Parsed times: {sp}, {ep}")
                print(f"Ground truth: {gt}")
                print('-' * 50)
                # An unparseable reply counts as a miss (IoU 0).
                iou_ = 0.0 if sp is None or ep is None else _interval_iou((sp, ep), gt)
                ious.append(iou_)
                recall += thresh <= iou_
                processed_items.add(item_id)
            except Exception as e:  # best-effort: log, move on, retry on resume
                print(f"Error processing {item_id}: {e}")

        # Checkpoint outside the per-item try, so a failure on the 5th or last
        # item no longer skips saving the progress made before it.
        if (idx + 1) % 5 == 0 or idx == last_idx:
            try:
                save_checkpoint(checkpoint_path, {
                    'processed_items': processed_items,
                    'ious': ious,
                    'recall': recall,
                })
            except Exception as e:
                print(f"Error saving checkpoint: {e}")

        miou = sum(ious) / len(ious) if ious else 0
        recall_str = str(recall / len(ious) if ious else [0, 0, 0])
        pbar.set_postfix({"mIoU": miou, 'recall': recall_str})

    _print_final_results(ious, recall, thresh)
    return ious, recall
def evaluate(data, args):
    """Run the grounding evaluation on *data* for the dataset named by ``args.dataset``.

    Videos are looked up in ``DATASETS[args.dataset]['video_path']``. Model,
    device, checkpoint and resume settings come from *args*.

    Args:
        data: annotation mapping (video id -> annotation dict).
        args: parsed command-line namespace from ``get_args``.

    Returns:
        ``(ious, recall)`` as produced by ``process_work_items``.
    """
    video_dir_path = DATASETS[args.dataset]['video_path']
    return process_work_items(
        create_work_items(data),
        video_dir_path,
        args.model_base,
        args.device,
        args.checkpoint_dir,
        args.resume,
    )
if __name__ == '__main__':
    args = get_args()
    # Explicit checks instead of `assert`: asserts are stripped under `python -O`,
    # and a message naming the valid choices is more useful than a bare
    # AssertionError.
    if args.dataset not in DATASETS:
        raise SystemExit(f"Unknown dataset {args.dataset!r}; available: {list(DATASETS)}")
    dataset = DATASETS[args.dataset]
    if args.split not in dataset['splits']:
        raise SystemExit(
            f"Unknown split {args.split!r} for dataset {args.dataset!r}; "
            f"available: {list(dataset['splits'])}"
        )
    print('evaluate', args.dataset, args.split)
    # Load the annotations. The encoding is explicit so non-ASCII sentences
    # decode the same way regardless of the platform's default locale.
    with open(dataset['splits'][args.split]['annotation_file'], encoding='utf-8') as f:
        data = json.load(f)
    evaluate(data, args)