| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.nn as nn |
| from scaling import ScaledLinear |
|
|
|
|
class Joiner(nn.Module):
    """Combine encoder and decoder outputs into per-position vocabulary logits.

    The two inputs are (optionally) projected to a common ``joiner_dim``,
    summed, squashed with ``tanh`` and mapped to ``vocab_size`` by a final
    linear layer.

    NOTE(review): ``ScaledLinear`` is imported from the project's ``scaling``
    module; the shape notes below assume it takes
    ``(in_features, out_features)`` like ``nn.Linear`` -- confirm.
    """

    def __init__(
        self,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        """
        Args:
          encoder_dim:
            Size of the last dimension of the encoder output.
          decoder_dim:
            Size of the last dimension of the decoder output.
          joiner_dim:
            Size of the shared space in which both outputs are summed.
          vocab_size:
            Size of the last dimension of the returned logits.
        """
        super().__init__()

        self.encoder_proj = ScaledLinear(encoder_dim, joiner_dim)
        self.decoder_proj = ScaledLinear(decoder_dim, joiner_dim)
        self.output_linear = ScaledLinear(joiner_dim, vocab_size)

    def forward(
        self,
        encoder_out: torch.Tensor,
        decoder_out: torch.Tensor,
        project_input: bool = True,
    ) -> torch.Tensor:
        """
        Args:
          encoder_out:
            Output from the encoder. Its shape is (N, T, s_range, C), where
            C is encoder_dim if project_input is true and joiner_dim
            otherwise.
          decoder_out:
            Output from the decoder. Its shape is (N, T, s_range, C), where
            C is decoder_dim if project_input is true and joiner_dim
            otherwise.
          project_input:
            If true, apply input projections encoder_proj and decoder_proj.
            If this is false, it is the user's responsibility to do this
            manually, and both inputs must then share the same last
            dimension.
        Returns:
          Return a tensor of shape (N, T, s_range, vocab_size).
        """
        # Asserts (rather than raised exceptions) are kept so that callers
        # keep seeing the same AssertionError as before.
        assert (
            encoder_out.ndim == decoder_out.ndim == 4
        ), "encoder_out and decoder_out must both be 4-D tensors"
        assert (
            encoder_out.shape[:-1] == decoder_out.shape[:-1]
        ), "encoder_out and decoder_out must agree on all but the last dim"

        if project_input:
            logit = self.encoder_proj(encoder_out) + self.decoder_proj(decoder_out)
        else:
            # The inputs are summed as-is here, so their last dimensions
            # must agree as well; otherwise a size-1 last dimension would be
            # silently broadcast instead of being reported as a caller error.
            assert (
                encoder_out.shape[-1] == decoder_out.shape[-1]
            ), "with project_input=False the last dims must be equal"
            logit = encoder_out + decoder_out

        return self.output_linear(torch.tanh(logit))
|
|