#!/usr/bin/env python # encoding: utf-8 # The MIT License (MIT) # Copyright (c) 2012-2019 CNRS # Permission is hereby granted, free of charge, to any person obtaining a copy # of this software and associated documentation files (the "Software"), to deal # in the Software without restriction, including without limitation the rights # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell # copies of the Software, and to permit persons to whom the Software is # furnished to do so, subject to the following conditions: # The above copyright notice and this permission notice shall be included in # all copies or substantial portions of the Software. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. # AUTHORS # Hervé BREDIN - http://herve.niderb.fr from typing import Optional from pyannote_audio_utils.core import Annotation, Timeline from .base import BaseMetric from .base import Precision, PRECISION_RETRIEVED, PRECISION_RELEVANT_RETRIEVED from .base import Recall, RECALL_RELEVANT, RECALL_RELEVANT_RETRIEVED from .matcher import LabelMatcher, \ MATCH_TOTAL, MATCH_CORRECT, MATCH_CONFUSION, \ MATCH_MISSED_DETECTION, MATCH_FALSE_ALARM from .types import MetricComponents, Details from .utils import UEMSupportMixin # TODO: can't we put these as class attributes? IER_TOTAL = MATCH_TOTAL IER_CORRECT = MATCH_CORRECT IER_CONFUSION = MATCH_CONFUSION IER_FALSE_ALARM = MATCH_FALSE_ALARM IER_MISS = MATCH_MISSED_DETECTION IER_NAME = 'identification error rate' class IdentificationErrorRate(UEMSupportMixin, BaseMetric): """Identification error rate ``ier = (wc x confusion + wf x false_alarm + wm x miss) / total`` where - `confusion` is the total confusion duration in seconds - `false_alarm` is the total hypothesis duration where there are - `miss` is - `total` is the total duration of all tracks - wc, wf and wm are optional weights (default to 1) Parameters ---------- collar : float, optional Duration (in seconds) of collars removed from evaluation around boundaries of reference segments. skip_overlap : bool, optional Set to True to not evaluate overlap regions. Defaults to False (i.e. keep overlap regions). confusion, miss, false_alarm: float, optional Optional weights for confusion, miss and false alarm respectively. Default to 1. (no weight) """ @classmethod def metric_name(cls) -> str: return IER_NAME @classmethod def metric_components(cls) -> MetricComponents: return [ IER_TOTAL, IER_CORRECT, IER_FALSE_ALARM, IER_MISS, IER_CONFUSION] def __init__(self, confusion: float = 1., miss: float = 1., false_alarm: float = 1., collar: float = 0., skip_overlap: bool = False, **kwargs): super().__init__(**kwargs) self.matcher_ = LabelMatcher() self.confusion = confusion self.miss = miss self.false_alarm = false_alarm self.collar = collar self.skip_overlap = skip_overlap def compute_components(self, reference: Annotation, hypothesis: Annotation, uem: Optional[Timeline] = None, collar: Optional[float] = None, skip_overlap: Optional[float] = None, **kwargs) -> Details: """ Parameters ---------- collar : float, optional Override self.collar skip_overlap : bool, optional Override self.skip_overlap See also -------- :class:`pyannote_audio_utils.metric.diarization.DiarizationErrorRate` uses these two options in its `compute_components` method. """ detail = self.init_components() if collar is None: collar = self.collar if skip_overlap is None: skip_overlap = self.skip_overlap R, H, common_timeline = self.uemify( reference, hypothesis, uem=uem, collar=collar, skip_overlap=skip_overlap, returns_timeline=True) # loop on all segments for segment in common_timeline: # segment duration duration = segment.duration # list of IDs in reference segment r = R.get_labels(segment, unique=False) # list of IDs in hypothesis segment h = H.get_labels(segment, unique=False) counts, _ = self.matcher_(r, h) detail[IER_TOTAL] += duration * counts[IER_TOTAL] detail[IER_CORRECT] += duration * counts[IER_CORRECT] detail[IER_CONFUSION] += duration * counts[IER_CONFUSION] detail[IER_MISS] += duration * counts[IER_MISS] detail[IER_FALSE_ALARM] += duration * counts[IER_FALSE_ALARM] return detail def compute_metric(self, detail: Details) -> float: numerator = 1. * ( self.confusion * detail[IER_CONFUSION] + self.false_alarm * detail[IER_FALSE_ALARM] + self.miss * detail[IER_MISS] ) denominator = 1. * detail[IER_TOTAL] if denominator == 0.: if numerator == 0: return 0. else: return 1. else: return numerator / denominator class IdentificationPrecision(UEMSupportMixin, Precision): """Identification Precision Parameters ---------- collar : float, optional Duration (in seconds) of collars removed from evaluation around boundaries of reference segments. skip_overlap : bool, optional Set to True to not evaluate overlap regions. Defaults to False (i.e. keep overlap regions). """ def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs): super().__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap self.matcher_ = LabelMatcher() def compute_components(self, reference: Annotation, hypothesis: Annotation, uem: Optional[Timeline] = None, **kwargs) -> Details: detail = self.init_components() R, H, common_timeline = self.uemify( reference, hypothesis, uem=uem, collar=self.collar, skip_overlap=self.skip_overlap, returns_timeline=True) # loop on all segments for segment in common_timeline: # segment duration duration = segment.duration # list of IDs in reference segment r = R.get_labels(segment, unique=False) # list of IDs in hypothesis segment h = H.get_labels(segment, unique=False) counts, _ = self.matcher_(r, h) detail[PRECISION_RETRIEVED] += duration * len(h) detail[PRECISION_RELEVANT_RETRIEVED] += \ duration * counts[IER_CORRECT] return detail class IdentificationRecall(UEMSupportMixin, Recall): """Identification Recall Parameters ---------- collar : float, optional Duration (in seconds) of collars removed from evaluation around boundaries of reference segments. skip_overlap : bool, optional Set to True to not evaluate overlap regions. Defaults to False (i.e. keep overlap regions). """ def __init__(self, collar: float = 0., skip_overlap: bool = False, **kwargs): super().__init__(**kwargs) self.collar = collar self.skip_overlap = skip_overlap self.matcher_ = LabelMatcher() def compute_components(self, reference: Annotation, hypothesis: Annotation, uem: Optional[Timeline] = None, **kwargs) -> Details: detail = self.init_components() R, H, common_timeline = self.uemify( reference, hypothesis, uem=uem, collar=self.collar, skip_overlap=self.skip_overlap, returns_timeline=True) # loop on all segments for segment in common_timeline: # segment duration duration = segment.duration # list of IDs in reference segment r = R.get_labels(segment, unique=False) # list of IDs in hypothesis segment h = H.get_labels(segment, unique=False) counts, _ = self.matcher_(r, h) detail[RECALL_RELEVANT] += duration * counts[IER_TOTAL] detail[RECALL_RELEVANT_RETRIEVED] += duration * counts[IER_CORRECT] return detail