csuhan's picture
Upload folder using huggingface_hub
b0c0df0 verified
import datetime
import json
import os
import re
import sys
from collections import defaultdict
from pathlib import Path
from typing import Dict, List, Optional, Union
import cv2
import numpy as np
import yaml
from loguru import logger as eval_logger
from lmms_eval.tasks._task_utils.file_utils import generate_submission_file
# with open(Path(__file__).parent / "_default_template_yaml", "r") as f:
# raw_data = f.readlines()
# safe_data = []
# for i, line in enumerate(raw_data):
# # remove function definition since yaml load cannot handle it
# if "!function" not in line:
# safe_data.append(line)
# config = yaml.safe_load("".join(safe_data))
# Root of the Hugging Face cache; the HF_HOME environment variable overrides
# the default location. "~" is expanded to the user's home directory.
hf_home = os.getenv("HF_HOME", "~/.cache/huggingface/")
base_cache_dir = os.path.expanduser(hf_home)

# Read the task YAML to find the dataset's cache sub-directory. Lines that use
# the custom "!function" tag are dropped first because yaml.safe_load cannot
# handle that tag.
with open(Path(__file__).parent / "egoplan.yaml", "r", encoding="utf-8") as f:
    raw_data = f.readlines()
safe_data = [line for line in raw_data if "!function" not in line]
cache_name = yaml.safe_load("".join(safe_data))["dataset_kwargs"]["cache_dir"]
def parse_subtitle_time(time_str):
    """Convert an SRT timestamp such as ``"01:02:03,456"`` to seconds.

    Surrounding whitespace is ignored, and ``.`` is accepted as well as ``,``
    for the millisecond separator.

    Args:
        time_str: Timestamp in ``HH:MM:SS,mmm`` form.

    Returns:
        The timestamp in seconds, as a float.
    """
    h, m, s_ms = time_str.strip().split(":")
    s, ms = s_ms.replace(".", ",").split(",")
    return int(h) * 3600 + int(m) * 60 + int(s) + int(ms) / 1000


def load_subtitles(subtitle_path):
    """Parse an SRT subtitle file into ``{(start_sec, end_sec): text}``.

    Cues are separated by one or more blank (or whitespace-only) lines. Each
    cue is read as: an index line, a ``start --> end`` timing line, then one
    or more text lines, which are joined with single spaces. Cues with fewer
    than three lines are skipped. If two cues share the same timing, the later
    one wins.

    Args:
        subtitle_path: Path to a UTF-8 encoded ``.srt`` file.

    Returns:
        Dict mapping ``(start_seconds, end_seconds)`` to the cue text.
    """
    subtitles = {}
    with open(subtitle_path, "r", encoding="utf-8") as file:
        content = file.read()
    # Split on runs of blank lines rather than exactly "\n\n": with extra blank
    # lines between cues, a "\n\n" split leaves a leading empty line in the
    # next section, so lines[1] would be the index instead of the timing line.
    for section in re.split(r"\n\s*\n", content.strip()):
        lines = section.strip().split("\n")
        if len(lines) < 3:
            continue
        time_range = lines[1].split(" --> ")
        start_time = parse_subtitle_time(time_range[0])
        end_time = parse_subtitle_time(time_range[1])
        subtitles[(start_time, end_time)] = " ".join(lines[2:])
    return subtitles


def convert_time_to_frame(time_in_seconds, fps):
    """Return the frame index at ``time_in_seconds`` for a video at ``fps``.

    The product is truncated toward zero by ``int``.
    """
    return int(time_in_seconds * fps)


def extract_subtitles(video_path, subtitle_path):
    """Map the subtitle cues of a video onto frame indices.

    Args:
        video_path: Video file opened with OpenCV to read its fps and frame
            count.
        subtitle_path: SRT file parsed with ``load_subtitles``.

    Returns:
        ``(subtitle_frames, total_frame)``, where ``subtitle_frames`` is a list
        of ``(start_frame, end_frame, text)`` tuples and ``total_frame`` is the
        frame count reported by OpenCV.
    """
    video = cv2.VideoCapture(video_path)
    try:
        fps = video.get(cv2.CAP_PROP_FPS)
        total_frame = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    finally:
        # Only the container metadata is needed, so free the capture handle
        # immediately instead of leaking it.
        video.release()
    # NOTE(review): if the video cannot be opened, OpenCV presumably reports
    # fps == 0, which maps every cue to frame 0. Confirm whether callers need
    # an explicit error here.
    subtitles = load_subtitles(subtitle_path)
    subtitle_frames = [
        (convert_time_to_frame(start_time, fps), convert_time_to_frame(end_time, fps), text)
        for (start_time, end_time), text in subtitles.items()
    ]
    return subtitle_frames, total_frame
def egoplan_doc_to_visual(doc):
    """Resolve the video file for one EgoPlan sample.

    Looks in ``<base_cache_dir>/<cache_name>`` for ``<sample_id>.mp4``, then
    ``<sample_id>.MP4``, then ``<sample_id>.mkv``, in that order.

    Args:
        doc: Sample dict; only ``doc["sample_id"]`` is read.

    Returns:
        A one-element list containing the first candidate path that exists.

    Raises:
        SystemExit: via ``sys.exit`` when none of the candidates exist.
    """
    cache_dir = os.path.join(base_cache_dir, cache_name)
    stem = os.path.join(cache_dir, str(doc["sample_id"]))
    video_path = stem + ".mp4"
    # Build each candidate from the stem so that only the extension changes.
    # A plain str.replace("mp4", ...) would also rewrite any "mp4" appearing
    # elsewhere in the path, e.g. in a directory name or the sample id.
    for extension in (".mp4", ".MP4", ".mkv"):
        candidate = stem + extension
        if os.path.exists(candidate):
            return [candidate]
    sys.exit(f"video path:{video_path} does not exist, please check")
def egoplan_doc_to_text(doc, lmms_eval_specific_kwargs=None):
    """Build the multiple-choice prompt for one EgoPlan sample.

    Args:
        doc: Sample dict; reads ``task_goal`` and ``choice_a`` .. ``choice_d``.
        lmms_eval_specific_kwargs: Accepted for compatibility with the task
            API; not used by this function.

    Returns:
        The task description, the four lettered options (one per line), and
        an instruction to answer with the option letter.
    """
    task_goal = doc["task_goal"]
    if "goal" in task_goal:
        # Drop a leading "... goal is to" preamble and keep the text after the
        # first standalone word "to". Matching the whole word avoids cutting
        # inside words such as "tournament" or "into". A goal with no "to" is
        # left unchanged instead of raising IndexError.
        parts = re.split(r"\bto\b", task_goal, maxsplit=1)
        if len(parts) == 2:
            task_goal = parts[1].strip()

    words = task_goal.split()
    # A gerund phrase ("making tea") reads naturally after "tasked with";
    # anything else is treated as an infinitive phrase after "is to".
    if words and words[0].endswith("ing"):
        intro = f"I am tasked with {task_goal}. "
    else:
        intro = f"My current task is to {task_goal}. "
    context = (
        "The task's progress is demonstrated in the provided video. "
        "My current field of view is shown in the provided image. "
        "What should be my next action? "
        "Please output the most reasonable action you think, expressed in a short phrase."
    )

    question = intro + context
    for letter in "ABCD":
        question += f"\n{letter}. {doc['choice_' + letter.lower()]}"
    post_prompt = "\nAnswer with the option's letter from the given choices"
    return f"{question}{post_prompt}"
def extract_characters_regex(s):
    """Extract the chosen option letter (A-D) from a model response.

    Known answer preambles such as "The best answer is" are removed first, so
    capital letters inside them (e.g. the "B" of "Best answer:") are not
    mistaken for the chosen option. The first remaining A/B/C/D character is
    returned.

    Args:
        s: Raw model output.

    Returns:
        "A", "B", "C" or "D", or "" when no option letter is found.
    """
    s = s.strip()
    # Each prefix is a separate list element; without commas between them,
    # adjacent string literals would silently concatenate into one prefix.
    answer_prefixes = [
        "The best answer is",
        "The correct answer is",
        "The answer is",
        "The answer",
        "The best option is",
        "The correct option is",
        "Best answer:",
        "Best option:",
    ]
    for answer_prefix in answer_prefixes:
        s = s.replace(answer_prefix, "")
    match = re.search(r"[ABCD]", s)
    return match[0] if match else ""
def egoplan_process_results(doc, results):
    """Attach the extracted answer letter to a copy of the sample.

    Args:
        doc: The dataset sample; it is not modified.
        results: Model outputs for the sample; only ``results[0]`` is used.

    Returns:
        ``{"egoplan_mcq_accuracy": record}``, where ``record`` is a copy of
        ``doc`` with ``pred_answer`` set to the letter returned by
        ``extract_characters_regex`` ("" when none was found).
    """
    pred_ans = extract_characters_regex(results[0])
    # Copy first so the caller's sample dict is not mutated.
    data_dict = doc.copy()
    data_dict["pred_answer"] = pred_ans
    return {"egoplan_mcq_accuracy": data_dict}
def egoplan_aggregate_results(results):
    """Compute multiple-choice accuracy over the scored samples.

    Args:
        results: Sequence of per-sample dicts (presumably the records built by
            ``egoplan_process_results``), each with ``pred_answer`` and
            ``golden_choice_idx`` keys.

    Returns:
        Fraction of records whose ``pred_answer`` equals ``golden_choice_idx``,
        or 0.0 when ``results`` is empty (instead of dividing by zero).
    """
    if not results:
        return 0.0
    correct_num = sum(1 for result in results if result["pred_answer"] == result["golden_choice_idx"])
    return correct_num / len(results)