|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Tuple, Iterable, List, TYPE_CHECKING, Union

import numpy as np
from scipy.optimize import linear_sum_assignment

from pyannote_audio_utils.core import Annotation

if TYPE_CHECKING:
    from pyannote_audio_utils.core.utils.types import Label
|
|
|
|
|
# Keys of the ``counts`` / ``details`` dictionaries returned by
# ``LabelMatcher.__call__`` (``MATCH_TOTAL`` only appears in ``counts``).
MATCH_CORRECT = 'correct'
MATCH_CONFUSION = 'confusion'
MATCH_MISSED_DETECTION = 'missed detection'
MATCH_FALSE_ALARM = 'false alarm'
MATCH_TOTAL = 'total'


class LabelMatcher:
    """
    ID matcher base class mixin.

    All ID matcher classes must inherit from this class and implement
    .match() -- ie return True if two IDs match and False
    otherwise.

    Calling an instance pairs reference and hypothesis labels one-to-one so
    that the number of matching pairs is as large as possible, then classifies
    every pair as correct, confusion, missed detection or false alarm.
    """

    def match(self, rlabel: 'Label', hlabel: 'Label') -> bool:
        """
        Parameters
        ----------
        rlabel :
            Reference label
        hlabel :
            Hypothesis label

        Returns
        -------
        match : bool
            True if labels match, False otherwise. The default implementation
            uses plain equality; subclasses may override it.
        """
        return rlabel == hlabel

    def __call__(self, rlabels: Iterable['Label'], hlabels: Iterable['Label']) \
            -> Tuple[Dict[str, int],
                     Dict[str, List[Union['Label', Tuple['Label', 'Label']]]]]:
        """
        Pair reference and hypothesis labels and classify each pair.

        Parameters
        ----------
        rlabels, hlabels : iterable
            Reference and hypothesis labels. Each is consumed exactly once.

        Returns
        -------
        counts : dict
            Number of ``MATCH_CORRECT``, ``MATCH_CONFUSION``,
            ``MATCH_MISSED_DETECTION`` and ``MATCH_FALSE_ALARM`` pairs, plus
            ``MATCH_TOTAL`` (the number of reference labels).
        details : dict
            ``details[MATCH_CORRECT]`` and ``details[MATCH_CONFUSION]`` hold
            ``(rlabel, hlabel)`` tuples; ``details[MATCH_MISSED_DETECTION]``
            holds unpaired reference labels and
            ``details[MATCH_FALSE_ALARM]`` holds unpaired hypothesis labels.
        """
        counts = {
            MATCH_CORRECT: 0,
            MATCH_CONFUSION: 0,
            MATCH_MISSED_DETECTION: 0,
            MATCH_FALSE_ALARM: 0,
            MATCH_TOTAL: 0,
        }
        details = {
            MATCH_CORRECT: [],
            MATCH_CONFUSION: [],
            MATCH_MISSED_DETECTION: [],
            MATCH_FALSE_ALARM: [],
        }

        # materialize inputs: they are indexed and measured below
        rlabels = list(rlabels)
        hlabels = list(hlabels)

        NR = len(rlabels)
        NH = len(hlabels)
        N = max(NR, NH)

        # nothing to match
        if N == 0:
            return counts, details

        # square N x N boolean matrix; padding rows/columns (index >= NR or
        # >= NH) stay False and stand for "no counterpart"
        is_match = np.zeros((N, N), dtype=bool)
        for r, rlabel in enumerate(rlabels):
            for h, hlabel in enumerate(hlabels):
                is_match[r, h] = self.match(rlabel, hlabel)

        # cost is 1 for a non-matching pair and 0 for a matching one, so the
        # minimum-cost assignment maximizes the number of matching pairs
        for r, h in zip(*linear_sum_assignment(~is_match)):

            if r >= NR:
                # padding reference row: hypothesis label left unpaired
                counts[MATCH_FALSE_ALARM] += 1
                details[MATCH_FALSE_ALARM].append(hlabels[h])

            elif h >= NH:
                # padding hypothesis column: reference label left unpaired
                counts[MATCH_MISSED_DETECTION] += 1
                details[MATCH_MISSED_DETECTION].append(rlabels[r])

            elif is_match[r, h]:
                counts[MATCH_CORRECT] += 1
                details[MATCH_CORRECT].append((rlabels[r], hlabels[h]))

            else:
                counts[MATCH_CONFUSION] += 1
                details[MATCH_CONFUSION].append((rlabels[r], hlabels[h]))

        counts[MATCH_TOTAL] += NR

        return counts, details
|
|
|
|
|
|
|
|
class HungarianMapper:
    """Map labels of one annotation onto labels of another, one-to-one.

    The pairing is the assignment that maximizes the summed value of the
    ``A * B`` co-occurrence matrix (Hungarian algorithm via scipy). Pairs whose
    co-occurrence is not strictly positive are dropped from the mapping.
    """

    def __call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']:
        """
        Parameters
        ----------
        A, B : Annotation
            Annotations whose labels are to be paired.

        Returns
        -------
        mapping : dict
            Maps a label of ``A`` to the label of ``B`` it was paired with.
        """
        # NOTE(review): assumes ``A * B`` is indexed [label of A, label of B]
        # in the order of ``A.labels()`` / ``B.labels()`` — confirm.
        overlap = A * B
        source_labels = A.labels()
        target_labels = B.labels()

        # negate so that minimizing the cost maximizes total co-occurrence
        rows, cols = linear_sum_assignment(-overlap)

        return {
            source_labels[row]: target_labels[col]
            for row, col in zip(rows, cols)
            if overlap[row, col] > 0
        }
|
|
|
|
|
|
|
|
class GreedyMapper:
    """Map labels of one annotation onto labels of another, greedily.

    At each step the pair with the largest remaining ``A * B`` co-occurrence
    is taken, and its row and column are zeroed so neither label is reused.
    Stops as soon as the best remaining co-occurrence is not strictly positive.
    """

    def __call__(self, A: Annotation, B: Annotation) -> Dict['Label', 'Label']:
        """
        Parameters
        ----------
        A, B : Annotation
            Annotations whose labels are to be paired.

        Returns
        -------
        mapping : dict
            Maps a label of ``A`` to the label of ``B`` it was paired with.
        """
        # NOTE: the matrix returned by ``A * B`` is modified in place below.
        overlap = A * B
        n_rows, n_cols = overlap.shape
        source_labels, target_labels = A.labels(), B.labels()

        result = {}
        # at most min(n_rows, n_cols) one-to-one pairs can be formed
        for _ in range(min(n_rows, n_cols)):
            row, col = np.unravel_index(np.argmax(overlap), (n_rows, n_cols))

            # best remaining pair has no overlap: nothing left worth pairing
            if not (overlap[row, col] > 0):
                break

            result[source_labels[row]] = target_labels[col]
            overlap[row, :] = 0.
            overlap[:, col] = 0.

        return result
|
|
|