| | import math |
| | import os |
| |
|
| | import librosa |
| | import numpy as np |
| | import onnxruntime as ort |
| | from numpy.fft import rfft |
| | from numpy.lib.stride_tricks import as_strided |
| |
|
| | class PLCMOSEstimator(): |
| | def __init__(self, model_version=1): |
| | """ |
| | Initialize a PLC-MOS model of a given version. There are currently three models available, v0 (intrusive) |
| | and v1 (both non-intrusive and intrusive available). The default is to use the v1 models. |
| | """ |
| |
|
| | self.model_version = model_version |
| | model_paths = [ |
| | |
| | [("models/plcmos_v0.onnx", 999999999999), (None, 0)], |
| |
|
| | |
| | [("models/plcmos_v1_intrusive.onnx", 768), |
| | ("models/plcmos_v1_nonintrusive.onnx", 999999999999)], |
| | ] |
| | self.sessions = [] |
| | self.max_lens = [] |
| | options = ort.SessionOptions() |
| | options.intra_op_num_threads = 8 |
| | options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL |
| | for path, max_len in model_paths[model_version]: |
| | if not path is None: |
| | file_dir = os.path.dirname(os.path.realpath(__file__)) |
| | self.sessions.append(ort.InferenceSession( |
| | os.path.join(file_dir, path), options)) |
| | self.max_lens.append(max_len) |
| | else: |
| | self.sessions.append(None) |
| | self.max_lens.append(0) |
| |
|
| | def logpow_dns(self, sig, floor=-30.): |
| | """ |
| | Compute log power of complex spectrum. |
| | |
| | Floor any -`np.inf` value to (nonzero minimum + `floor`) dB. |
| | If all values are 0s, floor all values to -80 dB. |
| | """ |
| | log10e = np.log10(np.e) |
| | pspec = sig.real ** 2 + sig.imag ** 2 |
| | zeros = pspec == 0 |
| | logp = np.empty_like(pspec) |
| | if np.any(~zeros): |
| | logp[~zeros] = np.log(pspec[~zeros]) |
| | logp[zeros] = np.log(pspec[~zeros].min()) + floor / 10 / log10e |
| | else: |
| | logp.fill(-80 / 10 / log10e) |
| |
|
| | return logp |
| |
|
| | def hop2hsize(self, wind, hop): |
| | """ |
| | Convert hop fraction to integer size if necessary. |
| | """ |
| | if hop >= 1: |
| | assert type(hop) == int, "Hop size must be integer!" |
| | return hop |
| | else: |
| | assert 0 < hop < 1, "Hop fraction has to be in range (0,1)!" |
| | return int(len(wind) * hop) |
| |
|
| | def stana(self, sig, sr, wind, hop, synth=False, center=False): |
| | """ |
| | Short term analysis by windowing |
| | """ |
| | ssize = len(sig) |
| | fsize = len(wind) |
| | hsize = self.hop2hsize(wind, hop) |
| | if synth: |
| | sstart = hsize - fsize |
| | elif center: |
| | sstart = -int(len(wind) / 2) |
| | else: |
| | sstart = 0 |
| | send = ssize |
| |
|
| | nframe = math.ceil((send - sstart) / hsize) |
| |
|
| | |
| | zpleft = -sstart |
| | zpright = (nframe - 1) * hsize + fsize - zpleft - ssize |
| | if zpleft > 0 or zpright > 0: |
| | sigpad = np.zeros(ssize + zpleft + zpright, dtype=sig.dtype) |
| | sigpad[zpleft:len(sigpad) - zpright] = sig |
| | else: |
| | sigpad = sig |
| |
|
| | return as_strided(sigpad, shape=(nframe, fsize), |
| | strides=(sig.itemsize * hsize, sig.itemsize)) * wind |
| |
|
| | def stft(self, sig, sr, wind, hop, nfft): |
| | """ |
| | Compute STFT: window + rfft |
| | """ |
| | frames = self.stana(sig, sr, wind, hop, synth=True) |
| | return rfft(frames, n=nfft) |
| |
|
| | def stft_transform(self, audio, dft_size=512, hop_fraction=0.5, sr=16000): |
| | """ |
| | Compute STFT parameters, then compute STFT |
| | """ |
| | window = np.hamming(dft_size + 1) |
| | window = window[:-1] |
| | amp = np.abs(self.stft(audio, sr, window, hop_fraction, dft_size)) |
| | feat = self.logpow_dns(amp, floor=-120.) |
| | return feat / 20. |
| |
|
| | def run(self, audio_degraded, audio_clean=None, combined=False): |
| | """ |
| | Run the PLCMOS model and return the MOS for the given audio. If a clean audio file is passed and the |
| | selected model version has an intrusive version, that version will be used, otherwise, the nonintrusive |
| | model will be used. If combined is set to true (default), the mean of intrusive and nonintrusive models |
| | results will be returned, when both are available |
| | |
| | For intrusive models, the clean reference should be the unprocessed audio file the degraded audio is |
| | based on. It is not required to be aligned with the degraded audio. |
| | |
| | Audio data should be 16kHz, mono, [-1, 1] range. |
| | """ |
| | audio_features_degraded = np.float32(self.stft_transform(audio_degraded))[ |
| | np.newaxis, np.newaxis, ...] |
| | assert len( |
| | audio_features_degraded) <= self.max_lens[0], "Maximum input length exceeded" |
| |
|
| | if audio_clean is None: |
| | combined = False |
| |
|
| | mos = 0 |
| |
|
| | session = self.sessions[0] |
| | assert not session is None, "Intrusive model not available for this model version." |
| | audio_features_clean = np.float32(self.stft_transform(audio_clean))[ |
| | np.newaxis, np.newaxis, ...] |
| | assert len( |
| | audio_features_clean) <= self.max_lens[0], "Maximum input length exceeded" |
| | onnx_inputs = {"degraded_audio": audio_features_degraded, |
| | "clean_audio": audio_features_clean} |
| | mos = float(session.run(None, onnx_inputs)[0]) |
| |
|
| | session = self.sessions[1] |
| | assert not session is None, "Nonintrusive model not available for this model version." |
| | onnx_inputs = {"degraded_audio": audio_features_degraded} |
| | mos_2 = float(session.run(None, onnx_inputs)[0]) |
| | mos = [mos, mos_2] |
| | return mos |
| |
|