File size: 2,601 Bytes
eb34860
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
"""Gradio demo for AnimeScore: audio in -> anime-likeness score out."""

import sys
from pathlib import Path

import gradio as gr
import torch
import torchaudio

# Absolute directory containing this script; config.json and
# model.safetensors are loaded relative to it.
HERE = Path(__file__).resolve().parent
# Put that directory first on the import path so the modeling_animescore
# import below can be resolved from here regardless of the working directory.
sys.path.insert(0, str(HERE))

from modeling_animescore import AnimeScoreConfig, AnimeScoreRankNet
from safetensors.torch import load_file


# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")


def _build_model() -> AnimeScoreRankNet:
    """Build the AnimeScore model and load its checkpoint from this directory.

    The architecture is configured from ``config.json`` and the weights are
    read from ``model.safetensors``, both located in ``HERE``. The state dict
    is loaded non-strictly because keys under the ``ssl.`` prefix are allowed
    to be absent from the checkpoint; any other mismatch is treated as an
    error.

    Returns:
        The model, moved to ``DEVICE`` and switched to eval mode.

    Raises:
        RuntimeError: If a non-``ssl.`` key is missing from the checkpoint, or
            if the checkpoint contains keys the model does not define.
    """
    cfg = AnimeScoreConfig.from_json_file(str(HERE / "config.json"))
    model = AnimeScoreRankNet(cfg).to(DEVICE).eval()
    sd = load_file(str(HERE / "model.safetensors"))
    missing, unexpected = model.load_state_dict(sd, strict=False)
    # Report only the offending keys: the tolerated ``ssl.*`` entries would
    # otherwise bury them in the error message.
    missing_head = [key for key in missing if not key.startswith("ssl.")]
    if missing_head:
        raise RuntimeError(f"unexpected missing head keys: {missing_head}")
    if unexpected:
        raise RuntimeError(f"unexpected keys in safetensors: {unexpected}")
    return model


# Build the model once at import time; predict() reuses this single instance.
MODEL = _build_model()
# Sample rate that input audio is resampled to before scoring.
TARGET_SR = MODEL.config.target_sr


def _read_audio(path: str):
    """Decode an audio file into a float32 waveform and its sample rate.

    soundfile (which bundles libsndfile) is tried first so the demo does not
    rely on torchaudio's optional torchcodec/ffmpeg backend. Any failure on
    that path, including soundfile not being importable, hands the file to
    torchaudio.load instead.

    Returns:
        A ``(waveform, sample_rate)`` pair where ``waveform`` is a
        ``[channels, frames]`` float32 tensor.
    """
    try:
        import soundfile as sf

        # always_2d yields [frames, channels]; transpose to channels-first.
        samples, rate = sf.read(path, dtype="float32", always_2d=True)
        waveform = torch.from_numpy(samples.T).contiguous()
        return waveform, rate
    except Exception:
        # Deliberate best-effort fallback for anything soundfile cannot handle.
        waveform, rate = torchaudio.load(path)
        return waveform.to(torch.float32), rate


def _load_wav_to_tensor(path: str) -> torch.Tensor:
    """Load ``path`` as a mono waveform at ``TARGET_SR`` placed on ``DEVICE``."""
    waveform, sample_rate = _read_audio(path)

    has_multiple_channels = waveform.size(0) > 1
    if has_multiple_channels:
        # Downmix by averaging the channels, keeping the leading channel axis.
        waveform = waveform.mean(dim=0, keepdim=True)

    needs_resampling = sample_rate != TARGET_SR
    if needs_resampling:
        waveform = torchaudio.functional.resample(waveform, sample_rate, TARGET_SR)

    return waveform.to(DEVICE)


def predict(audio):
    """Score an audio file and return the result formatted for the UI.

    Args:
        audio: Path to the audio file, or ``None`` when no audio is present.

    Returns:
        The score formatted to four decimal places, or ``"—"`` when ``audio``
        is ``None``.
    """
    if audio is None:
        return "—"

    waveform = _load_wav_to_tensor(audio)
    # Inference only: skip autograd bookkeeping while scoring.
    with torch.no_grad():
        raw_score = MODEL.score(waveform)
    value = raw_score.item()
    return f"{value:.4f}"


# UI: an audio input, a button that triggers scoring, and a read-only score box.
with gr.Blocks(title="AnimeScore") as demo:
    gr.Markdown(
        "# AnimeScore\n\nScore a speech clip for anime-likeness. Higher = more anime-like."
    )
    audio_in = gr.Audio(
        label="Audio",
        type="filepath",
        sources=["upload", "microphone"],
    )
    run = gr.Button("Score", variant="primary")
    score_out = gr.Textbox(interactive=False, label="AnimeScore")

    # Score both on an explicit button press and whenever the audio value
    # changes; both events use the same handler and components.
    for register in (run.click, audio_in.change):
        register(predict, inputs=audio_in, outputs=score_out)


# Start the server only when run as a script; importing this module defines
# `demo` without launching it.
if __name__ == "__main__":
    # Enable request queuing, then launch the app.
    demo.queue().launch()