|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from dataclasses import dataclass, field |
|
|
import torch |
|
|
from fairseq import utils |
|
|
from fairseq.logging import metrics |
|
|
from fairseq.criterions import register_criterion |
|
|
from fairseq.criterions.label_smoothed_cross_entropy import ( |
|
|
LabelSmoothedCrossEntropyCriterion, |
|
|
LabelSmoothedCrossEntropyCriterionConfig, |
|
|
) |
|
|
|
|
|
# SimulEval is an optional dependency. When it cannot be imported,
# LATENCY_METRICS is set to None so that code using the table can detect the
# missing package instead of failing at import time of this module.
try:
    from simuleval.metrics.latency import (
        AverageLagging,
        AverageProportion,
        DifferentiableAverageLagging,
    )
except ImportError:
    LATENCY_METRICS = None
else:
    # Maps a latency metric name to the SimulEval object implementing it.
    # Built in the ``else`` clause so the ``try`` body only contains the
    # statement that can actually raise ImportError.
    LATENCY_METRICS = {
        "average_lagging": AverageLagging,
        "average_proportion": AverageProportion,
        "differentiable_average_lagging": DifferentiableAverageLagging,
    }
|
|
|
|
|
|
|
|
@dataclass
class LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig(
    LabelSmoothedCrossEntropyCriterionConfig
):
    """Options for the latency-augmented label-smoothed cross entropy criterion.

    Extends the label-smoothed cross entropy configuration with the weights
    and selectors of the additional latency loss terms.
    """

    # Multiplier applied to the gathered average-latency term.
    latency_avg_weight: float = field(
        default=0.0,
        metadata={"help": "weight for average latency loss."},
    )
    # Multiplier for the variance-of-delays term.
    latency_var_weight: float = field(
        default=0.0,
        metadata={"help": "weight for variance latency loss."},
    )
    # Name of the latency metric used for the average term.
    latency_avg_type: str = field(
        default="differentiable_average_lagging",
        metadata={"help": "latency type for average loss"},
    )
    # Name of the latency measure used for the variance term.
    latency_var_type: str = field(
        default="variance_delay",
        metadata={"help": "latency type for variance loss"},
    )
    # How per-head latencies are combined into a single value per sentence.
    latency_gather_method: str = field(
        default="weighted_average",
        metadata={"help": "method to gather latency loss for all heads"},
    )
    # Number of updates to wait before the latency loss is added; 0 disables
    # the delay.
    latency_update_after: int = field(
        default=0,
        metadata={"help": "Add latency loss after certain steps"},
    )
|
|
|
|
|
|
|
|
@register_criterion(
    "latency_augmented_label_smoothed_cross_entropy",
    dataclass=LabelSmoothedCrossEntropyCriterionLatencyAugmentConfig,
)
class LatencyAugmentedLabelSmoothedCrossEntropyCriterion(
    LabelSmoothedCrossEntropyCriterion
):
    """Label-smoothed cross entropy with an additional latency penalty.

    The penalty is derived from the per-head ``alpha`` tensors found in
    ``net_output[1].attn_list`` and is the weighted sum of an average-latency
    term and a variance-of-delays term.
    """

    def __init__(
        self,
        task,
        sentence_avg,
        label_smoothing,
        ignore_prefix_size,
        report_accuracy,
        latency_avg_weight,
        latency_var_weight,
        latency_avg_type,
        latency_var_type,
        latency_gather_method,
        latency_update_after,
    ):
        """Store the latency options on top of the base criterion setup.

        Args:
            task, sentence_avg, label_smoothing, ignore_prefix_size,
            report_accuracy: forwarded unchanged to the parent criterion.
            latency_avg_weight: multiplier of the average-latency term.
            latency_var_weight: multiplier of the variance-of-delays term.
            latency_avg_type: key into ``LATENCY_METRICS`` selecting the
                latency metric.
            latency_var_type: stored for reference; not read anywhere else
                in this class.
            latency_gather_method: ``"average"``, ``"weighted_average"`` or
                ``"max"``; how per-head latencies are combined.
            latency_update_after: when positive, the latency penalty is only
                added once ``model.decoder.num_updates`` exceeds this value.

        Raises:
            AssertionError: if SimulEval could not be imported.
        """
        super().__init__(
            task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy
        )
        assert LATENCY_METRICS is not None, "Please make sure SimulEval is installed."

        self.latency_avg_weight = latency_avg_weight
        self.latency_var_weight = latency_var_weight
        self.latency_avg_type = latency_avg_type
        self.latency_var_type = latency_var_type
        self.latency_gather_method = latency_gather_method
        self.latency_update_after = latency_update_after

    @staticmethod
    def _detached(value):
        """Return ``value`` without its autograd history (non-tensors pass through)."""
        return value.detach() if torch.is_tensor(value) else value

    def forward(self, model, sample, reduce=True):
        """Compute the total loss (cross entropy + latency) for ``sample``.

        Returns:
            A tuple ``(loss, sample_size, logging_output)`` where
            ``logging_output`` is a dict of values to be aggregated by
            :meth:`reduce_metrics`.

        Raises:
            AssertionError: if ``latency_update_after`` is positive and
                ``model.decoder`` has no ``num_updates`` attribute.
        """
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)

        # The latency statistics are always computed so they can be logged,
        # even while the penalty itself is still switched off below.
        latency_loss, expected_latency, expected_delays_var = self.compute_latency_loss(
            model, sample, net_output
        )

        if self.latency_update_after > 0:
            num_updates = getattr(model.decoder, "num_updates", None)
            assert (
                num_updates is not None
            ), "model.decoder doesn't have attribute 'num_updates'"
            if num_updates <= self.latency_update_after:
                # Warm-up phase: train on cross entropy only.
                latency_loss = 0

        loss += latency_loss

        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )

        logging_output = {
            # Logged after the latency term has been added.
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": sample["ntokens"],
            "nsentences": sample["target"].size(0),
            "sample_size": sample_size,
            # Detached like ``loss``/``nll_loss`` so that logging does not
            # keep the autograd graph alive.
            "latency": self._detached(expected_latency),
            "delays_var": self._detached(expected_delays_var),
            "latency_loss": self._detached(latency_loss),
        }

        if self.report_accuracy:
            n_correct, total = self.compute_accuracy(model, net_output, sample)
            logging_output["n_correct"] = utils.item(n_correct.data)
            logging_output["total"] = utils.item(total.data)
        return loss, sample_size, logging_output

    def compute_latency_loss(self, model, sample, net_output):
        """Compute the latency penalty from the model's ``alpha`` tensors.

        Returns:
            A tuple ``(latency_loss, expected_latency, expected_delays_var)``:
            the weighted penalty, the gathered latency summed over the batch,
            and the delay variance summed over the batch.

        Raises:
            AssertionError: if the first source position of any sentence is
                marked as padding (only right padding is supported).
            NotImplementedError: if ``latency_gather_method`` is not one of
                ``"average"``, ``"weighted_average"`` or ``"max"``.
        """
        assert (
            net_output[-1].encoder_padding_mask is None
            or not net_output[-1].encoder_padding_mask[:, 0].any()
        ), "Only right padding on source is supported."

        # 1. Collect the alignment tensors, one entry per attn_list item.
        alpha_list = [item["alpha"] for item in net_output[1].attn_list]
        num_layers = len(alpha_list)
        bsz, num_heads, tgt_len, src_len = alpha_list[0].size()

        # (bsz * num_layers * num_heads, tgt_len, src_len); assumes every
        # entry has the same shape as the first one -- TODO confirm.
        alpha_all = torch.cat(alpha_list, dim=1).view(-1, tgt_len, src_len)

        # 2. Expected delay of each target position: sum over source steps
        # 1..src_len weighted by alpha. ``steps`` broadcasts over the leading
        # dimensions of ``alpha_all``.
        steps = torch.arange(1, 1 + src_len).type_as(alpha_all)
        expected_delays = torch.sum(steps * alpha_all, dim=-1)

        # Replicate the per-sentence target padding mask and source lengths
        # for every (layer, head) pair so they line up with expected_delays.
        target_padding_mask = (
            model.get_targets(sample, net_output)
            .eq(self.padding_idx)
            .unsqueeze(1)
            .expand(bsz, num_layers * num_heads, tgt_len)
            .contiguous()
            .view(-1, tgt_len)
        )
        src_lengths = (
            sample["net_input"]["src_lengths"]
            .unsqueeze(1)
            .expand(bsz, num_layers * num_heads)
            .contiguous()
            .view(-1)
        )

        # NOTE(review): the meaning of the third positional argument (None)
        # is defined by SimulEval -- confirm against the installed version.
        expected_latency = LATENCY_METRICS[self.latency_avg_type](
            expected_delays, src_lengths, None, target_padding_mask=target_padding_mask
        )

        # 2.1 Gather the per-head latencies into one value per sentence.
        # (bsz, num_layers * num_heads)
        expected_latency = expected_latency.view(bsz, -1)
        if self.latency_gather_method == "average":
            # Average over heads. The previous code averaged
            # ``expected_delays`` over target positions here, which ignored
            # the selected metric and produced one value per head instead of
            # one per sentence.
            expected_latency = expected_latency.mean(dim=1)
        elif self.latency_gather_method == "weighted_average":
            # Softmax weights emphasise the heads with the largest latency.
            weights = torch.nn.functional.softmax(expected_latency, dim=1)
            expected_latency = torch.sum(expected_latency * weights, dim=1)
        elif self.latency_gather_method == "max":
            expected_latency = expected_latency.max(dim=1)[0]
        else:
            raise NotImplementedError(
                f"Unsupported latency_gather_method: {self.latency_gather_method!r}"
            )

        expected_latency = expected_latency.sum()
        avg_loss = self.latency_avg_weight * expected_latency

        # 2.2 Variance of the expected delays across heads, averaged over
        # target positions and summed over the batch.
        # NOTE(review): with a single (layer, head) pair the default
        # (unbiased) variance is NaN -- confirm models always expose more
        # than one head.
        expected_delays_var = (
            expected_delays.view(bsz, -1, tgt_len).var(dim=1).mean(dim=1)
        )
        expected_delays_var = expected_delays_var.sum()
        # Scaled by the variance weight (previously the average weight was
        # used here, leaving ``latency_var_weight`` without effect).
        var_loss = self.latency_var_weight * expected_delays_var

        # 3. Final loss
        latency_loss = avg_loss + var_loss

        return latency_loss, expected_latency, expected_delays_var

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        """Aggregate the latency statistics and log them per sentence.

        Calls the parent implementation first, then logs ``latency``,
        ``delays_var`` and ``latency_loss`` divided by the total number of
        sentences.
        """
        super().reduce_metrics(logging_outputs)
        latency = sum(log.get("latency", 0) for log in logging_outputs)
        delays_var = sum(log.get("delays_var", 0) for log in logging_outputs)
        latency_loss = sum(log.get("latency_loss", 0) for log in logging_outputs)
        nsentences = sum(log.get("nsentences", 0) for log in logging_outputs)

        # The sum is a plain number when no log carries a tensor under
        # "latency"; ``.float()`` only exists on tensors.
        if torch.is_tensor(latency):
            latency = latency.float()
        else:
            latency = float(latency)

        metrics.log_scalar("latency", latency / nsentences, nsentences, round=3)
        metrics.log_scalar("delays_var", delays_var / nsentences, nsentences, round=3)
        metrics.log_scalar(
            "latency_loss", latency_loss / nsentences, nsentences, round=3
        )
|
|
|