import bisect
import glob
import os

import pandas as pd

from dataset.custom_types import MsaInfo
from dataset.msa_info_utils import load_msa_info


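def cal_acc(ann_info: MsaInfo | str, est_info: MsaInfo | str, post_digit: int = 3):
    """Return the fraction of the shared time span where both labelings agree.

    Each argument is a sequence of (time, label) pairs or a path to a file
    readable by load_msa_info. Times are quantized to post_digit decimals.
    """
    if isinstance(ann_info, str):
        assert os.path.exists(ann_info), f"{ann_info} does not exist"
        ann_info = load_msa_info(ann_info)

    if isinstance(est_info, str):
        assert os.path.exists(est_info), f"{est_info} does not exist"
        est_info = load_msa_info(est_info)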

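    # Convert times to integer ticks of 10**-post_digit seconds so boundaries
    # compare exactly; round after scaling so float error cannot drop a tick.
    scale = 10**post_digit
    ann_info_time = [int(round(time_ * scale)) for time_, _ in ann_info]
    est_info_time = [int(round(time_ * scale)) for time_, _ in est_info]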

    # Evaluate only over the span covered by both label sequences.
    common_start_time = max(ann_info_time[0], est_info_time[0])
    common_end_time = min(ann_info_time[-1], est_info_time[-1])
    if common_start_time >= common_end_time:
        raise ValueError("annotation and estimate do not overlap in time")

    # Collect every boundary from either sequence that lies in the common span.
    time_points = {common_start_time, common_end_time}
    for time_ in ann_info_time + est_info_time:
        if common_start_time <= time_ <= common_end_time:
            time_points.add(time_)
    time_points = sorted(time_points)
    total_duration = 0
    total_score = 0

    # Weight each segment between consecutive boundaries by its duration.
    for idx in range(len(time_points) - 1):
        start = time_points[idx]
        duration = time_points[idx + 1] - start
        # The label in force at `start` belongs to the last boundary <= start.
        ann_label = ann_info[bisect.bisect_right(ann_info_time, start) - 1][1]
        est_label = est_info[bisect.bisect_right(est_info_time, start) - 1][1]
        total_duration += duration
        if ann_label == est_label:
            total_score += duration
    return total_score / total_duration


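if __name__ == "__main__":
    # The glob pattern and the annotation directory are left blank in the source.
    ext_paths = glob.glob("")
    results = []
    for ext_path in ext_paths:
        data_id = os.path.basename(ext_path).split(".")[0]
        ann_path = os.path.join("", data_id + ".txt")
        try:
            acc = cal_acc(ann_info=ann_path, est_info=ext_path)
        except Exception as e:
            # Skip files that fail to load or score, but report which one.
            print(f"{data_id}: {e}")
            continue
        results.append({"data_id": data_id, "acc": acc})
    df = pd.DataFrame(results)
    # An empty result list yields a frame without an "acc" column.
    if df.empty:
        print("no files were scored")
    else:
        print(df["acc"].mean())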