| import torch
|
| from audiotools import AudioSignal
|
| from audiotools.ml import BaseModel
|
| from encodec import EncodecModel
|
|
|
|
|
| class Encodec(BaseModel):
|
| def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0):
|
| super().__init__()
|
|
|
| if sample_rate == 24000:
|
| self.model = EncodecModel.encodec_model_24khz()
|
| else:
|
| self.model = EncodecModel.encodec_model_48khz()
|
| self.model.set_target_bandwidth(bandwidth)
|
| self.sample_rate = 44100
|
|
|
| def forward(
|
| self,
|
| audio_data: torch.Tensor,
|
| sample_rate: int = 44100,
|
| n_quantizers: int = None,
|
| ):
|
| signal = AudioSignal(audio_data, sample_rate)
|
| signal.resample(self.model.sample_rate)
|
| recons = self.model(signal.audio_data)
|
| recons = AudioSignal(recons, self.model.sample_rate)
|
| recons.resample(sample_rate)
|
| return {"audio": recons.audio_data}
|
|
|
|
|
| if __name__ == "__main__":
|
| import numpy as np
|
| from functools import partial
|
|
|
| model = Encodec()
|
|
|
| for n, m in model.named_modules():
|
| o = m.extra_repr()
|
| p = sum([np.prod(p.size()) for p in m.parameters()])
|
| fn = lambda o, p: o + f" {p/1e6:<.3f}M params."
|
| setattr(m, "extra_repr", partial(fn, o=o, p=p))
|
| print(model)
|
| print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()]))
|
|
|
| length = 88200 * 2
|
| x = torch.randn(1, 1, length).to(model.device)
|
| x.requires_grad_(True)
|
| x.retain_grad()
|
|
|
|
|
| out = model(x)["audio"]
|
|
|
| print(x.shape, out.shape)
|
|
|