# animescore/app.py — from the "AnimeScore release" commit (eb34860) by nonmetal; 2.6 kB.
"""Gradio demo for AnimeScore: audio in -> anime-likeness score out."""
import sys
from pathlib import Path
import gradio as gr
import torch
import torchaudio
# Directory containing this script (symlinks resolved); config.json and
# model.safetensors are read from here by _build_model.
HERE = Path(__file__).resolve().parent
# Put this script's directory first on the import path so that
# `modeling_animescore` (imported on the next line) is found there.
# Must run before that import.
sys.path.insert(0, str(HERE))
from modeling_animescore import AnimeScoreConfig, AnimeScoreRankNet
from safetensors.torch import load_file
# Use CUDA when PyTorch reports it is available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def _build_model() -> AnimeScoreRankNet:
    """Build the AnimeScore model from the files bundled next to this script.

    Reads ``config.json`` and ``model.safetensors`` from ``HERE``. It moves the
    model to ``DEVICE``, switches it to eval mode, and loads the weights.

    Returns:
        The ready-to-use ``AnimeScoreRankNet`` instance.

    Raises:
        RuntimeError: if the checkpoint is missing any parameter outside the
            ``ssl.`` prefix, or contains keys the model does not define.
    """
    cfg = AnimeScoreConfig.from_json_file(str(HERE / "config.json"))
    model = AnimeScoreRankNet(cfg).to(DEVICE).eval()
    state_dict = load_file(str(HERE / "model.safetensors"))
    # strict=False because parameters under the ``ssl.`` prefix may be absent
    # from the checkpoint. Presumably that sub-module is initialised elsewhere
    # (e.g. by the model constructor) — TODO confirm. Every other key must
    # match, which is enforced by the checks below.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    missing_head = [key for key in missing if not key.startswith("ssl.")]
    if missing_head:
        # Report only the offending keys; listing the tolerated ``ssl.*`` keys
        # as well would bury the real problem.
        raise RuntimeError(f"unexpected missing head keys: {missing_head}")
    if unexpected:
        raise RuntimeError(f"unexpected keys in safetensors: {unexpected}")
    return model
# Built once at import time; `predict` reuses this single instance for every request.
MODEL = _build_model()
# Sample rate taken from the model config; `_load_wav_to_tensor` resamples input to it.
TARGET_SR = MODEL.config.target_sr
def _read_audio(path: str):
    """Decode *path* into a float32 ``[channels, frames]`` tensor and its sample rate.

    soundfile (a self-contained libsndfile binding) is tried first, so the demo
    does not need torchaudio's optional torchcodec/ffmpeg backend. If soundfile
    is unavailable or cannot decode the file, ``torchaudio.load`` is used instead.
    """
    try:
        import soundfile as sf

        samples, rate = sf.read(path, dtype="float32", always_2d=True)
        # soundfile yields [frames, channels]; transpose to [channels, frames].
        tensor = torch.from_numpy(samples.T).contiguous()
    except Exception:
        # Deliberately broad: a missing soundfile package (ImportError) and any
        # decode failure both take the torchaudio fallback path.
        decoded, rate = torchaudio.load(path)
        return decoded.to(torch.float32), rate
    else:
        return tensor, rate
def _load_wav_to_tensor(path: str) -> torch.Tensor:
    """Return the audio at *path* as a mono ``[1, frames]`` tensor at TARGET_SR on DEVICE.

    Multi-channel audio is averaged into one channel. Resampling happens only
    when the file's rate differs from TARGET_SR.
    """
    signal, rate = _read_audio(path)
    # Average all channels together; keepdim retains the leading channel axis.
    mono = signal if signal.size(0) <= 1 else signal.mean(0, keepdim=True)
    if rate == TARGET_SR:
        return mono.to(DEVICE)
    return torchaudio.functional.resample(mono, rate, TARGET_SR).to(DEVICE)
def predict(audio):
    """Score one audio clip for anime-likeness.

    Args:
        audio: File path from the ``gr.Audio`` component (configured with
            ``type="filepath"``), or ``None`` when the component is cleared.

    Returns:
        The score formatted with four decimal places, or ``"—"`` when there is
        no audio to score.

    Raises:
        gr.Error: if the file cannot be decoded or contains no samples, so the
            Gradio UI shows a readable message rather than a generic failure.
    """
    if not audio:
        # None (component cleared) or an empty path: nothing to score.
        return "—"
    try:
        wav = _load_wav_to_tensor(audio)
    except Exception as err:  # request boundary: report decode failures to the user
        raise gr.Error(f"Could not read audio file: {err}") from err
    if wav.size(-1) == 0:
        # A zero-length recording/upload has no frames for the model to score.
        raise gr.Error("The audio clip is empty.")
    with torch.no_grad():
        score = MODEL.score(wav).item()
    return f"{score:.4f}"
with gr.Blocks(title="AnimeScore") as demo:
    # Page heading and a one-line description of the score's meaning.
    gr.Markdown(
        value=(
            "# AnimeScore\n\n"
            "Score a speech clip for anime-likeness. Higher = more anime-like."
        )
    )
    audio_in = gr.Audio(
        label="Audio",
        type="filepath",
        sources=["upload", "microphone"],
    )
    run = gr.Button(value="Score", variant="primary")
    score_out = gr.Textbox(interactive=False, label="AnimeScore")
    # The same handler is wired to both triggers: an explicit button press,
    # and any change to the audio input (including clearing it).
    # The click handler is registered first.
    run.click(fn=predict, inputs=audio_in, outputs=score_out)
    audio_in.change(fn=predict, inputs=audio_in, outputs=score_out)
if __name__ == "__main__":
    # Enable Gradio's request queue on the app, then launch it.
    demo.queue()
    demo.launch()