| import torch | |
| import torch.nn as nn | |
class Adapter(nn.Module):
    """Downsampling bridge between a speech encoder and an LLM.

    Stacks every ``downsample_rate`` consecutive encoder frames along the
    feature axis (shortening the time axis by the same factor), then maps
    the stacked features into the LLM embedding space with a two-layer
    MLP (Linear -> ReLU -> Linear).
    """

    def __init__(self, encoder_dim, llm_dim, downsample_rate=2):
        super().__init__()
        self.ds = downsample_rate
        # First projection consumes `downsample_rate` frames stacked together.
        self.linear1 = nn.Linear(encoder_dim * downsample_rate, llm_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(llm_dim, llm_dim)

    def forward(self, x, x_lens):
        """Project encoder features into the LLM embedding space.

        Args:
            x: encoder output of shape ``(batch, time, encoder_dim)``.
            x_lens: per-example valid frame counts, shape ``(batch,)``.

        Returns:
            Tuple of the projected features, shape
            ``(batch, time // ds, llm_dim)``, and the downsampled lengths.
        """
        batch, time, dim = x.shape
        # Drop trailing frames so the time axis divides evenly by `ds`.
        trimmed = time - time % self.ds
        if trimmed != time:
            x = x[:, :trimmed, :]
        # Fold each group of `ds` frames into the feature dimension.
        x = x.contiguous().view(batch, trimmed // self.ds, dim * self.ds)
        out = self.linear2(self.relu(self.linear1(x)))
        # Valid lengths cannot exceed the trimmed time before downsampling.
        new_x_lens = x_lens.clamp(max=trimmed) // self.ds
        return out, new_x_lens