import logging
import os
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
from sklearn import metrics as sklearn_metrics

from fairseq.data import AddTargetDataset, Dictionary
from fairseq.data.multi_corpus_dataset import MultiCorpusDataset
from fairseq.data.text_compressor import TextCompressionLevel, TextCompressor
from fairseq.tasks.audio_finetuning import LabelEncoder, label_len_fn
from fairseq.tasks.audio_pretraining import AudioPretrainingConfig, AudioPretrainingTask

from .. import utils
from ..logging import metrics
from . import register_task

logger = logging.getLogger(__name__)


@dataclass
class AudioClassificationConfig(AudioPretrainingConfig):
    target_dictionary: Optional[str] = field(
        default=None,
        metadata={
            "help": "directory containing dict.<labels>.txt (defaults to the data directory)"
        },
    )
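

# Illustrative config override (values are hypothetical): the label dictionary
# can live outside the data directory, e.g.
#
#   task:
#     _name: audio_classification
#     data: /data/snips
#     labels: lbl
#     target_dictionary: /data/snips/dicts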


@register_task("audio_classification", dataclass=AudioClassificationConfig)
class AudioClassificationTask(AudioPretrainingTask):
    """Task for single-label audio classification (one label per utterance)."""

    cfg: AudioClassificationConfig

    def __init__(self, cfg: AudioClassificationConfig):
        super().__init__(cfg)
        # The dictionary is materialized lazily: the state factory calls
        # load_target_dictionary() on first access to self.state.target_dictionary.
        self.state.add_factory("target_dictionary", self.load_target_dictionary)
        if self.target_dictionary is not None:
            logger.info(f"Number of labels = {len(self.target_dictionary)}")

    def load_target_dictionary(self):
        """Load dict.<labels>.txt from target_dictionary (or the data dir)."""
        if self.cfg.labels:
            target_dictionary = self.cfg.data
            if self.cfg.target_dictionary:
                target_dictionary = self.cfg.target_dictionary
            dict_path = os.path.join(target_dictionary, f"dict.{self.cfg.labels}.txt")
            logger.info(f"Using dict_path: {dict_path}")
            return Dictionary.load(dict_path, add_special_symbols=False)
        return None
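
    # Illustrative dictionary file (label names and counts are hypothetical):
    # fairseq's Dictionary format is one "<symbol> <count>" pair per line, e.g.
    #
    #   PlayMusic 1024
    #   GetWeather 987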

    def load_dataset(
        self, split: str, task_cfg: Optional[AudioClassificationConfig] = None, **kwargs
    ):
        super().load_dataset(split, task_cfg, **kwargs)

        task_cfg = task_cfg or self.cfg
        assert task_cfg.labels is not None
        text_compression_level = getattr(
            TextCompressionLevel, str(self.cfg.text_compression_level)
        )
        data_path = self.cfg.data
        if task_cfg.multi_corpus_keys is None:
            label_path = os.path.join(data_path, f"{split}.{task_cfg.labels}")
            skipped_indices = getattr(self.datasets[split], "skipped_indices", set())
            text_compressor = TextCompressor(level=text_compression_level)
            with open(label_path, "r") as f:
                labels = [
                    text_compressor.compress(l)
                    for i, l in enumerate(f)
                    if i not in skipped_indices
                ]
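
            # The label file is aligned line-for-line with the audio manifest;
            # entries whose audio was skipped during manifest loading (e.g.
            # clips outside the min/max sample-size bounds) are dropped here.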

            assert len(labels) == len(self.datasets[split]), (
                f"labels length ({len(labels)}) and dataset length "
                f"({len(self.datasets[split])}) do not match"
            )

            process_label = LabelEncoder(self.target_dictionary)
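
            # AddTargetDataset pairs each audio example with its encoded label;
            # batch_targets=True collates the per-example labels into a single
            # padded target tensor, and add_to_input=False keeps labels out of
            # the model input (they are used only as the prediction target).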
            self.datasets[split] = AddTargetDataset(
                self.datasets[split],
                labels,
                pad=self.target_dictionary.pad(),
                eos=self.target_dictionary.eos(),
                batch_targets=True,
                process_label=process_label,
                label_len_fn=label_len_fn,
                add_to_input=False,
            )
        else:
            target_dataset_map = OrderedDict()

            multi_corpus_keys = [
                k.strip() for k in task_cfg.multi_corpus_keys.split(",")
            ]
            corpus_idx_map = {k: idx for idx, k in enumerate(multi_corpus_keys)}

            data_keys = [k.split(":") for k in split.split(",")]
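
            # Illustrative (corpus names are hypothetical): in the multi-corpus
            # case the split string encodes "corpus_key:manifest_name" pairs,
            # e.g. split="clean:train_clean,noisy:train_noisy" with
            # multi_corpus_keys="clean,noisy" and
            # multi_corpus_sampling_weights="0.7,0.3".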
            multi_corpus_sampling_weights = [
                float(val.strip())
                for val in task_cfg.multi_corpus_sampling_weights.split(",")
            ]
            data_weights = []

            for key, file_name in data_keys:
                k = key.strip()
                label_path = os.path.join(
                    data_path, f"{file_name.strip()}.{task_cfg.labels}"
                )
                skipped_indices = getattr(
                    self.dataset_map[split][k], "skipped_indices", set()
                )
                text_compressor = TextCompressor(level=text_compression_level)
                with open(label_path, "r") as f:
                    labels = [
                        text_compressor.compress(l)
                        for i, l in enumerate(f)
                        if i not in skipped_indices
                    ]

                assert len(labels) == len(self.dataset_map[split][k]), (
                    f"labels length ({len(labels)}) and dataset length "
                    f"({len(self.dataset_map[split][k])}) do not match"
                )

                process_label = LabelEncoder(self.target_dictionary)

                target_dataset_map[k] = AddTargetDataset(
                    self.dataset_map[split][k],
                    labels,
                    pad=self.target_dictionary.pad(),
                    eos=self.target_dictionary.eos(),
                    batch_targets=True,
                    process_label=process_label,
                    label_len_fn=label_len_fn,
                    add_to_input=False,
                )

                data_weights.append(multi_corpus_sampling_weights[corpus_idx_map[k]])

            if len(target_dataset_map) == 1:
                self.datasets[split] = next(iter(target_dataset_map.values()))
            else:
                self.datasets[split] = MultiCorpusDataset(
                    target_dataset_map,
                    distribution=data_weights,
                    seed=0,
                    sort_indices=True,
                )
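
            # MultiCorpusDataset mixes the per-corpus datasets, sampling each
            # corpus in proportion to its weight in `distribution`; the fixed
            # seed keeps the sampling order reproducible across runs.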

    @property
    def source_dictionary(self):
        return None

    @property
    def target_dictionary(self):
        """Return the :class:`~fairseq.data.Dictionary` of classification
        labels."""
        return self.state.target_dictionary
    def train_step(self, sample, model, *args, **kwargs):
        # Classification criteria expect integer class indices.
        sample["target"] = sample["target"].to(dtype=torch.long)
        loss, sample_size, logging_output = super().train_step(
            sample, model, *args, **kwargs
        )
        self._log_metrics(sample, model, logging_output)
        return loss, sample_size, logging_output
    def valid_step(self, sample, model, criterion):
        sample["target"] = sample["target"].to(dtype=torch.long)
        loss, sample_size, logging_output = super().valid_step(sample, model, criterion)
        self._log_metrics(sample, model, logging_output)
        return loss, sample_size, logging_output
    def _log_metrics(self, sample, model, logging_output):
        # `batch_metrics` avoids shadowing the imported `metrics` module.
        batch_metrics = self._inference_with_metrics(sample, model)
        # Additional metrics (precision/recall/F1/EER/accuracy), currently
        # disabled:
        # logging_output["_precision"] = batch_metrics["precision"]
        # logging_output["_recall"] = batch_metrics["recall"]
        # logging_output["_f1"] = batch_metrics["f1"]
        # logging_output["_eer"] = batch_metrics["eer"]
        # logging_output["_accuracy"] = batch_metrics["accuracy"]
        logging_output["_correct"] = batch_metrics["correct"]
        logging_output["_total"] = batch_metrics["total"]
    def _inference_with_metrics(self, sample, model):
        def _compute_eer(target_list, lprobs):
            # Currently unused; kept for the disabled EER logging in
            # _log_metrics. Micro-averaged one-vs-rest ROC: flatten the
            # one-hot targets and the per-class scores, then take the
            # operating point where false-positive and false-negative
            # rates cross.
            y_one_hot = np.eye(len(self.state.target_dictionary))[target_list]
            fpr, tpr, _thresholds = sklearn_metrics.roc_curve(
                y_one_hot.ravel(), lprobs.ravel()
            )
            fnr = 1 - tpr
            eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]
            return eer
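
        # For intuition: an EER of 0.05 means that at the chosen threshold
        # both the false-positive and false-negative rates are roughly 5%.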

        with torch.no_grad():
            net_output = model(**sample["net_input"])
            lprobs = (
                model.get_normalized_probs(net_output, log_probs=True).cpu().detach()
            )
            # Each target row holds a single class label in its first position.
            target_list = sample["target"][:, 0].detach().cpu()
            predicted_list = torch.argmax(lprobs, 1).detach().cpu()

            batch_metrics = {
                "correct": torch.sum(target_list == predicted_list).item(),
                "total": len(target_list),
            }
            return batch_metrics

    def reduce_metrics(self, logging_outputs, criterion):
        super().reduce_metrics(logging_outputs, criterion)

        zero = torch.scalar_tensor(0.0)
        correct, total = 0, 0
        for log in logging_outputs:
            correct += log.get("_correct", zero)
            total += log.get("_total", zero)
        metrics.log_scalar("_correct", correct)
        metrics.log_scalar("_total", total)

        if total > 0:
            # Deriving accuracy from the summed counts gives an exact
            # aggregate over all batches and workers, rather than an
            # average of per-batch accuracies.
            def _fn_accuracy(meters):
                if meters["_total"].sum > 0:
                    return utils.item(meters["_correct"].sum / meters["_total"].sum)
                return float("nan")

            metrics.log_derived("accuracy", _fn_accuracy)
""" |
|
|
prec_sum, recall_sum, f1_sum, acc_sum, eer_sum = 0.0, 0.0, 0.0, 0.0, 0.0 |
|
|
for log in logging_outputs: |
|
|
prec_sum += log.get("_precision", zero).item() |
|
|
recall_sum += log.get("_recall", zero).item() |
|
|
f1_sum += log.get("_f1", zero).item() |
|
|
acc_sum += log.get("_accuracy", zero).item() |
|
|
eer_sum += log.get("_eer", zero).item() |
|
|
|
|
|
metrics.log_scalar("avg_precision", prec_sum / len(logging_outputs)) |
|
|
metrics.log_scalar("avg_recall", recall_sum / len(logging_outputs)) |
|
|
metrics.log_scalar("avg_f1", f1_sum / len(logging_outputs)) |
|
|
metrics.log_scalar("avg_accuracy", acc_sum / len(logging_outputs)) |
|
|
metrics.log_scalar("avg_eer", eer_sum / len(logging_outputs)) |
|
|
""" |