animescore / example_inference.py
nonmetal's picture
AnimeScore release
eb34860
Raw
History Blame Contribute Delete
3.34 kB
"""Minimal inference example for the AnimeScore release.

Loads the model from a local directory (or the Hugging Face Hub) and scores
a single wav, every wav in a directory, or a pair of wavs (reporting the
pairwise probability that the first scores higher than the second).

Usage:
    # single wav
    python example_inference.py --ckpt . --wav path/to/audio.wav
    # batch over a directory
    python example_inference.py --ckpt . --dir path/to/wavs --csv out.csv
    # pairwise probability A > B
    python example_inference.py --ckpt . --pair a.wav b.wav
"""
import argparse
import os
from pathlib import Path
import torch
import torchaudio
from transformers import AutoModel
def _read_audio(path: str):
    """Decode an audio file into a float32 tensor plus its sample rate.

    soundfile (backed by libsndfile) is tried first so that decoding does
    not depend on torchaudio's optional torchcodec/ffmpeg backend. If
    anything on that path raises -- the import, the read, or the tensor
    conversion -- torchaudio.load is used instead.

    Args:
        path: Location of the audio file.

    Returns:
        A ``(waveform, sample_rate)`` tuple. On the soundfile path the
        waveform is the transposed ``always_2d`` array, i.e. laid out as
        [channels, frames].
    """
    try:
        import soundfile

        # soundfile yields [frames, channels]; transpose to channels-first.
        samples, sample_rate = soundfile.read(path, dtype="float32", always_2d=True)
        channels_first = torch.from_numpy(samples.T)
        return channels_first.contiguous(), sample_rate
    except Exception:
        # Deliberate best-effort fallback for anything soundfile cannot handle.
        waveform, sample_rate = torchaudio.load(path)
        return waveform.to(torch.float32), sample_rate
def load_wav(path: str, target_sr: int = 16000) -> torch.Tensor:
    """Load an audio file as a single-channel waveform at ``target_sr``.

    Args:
        path: Location of the audio file.
        target_sr: Sample rate the returned waveform is resampled to.

    Returns:
        The waveform with its leading channel dimension squeezed away.
    """
    audio, native_sr = _read_audio(path)

    # Down-mix multi-channel audio by averaging across the channel axis.
    has_multiple_channels = audio.size(0) > 1
    if has_multiple_channels:
        audio = audio.mean(dim=0, keepdim=True)

    # Resample only when the file's rate differs from the requested one.
    if native_sr != target_sr:
        audio = torchaudio.functional.resample(audio, native_sr, target_sr)

    return audio.squeeze(0)
def score_paths(model, paths, device):
    """Score each audio file in *paths* with *model*.

    Args:
        model: Object exposing ``config.target_sr`` and a ``score`` method;
            ``score`` is called with one waveform that has a leading
            dimension of size 1 added.
        paths: Iterable of audio file paths.
        device: Device each waveform is moved to before scoring.

    Returns:
        A list with one scalar score per path, in input order.
    """
    scores = []
    # This is pure inference: disabling autograd avoids building (and
    # holding memory for) a computation graph for every file scored.
    with torch.no_grad():
        for p in paths:
            wav = load_wav(p, model.config.target_sr).unsqueeze(0).to(device)
            scores.append(model.score(wav).item())
    return scores
def main():
    """Command-line entry point: score a wav, a directory of wavs, or a pair.

    Parses the CLI flags, loads the model named by ``--ckpt`` and then runs
    every requested mode in turn:

    * ``--wav``: print the score of one file.
    * ``--dir``: score every ``*.wav`` directly inside the directory, either
      printing the results or writing them to ``--csv``.
    * ``--pair A B``: print both scores and ``sigmoid(score_A - score_B)``
      as the probability that A ranks above B.

    Exits with a usage error (before loading the model) when none of the
    three modes is requested.
    """
    import csv  # local import: only needed for the --csv output path

    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True, help="HF repo id or local directory")
    ap.add_argument("--wav", help="single wav path")
    ap.add_argument("--dir", help="directory of wavs to score")
    ap.add_argument("--pair", nargs=2, metavar=("A", "B"), help="two wavs for pairwise prob")
    ap.add_argument("--csv", default="", help="optional output CSV when using --dir")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    # Fail fast instead of loading the model and then silently doing nothing.
    if not (args.wav or args.dir or args.pair):
        ap.error("nothing to do: pass at least one of --wav, --dir or --pair")

    model = AutoModel.from_pretrained(args.ckpt, trust_remote_code=True).eval().to(args.device)

    if args.wav:
        s = score_paths(model, [args.wav], args.device)[0]
        print(f"{args.wav}\tanimescore={s:.4f}")

    if args.dir:
        paths = sorted(str(p) for p in Path(args.dir).glob("*.wav"))
        scores = score_paths(model, paths, args.device)
        if args.csv:
            # csv.writer quotes paths containing commas or quotes; explicit
            # UTF-8 keeps non-ASCII file names writable on every platform.
            with open(args.csv, "w", newline="", encoding="utf-8") as f:
                writer = csv.writer(f, lineterminator="\n")
                writer.writerow(["path", "animescore"])
                writer.writerows((p, f"{s:.6f}") for p, s in zip(paths, scores))
            print(f"wrote {len(paths)} rows to {args.csv}")
        else:
            for p, s in zip(paths, scores):
                print(f"{p}\t{s:.4f}")

    if args.pair:
        a, b = args.pair
        sa, sb = score_paths(model, [a, b], args.device)
        # Probability that A outranks B from the score difference.
        p_a_gt_b = torch.sigmoid(torch.tensor(sa - sb)).item()
        print(f"score({a}) = {sa:.4f}")
        print(f"score({b}) = {sb:.4f}")
        print(f"P({os.path.basename(a)} > {os.path.basename(b)}) = {p_a_gt_b:.4f}")
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()