File size: 3,343 Bytes
eb34860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
"""Minimal inference example for the AnimeScore release.

Loads the model from a local directory (or HuggingFace Hub) and scores
either a single wav or a directory of wavs.

Usage:
    # single wav
    python example_inference.py --ckpt . --wav path/to/audio.wav

    # batch over a directory
    python example_inference.py --ckpt . --dir path/to/wavs --csv out.csv

    # pairwise probability A > B
    python example_inference.py --ckpt . --pair a.wav b.wav
"""

import argparse
import os
from pathlib import Path

import torch
import torchaudio
from transformers import AutoModel


def _read_audio(path: str):
    """Decode an audio file into a float32 tensor and return it with its rate.

    soundfile is tried first so that decoding does not rely on torchaudio's
    optional torchcodec/ffmpeg backend. If anything in that attempt raises
    (including soundfile not being importable), torchaudio.load is used
    instead.

    Args:
        path: Location of the audio file to decode.

    Returns:
        A ``(waveform, sample_rate)`` tuple; the soundfile branch lays the
        waveform out as [channels, frames].
    """
    try:
        import soundfile

        # always_2d=True gives [frames, channels]; flip to channels-first.
        samples, sample_rate = soundfile.read(path, dtype="float32", always_2d=True)
        waveform = torch.from_numpy(samples.T).contiguous()
        return waveform, sample_rate
    except Exception:
        # Best-effort fallback for inputs the soundfile path could not handle.
        waveform, sample_rate = torchaudio.load(path)
        return waveform.to(torch.float32), sample_rate


def load_wav(path: str, target_sr: int = 16000) -> torch.Tensor:
    """Read an audio file, mix it down to one channel, and resample it.

    Args:
        path: Audio file to read.
        target_sr: Sample rate the returned waveform should have.

    Returns:
        The waveform with its leading channel dimension squeezed away.
    """
    waveform, source_sr = _read_audio(path)

    # Average across channels when the file has more than one.
    has_multiple_channels = waveform.size(0) > 1
    if has_multiple_channels:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Resample only when the file's rate differs from the requested one.
    if source_sr != target_sr:
        waveform = torchaudio.functional.resample(waveform, source_sr, target_sr)

    return waveform.squeeze(0)


def score_paths(model, paths, device):
    """Score each audio file in ``paths`` with ``model``, one file at a time.

    Args:
        model: Object exposing ``config.target_sr`` and a ``score(wav)``
            method whose result supports ``.item()``.
        paths: Iterable of audio file paths.
        device: Device the input tensor is moved to before scoring.

    Returns:
        A list with one Python scalar per path, in input order. An empty
        ``paths`` yields an empty list.
    """
    scores = []
    # Inference only: with autograd disabled, no graph state is kept for a
    # backward pass that never happens.
    with torch.no_grad():
        for path in paths:
            wav = load_wav(path, model.config.target_sr)
            batch = wav.unsqueeze(0).to(device)  # add a leading batch dimension
            scores.append(model.score(batch).item())
    return scores


def main():
    """Parse CLI arguments, load the model, and run the requested mode(s).

    The three modes are independent and run in this order when several are
    given:

    * ``--wav``: score one file and print ``<path>\\tanimescore=<score>``.
    * ``--dir``: score every ``*.wav`` directly inside the directory (sorted,
      non-recursive); write a CSV when ``--csv`` is set, otherwise print one
      tab-separated line per file.
    * ``--pair``: score two files and print both scores plus
      ``sigmoid(score_a - score_b)`` as the pairwise probability.
    """
    import csv  # local: only needed by the --dir/--csv path

    ap = argparse.ArgumentParser()
    ap.add_argument("--ckpt", required=True, help="HF repo id or local directory")
    ap.add_argument("--wav", help="single wav path")
    ap.add_argument("--dir", help="directory of wavs to score")
    ap.add_argument("--pair", nargs=2, metavar=("A", "B"), help="two wavs for pairwise prob")
    ap.add_argument("--csv", default="", help="optional output CSV when using --dir")
    ap.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu")
    args = ap.parse_args()

    model = AutoModel.from_pretrained(args.ckpt, trust_remote_code=True).eval().to(args.device)

    if args.wav:
        s = score_paths(model, [args.wav], args.device)[0]
        print(f"{args.wav}\tanimescore={s:.4f}")

    if args.dir:
        paths = sorted(str(p) for p in Path(args.dir).glob("*.wav"))
        scores = score_paths(model, paths, args.device)
        if args.csv:
            # csv.writer quotes paths containing commas, quotes or newlines,
            # which hand-built f-string rows would corrupt. UTF-8 keeps
            # non-ASCII filenames writable regardless of the locale, and
            # newline="" is what the csv module expects for its file objects.
            with open(args.csv, "w", encoding="utf-8", newline="") as f:
                writer = csv.writer(f, lineterminator="\n")
                writer.writerow(["path", "animescore"])
                writer.writerows((p, f"{s:.6f}") for p, s in zip(paths, scores))
            print(f"wrote {len(paths)} rows to {args.csv}")
        else:
            for p, s in zip(paths, scores):
                print(f"{p}\t{s:.4f}")

    if args.pair:
        a, b = args.pair
        sa, sb = score_paths(model, [a, b], args.device)
        # Pairwise probability that A outranks B: sigmoid of the score gap.
        p_a_gt_b = torch.sigmoid(torch.tensor(sa - sb)).item()
        print(f"score({a}) = {sa:.4f}")
        print(f"score({b}) = {sb:.4f}")
        print(f"P({os.path.basename(a)} > {os.path.basename(b)}) = {p_a_gt_b:.4f}")


# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()