# ASLP-lab's picture
# init
# 70d8fcf
import os
from dataset.custom_types import MsaInfo
from dataset.label2id import LABEL_TO_ID
from pprint import pprint
def load_msa_info(msa_info_path):
    """Parse an MSA annotation file into a list of ``(time, label)`` tuples.

    Each non-blank line must consist of exactly two whitespace-separated
    fields: a timestamp parseable by ``float`` and a label. Blank lines are
    skipped.

    Args:
        msa_info_path: Path of the annotation file to read.

    Returns:
        The parsed entries in file order; the last entry's label is ``"end"``.

    Raises:
        AssertionError: If a label is neither ``"end"`` nor a key of
            ``LABEL_TO_ID``, if the file contains no entries, or if the last
            entry's label is not ``"end"``.
        ValueError: If a line does not split into exactly two fields or the
            timestamp is not a valid float.
    """
    msa_info: MsaInfo = []
    with open(msa_info_path) as f:
        for raw_line in f:
            line = raw_line.strip()
            if not line:
                continue
            time_, label = line.split()
            time_ = float(time_)
            label = str(label)
            # Compare against the "end" sentinel first; it is the cheaper check
            # and does not need the label table.
            assert label == "end" or label in LABEL_TO_ID, f"{label} not in LABEL_TO_ID"
            msa_info.append((time_, label))
    # An empty or blank-only file would otherwise fail below with an opaque
    # IndexError on msa_info[-1].
    assert msa_info, f"{msa_info_path} contains no entries"
    assert msa_info[-1][1] == "end", f"last {msa_info[-1][1]} != end"
    return msa_info
def msa_info_to_segments(msa_info):
    """Turn boundary entries into ``(start, end, label)`` segments.

    Every entry except the last opens a segment that runs until the next
    entry's time and carries the opening entry's label. The final entry only
    closes the previous segment, so its own label never appears in the output.

    Args:
        msa_info: Sequence of entries whose item 0 is a time and item 1 a label.

    Returns:
        A list with ``len(msa_info) - 1`` segments (empty if there are fewer
        than two entries).
    """
    # Pair each entry with its successor; the trailing entry has no successor
    # and therefore opens no segment.
    return [
        (current[0], following[0], current[1])
        for current, following in zip(msa_info, msa_info[1:])
    ]
def compute_iou_for_label(segments_a, segments_b, label):
    """Compute the temporal intersection-over-union of one label.

    Only segments whose label equals ``label`` are considered. The
    intersection is the summed overlap of every (a, b) interval pair; the
    union is the total length of both interval sets minus that intersection.

    NOTE: the union formula does not merge intervals that overlap *within*
    the same set; it is exact when each set's intervals are disjoint.

    Args:
        segments_a: Iterable of ``(start, end, label)`` tuples.
        segments_b: Iterable of ``(start, end, label)`` tuples.
        label: The label to evaluate.

    Returns:
        A tuple ``(iou, intersection, union)``. When the union is zero (the
        label is absent from both inputs or only has zero-length segments)
        the IoU is reported as ``0.0``.
    """
    intervals_a = [(s, e) for s, e, seg_label in segments_a if seg_label == label]
    intervals_b = [(s, e) for s, e, seg_label in segments_b if seg_label == label]

    # Sum the overlap of every pair of intervals across the two sets.
    intersection = 0.0
    for start_a, end_a in intervals_a:
        for start_b, end_b in intervals_b:
            overlap = min(end_a, end_b) - max(start_a, start_b)
            if overlap > 0:
                intersection += overlap

    length_a = sum(e - s for s, e in intervals_a)
    length_b = sum(e - s for s, e in intervals_b)
    union = length_a + length_b - intersection

    if union == 0:
        # Keep the 3-tuple shape: callers unpack three values, so returning a
        # bare float here would raise TypeError at the call site.
        return 0.0, intersection, union
    return intersection / union, intersection, union
def compute_mean_iou(segments_a, segments_b, labels):
    """Collect per-label IoU statistics for two segmentations.

    Despite the name, no averaging happens here: the result is one record per
    entry of ``labels``, in the same order.

    Args:
        segments_a: ``(start, end, label)`` tuples of the first segmentation.
        segments_b: ``(start, end, label)`` tuples of the second segmentation.
        labels: Iterable of labels to evaluate.

    Returns:
        A list of dicts with keys ``"label"``, ``"iou"``, ``"intsec_dur"``
        and ``"uni_dur"``.
    """

    def _record(current_label):
        # compute_iou_for_label yields (iou, intersection, union).
        score, overlap_dur, combined_dur = compute_iou_for_label(
            segments_a, segments_b, current_label
        )
        return {
            "label": current_label,
            "iou": score,
            "intsec_dur": overlap_dur,
            "uni_dur": combined_dur,
        }

    return [_record(current_label) for current_label in labels]
def _resolve_msa_info(info):
    """Load MSA info from disk when given a path; pass parsed info through."""
    if isinstance(info, (str, os.PathLike)):
        assert os.path.exists(info), f"{info} not exists"
        return load_msa_info(info)
    return info


def cal_iou(ann_info, est_info):
    """Compute per-label IoU between an annotation and an estimate.

    Args:
        ann_info: Reference annotation, either a path (``str`` or
            ``os.PathLike``) to an MSA file or already-parsed
            ``(time, label)`` entries.
        est_info: Estimated annotation, in the same accepted forms.

    Returns:
        A list of per-label dicts as produced by ``compute_mean_iou``, one
        for every label that occurs in either input, ordered by label so the
        output is stable across runs.

    Raises:
        AssertionError: If a given path does not exist.
    """
    ann_info = _resolve_msa_info(ann_info)
    est_info = _resolve_msa_info(est_info)

    segments_ann = msa_info_to_segments(ann_info)
    segments_est = msa_info_to_segments(est_info)

    # Set iteration order of strings varies between interpreter runs; sort so
    # results are reproducible. key=str keeps sorting safe even if pre-parsed
    # input carries non-string labels.
    occurred_labels = sorted(
        {label for _, _, label in segments_ann}
        | {label for _, _, label in segments_est},
        key=str,
    )
    return compute_mean_iou(segments_ann, segments_est, occurred_labels)
if __name__ == "__main__":
    # Placeholder paths: fill in the annotation and estimate files before
    # running. cal_iou asserts that string paths exist, so the empty defaults
    # fail immediately.
    ann_info, est_info = "", ""
    report = cal_iou(ann_info, est_info)
    pprint(report)