| import numpy as np |
| import torch |
| import torchaudio |
| import soundfile as sf |
|
|
| from s3prl.nn import S3PRLUpstream |
|
|
class PhonologicalTokenizer:
    """Turn an audio file into a sequence of discrete cluster indices.

    A ``wavlm_large`` S3PRL upstream produces per-layer hidden states for the
    waveform; the hidden states of one layer are then assigned to their
    nearest centroid (Euclidean distance) from a pre-computed codebook.

    Args:
        ssl_model_path: Path to the state dict loaded into the upstream model.
        centroids_path: Path to a ``.npy`` file holding the centroid matrix
            (presumably ``(num_clusters, feature_dim)`` -- must match the
            feature dimension of the selected layer).
        device: Torch device string the model and centroids are placed on.
        layer: Index into the list of hidden states returned by the upstream.
            Defaults to 21, the value that used to be hard-coded.
    """

    # Sample rate the waveform is resampled to before it reaches the upstream.
    SAMPLE_RATE = 16000

    def __init__(
        self,
        ssl_model_path: str = "ssl.pth",
        centroids_path: str = "centroids.npy",
        device: str = "cpu",
        layer: int = 21,
    ):
        self.device = device
        self.layer = layer

        self.ssl_model = S3PRLUpstream(name="wavlm_large").to(device)
        # map_location lets a checkpoint saved from GPU tensors be loaded on
        # a CPU-only machine (and vice versa) instead of failing in torch.load.
        state_dict = torch.load(ssl_model_path, map_location=device)
        self.ssl_model.load_state_dict(state_dict, strict=True)
        self.ssl_model.eval()

        # Cast to float32: a float64 .npy would otherwise make torch.cdist
        # fail with a dtype mismatch against the float32 features.
        centroids = torch.from_numpy(np.load(centroids_path))
        self.centroids = centroids.float().to(device)

    def _prepare_waveform(self, wav: np.ndarray, sr: int) -> torch.Tensor:
        """Return *wav* as a float32 ``(1, num_samples)`` tensor at SAMPLE_RATE."""
        # soundfile yields (frames, channels) for multi-channel audio; average
        # the channels so the model always receives a single mono waveform.
        if wav.ndim > 1:
            wav = wav.mean(axis=1)
        batch = torch.from_numpy(wav).unsqueeze(0).float().to(self.device)
        if sr != self.SAMPLE_RATE:
            batch = torchaudio.functional.resample(batch, sr, self.SAMPLE_RATE)
        return batch

    def _quantize(self, hs: torch.Tensor) -> torch.Tensor:
        """Return, for each feature vector in *hs*, the index of its nearest centroid."""
        distances = torch.cdist(hs, self.centroids)
        return torch.argmin(distances, dim=-1)

    @torch.no_grad()
    def __call__(
        self,
        wav_path: str,
    ):
        """Tokenize the audio file at *wav_path*.

        Args:
            wav_path: Path to an audio file readable by ``soundfile``.

        Returns:
            An integer tensor of cluster indices with a leading batch
            dimension of size 1 (one index per frame of the selected layer).
        """
        wav, sr = sf.read(wav_path)
        wav = self._prepare_waveform(wav, sr)
        wav_len = torch.tensor([wav.shape[1]], device=self.device)
        # The upstream returns one hidden-state tensor per layer plus lengths;
        # only the configured layer is used and the lengths are not needed
        # for a single, unpadded utterance.
        all_hs, _ = self.ssl_model(wav, wav_len)
        return self._quantize(all_hs[self.layer])