import os import bisect from dataset.msa_info_utils import ( load_msa_info, ) from dataset.custom_types import MsaInfo import glob import pdb import pandas as pd def cal_acc(ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3): if type(ann_info) is str: assert os.path.exists(ann_info), f"{ann_info} not exists" ann_info = load_msa_info(ann_info) if type(ann_info) is str: assert os.path.exists(est_info), f"{est_info} not exists" est_info = load_msa_info(est_info) ann_info_time = [ int(round(time_, post_digit) * (10**post_digit)) for time_, label in ann_info ] est_info_time = [ int(round(time_, post_digit) * (10**post_digit)) for time_, label in est_info ] common_start_time = max(ann_info_time[0], est_info_time[0]) common_end_time = min(ann_info_time[-1], est_info_time[-1]) time_points = set() time_points.add(common_start_time) time_points.add(common_end_time) for time_ in ann_info_time: if time_ >= common_start_time and time_ <= common_end_time: time_points.add(time_) for time_ in est_info_time: if time_ >= common_start_time and time_ <= common_end_time: time_points.add(time_) time_points = sorted(list(time_points)) total_duration = 0 total_score = 0 for idx in range(len(time_points) - 1): duration = time_points[idx + 1] - time_points[idx] ann_label = ann_info[bisect.bisect_right(ann_info_time, time_points[idx]) - 1][ 1 ] est_label = est_info[bisect.bisect_right(est_info_time, time_points[idx]) - 1][ 1 ] total_duration += duration if ann_label == est_label: total_score += duration return total_score / total_duration if __name__ == "__main__": ext_paths = glob.glob("") results = [] for ext_path in ext_paths: try: ann_path = os.path.join( "", os.path.basename(ext_path).split(".")[0] + ".txt", ) results.append( { "data_id": os.path.basename(ext_path).split(".")[0], "acc": cal_acc( ann_info=ann_path, est_info=ext_path, ), } ) except Exception as e: print(e) continue df = pd.DataFrame(results) print(df["acc"].mean())