Spaces:
Build error
Build error
| import gradio as gr | |
| import torch | |
| import torchaudio | |
| from hydra.utils import instantiate | |
| from src.models import AudioBatch | |
# Selectable models: display name (shown in the UI dropdown) -> checkpoint path.
CHECKPOINTS = {
    "RU+EN": "gigaam_ru_en/gigaam_ru_en.ckpt",
}

# In-process cache of loaded models, keyed by checkpoint path.
LOADED_MODELS = {}
def load_model(ckpt_path):
    """Load a fine-tuned model from a checkpoint file, caching it per path.

    The checkpoint is read as a dict with the keys ``'config'`` (passed to
    Hydra's ``instantiate`` to build the model), ``'id2name'`` (the class-label
    mapping returned alongside the model) and ``'state_dict'`` (the weights).

    Args:
        ckpt_path: Filesystem path to the ``.ckpt`` file.

    Returns:
        Tuple ``(model, id2name)`` with the model switched to eval mode.
        Subsequent calls with the same path return the cached tuple stored in
        ``LOADED_MODELS`` without touching the disk again.
    """
    cached = LOADED_MODELS.get(ckpt_path)
    if cached is not None:
        return cached

    # The checkpoint stores non-tensor objects (model config, label mapping)
    # next to the weights. Since PyTorch 2.6 torch.load defaults to
    # weights_only=True, which can refuse to unpickle such objects, so the
    # result would depend on the installed torch version; request full
    # unpickling explicitly.
    # SECURITY: full unpickling can execute arbitrary code. Only load trusted
    # checkpoints (here: paths from the hard-coded CHECKPOINTS table).
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    config = checkpoint["config"]
    id2name = checkpoint["id2name"]

    # _recursive_=False: Hydra passes nested config nodes through
    # un-instantiated, leaving the target class to build its own sub-modules.
    model = instantiate(config, _recursive_=False)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()  # inference behaviour for dropout / normalization layers

    LOADED_MODELS[ckpt_path] = (model, id2name)
    return model, id2name
def classify_emotion(audio, ckpt_name):
    """Predict the emotion-class probabilities for a speech recording.

    Args:
        audio: Path to an audio file, as produced by
            ``gr.Audio(type="filepath")``; ``None`` if nothing was uploaded.
        ckpt_name: Key into ``CHECKPOINTS`` selecting the model to use.

    Returns:
        Dict mapping each class label to its softmax probability (the format
        ``gr.Label`` renders).

    Raises:
        gr.Error: if no audio was provided, so the UI shows a readable message
            instead of a traceback from ``torchaudio.load(None)``.
    """
    from collections.abc import Mapping

    target_sr = 16000  # sample rate the audio is resampled to before inference

    if audio is None:
        raise gr.Error("Загрузите аудиофайл")

    # torchaudio.load returns the waveform tensor and its sample rate.
    waveform, sr = torchaudio.load(audio)
    model, id2name = load_model(CHECKPOINTS[ckpt_name])

    # Down-mix to mono *before* resampling so only one channel is resampled;
    # both steps are linear, so the order does not change the result.
    # keepdim=True leaves shape (1, time), i.e. a batch of one (B x T).
    if waveform.dim() > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    if sr != target_sr:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sr, new_freq=target_sr
        )

    length = torch.tensor([waveform.shape[-1]])
    # The third AudioBatch field is None here (presumably targets/labels,
    # unused at inference time; TODO confirm against src.models).
    batch = AudioBatch(waveform, length, None)

    with torch.no_grad():
        logits, _ = model(batch)
    # Assumes logits are (1, num_classes); squeeze the batch dim away.
    probs = torch.softmax(logits, dim=-1).squeeze(0).tolist()

    if isinstance(id2name, Mapping):
        # id -> name mapping; ids may be ints or numeric strings (e.g. JSON).
        return {name: float(probs[int(idx)]) for idx, name in id2name.items()}
    # Sequence of label names indexed by class id.
    return {label: float(probs[i]) for i, label in enumerate(id2name)}
demo = gr.Interface(
    fn=classify_emotion,
    inputs=[
        gr.Audio(type="filepath", label="Загрузите аудиофайл"),
        gr.Dropdown(
            choices=list(CHECKPOINTS.keys()),
            label="Выберите модель",
            # Default to the first configured checkpoint instead of a
            # hard-coded key, so editing CHECKPOINTS cannot leave the dropdown
            # pointing at a model that no longer exists.
            value=next(iter(CHECKPOINTS)),
        ),
    ],
    outputs=gr.Label(label="Эмоциональная окраска (вероятности)"),
    title="Эмоциональная классификация речи (дообученная GigaAM на 8 классов)",
    description="Загрузите аудио",
)

# Start the web server only when executed as a script (e.g. `python app.py`);
# merely importing this module no longer launches it as a side effect.
if __name__ == "__main__":
    demo.launch()