from pathlib import Path import argbind import audiotools import pandas import torch from audiotools import AudioSignal from frechet_audio_distance import FrechetAudioDistance from tqdm import tqdm @argbind.bind(without_prefix=True) def eval( exp_dir: str = None, baseline_key: str = "baseline", audio_ext: str = ".wav", ): assert exp_dir is not None exp_dir = Path(exp_dir) assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist" # set up our metrics # sisdr_loss = audiotools.metrics.distance.SISDRLoss() # stft_loss = audiotools.metrics.spectral.MultiScaleSTFTLoss() mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss() frechet = FrechetAudioDistance( use_pca=False, use_activation=False, verbose=True, audio_load_worker=4, ) frechet.model.to("cuda" if torch.cuda.is_available() else "cpu") # figure out what conditions we have conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()] assert baseline_key in conditions, ( f"baseline_key {baseline_key} not found in {exp_dir}" ) conditions.remove(baseline_key) print(f"Found {len(conditions)} conditions in {exp_dir}") print(f"conditions: {conditions}") baseline_dir = exp_dir / baseline_key baseline_files = sorted( list(baseline_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem) ) metrics = [] for condition in tqdm(conditions): cond_dir = exp_dir / condition cond_files = sorted( list(cond_dir.glob(f"*{audio_ext}")), key=lambda x: int(x.stem) ) print(f"computing fad for {baseline_dir} and {cond_dir}") frechet_score = frechet.score(baseline_dir, cond_dir) # make sure we have the same number of files num_files = min(len(baseline_files), len(cond_files)) baseline_files = baseline_files[:num_files] cond_files = cond_files[:num_files] assert len(list(baseline_files)) == len(list(cond_files)), ( f"number of files in {baseline_dir} and {cond_dir} do not match. {len(list(baseline_files))} vs {len(list(cond_files))}" ) def process(baseline_file, cond_file): # make sure the files match (same name) assert baseline_file.stem == cond_file.stem, ( f"baseline file {baseline_file} and cond file {cond_file} do not match" ) # load the files baseline_sig = AudioSignal(str(baseline_file)) cond_sig = AudioSignal(str(cond_file)) cond_sig.resample(baseline_sig.sample_rate) cond_sig.truncate_samples(baseline_sig.length) # if our condition is inpainting, we need to trim the conditioning off if "inpaint" in condition: ctx_amt = float(condition.split("_")[-1]) ctx_samples = int(ctx_amt * baseline_sig.sample_rate) print( f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}" ) cond_sig.trim(ctx_samples, ctx_samples) baseline_sig.trim(ctx_samples, ctx_samples) return { # "sisdr": -sisdr_loss(baseline_sig, cond_sig).item(), # "stft": stft_loss(baseline_sig, cond_sig).item(), "mel": mel_loss(baseline_sig, cond_sig).item(), "frechet": frechet_score, # "visqol": vsq, "condition": condition, "file": baseline_file.stem, } print( f"processing {len(baseline_files)} files in {baseline_dir} and {cond_dir}" ) metrics.extend( tqdm(map(process, baseline_files, cond_files), total=len(baseline_files)) ) metric_keys = [k for k in metrics[0].keys() if k not in ("condition", "file")] for mk in metric_keys: stat = pandas.DataFrame(metrics) stat = stat.groupby(["condition"])[mk].agg(["mean", "count", "std"]) stat.to_csv(exp_dir / f"stats-{mk}.csv") df = pandas.DataFrame(metrics) df.to_csv(exp_dir / "metrics-all.csv", index=False) if __name__ == "__main__": args = argbind.parse_args() with argbind.scope(args): eval()