Spaces:
Paused
Paused
| import json | |
| import os | |
| import librosa | |
| import numpy as np | |
| import torch | |
| from tqdm import tqdm | |
| from scipy.linalg import sqrtm | |
| from metrics.pipelines import sample_pipeline, sample_pipeline_GAN | |
| from metrics.pipelines_STFT import sample_pipeline_STFT, sample_pipeline_GAN_STFT | |
| from tools import rms_normalize | |
def ASTaudio2feature(device, signal, processor, AST, sampling_rate):
    """Return the first-token embedding AST produces for each clip in ``signal``.

    ``signal`` is passed straight to ``processor`` (together with ``sampling_rate``);
    the resulting tensors are moved to ``device`` and run through ``AST`` without
    gradient tracking. The embedding at position 0 of ``last_hidden_state`` is taken
    for every item and returned as a NumPy array on the CPU.
    """
    model_inputs = processor(signal, sampling_rate=sampling_rate, return_tensors="pt")
    model_inputs = model_inputs.to(device)
    with torch.no_grad():
        hidden = AST(**model_inputs).last_hidden_state
    first_token = hidden[:, 0, :]
    return first_token.to("cpu").detach().numpy()
def calculate_statistics(features):
    """Return the mean vector and covariance matrix of a set of feature vectors.

    ``features`` holds one observation per row (columns are variables), so the
    covariance is computed with ``rowvar=False``.
    """
    samples = np.asarray(features)
    mean_vector = samples.mean(axis=0)
    covariance = np.cov(samples, rowvar=False)
    return mean_vector, covariance
def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
    """Compute the Frechet (FID) distance between N(mu1, sigma1) and N(mu2, sigma2).

    FID = ||mu1 - mu2||^2 + Tr(sigma1 + sigma2 - 2 * sqrtm(sigma1 @ sigma2))

    Args:
        mu1, mu2: Mean vectors (array-like).
        sigma1, sigma2: Covariance matrices (array-like). A 0-d covariance, as
            returned by ``np.cov`` for single-feature data, is treated as 1x1.
        eps: Small value added to each covariance diagonal for numerical stability.

    Returns:
        The distance as a Python float.

    The caller's arrays are never modified (the regularised matrices are new arrays).
    """
    mu1 = np.atleast_1d(np.asarray(mu1, dtype=np.float64))
    mu2 = np.atleast_1d(np.asarray(mu2, dtype=np.float64))
    sigma1 = np.atleast_2d(np.asarray(sigma1, dtype=np.float64))
    sigma2 = np.atleast_2d(np.asarray(sigma2, dtype=np.float64))

    # Add a small positive value to the diagonals. Use out-of-place addition so the
    # caller's covariance matrices are not changed as a side effect.
    sigma1 = sigma1 + np.eye(sigma1.shape[0]) * eps
    sigma2 = sigma2 + np.eye(sigma2.shape[0]) * eps

    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    # Numerical error can make sqrtm return a complex result; keep only the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return float(ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
def calculate_fid_dict(dict1, dict2, eps=1e-6):
    """Compute the FID between two pre-computed statistics dicts.

    Args:
        dict1, dict2: Mappings with keys "mu" (mean vector) and "sigma"
            (covariance matrix); array-likes are accepted.
        eps: Small value added to each covariance diagonal for numerical stability.

    Returns:
        The Frechet distance as a Python float.

    The input dicts are left untouched. Adding eps in place would permanently alter a
    cached "sigma" and make repeated comparisons against the same dict drift.
    """
    mu1 = np.atleast_1d(np.asarray(dict1["mu"], dtype=np.float64))
    mu2 = np.atleast_1d(np.asarray(dict2["mu"], dtype=np.float64))
    sigma1 = np.atleast_2d(np.asarray(dict1["sigma"], dtype=np.float64))
    sigma2 = np.atleast_2d(np.asarray(dict2["sigma"], dtype=np.float64))

    # Regularise the diagonals out of place (new arrays, dict contents unchanged).
    sigma1 = sigma1 + np.eye(sigma1.shape[0]) * eps
    sigma2 = sigma2 + np.eye(sigma2.shape[0]) * eps

    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    covmean = sqrtm(sigma1.dot(sigma2))
    # Numerical error can make sqrtm return a complex result; keep only the real part.
    if np.iscomplexobj(covmean):
        covmean = covmean.real

    return float(ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean))
def generate_features_with_AudioLDM_and_AST(device, processor, AST, AudioLDM_signals_directory_path,
                                            return_feature=False, batch_size=8):
    """Extract AST features from a directory of AudioLDM-generated ``.wav`` files.

    Every ``*.wav`` file in the directory is processed, except macOS ``._*``
    metadata files. Each file is:

    - loaded at 16 kHz;
    - trimmed to its first 4 seconds;
    - normalised with ``rms_normalize``;
    - passed to AST in batches of ``batch_size``.

    Files that fail to load or normalise, or that are shorter than 4 seconds, are
    reported with a printed message and skipped.

    Args:
        device: Device the AST inputs are moved to.
        processor: Feature extractor passed to ``ASTaudio2feature``.
        AST: AST model passed to ``ASTaudio2feature``.
        AudioLDM_signals_directory_path: Directory containing the ``.wav`` files.
        return_feature: If True, return the list of per-clip features; otherwise
            return their statistics.
        batch_size: Number of clips per AST forward pass (default 8).

    Returns:
        A list of feature vectors if ``return_feature`` is True, else
        ``{"mu": mean, "sigma": covariance}``.

    Raises:
        ValueError: If no usable (loadable, at least 4 s long) ``.wav`` file was found.
    """
    sampling_rate = 16000
    target_length = 4 * sampling_rate  # 4 seconds * 16000 samples per second

    signals = []
    # Sorted so the order of returned features is reproducible across file systems.
    for file_name in sorted(os.listdir(AudioLDM_signals_directory_path)):
        # Skip non-wav files and macOS AppleDouble ("._*") metadata files.
        if not file_name.endswith('.wav') or file_name.startswith('._'):
            continue
        file_path = os.path.join(AudioLDM_signals_directory_path, file_name)
        try:
            signal, _ = librosa.load(file_path, sr=sampling_rate)
            if len(signal) < target_length:
                print(f"Skipping {file_name}: the file is shorter than 4 seconds.")
                continue
            # Every kept clip has exactly target_length samples, so batches are uniform.
            signals.append(rms_normalize(signal[:target_length]))
        except Exception as e:  # best effort: report unreadable files and carry on
            print(f"Error loading {file_name}: {e}")

    if not signals:
        raise ValueError(
            f"No usable .wav files (at least 4 seconds long) found in {AudioLDM_signals_directory_path}."
        )

    features = []
    for start in tqdm(range(0, len(signals), batch_size)):
        batch = signals[start:start + batch_size]
        features.extend(ASTaudio2feature(device, batch, processor, AST, sampling_rate=sampling_rate))

    if return_feature:
        return features
    mu, sigma = calculate_statistics(features)
    return {"mu": mu, "sigma": sigma}
def generate_features_with_diffuSynth_and_AST(device, uNet, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches,
                                              positive_prompts, negative_prompts="", CFG=1, sample_steps=10,
                                              task="spectrograms", return_feature=False):
    """Sample audio with the diffusion pipeline and extract AST features from it.

    ``task`` selects the pipeline: "spectrograms" -> ``sample_pipeline``,
    "STFT" -> ``sample_pipeline_STFT``; any other value raises NotImplementedError.
    ``num_batches`` pipeline calls are made, each with batchsize 8. The resulting
    signals are embedded with ``ASTaudio2feature`` at 16 kHz.

    Returns the list of feature vectors if ``return_feature`` is True, otherwise
    ``{"mu": mean, "sigma": covariance}`` of those features.
    """
    pipe = None
    for task_name, candidate in (("spectrograms", sample_pipeline), ("STFT", sample_pipeline_STFT)):
        if task == task_name:
            pipe = candidate
            break
    if pipe is None:
        raise NotImplementedError

    collected = []
    for _ in tqdm(range(num_batches)):
        _latents, _reconstructions, signals = pipe(device, uNet, VAE, mmm, CLAP_tokenizer,
                                                   positive_prompts=positive_prompts,
                                                   negative_prompts=negative_prompts,
                                                   batchsize=8,
                                                   sample_steps=sample_steps,
                                                   CFG=CFG, seed=None,
                                                   return_latent=False)
        collected.extend(ASTaudio2feature(device, signals, processor, AST, sampling_rate=16000))

    if not return_feature:
        mu, sigma = calculate_statistics(collected)
        return {"mu": mu, "sigma": sigma}
    return collected
def generate_features_with_GAN_and_AST(device, gan_generator, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches,
                                       positive_prompts, negative_prompts="", CFG=1, sample_steps=10,
                                       task="spectrograms", return_feature=False):
    """Sample audio with the GAN pipeline and extract AST features from it.

    ``task`` selects the pipeline: "spectrograms" -> ``sample_pipeline_GAN``,
    "STFT" -> ``sample_pipeline_GAN_STFT``; any other value raises NotImplementedError.
    ``num_batches`` pipeline calls are made, each with batchsize 8. The resulting
    signals are embedded with ``ASTaudio2feature`` at 16 kHz.

    Returns the list of feature vectors if ``return_feature`` is True, otherwise
    ``{"mu": mean, "sigma": covariance}`` of those features.
    """
    pipe = None
    for task_name, candidate in (("spectrograms", sample_pipeline_GAN), ("STFT", sample_pipeline_GAN_STFT)):
        if task == task_name:
            pipe = candidate
            break
    if pipe is None:
        raise NotImplementedError

    gan_features = []
    for _ in tqdm(range(num_batches)):
        _latents, _reconstructions, signals = pipe(device, gan_generator, VAE, mmm, CLAP_tokenizer,
                                                   positive_prompts=positive_prompts,
                                                   negative_prompts=negative_prompts,
                                                   batchsize=8,
                                                   sample_steps=sample_steps,
                                                   CFG=CFG, seed=None,
                                                   return_latent=False)
        gan_features.extend(ASTaudio2feature(device, signals, processor, AST, sampling_rate=16000))

    if not return_feature:
        mu, sigma = calculate_statistics(gan_features)
        return {"mu": mu, "sigma": sigma}
    return gan_features
def get_FD(train_features, device, uNet, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches, positive_prompts,
           negative_prompts="", CFG=1, sample_steps=10, task="spectrograms"):
    """Sample from the diffusion pipeline and compute its FID against ``train_features``.

    Args:
        train_features: Reference AST feature vectors (one per row).
        device, uNet, VAE, mmm, CLAP_tokenizer, processor, AST, num_batches,
        positive_prompts, negative_prompts, CFG, sample_steps: Forwarded to
            ``generate_features_with_diffuSynth_and_AST``.
        task: Pipeline variant ("spectrograms" or "STFT"); the default matches
            the previously hard-wired behaviour.

    Returns:
        The FID score. It is also printed, as before.
    """
    # generate_features_with_diffuSynth_and_AST returns {"mu", "sigma"} by default, so use
    # those statistics directly; passing that dict to calculate_statistics would fail.
    generated_stats = generate_features_with_diffuSynth_and_AST(device, uNet, VAE, mmm, CLAP_tokenizer, processor,
                                                                AST, num_batches, positive_prompts,
                                                                negative_prompts=negative_prompts, CFG=CFG,
                                                                sample_steps=sample_steps, task=task)
    mu_real, sigma_real = calculate_statistics(train_features)
    fid_score = calculate_fid(mu_real, sigma_real, generated_stats["mu"], generated_stats["sigma"])
    print('FID score:', fid_score)
    return fid_score
def get_fid_score(feature1, features2):
    """Return the FID between two sets of feature vectors (one vector per row).

    Statistics are computed with ``calculate_statistics`` and compared with
    ``calculate_fid``.
    """
    stats_first = calculate_statistics(feature1)
    stats_second = calculate_statistics(features2)
    return calculate_fid(*stats_first, *stats_second)
def calculate_fid_matrix(features_list_1, features_list_2, get_fid_score):
    """Build the pairwise score matrix between two lists of feature sets.

    Returns a nested list ``M`` with ``len(features_list_1)`` rows and
    ``len(features_list_2)`` columns, where
    ``M[i][j] = get_fid_score(features_list_1[i], features_list_2[j])``.
    Scores are computed in row-major order.
    """
    return [
        [get_fid_score(row_features, column_features) for column_features in features_list_2]
        for row_features in features_list_1
    ]
def save_AST_feature(key, mu, sigma, path='results/AST_metric/pre_calculated_features/AST_features.json'):
    """Store (or overwrite) the statistics ``mu``/``sigma`` under ``key`` in a JSON cache file.

    The existing file is read first so entries for other keys are preserved. If it
    does not exist, a new file is created, along with any missing parent
    directories. NumPy arrays are converted to nested lists so they can be
    JSON-encoded.

    The updated data is written to ``path + '.tmp'`` and then moved over ``path``
    with ``os.replace``. If serialisation fails, the previous cache file stays intact
    and the temporary file is removed.

    Raises:
        json.JSONDecodeError: If the existing cache file is not valid JSON.
        TypeError: If ``mu`` or ``sigma`` cannot be JSON-encoded.
    """
    # Read the existing cache, if any.
    try:
        with open(path, 'r', encoding='utf-8') as file:
            data = json.load(file)
    except FileNotFoundError:
        # No cache yet: start a new dictionary.
        data = {}

    if isinstance(mu, np.ndarray):
        mu = mu.tolist()
    if isinstance(sigma, np.ndarray):
        sigma = sigma.tolist()
    data[key] = {"mu": mu, "sigma": sigma}

    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Write to a temporary file first so a failure mid-write cannot truncate the cache.
    tmp_path = path + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as file:
            json.dump(data, file, indent=4)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
def read_AST_features(path='results/AST_metric/pre_calculated_features/AST_features.json'):
    """Load cached statistics from the JSON file at ``path``.

    Each top-level value is expected to be a dict with "mu" and "sigma" entries.
    Those entries are converted back to NumPy arrays in place.

    If the file does not exist, or is not valid JSON, a message (in Chinese) is
    printed and an empty dict is returned.
    """
    try:
        with open(path, 'r') as file:
            features_by_name = json.load(file)
    except FileNotFoundError:
        # File is missing: report it and return an empty dict.
        print(f"文件 {path} 未找到.")
        return {}
    except json.JSONDecodeError:
        # File is not valid JSON: report it and return an empty dict.
        print(f"文件 {path} 不是有效的JSON格式.")
        return {}

    for entry in features_by_name.values():
        entry["mu"] = np.array(entry["mu"])
        entry["sigma"] = np.array(entry["sigma"])
    return features_by_name