| | |
| | |
| | |
| | |
| |
|
| | import sys |
| |
|
| | import torch |
| | from fairseq import utils |
| |
|
| |
|
class SequenceScorer(object):
    """Scores the target for a given source sentence.

    Rather than searching for new hypotheses, this "generator" force-decodes
    ``sample["target"]`` through one or more models and returns, per sentence,
    the gold tokens together with their per-position and average
    log-probabilities (and optionally attention/alignment).
    """

    def __init__(
        self,
        tgt_dict,
        softmax_batch=None,
        compute_alignment=False,
        eos=None,
        symbols_to_strip_from_output=None,
    ):
        """
        Args:
            tgt_dict: target dictionary; supplies ``pad()``/``eos()`` ids
            softmax_batch (int, optional): maximum number of
                (batch x time) positions pushed through the softmax at once;
                ``None`` (or 0) disables chunking
            compute_alignment (bool): if True, extract a hard alignment from
                the averaged attention for each hypothesis
            eos (int, optional): overrides the dictionary's EOS id
            symbols_to_strip_from_output (set, optional): extra symbol ids to
                strip from output; EOS is always included
        """
        self.pad = tgt_dict.pad()
        self.eos = tgt_dict.eos() if eos is None else eos
        # None/0 means "no chunking": sys.maxsize makes every batch take the
        # single-pass path in batch_for_softmax below.
        self.softmax_batch = softmax_batch or sys.maxsize
        assert self.softmax_batch > 0
        self.compute_alignment = compute_alignment
        # EOS is always stripped from output, in addition to any
        # caller-supplied symbols.
        self.symbols_to_strip_from_output = (
            symbols_to_strip_from_output.union({self.eos})
            if symbols_to_strip_from_output is not None
            else {self.eos}
        )

    @torch.no_grad()
    def generate(self, models, sample, **kwargs):
        """Score a batch of translations.

        Args:
            models: list of models; with more than one model, probabilities
                are averaged across the ensemble before taking the log
            sample: batch dict with ``net_input`` and ``target``; may also
                carry ``start_indices`` to skip a per-sentence target prefix

        Returns:
            list (one entry per sentence) of single-element lists, each
            holding a hypothesis dict with keys ``tokens``, ``score``,
            ``attention``, ``alignment`` and ``positional_scores``.
        """
        net_input = sample["net_input"]

        def batch_for_softmax(dec_out, target):
            # Yield (decoder_out, target, is_single) chunks small enough to
            # softmax at once; is_single is True when no chunking occurred.
            first, rest = dec_out[0], dec_out[1:]
            bsz, tsz, dim = first.shape
            if bsz * tsz < self.softmax_batch:
                yield dec_out, target, True
            else:
                # Flatten (bsz, tsz) into one long sequence and slice it into
                # chunks of at most self.softmax_batch positions. NOTE(review):
                # the extra decoder outputs in `rest` are passed through
                # unsliced — presumably get_normalized_probs only needs them
                # whole; confirm against the model implementations.
                flat = first.contiguous().view(1, -1, dim)
                flat_tgt = target.contiguous().view(flat.shape[:-1])
                s = 0
                while s < flat.size(1):
                    e = s + self.softmax_batch
                    yield (flat[:, s:e],) + rest, flat_tgt[:, s:e], False
                    s = e

        def gather_target_probs(probs, target):
            # Pick out, at every position, the probability assigned to the
            # gold target token (keeps a trailing singleton dim).
            probs = probs.gather(
                dim=2,
                index=target.unsqueeze(-1),
            )
            return probs

        orig_target = sample["target"]

        # compute scores for each model in the ensemble
        avg_probs = None
        avg_attn = None
        for model in models:
            model.eval()
            decoder_out = model(**net_input)
            attn = decoder_out[1] if len(decoder_out) > 1 else None
            if type(attn) is dict:
                attn = attn.get("attn", None)

            batched = batch_for_softmax(decoder_out, orig_target)
            probs, idx = None, 0
            for bd, tgt, is_single in batched:
                # Temporarily point sample["target"] at the current chunk
                # (get_normalized_probs takes `sample`); restored below.
                sample["target"] = tgt
                # Single model: take log-probs directly. Ensemble: collect
                # raw probs here and convert to log space after averaging.
                curr_prob = model.get_normalized_probs(
                    bd, log_probs=len(models) == 1, sample=sample
                ).data
                if is_single:
                    probs = gather_target_probs(curr_prob, orig_target)
                else:
                    if probs is None:
                        # Flat (uninitialized) buffer: one score per target
                        # position, filled chunk by chunk below.
                        probs = curr_prob.new(orig_target.numel())
                    step = curr_prob.size(0) * curr_prob.size(1)
                    end = step + idx
                    tgt_probs = gather_target_probs(
                        curr_prob.view(tgt.shape + (curr_prob.size(-1),)), tgt
                    )
                    probs[idx:end] = tgt_probs.view(-1)
                    idx = end
            sample["target"] = orig_target

            probs = probs.view(sample["target"].shape)

            # accumulate (in place) over the ensemble
            if avg_probs is None:
                avg_probs = probs
            else:
                avg_probs.add_(probs)
            if attn is not None:
                if torch.is_tensor(attn):
                    attn = attn.data
                else:
                    attn = attn[0]
                if avg_attn is None:
                    avg_attn = attn
                else:
                    avg_attn.add_(attn)
        if len(models) > 1:
            # Ensemble path accumulated raw probs: average, then move to log
            # space (the single-model path already produced log-probs).
            avg_probs.div_(len(models))
            avg_probs.log_()
            if avg_attn is not None:
                avg_attn.div_(len(models))

        bsz = avg_probs.size(0)
        hypos = []
        # Optional per-sentence offsets into the target (e.g. to score only
        # the continuation after a given prefix).
        start_idxs = sample["start_indices"] if "start_indices" in sample else [0] * bsz
        for i in range(bsz):
            # remove padding from ref
            ref = (
                utils.strip_pad(sample["target"][i, start_idxs[i] :], self.pad)
                if sample["target"] is not None
                else None
            )
            tgt_len = ref.numel()
            avg_probs_i = avg_probs[i][start_idxs[i] : start_idxs[i] + tgt_len]
            # sentence score = mean per-token log-probability
            score_i = avg_probs_i.sum() / tgt_len
            if avg_attn is not None:
                avg_attn_i = avg_attn[i]
                if self.compute_alignment:
                    alignment = utils.extract_hard_alignment(
                        avg_attn_i,
                        sample["net_input"]["src_tokens"][i],
                        sample["target"][i],
                        self.pad,
                        self.eos,
                    )
                else:
                    alignment = None
            else:
                avg_attn_i = alignment = None
            hypos.append(
                [
                    {
                        "tokens": ref,
                        "score": score_i,
                        "attention": avg_attn_i,
                        "alignment": alignment,
                        "positional_scores": avg_probs_i,
                    }
                ]
            )
        return hypos
| |
|