Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| from torch.nn.functional import pad | |
| from collections import OrderedDict | |
class Encoder(nn.Module):
    """Config-driven wrapper around a backbone encoder.

    The backbone is built in ``__init__`` from ``config``. In this class only
    one combination builds a backbone: ``config.model.task == "audio"`` with
    ``config.model.encoder.name`` equal to ``"wavlm"`` (case-insensitive).
    For every other configuration ``self.encoder`` stays ``None`` and the
    delegating methods raise ``RuntimeError`` until an encoder is assigned.
    """

    def __init__(self, config):
        """Store ``config`` and build the backbone it selects, if supported.

        Args:
            config: Nested attribute-style configuration. Reads
                ``config.model.task`` and, for the audio task,
                ``config.model.encoder.name`` and
                ``config.model.encoder.pretrained_path``.
        """
        super().__init__()
        self.config = config
        self.encoder = None
        self.succeeding_layers = None

        # AUDIO
        if config.model.task == "audio":
            if config.model.encoder.name.lower() == "wavlm":
                # Imported lazily so the WavLM package is only required when
                # this backbone is actually selected.
                from manipulate_model.encoder.wavlm.WavLM import WavLM, WavLMConfig

                # NOTE(security): torch.load unpickles the file; only point
                # pretrained_path at checkpoints from a trusted source.
                ckpt = torch.load(
                    config.model.encoder.pretrained_path, map_location="cpu"
                )
                # The whole loaded object is handed to WavLMConfig and no
                # state dict is loaded here.
                # NOTE(review): confirm the checkpoint layout matches what
                # WavLMConfig expects and that weights are restored elsewhere.
                cfg = WavLMConfig(ckpt)
                self.encoder = WavLM(cfg)

    def _require_encoder(self):
        """Return the backbone, or raise a clear error if none was built."""
        if self.encoder is None:
            raise RuntimeError(
                "Encoder has no backbone: no encoder was built for the "
                "configured model.task / model.encoder.name combination."
            )
        return self.encoder

    def forward(self, x):
        """Run ``x`` through the backbone.

        For a ``"wavlm"`` encoder name the call also forwards
        ``config.model.encoder.output_layer`` as ``output_layer``; every other
        encoder name is called with ``x`` alone.

        Raises:
            RuntimeError: If no backbone encoder has been built.
        """
        encoder = self._require_encoder()
        encoder_cfg = self.config.model.encoder
        if encoder_cfg.name.lower() == "wavlm":
            return encoder(x, output_layer=encoder_cfg.output_layer)
        return encoder(x)

    def get_encoding_dim(self):
        """Return the backbone's ``get_encoding_dim()`` result.

        Raises:
            RuntimeError: If no backbone encoder has been built.
        """
        return self._require_encoder().get_encoding_dim()

    def get_temporal_dim(self):
        """Return the backbone's temporal dim for ``config.data.window_size``.

        Raises:
            RuntimeError: If no backbone encoder has been built.
        """
        encoder = self._require_encoder()
        return encoder.get_temporal_dim(window_size=self.config.data.window_size)