Spaces:
Sleeping
Sleeping
| import os | |
| import torch | |
| import torch.nn.functional as F | |
| from ssl_ecapa_model import SSL_ECAPA_TDNN | |
| from huggingface_hub import hf_hub_download | |
def load_model(ckpt_path):
    """Build the SSL-ECAPA-TDNN network and fill it with checkpoint weights.

    Args:
        ckpt_path: checkpoint location, forwarded unchanged to
            ``load_parameters``.

    Returns:
        The constructed ``SSL_ECAPA_TDNN`` instance after its parameters
        have been loaded.
    """
    # Fixed architecture: 1024-dim features from 'wavlm_large', 256-dim embeddings.
    network = SSL_ECAPA_TDNN(feat_type='wavlm_large', emb_dim=256, feat_dim=1024)
    load_parameters(network, ckpt_path)
    return network
def load_parameters(model, ckpt_path):
    """Copy checkpoint tensors into ``model`` in place.

    Only checkpoint entries whose key starts with ``'__S__.'`` are used; the
    prefix is stripped and the tensor is copied into the model's state entry
    of the same name. Prefixed entries with no counterpart in the model are
    reported on stdout; entries without the prefix are ignored.

    Args:
        model: module whose ``state_dict()`` tensors are overwritten.
        ckpt_path: local checkpoint file. If no such file exists, the value
            is used as the filename to fetch from the
            ``junseok520/voxsim-models`` repository on the Hugging Face Hub.
    """
    prefix = '__S__.'
    target_state = model.state_dict()

    if not os.path.isfile(ckpt_path):
        print("Downloading model from Hugging Face Hub...")
        ckpt_path = hf_hub_download(repo_id="junseok520/voxsim-models", filename=ckpt_path, local_dir="./")

    checkpoint = torch.load(ckpt_path, map_location='cpu', weights_only=True)
    for key, tensor in checkpoint.items():
        if not key.startswith(prefix):
            continue
        stripped = key[len(prefix):]
        if stripped not in target_state:
            print("{} is not in the model.".format(stripped))
            continue
        # state_dict tensors share storage with the live parameters, so an
        # in-place copy updates the model itself.
        target_state[stripped].copy_(tensor)
class Score:
    """Predicting score for each audio clip.

    Wraps the embedding model returned by ``load_model`` and scores an input
    against a reference as the mean cosine similarity between their
    L2-normalised embeddings.
    """

    def __init__(
            self,
            ckpt_path: str = "wavlm_ecapa.pt",
            device: str = "gpu"):
        """
        Args:
            ckpt_path: path to pretrained checkpoint of voxsim evaluator;
                forwarded to ``load_model``.
            device: torch device string the model and inputs are moved to.
                The legacy value ``"gpu"`` (the default) is not a valid torch
                device name, so it is resolved to ``"cuda"`` when CUDA is
                available and to ``"cpu"`` otherwise. Any other value is
                used as given.
        """
        device = self._resolve_device(device)
        print(f"Using device: {device}")
        self.device = device
        self.model = load_model(ckpt_path).to(self.device)
        self.model.eval()

    @staticmethod
    def _resolve_device(device: str) -> str:
        """Translate the ``"gpu"`` alias into a device name torch accepts."""
        if device == "gpu":
            return "cuda" if torch.cuda.is_available() else "cpu"
        return device

    def _embed(self, wavs: torch.Tensor, bs: int) -> torch.Tensor:
        """Return L2-normalised embeddings of ``wavs`` shaped ``(bs, -1, emb_dim)``."""
        # Collapse every leading axis so the model sees a 2-D (clips, samples) batch.
        flat = wavs.reshape(-1, wavs.shape[-1]).to(self.device)
        emb = F.normalize(self.model(flat), p=2, dim=1)
        return emb.reshape(bs, -1, emb.shape[-1])

    def score(self, inp_wavs: torch.Tensor, inp_wav: torch.Tensor, ref_wavs: torch.Tensor, ref_wav: torch.Tensor) -> "numpy.ndarray":
        """Score input audio against reference audio.

        Two similarities are computed and averaged: one between ``inp_wavs``
        and ``ref_wavs`` and one between ``inp_wav`` and ``ref_wav``. Each is
        the mean over all pairwise dot products of the L2-normalised
        embeddings within a batch item.

        Args:
            inp_wavs: input waveforms, 2-D (treated as a single batch item)
                or 3-D (first axis is the batch). The last axis is passed to
                the model as the per-clip sample axis.
            inp_wav: input waveform(s) for the second comparison; must
                reshape to the same batch size as ``inp_wavs``.
            ref_wavs: reference waveforms compared against ``inp_wavs``.
            ref_wav: reference waveform(s) compared against ``inp_wav``.

        Returns:
            NumPy array of shape ``(batch,)`` holding one score per batch item.

        Raises:
            ValueError: if ``inp_wavs`` is neither 2-D nor 3-D.

        Note:
            No resampling is performed here; the waveforms are passed to the
            model as given (presumably 16 kHz is expected -- confirm against
            the model definition).
        """
        n_dims = len(inp_wavs.shape)
        if n_dims == 2:
            bs = 1
        elif n_dims == 3:
            bs = inp_wavs.shape[0]
        else:
            raise ValueError('Dimension of input tensor needs to be 2 or 3.')

        with torch.no_grad():
            input_emb_1 = self._embed(inp_wavs, bs)
            input_emb_2 = self._embed(inp_wav, bs)
            ref_emb_1 = self._embed(ref_wavs, bs)
            ref_emb_2 = self._embed(ref_wav, bs)

            # Embeddings are unit-norm, so each bmm entry is a cosine similarity.
            score_1 = torch.mean(torch.bmm(input_emb_1, ref_emb_1.transpose(1, 2)), dim=(1, 2))
            score_2 = torch.mean(torch.bmm(input_emb_2, ref_emb_2.transpose(1, 2)), dim=(1, 2))
            score = (score_1 + score_2) / 2

        return score.cpu().numpy()