| |
| |
| |
| |
| |
| |
| import torch |
| import torch.nn as nn |
|
|
| from huggingface_hub import PyTorchModelHubMixin |
|
|
|
|
| def _unsqueeze_to_3d(x): |
| """Normalize shape of `x` to [batch, n_chan, time].""" |
| if x.ndim == 1: |
| return x.reshape(1, 1, -1) |
| elif x.ndim == 2: |
| return x.unsqueeze(1) |
| else: |
| return x |
|
|
|
|
def pad_to_appropriate_length(x, lcm):
    """Zero-pad the last (time) axis of `x` up to a multiple of `lcm`.

    Args:
        x: tensor whose trailing dimension is the one to pad.
        lcm: required length multiple (e.g. the model's stride LCM).

    Returns:
        `x` itself when its length is already a multiple of `lcm`,
        otherwise a new tensor, zero-padded on the right, on the same
        device and with the same dtype as `x`.
    """
    remainder = int(x.shape[-1]) % lcm
    if not remainder:
        return x
    # new_zeros keeps the input's dtype and device; the original
    # hard-coded float32, silently casting non-float32 inputs.
    padded = x.new_zeros(*x.shape[:-1], x.shape[-1] + lcm - remainder)
    padded[..., : x.shape[-1]] = x
    return padded
|
|
|
|
class BaseModel(nn.Module, PyTorchModelHubMixin, repo_url="https://github.com/JusperLee/Apollo", pipeline_tag="audio-to-audio"):
    """Common base class for Apollo audio models.

    Combines Hugging Face Hub integration (via ``PyTorchModelHubMixin``)
    with checkpoint (de)serialization helpers shared by concrete models.
    """

    def __init__(self, sample_rate, in_chan=1):
        super().__init__()
        # Stored for retrieval via `sample_rate()`; `_in_chan` is kept
        # around for subclasses even though this class never reads it.
        self._sample_rate = sample_rate
        self._in_chan = in_chan

    def forward(self, *args, **kwargs):
        """Subclasses must implement the actual computation."""
        raise NotImplementedError

    def sample_rate(self):
        """Return the sample rate this model was constructed with."""
        return self._sample_rate

    @staticmethod
    def load_state_dict_in_audio(model, pretrained_dict):
        """Merge audio-model weights from `pretrained_dict` into `model`.

        Every key containing "audio_model" has its first 12 characters
        (the length of "audio_model.") stripped before being merged into
        the model's own state dict; all other keys are ignored.
        """
        merged = model.state_dict()
        merged.update(
            {k[12:]: v for k, v in pretrained_dict.items() if "audio_model" in k}
        )
        model.load_state_dict(merged)
        return model

    @staticmethod
    def from_pretrain(pretrained_model_conf_or_path, *args, **kwargs):
        """Instantiate a model from a serialized checkpoint file.

        The checkpoint must contain "model_name" (resolved through the
        package-level `get` registry) and "state_dict".

        NOTE(review): `weights_only=False` unpickles arbitrary objects —
        only load checkpoints from trusted sources.
        """
        from . import get

        conf = torch.load(
            pretrained_model_conf_or_path, map_location="cpu", weights_only=False
        )
        model = get(conf["model_name"])(*args, **kwargs)
        model.load_state_dict(conf["state_dict"])
        return model

    def serialize(self):
        """Pack everything needed to restore this model into a dict."""
        # Imported lazily so serialization is the only path that needs
        # pytorch_lightning installed.
        import pytorch_lightning as pl

        return dict(
            model_name=self.__class__.__name__,
            state_dict=self.get_state_dict(),
            model_args=self.get_model_args(),
            infos=dict(
                software_versions=dict(
                    torch_version=torch.__version__,
                    pytorch_lightning_version=pl.__version__,
                )
            ),
        )

    def get_state_dict(self):
        """Hook: override to modify the state dict before sharing."""
        return self.state_dict()

    def get_model_args(self):
        """Hook: must return the args needed to re-instantiate the class."""
        raise NotImplementedError
|
|