File size: 7,154 Bytes
b0c0df0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 |
import datetime
import json
import os
import re
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
import cv2
import numpy as np
import yaml
from loguru import logger as eval_logger
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
# Root of the Hugging Face cache ($HF_HOME, defaulting to ~/.cache/huggingface/).
hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/")
base_cache_dir = os.path.expanduser(hf_home)


def _load_task_config(yaml_path):
    """Load a task YAML file, skipping every line that uses the ``!function`` tag.

    ``yaml.safe_load`` cannot construct the custom ``!function`` tag, so those
    lines are dropped before parsing; only plain data keys are read here.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        safe_lines = [line for line in f if "!function" not in line]
    return yaml.safe_load("".join(safe_lines))


# Sub-directory of ``base_cache_dir`` in which the sample videos are looked up,
# taken from ``dataset_kwargs.cache_dir`` of the task's egoplan.yaml.
_egoplan_config = _load_task_config(Path(__file__).parent / "egoplan.yaml")
cache_name = _egoplan_config["dataset_kwargs"]["cache_dir"]
# Subtitle helpers (parse_subtitle_time, load_subtitles, convert_time_to_frame,
# extract_subtitles) are defined once, further down in this module.
def parse_subtitle_time(time_str):
    """Convert an SRT timestamp ``HH:MM:SS,mmm`` into seconds.

    A ``.`` is also accepted as the millisecond separator (``HH:MM:SS.mmm``),
    and surrounding whitespace is ignored.

    Raises:
        ValueError: if ``time_str`` is not in one of those two forms.
    """
    clock, millis = time_str.strip().replace(".", ",").split(",")
    hours, minutes, seconds = clock.split(":")
    return int(hours) * 3600 + int(minutes) * 60 + int(seconds) + int(millis) / 1000
def load_subtitles(subtitle_path):
    """Parse an SRT subtitle file into a mapping of time spans to cue text.

    Each cue is expected to consist of an index line, a ``start --> end``
    timing line, and one or more text lines. Cues are separated by one or more
    blank (or whitespace-only) lines; cues with fewer than three lines are
    skipped.

    Returns:
        dict mapping ``(start_seconds, end_seconds)`` to the cue's text lines
        joined with single spaces. A later cue with the same span overwrites an
        earlier one.
    """
    with open(subtitle_path, "r", encoding="utf-8") as file:
        content = file.read()

    subtitles = {}
    # Split on runs of blank lines (tolerating stray whitespace) and trim each
    # cue, so the timing line is always lines[1] even after extra blank lines.
    for section in re.split(r"\n\s*\n", content):
        lines = section.strip().split("\n")
        if len(lines) < 3:
            continue
        time_range = lines[1].split(" --> ")
        start_time = parse_subtitle_time(time_range[0])
        end_time = parse_subtitle_time(time_range[1])
        subtitles[(start_time, end_time)] = " ".join(lines[2:])
    return subtitles
def convert_time_to_frame(time_in_seconds, fps):
    """Map a timestamp in seconds to a frame index at the given frame rate.

    The product ``time_in_seconds * fps`` is truncated toward zero by ``int``.
    """
    frame_position = time_in_seconds * fps
    return int(frame_position)
def extract_subtitles(video_path, subtitle_path):
    """Convert the cues of an SRT file into frame ranges for a video.

    Args:
        video_path: Video whose frame rate and frame count are read via OpenCV.
        subtitle_path: SRT subtitle file parsed with ``load_subtitles``.

    Returns:
        ``(subtitle_frames, total_frame)``: ``subtitle_frames`` is a list of
        ``(start_frame, end_frame, text)`` tuples, and ``total_frame`` is the
        frame count reported by OpenCV.
    """
    video = cv2.VideoCapture(video_path)
    try:
        fps = video.get(cv2.CAP_PROP_FPS)
        total_frame = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Only the metadata is needed; release the capture handle immediately so
        # it is not leaked (including when a later step raises).
        video.release()

    subtitles = load_subtitles(subtitle_path)
    subtitle_frames = [
        (convert_time_to_frame(start_time, fps), convert_time_to_frame(end_time, fps), text)
        for (start_time, end_time), text in subtitles.items()
    ]
    return subtitle_frames, total_frame
def egoplan_doc_to_visual(doc):
    """Return the local video path for an EgoPlan sample.

    The file is looked up as ``<base_cache_dir>/<cache_name>/<sample_id>`` with
    the extensions ``.mp4``, ``.MP4`` and ``.mkv`` tried in that order.

    Returns:
        A one-element list holding the first existing path.

    Exits:
        Calls ``sys.exit`` (raising ``SystemExit``) with an error message when
        none of the candidate files exists.
    """
    cache_dir = os.path.join(base_cache_dir, cache_name)
    stem = os.path.join(cache_dir, str(doc["sample_id"]))
    # Only the file extension varies, so directory names that happen to contain
    # "mp4" are never rewritten.
    for extension in (".mp4", ".MP4", ".mkv"):
        candidate = stem + extension
        if os.path.exists(candidate):
            return [candidate]
    sys.exit(f"video path:{stem}.mp4 does not exist, please check")
def egoplan_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Build the multiple-choice next-action prompt for one EgoPlan sample.

    Args:
        doc: Sample dict; reads ``task_goal`` and ``choice_a`` .. ``choice_d``.
        lmms_eval_specific_kwargs: Accepted for compatibility with the task
            interface; not used by this function.

    Returns:
        The question, the four lettered choices (one per line), and a final
        instruction to answer with the option's letter.
    """
    task_goal = doc["task_goal"]
    if "goal" in task_goal:
        # Goals phrased like "The goal is to <action>": keep the text after the
        # first standalone word "to". Matching the whole word avoids cutting
        # inside words such as "stop"; a goal without "to" is kept unchanged
        # instead of raising IndexError.
        parts = re.split(r"\bto\b", task_goal, maxsplit=1)
        if len(parts) == 2:
            task_goal = parts[1].strip()

    context = (
        "The task's progress is demonstrated in the provided video. "
        "My current field of view is shown in the provided image. "
        "What should be my next action? "
        "Please output the most reasonable action you think, expressed in a short phrase."
    )
    words = task_goal.split()
    # A goal starting with an "-ing" word reads as "I am tasked with ...";
    # anything else (including an empty goal) as "My current task is to ...".
    if words and words[0].endswith("ing"):
        question = f"I am tasked with {task_goal}. {context}"
    else:
        question = f"My current task is to {task_goal}. {context}"

    for letter in "ABCD":
        question += f"\n{letter}. {doc['choice_' + letter.lower()]}"
    post_prompt = "\nAnswer with the option's letter from the given choices"
    return f"{question}{post_prompt}"
def extract_characters_regex(s):
    """Extract the chosen option letter (A, B, C or D) from a model response.

    Common lead-in phrases such as "The best answer is" are removed first, so
    their capital letters (e.g. the "B" of "Best answer:") are not mistaken for
    the answer. The first remaining capital A, B, C or D is then returned.

    Returns:
        The option letter as a one-character string, or ``""`` when none is found.
    """
    s = s.strip()
    # Each phrase must be its own list item; without the separating commas,
    # adjacent literals would silently concatenate into one phrase that never matches.
    answer_prefixes = [
        "The best answer is",
        "The correct answer is",
        "The answer is",
        "The answer",
        "The best option is",
        "The correct option is",
        "Best answer:",
        "Best option:",
    ]
    for answer_prefix in answer_prefixes:
        s = s.replace(answer_prefix, "")
    match = re.search(r"[ABCD]", s)
    return match[0] if match else ""
def egoplan_process_results(doc, results):
    """Attach the parsed prediction to a copy of ``doc`` for aggregation.

    Args:
        doc: The sample dict; it is copied, not modified.
        results: Model outputs for the sample; only ``results[0]`` is used.

    Returns:
        ``{"egoplan_mcq_accuracy": record}``, where ``record`` is a copy of
        ``doc`` with ``pred_answer`` set to the extracted option letter
        (``""`` if none was found).
    """
    pred_ans = extract_characters_regex(results[0])
    # Copy before writing so the caller's doc is not mutated as a side effect.
    data_dict = doc.copy()
    data_dict["pred_answer"] = pred_ans
    return {"egoplan_mcq_accuracy": data_dict}
def egoplan_aggregate_results(results):
    """Compute multiple-choice accuracy over processed EgoPlan records.

    Args:
        results: Sized sequence of records; each must contain ``pred_answer``
            and ``golden_choice_idx``.

    Returns:
        The fraction of records whose ``pred_answer`` equals
        ``golden_choice_idx``, or ``0.0`` for an empty ``results`` (instead of
        raising ZeroDivisionError).
    """
    if not results:
        return 0.0
    correct_num = sum(1 for result in results if result["pred_answer"] == result["golden_choice_idx"])
    return correct_num / len(results)
|