| import torch | |
| from torch import Tensor, nn | |
| from transformers import PreTrainedModel | |
| from .config import AdapterConfig | |
class Model(nn.Module):
    """Convolutional analysis/synthesis pair for 1-D signals.

    ``encode`` applies a strided ``Conv1d`` and squashes the activations
    with ``tanh``; ``decode`` (a ``ConvTranspose1d`` attribute, callable
    as a method) maps filter activations back to the channel domain with
    matching kernel/stride/padding geometry.
    """

    def __init__(
        self,
        num_channels: int,
        num_filters: int,
        window_length: int,
        stride: int,
    ):
        super().__init__()
        self.stride = stride
        # Padding chosen from the window/stride geometry so the forward
        # and transposed convolutions keep compatible signal lengths.
        pad = window_length // 2 - stride // 2
        shared = dict(
            kernel_size=window_length,
            stride=stride,
            padding=pad,
            bias=False,
        )
        # NOTE: construction order (conv first, then decode) is kept so
        # seeded random initialization matches the previous layout.
        self.conv = nn.Conv1d(
            num_channels, num_filters, padding_mode="reflect", **shared
        )
        self.decode = nn.ConvTranspose1d(num_filters, num_channels, **shared)

    def encode(self, x: Tensor) -> Tensor:
        """Project *x* into filter space, bounded to (-1, 1) by tanh."""
        return torch.tanh(self.conv(x))
class Adapter(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around :class:`Model`.

    Delegates ``encode``/``decode`` straight to an internally constructed
    ``Model`` instance.

    NOTE(review): the Model hyper-parameters are hard-coded here rather
    than read from ``config`` — confirm whether ``AdapterConfig`` is meant
    to carry them.
    """

    config_class = AdapterConfig

    def __init__(self, config: AdapterConfig):
        super().__init__(config)
        # Fixed geometry: 2 channels, 128 filters, 128-sample window, hop 64.
        self.model = Model(
            num_channels=2,
            num_filters=128,
            window_length=128,
            stride=64,
        )

    def encode(self, x):
        """Delegate to the wrapped Model's tanh-conv encoder."""
        return self.model.encode(x)

    def decode(self, x):
        """Delegate to the wrapped Model's transposed-conv decoder."""
        return self.model.decode(x)