|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import List, Dict, Any |
|
|
from dataclasses import dataclass, field |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from fairseq import utils |
|
|
from fairseq.logging import metrics |
|
|
from fairseq.criterions import FairseqCriterion, register_criterion |
|
|
from fairseq.dataclass import FairseqDataclass |
|
|
from fairseq.data.data_utils import lengths_to_mask |
|
|
from fairseq.models.fairseq_model import FairseqEncoderModel |
|
|
|
|
|
|
|
|
@dataclass
class FastSpeech2CriterionConfig(FairseqDataclass):
    """Options accepted by the ``fastspeech2`` criterion."""

    # Multiplier for the CTC loss term; defaults to 0.0.
    ctc_weight: float = field(
        default=0.0,
        metadata={"help": "weight for CTC loss"},
    )
|
|
|
|
|
|
|
|
@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
class FastSpeech2Loss(FairseqCriterion):
    """Criterion registered under the name ``fastspeech2``.

    The total loss is the plain sum of:

    * an L1 loss on the predicted features (plus a second L1 term on the
      post-processed features when the model returns them),
    * MSE losses on the log-duration, pitch and energy predictions,
    * an optional CTC loss scaled by ``ctc_weight``.
    """

    def __init__(self, task, ctc_weight):
        """Initialise the criterion.

        Args:
            task: forwarded unchanged to ``FairseqCriterion.__init__``.
            ctc_weight: multiplier for the CTC term; the term is only
                computed when this value is greater than zero.
        """
        super().__init__(task)
        self.ctc_weight = ctc_weight

    def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
        """Run ``model`` on ``sample`` and compute the combined loss.

        Args:
            model: called with keyword arguments and expected to return a
                6-tuple ``(feat_out, feat_out_post, _, log_dur_out,
                pitch_out, energy_out)``; ``feat_out_post`` may be ``None``.
                ``model.get_normalized_probs`` is additionally used when the
                CTC term is enabled.
            sample: batch dict. Keys read here: ``net_input`` (with
                ``src_tokens``, ``src_lengths``, ``prev_output_tokens``),
                ``target``, ``target_lengths``, ``speaker``, ``durations``,
                ``pitches``, ``energies``, ``ntokens`` and ``nsentences``.
            reduction: passed through to every ``torch.nn.functional`` loss.

        Returns:
            A tuple ``(loss, sample_size, logging_output)`` where
            ``sample_size`` is ``sample["nsentences"]`` and
            ``logging_output`` holds the total and per-term loss values as
            Python scalars together with the batch counters.
        """
        net_input = sample["net_input"]
        src_tokens = net_input["src_tokens"]
        src_lens = net_input["src_lengths"]
        tgt_lens = sample["target_lengths"]
        _feat_out, _feat_out_post, _, log_dur_out, pitch_out, energy_out = model(
            src_tokens=src_tokens,
            src_lengths=src_lens,
            prev_output_tokens=net_input["prev_output_tokens"],
            incremental_state=None,
            target_lengths=tgt_lens,
            speaker=sample["speaker"],
            durations=sample["durations"],
            pitches=sample["pitches"],
            energies=sample["energies"],
        )

        # Boolean masks built from the lengths; they are used below to select
        # the entries that contribute to each loss term. The source mask is
        # computed once and reused by the CTC branch.
        src_mask = lengths_to_mask(src_lens)
        tgt_mask = lengths_to_mask(tgt_lens)

        # Pitch / energy targets and predictions are indexed with the
        # source-side mask.
        pitches, energies = sample["pitches"], sample["energies"]
        pitch_out, pitches = pitch_out[src_mask], pitches[src_mask]
        energy_out, energies = energy_out[src_mask], energies[src_mask]

        # Feature reconstruction loss uses the target-side mask.
        feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
        l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
        if _feat_out_post is not None:
            # The post-processed output is compared against the same target.
            l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)

        pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
        energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)

        # Durations are regressed in the log domain: the target is
        # log(duration + 1). The durations are cast to float first and
        # dropped to half precision only when the prediction is fp16, so
        # both operands of the MSE share a dtype in that case.
        log_dur_out = log_dur_out[src_mask]
        dur = sample["durations"].float()
        if log_dur_out.dtype == torch.float16:
            dur = dur.half()
        log_dur = torch.log(dur + 1)[src_mask]
        dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)

        # Zero placeholder matching l1_loss's type so the sum below and the
        # logging entry work whether or not CTC is enabled.
        ctc_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.ctc_weight > 0.0:
            lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
            # F.ctc_loss takes time-major log-probs (T, N, C); presumably
            # get_normalized_probs returns batch-major output -- TODO confirm.
            lprobs = lprobs.transpose(0, 1)
            # Flatten the non-padded source tokens into a 1-D target vector.
            src_tokens_flat = src_tokens.masked_select(src_mask)
            # Argument order of F.ctc_loss is (log_probs, targets,
            # input_lengths, target_lengths): the feature frames act as the
            # CTC inputs and the source tokens as the CTC targets.
            ctc_loss = (
                F.ctc_loss(
                    lprobs,
                    src_tokens_flat,
                    tgt_lens,
                    src_lens,
                    reduction=reduction,
                    zero_infinity=True,
                )
                * self.ctc_weight
            )

        loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss

        sample_size = sample["nsentences"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "dur_loss": utils.item(dur_loss.data),
            "pitch_loss": utils.item(pitch_loss.data),
            "energy_loss": utils.item(energy_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
        }
        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
        """Aggregate per-worker logging outputs into ``metrics``.

        Each loss entry is averaged across ``logging_outputs`` with weights
        proportional to that entry's ``sample_size``. If the first logging
        output also carries a ``targ_frames`` key, the ``mcd_loss``,
        ``pred_frames``, ``nins`` and ``ndel`` counters are summed and logged
        normalised by the total number of target frames.

        Args:
            logging_outputs: list of dicts as produced by ``forward``
                (the frame-level keys are not written by ``forward`` in this
                class; presumably they are added elsewhere -- verify against
                the caller).
        """
        sizes = [log.get("sample_size", 0) for log in logging_outputs]
        ntot = sum(sizes)
        # The epsilon keeps the weights finite when the total size is zero.
        weights = [size / (ntot + 1e-8) for size in sizes]
        for key in [
            "loss",
            "l1_loss",
            "dur_loss",
            "pitch_loss",
            "energy_loss",
            "ctc_loss",
        ]:
            vals = [log.get(key, 0) for log in logging_outputs]
            weighted = sum(v * w for v, w in zip(vals, weights))
            metrics.log_scalar(key, weighted, ntot, round=3)
        metrics.log_scalar("sample_size", ntot, len(logging_outputs))

        # Frame-level metrics are optional. Bail out when there is nothing
        # to inspect (an empty list would otherwise raise IndexError) or
        # when the first entry does not carry them.
        if not logging_outputs or "targ_frames" not in logging_outputs[0]:
            return
        n = sum(log.get("targ_frames", 0) for log in logging_outputs)
        if n == 0:
            # No target frames reported: the ratios below are undefined, so
            # skip them instead of dividing by zero.
            return
        for key, new_key in [
            ("mcd_loss", "mcd_loss"),
            ("pred_frames", "pred_ratio"),
            ("nins", "ins_rate"),
            ("ndel", "del_rate"),
        ]:
            total = sum(log.get(key, 0) for log in logging_outputs)
            metrics.log_scalar(new_key, total / n, n, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
        """Return ``False``.

        ``reduce_metrics`` combines the per-output loss values with a
        ``sample_size``-weighted average rather than a plain sum, so it
        needs the individual logging outputs.
        """
        return False
|
|
|