| | """ Wrapper for ngram_repeat_block cuda extension """ |
| | import torch |
| | from torch import nn |
| |
|
| | import math |
| | from typing import Dict, List, Optional |
| | import warnings |
| |
|
try:
    from fairseq import ngram_repeat_block_cuda

    EXTENSION_BUILT = True
except ImportError:
    EXTENSION_BUILT = False


def is_cuda_extension_usable() -> bool:
    """Check whether ngram_repeat_block_cuda is built properly"""
    if not EXTENSION_BUILT or not torch.cuda.is_available():
        return False
    bsz = 2
    tokens = torch.tensor([[4, 4, 3, 2], [1, 2, 3, 4]], dtype=torch.long, device="cuda")
    lprobs = torch.rand((8, 12), device="cuda")
    try:
        outputs = ngram_repeat_block_cuda.forward(tokens, lprobs, bsz, 3, 4, 3)
        outputs = outputs + 4  # This line breaks if the extension is built incorrectly.
        return True
    except RuntimeError:
        warnings.warn(
            "NGramRepeatBlock extension must be rebuilt. "
            'Run TORCH_CUDA_ARCH_LIST="6.0;6.1;7.0" python setup.py build_ext --inplace'
        )
        return False


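# NGramRepeatBlock below calls this probe from its constructor, so callers can
# request the extension and fall back silently to the pure-Python path when the
# kernel is missing or broken. A minimal, illustrative sketch:
#
#   blocker = NGramRepeatBlock(no_repeat_ngram_size=3, use_extension=True)
#   # blocker.use_extension is False whenever the probe above fails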
class NGramRepeatBlock(nn.Module):
    """Wrapper class for calling ngram_repeat_block cuda extension"""

    def __init__(self, no_repeat_ngram_size: int, use_extension: bool = True):
        super().__init__()
        self.use_extension = is_cuda_extension_usable() if use_extension else False
        self.no_repeat_ngram_size = no_repeat_ngram_size

    def reset_parameters(self):
        pass

    @torch.jit.unused
    def call_cuda_extension(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
        return ngram_repeat_block_cuda.forward(
            tokens, lprobs, bsz, step, beam_size, self.no_repeat_ngram_size
        )

    def forward(
        self,
        tokens,
        lprobs,
        bsz: int,
        beam_size: int,
        step: int,
    ):
        """
        Args:
            tokens(Tensor): input tokens (bsz * beam_size, seq_len)
            lprobs(Tensor): log probabilities of the next token, updated
                in place (bsz * beam_size, vocab_size)
            bsz(int): batch size
            beam_size(int): beam size
            step(int): current decoding step

        The ngram length is taken from self.no_repeat_ngram_size.
        """
        msg = f"expected {bsz * beam_size} got"
        assert tokens.size(0) == bsz * beam_size, f"{msg} {tokens.size(0)}"
        assert lprobs.size(0) == bsz * beam_size, f"{msg} {lprobs.size(0)}"
        if self.use_extension:
            return self.call_cuda_extension(tokens, lprobs, bsz, beam_size, step)
        else:
            return self._no_repeat_ngram(
                tokens,
                lprobs,
                bsz,
                beam_size,
                step,
            )

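    # Both branches return `lprobs` with banned continuations set to -inf: the
    # CUDA kernel matches ngrams on the GPU, while `_no_repeat_ngram` below
    # rebuilds the ngram table on the CPU at every decoding step. See the
    # illustrative __main__ demo at the bottom of this file for a concrete call.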
    def _no_repeat_ngram(self, tokens, lprobs, bsz: int, beam_size: int, step: int):
        """For each hypothesis generate a list of previous ngrams and set associated lprobs to -inf"""
        gen_ngrams: List[Dict[str, List[int]]] = [
            torch.jit.annotate(Dict[str, List[int]], {})
            for bbsz_idx in range(bsz * beam_size)
        ]
        cpu_tokens = tokens.cpu()
        # Build, per hypothesis, a map from each (n-1)-gram prefix to the tokens
        # that have followed it so far.
        for bbsz_idx in range(bsz * beam_size):
            gen_tokens: List[int] = cpu_tokens[bbsz_idx].tolist()
            for ngram in self.transpose_list(
                [gen_tokens[i:] for i in range(self.no_repeat_ngram_size)]
            ):
                key = ",".join([str(x) for x in ngram[:-1]])
                gen_ngrams[bbsz_idx][key] = gen_ngrams[bbsz_idx].get(
                    key, torch.jit.annotate(List[int], [])
                ) + [ngram[-1]]
        if step + 2 - self.no_repeat_ngram_size >= 0:
            banned_tokens = [
                self.calculate_banned_tokens(
                    tokens, step, gen_ngrams, self.no_repeat_ngram_size, bbsz_idx
                )
                for bbsz_idx in range(bsz * beam_size)
            ]
        else:
            # No tokens are banned until at least no_repeat_ngram_size tokens
            # have been generated.
            banned_tokens = [
                torch.jit.annotate(List[int], []) for bbsz_idx in range(bsz * beam_size)
            ]
        for bbsz_idx in range(bsz * beam_size):
            lprobs[bbsz_idx][
                torch.tensor(banned_tokens[bbsz_idx], dtype=torch.int64)
            ] = torch.tensor(-math.inf).to(lprobs)
        return lprobs

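    # Worked illustration (hypothetical numbers): with no_repeat_ngram_size=3 and
    # a hypothesis [1, 2, 3, 1, 2], the trigrams seen so far yield
    # gen_ngrams = {"1,2": [3], "2,3": [1], "3,1": [2]}. The last two tokens are
    # (1, 2), so token 3 is banned: emitting it would repeat the trigram (1, 2, 3).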
    @staticmethod
    def calculate_banned_tokens(
        tokens,
        step: int,
        gen_ngrams: List[Dict[str, List[int]]],
        no_repeat_ngram_size: int,
        bbsz_idx: int,
    ):
        tokens_list: List[int] = tokens[
            bbsz_idx, step + 2 - no_repeat_ngram_size : step + 1
        ].tolist()
        # Before decoding the next token, look up which continuations would
        # repeat an ngram that has already appeared.
        ngram_index = ",".join([str(x) for x in tokens_list])
        return gen_ngrams[bbsz_idx].get(ngram_index, torch.jit.annotate(List[int], []))

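    # E.g. (illustrative): with no_repeat_ngram_size=3 and step=4, the slice above
    # is tokens[bbsz_idx, 3:5], i.e. the last two generated tokens, and the method
    # returns every token that previously followed that (n-1)-gram.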
    @staticmethod
    def transpose_list(l: List[List[int]]):
        # Transpose a jagged list of lists, truncating to the shortest row.
        min_len = min([len(x) for x in l])
        l2 = [[row[i] for row in l] for i in range(min_len)]
        return l2
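    # E.g. transpose_list([[1, 2, 3], [2, 3]]) == [[1, 2], [2, 3]]: each inner
    # list is one offset of the token stream, so the transposed rows are ngrams.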
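

# Illustrative, self-contained demo of the pure-Python fallback path; the token
# values and sizes below are arbitrary assumptions, not part of the module API.
if __name__ == "__main__":
    blocker = NGramRepeatBlock(no_repeat_ngram_size=2, use_extension=False)
    tokens = torch.tensor([[5, 6, 5], [7, 8, 9]])  # (bsz * beam_size, seq_len)
    lprobs = torch.zeros(2, 10)  # (bsz * beam_size, vocab_size)
    lprobs = blocker(tokens, lprobs, bsz=1, beam_size=2, step=2)
    # Row 0 already contains the bigram (5, 6) and ends in 5, so token 6 is banned.
    assert lprobs[0, 6] == -math.inf
    assert torch.isfinite(lprobs[1]).all()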