# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from typing import List, Dict, Any
from dataclasses import dataclass, field

import torch
import torch.nn.functional as F

from fairseq import utils
from fairseq.logging import metrics
from fairseq.criterions import FairseqCriterion, register_criterion
from fairseq.dataclass import FairseqDataclass
from fairseq.data.data_utils import lengths_to_mask
from fairseq.models.fairseq_model import FairseqEncoderModel


@dataclass
class FastSpeech2CriterionConfig(FairseqDataclass):
    ctc_weight: float = field(default=0.0, metadata={"help": "weight for CTC loss"})


@register_criterion("fastspeech2", dataclass=FastSpeech2CriterionConfig)
class FastSpeech2Loss(FairseqCriterion):
    def __init__(self, task, ctc_weight):
        super().__init__(task)
        self.ctc_weight = ctc_weight

    def forward(self, model: FairseqEncoderModel, sample, reduction="mean"):
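        """Compute the FastSpeech2 multi-task training loss for one batch.

        The total loss sums an L1 spectrogram loss, MSE losses on
        log-duration, pitch, and energy, and an optional CTC loss scaled by
        ``ctc_weight``.
        """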
        src_tokens = sample["net_input"]["src_tokens"]
        src_lens = sample["net_input"]["src_lengths"]
        tgt_lens = sample["target_lengths"]
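        # Teacher-forced forward pass: ground-truth durations, pitches, and
        # energies are fed to the variance adaptor, and the model returns
        # decoder features (pre- and post-postnet), attention (unused here),
        # and the log-duration, pitch, and energy predictions.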
        _feat_out, _feat_out_post, _, log_dur_out, pitch_out, energy_out = model(
            src_tokens=src_tokens,
            src_lengths=src_lens,
            prev_output_tokens=sample["net_input"]["prev_output_tokens"],
            incremental_state=None,
            target_lengths=tgt_lens,
            speaker=sample["speaker"],
            durations=sample["durations"],
            pitches=sample["pitches"],
            energies=sample["energies"],
        )
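        # Variance predictions are made per source position, so they are
        # compared under the source mask; spectrogram frames are compared
        # under the target mask. Both masks exclude padding.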
        src_mask = lengths_to_mask(sample["net_input"]["src_lengths"])
        tgt_mask = lengths_to_mask(sample["target_lengths"])

        pitches, energies = sample["pitches"], sample["energies"]
        pitch_out, pitches = pitch_out[src_mask], pitches[src_mask]
        energy_out, energies = energy_out[src_mask], energies[src_mask]

        feat_out, feat = _feat_out[tgt_mask], sample["target"][tgt_mask]
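        # L1 reconstruction loss on the mel-spectrogram; when a postnet is
        # present, its refined output is penalized against the same target.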
        l1_loss = F.l1_loss(feat_out, feat, reduction=reduction)
        if _feat_out_post is not None:
            l1_loss += F.l1_loss(_feat_out_post[tgt_mask], feat, reduction=reduction)

        pitch_loss = F.mse_loss(pitch_out, pitches, reduction=reduction)
        energy_loss = F.mse_loss(energy_out, energies, reduction=reduction)
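        # Duration loss is computed in the log domain (log(d + 1)), matching
        # the model's log-duration predictor; the +1 keeps zero durations
        # finite. Ground-truth durations are cast to match the output dtype
        # under mixed-precision training.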
        log_dur_out = log_dur_out[src_mask]
        dur = sample["durations"].float()
        dur = dur.half() if log_dur_out.type().endswith(".HalfTensor") else dur
        log_dur = torch.log(dur + 1)[src_mask]
        dur_loss = F.mse_loss(log_dur_out, log_dur, reduction=reduction)
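        # Optional auxiliary CTC loss: frame-level log-probabilities are
        # aligned against the (unpadded) source token sequence, encouraging
        # the predicted frames to cover the input text monotonically.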
        ctc_loss = torch.tensor(0.0).type_as(l1_loss)
        if self.ctc_weight > 0.0:
            lprobs = model.get_normalized_probs((_feat_out,), log_probs=True)
            lprobs = lprobs.transpose(0, 1)  # T x B x C
            src_mask = lengths_to_mask(src_lens)
            src_tokens_flat = src_tokens.masked_select(src_mask)
            ctc_loss = (
                F.ctc_loss(
                    lprobs,
                    src_tokens_flat,
                    tgt_lens,
                    src_lens,
                    reduction=reduction,
                    zero_infinity=True,
                )
                * self.ctc_weight
            )
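        # Sample size is measured in sentences: every component loss is
        # already reduced to a scalar, so each sentence weighs equally when
        # metrics are aggregated.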
        loss = l1_loss + dur_loss + pitch_loss + energy_loss + ctc_loss

        sample_size = sample["nsentences"]
        logging_output = {
            "loss": utils.item(loss.data),
            "ntokens": sample["ntokens"],
            "nsentences": sample["nsentences"],
            "sample_size": sample_size,
            "l1_loss": utils.item(l1_loss.data),
            "dur_loss": utils.item(dur_loss.data),
            "pitch_loss": utils.item(pitch_loss.data),
            "energy_loss": utils.item(energy_loss.data),
            "ctc_loss": utils.item(ctc_loss.data),
        }
        return loss, sample_size, logging_output

    @classmethod
    def reduce_metrics(cls, logging_outputs: List[Dict[str, Any]]) -> None:
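        # Each batch's scalar is weighted by its share of the total sample
        # size, so the logged value is a weighted average rather than a sum.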
        ns = [log.get("sample_size", 0) for log in logging_outputs]
        ntot = sum(ns)
        ws = [n / (ntot + 1e-8) for n in ns]
        for key in [
            "loss",
            "l1_loss",
            "dur_loss",
            "pitch_loss",
            "energy_loss",
            "ctc_loss",
        ]:
            vals = [log.get(key, 0) for log in logging_outputs]
            val = sum(val * w for val, w in zip(vals, ws))
            metrics.log_scalar(key, val, ntot, round=3)
        metrics.log_scalar("sample_size", ntot, len(logging_outputs))

        # inference metrics
        if "targ_frames" not in logging_outputs[0]:
            return
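        # Normalize MCD and the predicted-frame / insertion / deletion counts
        # by the total number of target frames across all batches.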
        n = sum(log.get("targ_frames", 0) for log in logging_outputs)
        for key, new_key in [
            ("mcd_loss", "mcd_loss"),
            ("pred_frames", "pred_ratio"),
            ("nins", "ins_rate"),
            ("ndel", "del_rate"),
        ]:
            val = sum(log.get(key, 0) for log in logging_outputs)
            metrics.log_scalar(new_key, val / n, n, round=3)

    @staticmethod
    def logging_outputs_can_be_summed() -> bool:
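        # reduce_metrics performs weighted averaging, so raw logging outputs
        # cannot simply be summed across workers.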
        return False