import numpy as np
import soundfile as sf
import torch
import torchaudio
from s3prl.nn import S3PRLUpstream


class PhonologicalTokenizer:
    """Turn a speech recording into a sequence of discrete cluster indices.

    The audio is passed through an S3PRL ``wavlm_large`` upstream model, one
    hidden layer is selected, and every frame of that layer is assigned the
    index of its nearest centroid (Euclidean distance via ``torch.cdist``).
    """

    # Sample rate the waveform is resampled to before it is fed to the model.
    SAMPLE_RATE = 16000

    def __init__(
        self,
        ssl_model_path: str = "ssl.pth",
        centroids_path: str = "centroids.npy",
        device: str = "cpu",
        layer: int = 21,
    ):
        """Load the SSL model weights and the centroid table.

        Args:
            ssl_model_path: Path to a state dict loadable into
                ``S3PRLUpstream(name="wavlm_large")``.
            centroids_path: Path to a ``.npy`` file holding the centroids;
                presumably shaped (num_clusters, feature_dim) -- TODO confirm
                against the file that is actually shipped.
            device: Torch device the model and centroids are placed on.
            layer: Index into the list of hidden states returned by the
                upstream model. Defaults to 21, the value previously
                hard-coded.
        """
        self.device = device
        self.layer = layer

        self.ssl_model = S3PRLUpstream(name="wavlm_large").to(device)
        # map_location lets a checkpoint saved on another device (e.g. a GPU)
        # be loaded on the requested one.
        # NOTE(review): torch.load unpickles its input -- only point this at
        # checkpoints from a trusted source.
        state_dict = torch.load(ssl_model_path, map_location=device)
        self.ssl_model.load_state_dict(state_dict, strict=True)
        self.ssl_model.eval()

        # np.load keeps whatever dtype was saved (possibly float64); cast to
        # float32 so the centroids match the dtype of the model's features
        # when they are compared in torch.cdist.
        centroids = np.load(centroids_path)
        self.centroids = torch.from_numpy(centroids).float().to(device)

    @torch.no_grad()
    def __call__(self, wav_path: str) -> torch.Tensor:
        """Tokenize a single audio file (batch size 1).

        Args:
            wav_path: Path to an audio file readable by ``soundfile``.

        Returns:
            Integer tensor holding, for each frame of the selected hidden
            layer, the index of the closest centroid. The leading dimension
            is the batch dimension of size 1.
        """
        samples, sr = sf.read(wav_path)
        # soundfile yields (frames, channels) for multi-channel files; average
        # the channels so the model always receives a mono signal.
        if samples.ndim > 1:
            samples = samples.mean(axis=1)

        wav = torch.from_numpy(samples).unsqueeze(0).float().to(self.device)
        if sr != self.SAMPLE_RATE:
            wav = torchaudio.functional.resample(wav, sr, self.SAMPLE_RATE)

        # Length is taken after resampling so it matches what the model sees.
        wav_len = torch.tensor([wav.shape[1]], device=self.device)
        all_hs, _ = self.ssl_model(wav, wav_len)
        hs = all_hs[self.layer]

        distances = torch.cdist(hs, self.centroids)
        return torch.argmin(distances, dim=-1)