|
|
import datetime |
|
|
import json |
|
|
import os |
|
|
import re |
|
|
import sys |
|
|
from collections import defaultdict |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Union |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import yaml |
|
|
from loguru import logger as eval_logger |
|
|
|
|
|
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/")
base_cache_dir = os.path.expanduser(hf_home)

# Resolve the dataset cache sub-directory from this task's egoplan.yaml.
# Lines containing the "!function" tag are dropped before parsing, because
# yaml.safe_load refuses tags it does not know. Only
# dataset_kwargs.cache_dir is read from the parsed result.
with open(Path(__file__).parent / "egoplan.yaml", "r", encoding="utf-8") as f:
    raw_data = f.readlines()
safe_data = [line for line in raw_data if "!function" not in line]
cache_name = yaml.safe_load("".join(safe_data))["dataset_kwargs"]["cache_dir"]
|
|
|
|
|
|
|
|
def parse_subtitle_time(time_str):
    """Convert a subtitle timestamp such as ``"01:02:03,500"`` to seconds.

    The millisecond separator may be ``,`` (SRT) or ``.``, and the
    millisecond part may be omitted entirely (``"00:00:05"`` -> ``5.0``).

    Args:
        time_str: Timestamp of the form ``HH:MM:SS[,mmm]`` or ``HH:MM:SS[.mmm]``.

    Returns:
        The timestamp in seconds as a float.

    Raises:
        ValueError: if the string does not have exactly three ``:``-separated
            fields or a field is not an integer.

    NOTE(review): this module defines parse_subtitle_time a second time further
    down, and that later definition shadows this one; consider removing the duplicate.
    """
    h, m, s_ms = time_str.split(":")
    # Normalise "." to "," so both separator styles share one code path.
    s, _, ms = s_ms.replace(".", ",").partition(",")
    millis = int(ms) / 1000 if ms.strip() else 0.0
    return int(h) * 3600 + int(m) * 60 + int(s) + millis
|
|
|
|
|
|
|
|
def load_subtitles(subtitle_path):
    """Parse an SRT subtitle file into a ``{(start_s, end_s): text}`` mapping.

    Cues are separated by blank lines. Within a cue, the second line holds the
    ``start --> end`` timestamps and every following line is text; text lines
    are joined with single spaces. Cues with fewer than three lines are skipped,
    and cues sharing identical timestamps overwrite earlier ones.

    Args:
        subtitle_path: Path to a UTF-8 encoded subtitle file.

    Returns:
        Dict mapping ``(start_seconds, end_seconds)`` to the cue text.

    NOTE(review): this module defines load_subtitles a second time further down,
    and that later definition shadows this one; consider removing the duplicate.
    """
    subtitles = {}
    with open(subtitle_path, "r", encoding="utf-8") as file:
        content = file.read().split("\n\n")
    for section in content:
        # Strip first: runs of 3+ newlines leave a leading "\n" on the next
        # section (shifting the timestamp line), and a trailing newline would
        # otherwise leak into the last text line.
        section = section.strip()
        if not section:
            continue
        lines = section.split("\n")
        if len(lines) < 3:
            continue
        time_range = lines[1].split(" --> ")
        start_time = parse_subtitle_time(time_range[0])
        end_time = parse_subtitle_time(time_range[1])
        subtitles[(start_time, end_time)] = " ".join(lines[2:])
    return subtitles
|
|
|
|
|
|
|
|
def convert_time_to_frame(time_in_seconds, fps):
    """Map a time in seconds to a frame index at ``fps``, truncating toward zero.

    NOTE(review): this module defines convert_time_to_frame a second time further
    down, and that later definition shadows this one; consider removing the duplicate.
    """
    frame_position = time_in_seconds * fps
    return int(frame_position)
|
|
|
|
|
|
|
|
def extract_subtitles(video_path, subtitle_path):
    """Return subtitle cues as frame ranges, together with the video's frame count.

    Args:
        video_path: Video file whose FPS and frame count are read via OpenCV.
        subtitle_path: SRT file parsed with ``load_subtitles``.

    Returns:
        ``(subtitle_frames, total_frame)`` where ``subtitle_frames`` is a list of
        ``(start_frame, end_frame, text)`` tuples and ``total_frame`` is the value of
        ``CAP_PROP_FRAME_COUNT`` as an int.

    NOTE(review): OpenCV reports 0 for properties it cannot read (e.g. when the
    file fails to open), so fps/frame count may be 0 here. This module also
    defines extract_subtitles a second time further down, and that later
    definition shadows this one.
    """
    video = cv2.VideoCapture(video_path)
    try:
        fps = video.get(cv2.CAP_PROP_FPS)
        total_frame = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Only metadata is needed; release the capture handle instead of leaking it.
        video.release()

    subtitles = load_subtitles(subtitle_path)
    subtitle_frames = [
        (convert_time_to_frame(start_time, fps), convert_time_to_frame(end_time, fps), text)
        for (start_time, end_time), text in subtitles.items()
    ]
    return subtitle_frames, total_frame
|
|
|
|
|
|
|
|
def parse_subtitle_time(time_str):
    """Convert a subtitle timestamp such as ``"01:02:03,500"`` to seconds.

    The millisecond separator may be ``,`` (SRT) or ``.``, and the
    millisecond part may be omitted entirely (``"00:00:05"`` -> ``5.0``).

    Args:
        time_str: Timestamp of the form ``HH:MM:SS[,mmm]`` or ``HH:MM:SS[.mmm]``.

    Returns:
        The timestamp in seconds as a float.

    Raises:
        ValueError: if the string does not have exactly three ``:``-separated
            fields or a field is not an integer.
    """
    h, m, s_ms = time_str.split(":")
    # Normalise "." to "," so both separator styles share one code path.
    s, _, ms = s_ms.replace(".", ",").partition(",")
    millis = int(ms) / 1000 if ms.strip() else 0.0
    return int(h) * 3600 + int(m) * 60 + int(s) + millis
|
|
|
|
|
|
|
|
def load_subtitles(subtitle_path):
    """Parse an SRT subtitle file into a ``{(start_s, end_s): text}`` mapping.

    Cues are separated by blank lines. Within a cue, the second line holds the
    ``start --> end`` timestamps and every following line is text; text lines
    are joined with single spaces. Cues with fewer than three lines are skipped,
    and cues sharing identical timestamps overwrite earlier ones.

    Args:
        subtitle_path: Path to a UTF-8 encoded subtitle file.

    Returns:
        Dict mapping ``(start_seconds, end_seconds)`` to the cue text.
    """
    subtitles = {}
    with open(subtitle_path, "r", encoding="utf-8") as file:
        content = file.read().split("\n\n")
    for section in content:
        # Strip first: runs of 3+ newlines leave a leading "\n" on the next
        # section (shifting the timestamp line), and a trailing newline would
        # otherwise leak into the last text line.
        section = section.strip()
        if not section:
            continue
        lines = section.split("\n")
        if len(lines) < 3:
            continue
        time_range = lines[1].split(" --> ")
        start_time = parse_subtitle_time(time_range[0])
        end_time = parse_subtitle_time(time_range[1])
        subtitles[(start_time, end_time)] = " ".join(lines[2:])
    return subtitles
|
|
|
|
|
|
|
|
def convert_time_to_frame(time_in_seconds, fps):
    """Map a time in seconds to a frame index at ``fps``, truncating toward zero."""
    frame_position = time_in_seconds * fps
    return int(frame_position)
|
|
|
|
|
|
|
|
def extract_subtitles(video_path, subtitle_path):
    """Return subtitle cues as frame ranges, together with the video's frame count.

    Args:
        video_path: Video file whose FPS and frame count are read via OpenCV.
        subtitle_path: SRT file parsed with ``load_subtitles``.

    Returns:
        ``(subtitle_frames, total_frame)`` where ``subtitle_frames`` is a list of
        ``(start_frame, end_frame, text)`` tuples and ``total_frame`` is the value of
        ``CAP_PROP_FRAME_COUNT`` as an int.

    NOTE(review): OpenCV reports 0 for properties it cannot read (e.g. when the
    file fails to open), so fps/frame count may be 0 here.
    """
    video = cv2.VideoCapture(video_path)
    try:
        fps = video.get(cv2.CAP_PROP_FPS)
        total_frame = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Only metadata is needed; release the capture handle instead of leaking it.
        video.release()

    subtitles = load_subtitles(subtitle_path)
    subtitle_frames = [
        (convert_time_to_frame(start_time, fps), convert_time_to_frame(end_time, fps), text)
        for (start_time, end_time), text in subtitles.items()
    ]
    return subtitle_frames, total_frame
|
|
|
|
|
|
|
|
def egoplan_doc_to_visual(doc):
    """Return the local video path for an EgoPlan sample as a one-element list.

    The file is looked up as ``<base_cache_dir>/<cache_name>/<sample_id>`` with the
    extensions ``.mp4``, ``.MP4`` and ``.mkv`` tried in that order.

    Args:
        doc: Sample record; only ``doc["sample_id"]`` is read.

    Returns:
        ``[video_path]`` for the first candidate file that exists.

    Raises:
        SystemExit: via ``sys.exit`` when none of the candidate files exists.
    """
    cache_dir = os.path.join(base_cache_dir, cache_name)
    stem = os.path.join(cache_dir, str(doc["sample_id"]))
    # Swap only the extension. A blanket str.replace("mp4", ...) on the full
    # path would also rewrite any "mp4" inside the cache directory or sample id.
    for extension in (".mp4", ".MP4", ".mkv"):
        candidate = stem + extension
        if os.path.exists(candidate):
            return [candidate]
    # NOTE(review): sys.exit aborts the whole evaluation run; kept for backward
    # compatibility, although raising FileNotFoundError would be friendlier.
    sys.exit(f"video path:{stem}.mp4 does not exist, please check")
|
|
|
|
|
|
|
|
def egoplan_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Build the multiple-choice planning prompt for one EgoPlan sample.

    Args:
        doc: Sample record; reads ``task_goal`` and ``choice_a`` .. ``choice_d``.
        lmms_eval_specific_kwargs: Accepted for interface compatibility; unused.

    Returns:
        The question text, followed by the four options as ``"\\nA. ..."`` ..
        ``"\\nD. ..."`` and a final instruction to answer with the option letter.
    """
    task_goal = doc["task_goal"]
    if "goal" in task_goal:
        # Keep only what follows the first "to" (e.g. "The goal is to X" -> "X").
        # Without any "to", keep the goal unchanged rather than raising IndexError.
        _, sep, remainder = task_goal.partition("to")
        if sep:
            task_goal = remainder.strip()

    words = task_goal.split()
    # Goals starting with an "-ing" word read as "tasked with ...", others as
    # "task is to ...". An empty goal uses the latter instead of raising IndexError.
    if words and words[0].endswith("ing"):
        intro = "I am tasked with {}. "
    else:
        intro = "My current task is to {}. "
    task_context = (
        "The task's progress is demonstrated in the provided video. "
        "My current field of view is shown in the provided image. "
        "What should be my next action? "
        "Please output the most reasonable action you think, expressed in a short phrase."
    )
    question = intro.format(task_goal) + task_context

    for choice_idx in "ABCD":
        question += f"\n{choice_idx}. " + doc[f"choice_{choice_idx.lower()}"]

    post_prompt = "\nAnswer with the option's letter from the given choices"
    return f"{question}{post_prompt}"
|
|
|
|
|
|
|
|
def extract_characters_regex(s):
    """Extract the chosen option letter (A-D) from a free-form model response.

    Known answer prefixes ("The best answer is", "Best option:", ...) are removed
    first, so that their own capital letters (e.g. the "B" in "Best") are not
    mistaken for an answer. Then the first remaining ``A``-``D`` character is returned.

    Args:
        s: Raw model output.

    Returns:
        ``"A"``, ``"B"``, ``"C"`` or ``"D"``, or ``""`` when no option letter is found.
    """
    s = s.strip()
    answer_prefixes = [
        "The best answer is",
        "The correct answer is",
        "The answer is",
        "The answer",
        # Each prefix needs its own entry; adjacent string literals without a
        # comma would be silently concatenated into a single, never-matching prefix.
        "The best option is",
        "The correct option is",
        "Best answer:",
        "Best option:",
    ]
    for answer_prefix in answer_prefixes:
        s = s.replace(answer_prefix, "")

    match = re.search(r"[ABCD]", s)
    return match[0] if match else ""
|
|
|
|
|
|
|
|
def egoplan_process_results(doc, results):
    """Attach the extracted prediction to a copy of the sample for aggregation.

    Args:
        doc: Sample record; it is copied and left unmodified.
        results: Model outputs for this sample; only ``results[0]`` is used.

    Returns:
        ``{"egoplan_mcq_accuracy": record}``, where ``record`` is a shallow copy of
        ``doc`` with ``pred_answer`` set to the option letter found by
        ``extract_characters_regex`` (``""`` if none).
    """
    pred_ans = extract_characters_regex(results[0])
    # Copy before annotating so the caller's doc is not mutated.
    data_dict = doc.copy()
    data_dict["pred_answer"] = pred_ans
    return {"egoplan_mcq_accuracy": data_dict}
|
|
|
|
|
|
|
|
def egoplan_aggregate_results(results):
    """Compute multiple-choice accuracy over processed EgoPlan records.

    Args:
        results: Sized sequence of dicts, each with ``pred_answer`` and
            ``golden_choice_idx`` keys.

    Returns:
        The fraction of records whose ``pred_answer`` equals ``golden_choice_idx``,
        or ``0.0`` for an empty sequence (instead of raising ZeroDivisionError).
    """
    question_num = len(results)
    if question_num == 0:
        return 0.0
    correct_num = sum(1 for result in results if result["pred_answer"] == result["golden_choice_idx"])
    return correct_num / question_num
|
|
|