Spaces:
Running
Running
| import os | |
| from dataset.custom_types import MsaInfo | |
| from dataset.label2id import LABEL_TO_ID | |
| from pprint import pprint | |
def load_msa_info(msa_info_path):
    """Parse an MSA annotation file into a list of ``(time, label)`` tuples.

    Each non-blank line must contain exactly two whitespace-separated fields:
    a time that ``float`` can parse, and a label. Every label must be a key of
    ``LABEL_TO_ID`` or the literal ``"end"``, and the final entry must be
    ``"end"``.

    Args:
        msa_info_path: Path of the text file to read (decoded as UTF-8).

    Returns:
        The entries in file order as ``(float, str)`` tuples.

    Raises:
        AssertionError: If a label is unknown, the file has no entries, or the
            last entry is not ``"end"``.
        ValueError: If a line does not split into exactly two fields or its
            time field is not a valid float.
    """
    msa_info: MsaInfo = []
    with open(msa_info_path, encoding="utf-8") as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue  # tolerate blank lines (e.g. a trailing newline)
            time_str, label = line.split()
            time_ = float(time_str)
            # Test the "end" sentinel first so it never needs a table lookup.
            assert label == "end" or label in LABEL_TO_ID, f"{label} not in LABEL_TO_ID"
            msa_info.append((time_, label))
    # Without this guard an empty file fails on msa_info[-1] below with a bare
    # IndexError that gives no hint about which file was at fault.
    assert msa_info, f"{msa_info_path} contains no entries"
    assert msa_info[-1][1] == "end", f"last {msa_info[-1][1]} != end"
    return msa_info
def msa_info_to_segments(msa_info):
    """Convert an ordered boundary list into ``(start, end, label)`` segments.

    Each entry ``(time, label)`` opens a segment that closes at the next
    entry's time. The last entry (the ``"end"`` marker that ``load_msa_info``
    requires) only supplies a closing time and produces no segment of its own.
    Inputs with fewer than two entries therefore yield an empty list.
    """
    # Pair each boundary with its successor; zip stops at the shorter input,
    # so the final entry is consumed only as the closing boundary.
    return [
        (current[0], following[0], current[1])
        for current, following in zip(msa_info, msa_info[1:])
    ]
def compute_iou_for_label(segments_a, segments_b, label):
    """Compute the temporal intersection-over-union of one label.

    Args:
        segments_a: Iterable of ``(start, end, label)`` tuples.
        segments_b: Iterable in the same format as ``segments_a``.
        label: Only segments carrying this label are considered.

    Returns:
        A 3-tuple ``(iou, intersection, union)``.
        ``intersection`` is the overlap summed over every pair of ``label``
        intervals drawn one from each input.
        ``union`` is ``total_length_a + total_length_b - intersection``.
        When the union is zero (the label is absent, or has only zero-length
        segments), ``iou`` is ``0.0``. A 3-tuple is returned in every case so
        callers can always unpack it.

    Note:
        Overlaps are summed pairwise. If intervals *within* one input overlap
        each other, their shared time is counted more than once. Contiguous
        segments such as those built by ``msa_info_to_segments`` do not
        overlap, provided the boundary times increase.
    """
    intervals_a = [(s, e) for s, e, lbl in segments_a if lbl == label]
    intervals_b = [(s, e) for s, e, lbl in segments_b if lbl == label]

    # Sum the overlap of every (a, b) interval pair.
    intersection = 0.0
    for sa, ea in intervals_a:
        for sb, eb in intervals_b:
            left = max(sa, sb)
            right = min(ea, eb)
            if left < right:
                intersection += right - left

    # Union = total length of both sets minus the double-counted overlap.
    length_a = sum(e - s for s, e in intervals_a)
    length_b = sum(e - s for s, e in intervals_b)
    union = length_a + length_b - intersection
    if union == 0:
        # Keep the 3-tuple shape: callers unpack (iou, intersection, union).
        return 0.0, intersection, union
    return intersection / union, intersection, union
def compute_mean_iou(segments_a, segments_b, labels):
    """Collect per-label IoU statistics for two segmentations.

    Despite the name, no averaging happens here. The result has one dict per
    entry of ``labels``, in the same order, with keys ``"label"``, ``"iou"``,
    ``"intsec_dur"`` (intersection duration) and ``"uni_dur"`` (union
    duration), as produced by ``compute_iou_for_label``.
    """

    def _row(lbl):
        # Build the statistics record for a single label.
        score, overlap, total = compute_iou_for_label(segments_a, segments_b, lbl)
        return {"label": lbl, "iou": score, "intsec_dur": overlap, "uni_dur": total}

    return [_row(lbl) for lbl in labels]
def cal_iou(ann_info, est_info):
    """Compare an annotated and an estimated MSA and return per-label IoU rows.

    Args:
        ann_info: Reference MSA, either as a file path (``str`` or
            ``os.PathLike``) that is parsed with ``load_msa_info``, or as an
            already-parsed list of ``(time, label)`` entries.
        est_info: Estimated MSA, in the same forms as ``ann_info``.

    Returns:
        The list produced by ``compute_mean_iou``: one dict per label that
        occurs in either input, sorted by label. Sorting makes the order
        reproducible across runs.

    Raises:
        AssertionError: If a given path does not exist, or if parsing it fails
            one of ``load_msa_info``'s checks.
    """

    def _as_msa_info(info):
        # Paths are loaded from disk; anything else is assumed to already be
        # a parsed [(time, label), ...] list.
        if isinstance(info, (str, os.PathLike)):
            assert os.path.exists(info), f"{info} not exists"
            return load_msa_info(info)
        return info

    segments_ann = msa_info_to_segments(_as_msa_info(ann_info))
    segments_est = msa_info_to_segments(_as_msa_info(est_info))
    # Iterating a set of strings gives a different order in each interpreter
    # run (hash randomization), so sort it; key=str tolerates mixed label types.
    occurred_labels = sorted(
        {label for _, _, label in segments_ann} | {label for _, _, label in segments_est},
        key=str,
    )
    return compute_mean_iou(segments_ann, segments_est, occurred_labels)
if __name__ == "__main__":
    # The original hard-coded empty paths made the script always fail its
    # "not exists" check, so take both files from the command line instead.
    import argparse

    parser = argparse.ArgumentParser(
        description="Print per-label IoU between an annotated and an estimated MSA file."
    )
    parser.add_argument("ann_info", help="path to the reference (annotation) MSA file")
    parser.add_argument("est_info", help="path to the estimated MSA file")
    args = parser.parse_args()
    pprint(cal_iou(args.ann_info, args.est_info))