| import torch | |
| from torch import nn | |
| from .layers import CustomDiagonalLinear, CustomLinear | |
class FDDT(nn.Module):
    """Apply a separate learned transform per STNO class and mix them by mask.

    Four classes are distinguished, indexed along axis 1 of ``stno_mask``:
    0 = silence, 1 = target, 2 = non-target, 3 = overlap. Each enabled class
    owns either a bias vector (``bias_only=True``) or a linear layer
    (``CustomDiagonalLinear`` when ``is_diagonal`` else ``CustomLinear``).
    A disabled class gets no attribute at all and acts as the identity.
    """

    def __init__(self, d_model, non_target_rate=0.01, fddt_init=None, is_diagonal=False,
                 bias_only=False, use_silence=True, use_target=True, use_overlap=True,
                 use_non_target=True):
        """Create the per-class transforms.

        Args:
            d_model: Feature size of the hidden states.
            non_target_rate: Forwarded as ``init_eye_val`` to the non-target
                and silence layers (target and overlap receive ``1.0``).
                NOTE(review): presumably the initial diagonal scale; confirm
                against the layer implementations.
            fddt_init: Forwarded unchanged to the layer constructors.
            is_diagonal: Use ``CustomDiagonalLinear`` instead of
                ``CustomLinear``.
            bias_only: Use a zero-initialised bias vector of size ``d_model``
                per class instead of a linear layer.
            use_silence: Enable the silence transform.
            use_target: Enable the target transform.
            use_overlap: Enable the overlap transform.
            use_non_target: Enable the non-target transform.
        """
        super().__init__()
        # Creation order is kept (target, non-target, overlap, silence) so the
        # parameter registration order, and hence state_dict order, is stable.
        if use_target:
            self.target_linear = self._make_transform(
                d_model, 1.0, fddt_init, is_diagonal, bias_only)
        if use_non_target:
            self.non_target_linear = self._make_transform(
                d_model, non_target_rate, fddt_init, is_diagonal, bias_only)
        if use_overlap:
            self.overlap_linear = self._make_transform(
                d_model, 1.0, fddt_init, is_diagonal, bias_only)
        if use_silence:
            self.silence_linear = self._make_transform(
                d_model, non_target_rate, fddt_init, is_diagonal, bias_only)
        self.use_silence = use_silence
        self.use_target = use_target
        self.use_overlap = use_overlap
        self.use_non_target = use_non_target
        self.bias_only = bias_only

    @staticmethod
    def _make_transform(d_model, init_eye_val, fddt_init, is_diagonal, bias_only):
        """Build one class transform: a zero bias vector or a linear layer."""
        if bias_only:
            return nn.Parameter(torch.zeros(d_model))
        if is_diagonal:
            return CustomDiagonalLinear(d_model, bias=True, fddt_init=fddt_init,
                                        init_eye_val=init_eye_val)
        return CustomLinear(d_model, d_model, bias=True, fddt_init=fddt_init,
                            init_eye_val=init_eye_val)

    def forward(self, hidden_states, stno_mask):
        """Transform ``hidden_states`` according to the per-class mask.

        Args:
            hidden_states: Input features. It is never modified in place.
                NOTE(review): presumably shaped (batch, time, d_model); confirm
                against the caller.
            stno_mask: Per-class weights with the class on axis 1 (indices
                0..3 are read). It is moved to the device of
                ``hidden_states`` and given a trailing singleton axis so it
                broadcasts over the feature axis.

        Returns:
            With ``bias_only``: ``hidden_states`` plus, for every enabled
            class, that class' mask times its bias. Otherwise: the sum over
            all four classes of ``mask_k * f_k(hidden_states)``, where ``f_k``
            is the class' layer, or the identity when the class is disabled.
        """
        stno_mask = stno_mask.to(hidden_states.device)[..., None]
        # Listed in mask-index order: position in this tuple == index on axis 1.
        branches = (
            (self.use_silence, "silence_linear"),
            (self.use_target, "target_linear"),
            (self.use_non_target, "non_target_linear"),
            (self.use_overlap, "overlap_linear"),
        )

        if self.bias_only:
            # Accumulate into a copy: an in-place add on the argument would
            # overwrite the caller's tensor and fails outright for leaf
            # tensors that require grad.
            shifted = hidden_states.clone()
            for class_idx, (enabled, attr_name) in enumerate(branches):
                if enabled:
                    shifted += stno_mask[:, class_idx, ...] * getattr(self, attr_name)
            return shifted

        mixed = None
        for class_idx, (enabled, attr_name) in enumerate(branches):
            # A disabled class passes the features through unchanged.
            transformed = getattr(self, attr_name)(hidden_states) if enabled else hidden_states
            term = transformed * stno_mask[:, class_idx, ...]
            mixed = term if mixed is None else mixed + term
        return mixed