import os

import fairseq
import pytorch_lightning as pl
import requests
import torch
import torch.nn as nn
from tqdm import tqdm

UTMOS_CKPT_URL = "https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/epoch%3D3-step%3D7459.ckpt"
WAV2VEC_URL = "https://huggingface.co/spaces/sarulab-speech/UTMOS-demo/resolve/main/wav2vec_small.pt"
| """ | |
| UTMOS score, automatic Mean Opinion Score (MOS) prediction system, | |
| adapted from https://huggingface.co/spaces/sarulab-speech/UTMOS-demo | |
| """ | |
| class UTMOSScore: | |
| """Predicting score for each audio clip.""" | |

    def __init__(self, device, ckpt_path="epoch=3-step=7459.ckpt"):
        self.device = device
        filepath = os.path.join(os.path.dirname(__file__), ckpt_path)
        if not os.path.exists(filepath):
            download_file(UTMOS_CKPT_URL, filepath)
        self.model = BaselineLightningModule.load_from_checkpoint(filepath).eval().to(device)

    def score(self, wavs: torch.Tensor) -> torch.Tensor:
        """
        Args:
            wavs: audio waveform(s) to be evaluated. A 1-D tensor [time] or a
                2-D tensor [1, time] is processed as a single clip; a 3-D
                tensor [batch, 1, time] is processed as a batch.
        Returns:
            The predicted MOS for each clip, rescaled to the 1-5 MOS range.
        """
        if wavs.ndim == 1:
            out_wavs = wavs.unsqueeze(0).unsqueeze(0)
        elif wavs.ndim == 2:
            out_wavs = wavs.unsqueeze(0)
        elif wavs.ndim == 3:
            out_wavs = wavs
        else:
            raise ValueError("Input tensor must have at most 3 dimensions.")
        bs = out_wavs.shape[0]
        batch = {
            "wav": out_wavs,
            "domains": torch.zeros(bs, dtype=torch.int).to(self.device),
            "judge_id": torch.ones(bs, dtype=torch.int).to(self.device) * 288,
        }
        with torch.no_grad():
            output = self.model(batch)
        # Average the per-frame scores, then map the model output to the 1-5 MOS scale.
        return output.mean(dim=1).squeeze(1).cpu().detach() * 2 + 3
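
# Accepted input shapes for UTMOSScore.score (an illustrative sketch; the random
# tensors below stand in for real mono 16 kHz waveforms):
#   scorer = UTMOSScore(device="cpu")
#   scorer.score(torch.randn(16000))        # [time]           -> single clip
#   scorer.score(torch.randn(1, 16000))     # [1, time]        -> single clip
#   scorer.score(torch.randn(4, 1, 16000))  # [batch, 1, time] -> batch of 4 clips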

def download_file(url, filename):
    """
    Downloads a file from the given URL, streaming it to disk with a progress bar.

    Args:
        url (str): The URL of the file to download.
        filename (str): The name to save the file as.
    """
    print(f"Downloading file {filename}...")
    response = requests.get(url, stream=True)
    response.raise_for_status()
    total_size_in_bytes = int(response.headers.get("content-length", 0))
    progress_bar = tqdm(total=total_size_in_bytes, unit="iB", unit_scale=True)
    with open(filename, "wb") as f:
        for chunk in response.iter_content(chunk_size=8192):
            progress_bar.update(len(chunk))
            f.write(chunk)
    progress_bar.close()

def load_ssl_model(ckpt_path="wav2vec_small.pt"):
    filepath = os.path.join(os.path.dirname(__file__), ckpt_path)
    if not os.path.exists(filepath):
        download_file(WAV2VEC_URL, filepath)
    SSL_OUT_DIM = 768  # feature dimension of the wav2vec 2.0 "small" (base) model
    model, cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task([filepath])
    ssl_model = model[0]
    ssl_model.remove_pretraining_modules()
    return SSL_model(ssl_model, SSL_OUT_DIM)

class BaselineLightningModule(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.construct_model()
        self.save_hyperparameters()

    def construct_model(self):
        self.feature_extractors = nn.ModuleList(
            [load_ssl_model(ckpt_path="wav2vec_small.pt"), DomainEmbedding(3, 128)]
        )
        output_dim = sum(feature_extractor.get_output_dim() for feature_extractor in self.feature_extractors)
        output_layers = [LDConditioner(judge_dim=128, num_judges=3000, input_dim=output_dim)]
        output_dim = output_layers[-1].get_output_dim()
        output_layers.append(
            Projection(hidden_dim=2048, activation=torch.nn.ReLU(), range_clipping=False, input_dim=output_dim)
        )
        self.output_layers = nn.ModuleList(output_layers)

    def forward(self, inputs):
        # Collect named features from every extractor, then pass the feature dict
        # through the output layers in order.
        outputs = {}
        for feature_extractor in self.feature_extractors:
            outputs.update(feature_extractor(inputs))
        x = outputs
        for output_layer in self.output_layers:
            x = output_layer(x, inputs)
        return x
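
# Data flow through BaselineLightningModule.forward (a sketch for the wav2vec-small
# configuration above; B = batch size, T = number of SSL feature frames):
#   SSL_model       -> {"ssl-feature":    [B, T, 768]}
#   DomainEmbedding -> {"domain-feature": [B, 128]}
#   LDConditioner   -> [B, T, 1024]  (BiLSTM, 2 x 512 hidden units)
#   Projection      -> [B, T, 1]     (per-frame score; UTMOSScore.score averages over T)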

class SSL_model(nn.Module):
    def __init__(self, ssl_model, ssl_out_dim) -> None:
        super().__init__()
        self.ssl_model, self.ssl_out_dim = ssl_model, ssl_out_dim

    def forward(self, batch):
        wav = batch["wav"]
        wav = wav.squeeze(1)  # [batch, audio_len]
        res = self.ssl_model(wav, mask=False, features_only=True)
        x = res["x"]
        return {"ssl-feature": x}

    def get_output_dim(self):
        return self.ssl_out_dim

class DomainEmbedding(nn.Module):
    def __init__(self, n_domains, domain_dim) -> None:
        super().__init__()
        self.embedding = nn.Embedding(n_domains, domain_dim)
        self.output_dim = domain_dim

    def forward(self, batch):
        return {"domain-feature": self.embedding(batch["domains"])}

    def get_output_dim(self):
        return self.output_dim

class LDConditioner(nn.Module):
    """Conditions the SSL features on a listener (judge) embedding."""

    def __init__(self, input_dim, judge_dim, num_judges=None):
        super().__init__()
        self.input_dim = input_dim
        self.judge_dim = judge_dim
        self.num_judges = num_judges
        assert num_judges is not None
        self.judge_embedding = nn.Embedding(num_judges, self.judge_dim)
        # BiLSTM over the frame-wise concatenation of input features and the judge embedding.
        self.decoder_rnn = nn.LSTM(
            input_size=self.input_dim + self.judge_dim,
            hidden_size=512,
            num_layers=1,
            batch_first=True,
            bidirectional=True,
        )
        self.out_dim = self.decoder_rnn.hidden_size * 2  # doubled by bidirectionality

    def get_output_dim(self):
        return self.out_dim

    def forward(self, x, batch):
        judge_ids = batch["judge_id"]
        if "phoneme-feature" in x:
            concatenated_feature = torch.cat(
                (x["ssl-feature"], x["phoneme-feature"].unsqueeze(1).expand(-1, x["ssl-feature"].size(1), -1)),
                dim=2,
            )
        else:
            concatenated_feature = x["ssl-feature"]
        if "domain-feature" in x:
            concatenated_feature = torch.cat(
                (concatenated_feature, x["domain-feature"].unsqueeze(1).expand(-1, concatenated_feature.size(1), -1)),
                dim=2,
            )
        if judge_ids is not None:
            concatenated_feature = torch.cat(
                (
                    concatenated_feature,
                    self.judge_embedding(judge_ids).unsqueeze(1).expand(-1, concatenated_feature.size(1), -1),
                ),
                dim=2,
            )
        decoder_output, (h, c) = self.decoder_rnn(concatenated_feature)
        return decoder_output
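
# For this configuration the BiLSTM input is the frame-wise concatenation of
# 768 SSL dims + 128 domain dims + 128 judge dims = 1024, with the domain and
# judge embeddings broadcast across all T frames (illustrative dimensions).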

class Projection(nn.Module):
    def __init__(self, input_dim, hidden_dim, activation, range_clipping=False):
        super().__init__()
        self.range_clipping = range_clipping
        output_dim = 1
        if range_clipping:
            self.proj = nn.Tanh()
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            activation,
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, output_dim),
        )
        self.output_dim = output_dim

    def forward(self, x, batch):
        output = self.net(x)
        if self.range_clipping:
            # Map the tanh output into the 1-5 MOS range.
            return self.proj(output) * 2.0 + 3
        return output

    def get_output_dim(self):
        return self.output_dim
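
# A minimal smoke test (an illustrative addition, not part of the original module).
# Note: constructing UTMOSScore downloads the UTMOS and wav2vec checkpoints on
# first use, so this needs network access; random noise stands in for real audio.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    scorer = UTMOSScore(device=device)
    wavs = torch.randn(4, 1, 3 * 16000).to(device)  # [batch, 1, time]: four 3-second clips
    print(scorer.score(wavs))  # one predicted MOS per clip, roughly in [1, 5]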