|
|
from pathlib import Path |
|
|
import os |
|
|
from functools import partial |
|
|
|
|
|
from frechet_audio_distance import FrechetAudioDistance |
|
|
import pandas |
|
|
import argbind |
|
|
import torch |
|
|
from tqdm import tqdm |
|
|
|
|
|
import audiotools |
|
|
from audiotools import AudioSignal |
|
|
|
|
|
def _numbered_audio_files(directory, audio_ext):
    """Return files in *directory* ending with *audio_ext*, ordered by integer stem."""
    # Stems are parsed with int(), so "10.wav" sorts after "2.wav"; a
    # non-numeric stem raises ValueError.
    return sorted(directory.glob(f"*{audio_ext}"), key=lambda p: int(p.stem))


@argbind.bind(without_prefix=True)
def eval(
    exp_dir: str = None,
    baseline_key: str = "baseline",
    audio_ext: str = ".wav",
):
    """Score every condition directory in an experiment against the baseline.

    Each sub-directory of ``exp_dir`` other than ``baseline_key`` is treated
    as one condition. For every condition this computes one Frechet Audio
    Distance between the baseline directory and the condition directory, and
    one mel-spectrogram loss per (baseline file, condition file) pair.

    Outputs written into ``exp_dir``:
      * ``stats-<metric>.csv`` -- mean / count / std of each metric, grouped
        by condition.
      * ``metrics-all.csv`` -- one row per compared file pair.

    Args:
        exp_dir: Experiment directory; must exist and contain ``baseline_key``.
        baseline_key: Name of the sub-directory holding the reference audio.
        audio_ext: Suffix used to glob audio files in each directory.

    Raises:
        AssertionError: If ``exp_dir`` is missing or does not exist, if
            ``baseline_key`` is not one of its sub-directories, or if a
            paired baseline/condition file have different stems.
        ValueError: If an audio file's stem is not an integer.
    """
    # NOTE: the signature (including the `str = None` annotation) is left
    # exactly as it was because it is consumed by argbind.
    assert exp_dir is not None
    exp_dir = Path(exp_dir)
    assert exp_dir.exists(), f"exp_dir {exp_dir} does not exist"

    # Set up the metrics.
    mel_loss = audiotools.metrics.spectral.MelSpectrogramLoss()
    frechet = FrechetAudioDistance(
        use_pca=False,
        use_activation=False,
        verbose=True,
        audio_load_worker=4,
    )
    frechet.model.to("cuda" if torch.cuda.is_available() else "cpu")

    # Every sub-directory except the baseline is a condition to evaluate.
    conditions = [d.name for d in exp_dir.iterdir() if d.is_dir()]
    assert baseline_key in conditions, f"baseline_key {baseline_key} not found in {exp_dir}"
    conditions.remove(baseline_key)

    print(f"Found {len(conditions)} conditions in {exp_dir}")
    print(f"conditions: {conditions}")

    baseline_dir = exp_dir / baseline_key
    baseline_files = _numbered_audio_files(baseline_dir, audio_ext)

    metrics = []
    for condition in tqdm(conditions):
        cond_dir = exp_dir / condition
        cond_files = _numbered_audio_files(cond_dir, audio_ext)

        print(f"computing fad for {baseline_dir} and {cond_dir}")
        frechet_score = frechet.score(baseline_dir, cond_dir)

        # Compare only as many files as both directories provide. The slices
        # are per-condition locals: the full baseline list must survive for
        # the next condition (re-assigning `baseline_files` here used to
        # shrink it permanently after the first short condition).
        num_files = min(len(baseline_files), len(cond_files))
        if len(baseline_files) != len(cond_files):
            print(
                f"warning: {baseline_dir} has {len(baseline_files)} files but "
                f"{cond_dir} has {len(cond_files)}; comparing the first {num_files}"
            )
        paired_baseline = baseline_files[:num_files]
        paired_cond = cond_files[:num_files]

        def process(baseline_file, cond_file):
            """Return the metric row for one baseline/condition file pair."""
            # Files are paired positionally, so verify the names really match.
            assert baseline_file.stem == cond_file.stem, f"baseline file {baseline_file} and cond file {cond_file} do not match"

            baseline_sig = AudioSignal(str(baseline_file))
            cond_sig = AudioSignal(str(cond_file))

            # Bring the condition signal to the baseline's sample rate and
            # cap it at the baseline's length. The return values are unused,
            # so these calls are relied upon to modify the signal in place.
            cond_sig.resample(baseline_sig.sample_rate)
            cond_sig.truncate_samples(baseline_sig.length)

            if "inpaint" in condition:
                # The last "_"-separated token of the condition name is the
                # context amount (presumably seconds -- it is multiplied by
                # the sample rate); that many samples are trimmed from both
                # signals.
                ctx_amt = float(condition.split("_")[-1])
                ctx_samples = int(ctx_amt * baseline_sig.sample_rate)
                print(f"found inpainting condition. trimming off {ctx_samples} samples from {cond_file} and {baseline_file}")
                cond_sig.trim(ctx_samples, ctx_samples)
                baseline_sig.trim(ctx_samples, ctx_samples)

            return {
                "mel": mel_loss(baseline_sig, cond_sig).item(),
                # FAD is computed once per condition, so it repeats per row.
                "frechet": frechet_score,
                "condition": condition,
                "file": baseline_file.stem,
            }

        print(f"processing {len(paired_baseline)} files in {baseline_dir} and {cond_dir}")
        # `process` closes over `condition` / `frechet_score`; the map is
        # fully consumed here, inside the same iteration, so it always sees
        # the current values.
        metrics.extend(tqdm(map(process, paired_baseline, paired_cond), total=num_files))

    if not metrics:
        # No conditions, or no file pairs: nothing to summarise (indexing
        # metrics[0] below would raise IndexError).
        print(f"no file pairs were evaluated in {exp_dir}; no csv files written")
        return

    # Everything except the identifying columns is a metric.
    metric_keys = [k for k in metrics[0] if k not in ("condition", "file")]

    # Build the frame once and reuse it for the per-metric summaries and the
    # full dump.
    df = pandas.DataFrame(metrics)
    by_condition = df.groupby(['condition'])
    for mk in metric_keys:
        stat = by_condition[mk].agg(['mean', 'count', 'std'])
        stat.to_csv(exp_dir / f"stats-{mk}.csv")

    df.to_csv(exp_dir / "metrics-all.csv", index=False)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: parse the command-line arguments with argbind and
    # run `eval` inside the resulting argument scope, so the function is
    # called without explicit parameters.
    with argbind.scope(argbind.parse_args()):
        eval()