import os

import numpy as np
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download

from ssl_ecapa_model import SSL_ECAPA_TDNN

# Hugging Face repository the checkpoint is fetched from when it is not
# available on the local filesystem.
_HF_REPO_ID = "junseok520/voxsim-models"

# Some checkpoints store every parameter under this prefix; it is stripped
# before the name is looked up in the model's state dict.
_WRAPPER_PREFIX = "__S__."


def load_model(ckpt_path):
    """Build the SSL-ECAPA-TDNN network and load checkpoint weights into it.

    Args:
        ckpt_path: Local path of the checkpoint, or the file name to download
            from the Hugging Face Hub when no such local file exists.

    Returns:
        The constructed ``SSL_ECAPA_TDNN`` model with the checkpoint's
        parameters copied in.
    """
    model = SSL_ECAPA_TDNN(feat_dim=1024, emb_dim=256, feat_type='wavlm_large')
    load_parameters(model, ckpt_path)
    return model


def load_parameters(model, ckpt_path):
    """Copy the tensors stored in a checkpoint into ``model`` in place.

    If ``ckpt_path`` is not an existing file, it is treated as a file name in
    the ``junseok520/voxsim-models`` Hub repository and downloaded into the
    current directory first. Checkpoint entries whose (prefix-stripped) name
    is unknown to the model are reported on stdout and skipped.

    Args:
        model: Module whose ``state_dict()`` tensors receive the weights.
        ckpt_path: Local checkpoint path or Hub file name.
    """
    model_state = model.state_dict()

    if not os.path.isfile(ckpt_path):
        print("Downloading model from Hugging Face Hub...")
        ckpt_path = hf_hub_download(
            repo_id=_HF_REPO_ID,
            filename=ckpt_path,
            local_dir="./",
        )

    loaded_state = torch.load(ckpt_path, map_location='cpu', weights_only=True)

    for name, param in loaded_state.items():
        # Normalise the parameter name once so both naming schemes share a
        # single copy/report path.
        if name.startswith(_WRAPPER_PREFIX):
            key = name[len(_WRAPPER_PREFIX):]
        else:
            key = name

        if key in model_state:
            # state_dict() tensors share storage with the live parameters,
            # so the in-place copy updates the model itself.
            model_state[key].copy_(param)
        else:
            print(f"{key} is not in the model.")


def _resolve_device(device):
    """Translate the legacy ``"gpu"`` alias into a valid PyTorch device string.

    PyTorch has no ``"gpu"`` device type, so ``"gpu"`` / ``"gpu:N"`` are mapped
    to ``"cuda"`` / ``"cuda:N"``. Any other value (including ``torch.device``
    objects) is returned unchanged.
    """
    if isinstance(device, str) and (device == "gpu" or device.startswith("gpu:")):
        return "cuda" + device[len("gpu"):]
    return device


class Score:
    """Predicting score for each audio clip."""

    def __init__(
            self,
            ckpt_path: str = "wavlm_ecapa.pt",
            device: str = "gpu"):
        """Load the pretrained model and put it in evaluation mode.

        Args:
            ckpt_path: Path to the pretrained checkpoint of the voxsim
                evaluator. When the file does not exist locally it is
                downloaded from the Hugging Face Hub.
            device: Device the model and inputs are moved to. ``"gpu"`` is
                accepted as an alias for ``"cuda"``.
        """
        device = _resolve_device(device)
        print(f"Using device: {device}")
        self.device = device
        self.model = load_model(ckpt_path).to(self.device)
        self.model.eval()

    def _embed(self, wavs: torch.Tensor, bs: int) -> torch.Tensor:
        """Return L2-normalised embeddings of ``wavs`` shaped ``(bs, -1, emb)``."""
        # Collapse every leading axis so the model sees one row per clip.
        flat = wavs.reshape(-1, wavs.shape[-1]).to(self.device)
        emb = F.normalize(self.model(flat), p=2, dim=1)
        return emb.reshape(bs, -1, emb.shape[-1])

    def score(
            self,
            inp_wavs: torch.Tensor,
            inp_wav: torch.Tensor,
            ref_wavs: torch.Tensor,
            ref_wav: torch.Tensor) -> np.ndarray:
        """Score input audio against reference audio.

        Each tensor is flattened to ``(-1, last_dim)``, embedded, and
        L2-normalised. Two scores are computed per batch item, each the mean
        of all pairwise dot products between an input group and a reference
        group: ``inp_wavs`` vs ``ref_wavs`` and ``inp_wav`` vs ``ref_wav``.
        The returned score is the average of the two.

        Args:
            inp_wavs: Input waveforms, 2-D (treated as a single batch item)
                or 3-D (first axis is the batch). The last axis is presumably
                audio samples -- confirm against the caller.
            inp_wav: Second group of input waveforms, same batch size.
            ref_wavs: Reference waveforms compared with ``inp_wavs``.
            ref_wav: Reference waveforms compared with ``inp_wav``.

        Returns:
            NumPy array of shape ``(batch,)`` holding one score per batch item.

        Raises:
            ValueError: If ``inp_wavs`` is neither 2-D nor 3-D.
        """
        ndim = inp_wavs.dim()
        if ndim == 2:
            bs = 1
        elif ndim == 3:
            bs = inp_wavs.shape[0]
        else:
            raise ValueError(
                f'inp_wavs must be a 2-D or 3-D tensor, got {ndim}-D.')

        with torch.no_grad():
            input_emb_1 = self._embed(inp_wavs, bs)
            input_emb_2 = self._embed(inp_wav, bs)
            ref_emb_1 = self._embed(ref_wavs, bs)
            ref_emb_2 = self._embed(ref_wav, bs)

            # Embeddings are unit-length, so each bmm entry is a cosine
            # similarity; the mean runs over every input/reference pair.
            score_1 = torch.bmm(
                input_emb_1, ref_emb_1.transpose(1, 2)).mean(dim=(1, 2))
            score_2 = torch.bmm(
                input_emb_2, ref_emb_2.transpose(1, 2)).mean(dim=(1, 2))

        return ((score_1 + score_2) / 2).cpu().numpy()