|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Optional |
|
|
|
|
|
from pyannote_audio_utils.core import Annotation, Timeline |
|
|
|
|
|
from .base import BaseMetric |
|
|
from .base import Precision, PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED |
|
|
from .base import Recall, RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED |
|
|
from .matcher import LabelMatcher, \ |
|
|
MATCH_TOTAL, MATCH_CORRECT, MATCH_CONFUSION, \ |
|
|
MATCH_MISSED_DETECTION, MATCH_FALSE_ALARM |
|
|
from .types import MetricComponents, Details |
|
|
from .utils import UEMSupportMixin |
|
|
|
|
|
|
|
|
# Human-readable name returned by IdentificationErrorRate.metric_name().
IER_NAME = 'identification error rate'

# Component keys of the identification error rate. Each one is an alias of
# the corresponding label-matcher key, so the counts returned by the matcher
# can be indexed directly with these names.
IER_TOTAL = MATCH_TOTAL
IER_CORRECT = MATCH_CORRECT
IER_CONFUSION = MATCH_CONFUSION
IER_MISS = MATCH_MISSED_DETECTION
IER_FALSE_ALARM = MATCH_FALSE_ALARM
|
|
|
|
|
|
|
|
class IdentificationErrorRate(UEMSupportMixin, BaseMetric):
    """Identification error rate

    ``ier = (wc x confusion + wf x false_alarm + wm x miss) / total``

    where
        - `confusion` is the total confusion duration in seconds
        - `false_alarm` is the total false alarm duration in seconds
        - `miss` is the total missed detection duration in seconds
        - `total` is the total duration of all tracks
        - wc, wf and wm are optional weights (default to 1)

    Parameters
    ----------
    collar : float, optional
        Duration (in seconds) of collars removed from evaluation around
        boundaries of reference segments. Defaults to 0.
    skip_overlap : bool, optional
        Set to True to not evaluate overlap regions.
        Defaults to False (i.e. keep overlap regions).
    confusion, miss, false_alarm: float, optional
        Optional weights for confusion, miss and false alarm respectively.
        Default to 1. (no weight)
    """

    @classmethod
    def metric_name(cls) -> str:
        """Return the name of the metric."""
        return IER_NAME

    @classmethod
    def metric_components(cls) -> MetricComponents:
        """Return the names of the components accumulated by the metric."""
        return [
            IER_TOTAL,
            IER_CORRECT,
            IER_FALSE_ALARM, IER_MISS,
            IER_CONFUSION]

    def __init__(self,
                 confusion: float = 1.,
                 miss: float = 1.,
                 false_alarm: float = 1.,
                 collar: float = 0.,
                 skip_overlap: bool = False,
                 **kwargs):

        super().__init__(**kwargs)
        self.matcher_ = LabelMatcher()
        self.confusion = confusion
        self.miss = miss
        self.false_alarm = false_alarm
        self.collar = collar
        self.skip_overlap = skip_overlap

    def compute_components(self,
                           reference: Annotation,
                           hypothesis: Annotation,
                           uem: Optional[Timeline] = None,
                           collar: Optional[float] = None,
                           skip_overlap: Optional[bool] = None,
                           **kwargs) -> Details:
        """Accumulate the duration-weighted components of the metric

        Parameters
        ----------
        reference : Annotation
            Reference annotation.
        hypothesis : Annotation
            Hypothesis annotation.
        uem : Timeline, optional
            Evaluation map, forwarded to `uemify`.
        collar : float, optional
            Override self.collar
        skip_overlap : bool, optional
            Override self.skip_overlap

        Returns
        -------
        detail : Details
            Total, correct, confusion, miss and false alarm values, each
            one summed over the common timeline as
            (segment duration) x (matcher count).

        See also
        --------
        :class:`pyannote_audio_utils.metric.diarization.DiarizationErrorRate` uses these
        two options in its `compute_components` method.

        """

        detail = self.init_components()

        # per-call values take precedence over the ones set at construction
        if collar is None:
            collar = self.collar
        if skip_overlap is None:
            skip_overlap = self.skip_overlap

        R, H, common_timeline = self.uemify(
            reference, hypothesis, uem=uem,
            collar=collar, skip_overlap=skip_overlap,
            returns_timeline=True)

        # loop on all segments of the common timeline
        for segment in common_timeline:

            # segment duration
            duration = segment.duration

            # list of labels in reference segment
            # (unique=False keeps repeated labels)
            r = R.get_labels(segment, unique=False)

            # list of labels in hypothesis segment
            h = H.get_labels(segment, unique=False)

            # only the counts are needed; the second returned value is unused
            counts, _ = self.matcher_(r, h)

            detail[IER_TOTAL] += duration * counts[IER_TOTAL]
            detail[IER_CORRECT] += duration * counts[IER_CORRECT]
            detail[IER_CONFUSION] += duration * counts[IER_CONFUSION]
            detail[IER_MISS] += duration * counts[IER_MISS]
            detail[IER_FALSE_ALARM] += duration * counts[IER_FALSE_ALARM]

        return detail

    def compute_metric(self, detail: Details) -> float:
        """Compute the identification error rate from its components

        Parameters
        ----------
        detail : Details
            Components as returned by `compute_components`.

        Returns
        -------
        ier : float
            Weighted sum of confusion, false alarm and miss, divided by
            total. When total is zero, 0. is returned if the weighted sum
            is zero as well, and 1. otherwise.
        """

        numerator = 1. * (
            self.confusion * detail[IER_CONFUSION] +
            self.false_alarm * detail[IER_FALSE_ALARM] +
            self.miss * detail[IER_MISS]
        )
        denominator = 1. * detail[IER_TOTAL]

        # an empty total cannot be divided by: report "no error" when there
        # is nothing to penalize, and a full error rate otherwise
        if denominator == 0.:
            if numerator == 0:
                return 0.
            else:
                return 1.
        else:
            return numerator / denominator
|
|
|
|
|
|
|
|
class IdentificationPrecision(UEMSupportMixin, Precision):
    """Identification Precision

    Parameters
    ----------
    collar : float, optional
        Duration (in seconds) of collars removed from evaluation around
        boundaries of reference segments.
    skip_overlap : bool, optional
        Set to True to not evaluate overlap regions.
        Defaults to False (i.e. keep overlap regions).
    """

    def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.matcher_ = LabelMatcher()
        self.collar = collar
        self.skip_overlap = skip_overlap

    def compute_components(self,
                           reference: Annotation,
                           hypothesis: Annotation,
                           uem: Optional[Timeline] = None,
                           **kwargs) -> Details:
        """Accumulate retrieved and relevant-retrieved durations

        For every segment of the common timeline, the retrieved component
        grows by (duration) x (number of hypothesis labels) and the
        relevant-retrieved component by (duration) x (number of correct
        matches reported by the matcher).
        """
        components = self.init_components()

        cropped_reference, cropped_hypothesis, timeline = self.uemify(
            reference,
            hypothesis,
            uem=uem,
            collar=self.collar,
            skip_overlap=self.skip_overlap,
            returns_timeline=True,
        )

        for chunk in timeline:
            chunk_duration = chunk.duration

            # labels active on this chunk, repeated labels included
            reference_labels = cropped_reference.get_labels(
                chunk, unique=False)
            hypothesis_labels = cropped_hypothesis.get_labels(
                chunk, unique=False)

            match_counts, _ = self.matcher_(
                reference_labels, hypothesis_labels)

            retrieved = chunk_duration * len(hypothesis_labels)
            relevant_retrieved = chunk_duration * match_counts[IER_CORRECT]

            components[PRECISION_RETRIEVED] += retrieved
            components[PRECISION_RELEVANT_RETRIEVED] += relevant_retrieved

        return components
|
|
|
|
|
|
|
|
class IdentificationRecall(UEMSupportMixin, Recall):
    """Identification Recall

    Parameters
    ----------
    collar : float, optional
        Duration (in seconds) of collars removed from evaluation around
        boundaries of reference segments.
    skip_overlap : bool, optional
        Set to True to not evaluate overlap regions.
        Defaults to False (i.e. keep overlap regions).
    """

    def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs):
        super().__init__(**kwargs)
        self.matcher_ = LabelMatcher()
        self.collar = collar
        self.skip_overlap = skip_overlap

    def compute_components(self,
                           reference: Annotation,
                           hypothesis: Annotation,
                           uem: Optional[Timeline] = None,
                           **kwargs) -> Details:
        """Accumulate relevant and relevant-retrieved durations

        For every segment of the common timeline, the relevant component
        grows by (duration) x (matcher total count) and the
        relevant-retrieved component by (duration) x (number of correct
        matches reported by the matcher).
        """
        components = self.init_components()

        cropped_reference, cropped_hypothesis, timeline = self.uemify(
            reference,
            hypothesis,
            uem=uem,
            collar=self.collar,
            skip_overlap=self.skip_overlap,
            returns_timeline=True,
        )

        for chunk in timeline:
            chunk_duration = chunk.duration

            # labels active on this chunk, repeated labels included
            reference_labels = cropped_reference.get_labels(
                chunk, unique=False)
            hypothesis_labels = cropped_hypothesis.get_labels(
                chunk, unique=False)

            match_counts, _ = self.matcher_(
                reference_labels, hypothesis_labels)

            relevant = chunk_duration * match_counts[IER_TOTAL]
            relevant_retrieved = chunk_duration * match_counts[IER_CORRECT]

            components[RECALL_RELEVANT] += relevant
            components[RECALL_RELEVANT_RETRIEVED] += relevant_retrieved

        return components
|
|
|