| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class BeamableMM(nn.Module):
    """Batched matrix multiply with a fast path for beam-search decoding.

    The fast path leverages the fact that, during beam search, the
    source side of the input (``input2``) is replicated beam times and the
    target side (``input1``) has width one. In that situation the ``bsz``
    single-row products are regrouped into ``bsz / beam`` products with
    ``beam`` rows each, which is cheaper than ``bsz`` tiny products.

    The fast path is taken only when all of the following hold:

    * the module is in eval mode (``self.training`` is false);
    * ``beam_size`` is set (not ``None``);
    * ``input1`` is 3-D with ``input1.size(1) == 1``;
    * the batch size is a positive multiple of ``beam_size``.

    In every other case the module is exactly ``input1.bmm(input2)``.

    NOTE: the fast path reads only the first entry of each group of
    ``beam_size`` consecutive rows of ``input2`` (along dim 0). It therefore
    matches ``input1.bmm(input2)`` only if the rows within each such group
    are identical; the caller is responsible for that.
    """

    def __init__(self, beam_size=None):
        """Create the module.

        Args:
            beam_size: beam width used to regroup the batch, or ``None`` to
                always use the plain batched product.
        """
        super(BeamableMM, self).__init__()
        self.beam_size = beam_size

    def forward(self, input1, input2):
        """Return the batched product of ``input1`` and ``input2``.

        Args:
            input1: left operand; must be valid as ``input1.bmm(input2)``.
            input2: right operand; must be valid as ``input1.bmm(input2)``.

        Returns:
            ``input1.bmm(input2)`` on the regular path. On the fast path the
            result is viewed as ``(bsz, 1, -1)`` where ``bsz`` is
            ``input1.size(0)``.
        """
        beam = self.beam_size
        fast_path = (
            not self.training  # eval mode only
            and beam is not None  # beam size is set
            and input1.dim() == 3  # only batched input is supported
            and input1.size(1) == 1  # single time step update
        )
        if fast_path:
            bsz = input1.size(0)
            # unfold() silently drops trailing rows when bsz is not a
            # multiple of beam and raises when bsz < beam; in those cases
            # the regrouping is invalid, so use the plain product instead.
            fast_path = beam > 0 and bsz >= beam and bsz % beam == 0

        if not fast_path:
            return input1.bmm(input2)

        # (bsz, 1, d) -> (bsz / beam, beam, d)
        input1 = input1[:, 0, :].unfold(0, beam, beam).transpose(2, 1)

        # Keep one representative of every beam group along dim 0:
        # (bsz, ...) -> (bsz / beam, ...)
        input2 = input2.unfold(0, beam, beam)[:, :, :, 0]

        # A single group (bsz == beam) needs no batched operation.
        if input1.size(0) == 1:
            output = torch.mm(input1[0, :, :], input2[0, :, :])
        else:
            output = input1.bmm(input2)
        return output.view(bsz, 1, -1)

    def set_beam_size(self, beam_size):
        """Set the beam width used by the fast path (``None`` disables it)."""
        self.beam_size = beam_size
|
|