Ilia
update app.py
a65497a
Raw
History Blame Contribute Delete
1.99 kB
import gradio as gr
import torch
import torchaudio
from hydra.utils import instantiate
from src.models import AudioBatch
# Checkpoint file paths, keyed by the model name offered in the UI dropdown
CHECKPOINTS = {
"RU+EN": "gigaam_ru_en/gigaam_ru_en.ckpt"
}
# Cache of already-loaded models: checkpoint path -> (model, id2name)
LOADED_MODELS = {}
def load_model(ckpt_path):
    """Return the ``(model, id2name)`` pair for a checkpoint, loading it once.

    The first request for ``ckpt_path`` reads the checkpoint onto the CPU,
    builds the model from the stored ``config`` entry, restores its weights
    from ``state_dict`` and puts it in evaluation mode. The resulting pair
    is kept in ``LOADED_MODELS`` so later requests skip the load entirely.

    Args:
        ckpt_path: Path to a checkpoint file containing the ``config``,
            ``id2name`` and ``state_dict`` entries.

    Returns:
        Tuple ``(model, id2name)``.
    """
    cached = LOADED_MODELS.get(ckpt_path)
    if cached is not None:
        return cached

    # NOTE: torch.load unpickles the file, so only trusted checkpoints
    # should ever be passed here.
    ckpt = torch.load(ckpt_path, map_location="cpu")
    model_cfg = ckpt["config"]
    labels = ckpt["id2name"]

    # _recursive_=False: nested config nodes are handed to the model
    # constructor as-is rather than being instantiated by hydra first.
    net = instantiate(model_cfg, _recursive_=False)
    net.load_state_dict(ckpt["state_dict"])
    net.eval()

    entry = (net, labels)
    LOADED_MODELS[ckpt_path] = entry
    return entry
def classify_emotion(audio, ckpt_name):
    """Classify the emotion of an audio file and return class probabilities.

    Args:
        audio: Path to the audio file to classify (the Gradio audio input
            is configured with ``type="filepath"``). Falsy (``None`` or an
            empty string) when the user submitted without uploading a file.
        ckpt_name: Key into ``CHECKPOINTS`` selecting which model to run.

    Returns:
        Dict mapping each entry of the checkpoint's ``id2name`` to its
        softmax probability as a float.

    Raises:
        gr.Error: If no audio file was provided, so the UI shows a readable
            message instead of a low-level loading failure.
        KeyError: If ``ckpt_name`` is not a key of ``CHECKPOINTS``.
    """
    if not audio:
        raise gr.Error("Пожалуйста, загрузите аудиофайл.")

    target_sr = 16000

    # Load waveform
    waveform, sr = torchaudio.load(audio)

    # Load model (cached after the first call)
    model, id2name = load_model(CHECKPOINTS[ckpt_name])

    # Resample to 16 kHz if the file uses a different rate
    if sr != target_sr:
        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=target_sr)
        waveform = resampler(waveform)

    # Average the channels down to mono; keepdim leaves a leading axis of
    # size 1, giving the B x T layout with a batch of one.
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Number of samples in the (single) batch item
    length = torch.tensor([waveform.shape[-1]])
    # NOTE(review): the third AudioBatch field is left as None — presumably
    # it is only needed for training; confirm against src.models.AudioBatch.
    batch = AudioBatch(waveform, length, None)

    with torch.no_grad():
        # The model's second output is not needed here
        logits, _ = model(batch)
        probs = torch.softmax(logits, dim=-1).squeeze(0)

    return {label: float(probs[i]) for i, label in enumerate(id2name)}
# Input widgets: the uploaded audio (handed to the handler as a file path)
# and a selector for which checkpoint to run.
_audio_input = gr.Audio(type="filepath", label="Загрузите аудиофайл")
_model_choice = gr.Dropdown(
    choices=list(CHECKPOINTS),
    label="Выберите модель",
    value="RU+EN",
)

# Output widget: renders the label -> probability dict returned by the handler.
_probabilities = gr.Label(label="Эмоциональная окраска (вероятности)")

demo = gr.Interface(
    fn=classify_emotion,
    inputs=[_audio_input, _model_choice],
    outputs=_probabilities,
    title="Эмоциональная классификация речи (дообученная GigaAM на 8 классов)",
    description="Загрузите аудио",
)

# Start the Gradio app as soon as the script is executed.
demo.launch()