Spaces: Runtime error
| import torch | |
| import torchaudio | |
def load_model(model_path, model_class, num_classes=10):
    """Construct a model, load trained weights from disk, and set eval mode.

    Args:
        model_path: Path to a checkpoint containing a saved ``state_dict``.
        model_class: Model constructor (an ``nn.Module`` subclass) that
            accepts a ``num_classes`` keyword argument.
        num_classes: Number of output classes the checkpoint was trained with.

    Returns:
        The model on CPU, in evaluation mode (dropout/batch-norm frozen).
    """
    model = model_class(num_classes=num_classes)
    # weights_only=True restricts unpickling to tensors/primitives, preventing
    # arbitrary-code execution from a malicious checkpoint file; a plain
    # state_dict never needs full pickle support. (Requires torch >= 1.13.)
    state_dict = torch.load(
        model_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()
    return model
def preprocess_audio(file_path):
    """Load an audio file and convert it to a mono mel-spectrogram tensor.

    Args:
        file_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        A tensor of shape ``(1, 64, time)``: per-channel mel-spectrograms
        averaged down to a single channel.
    """
    signal, sr = torchaudio.load(file_path)
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=sr,
        n_mels=64,
        n_fft=1024,
        hop_length=512,
    )
    mel = to_mel(signal)
    # Average across the channel axis (dim 0) to mono, keeping a leading
    # singleton axis so downstream code sees a (1, n_mels, time) tensor.
    return mel.mean(dim=0, keepdim=True)
def predict(model, input_tensor, labels):
    """Classify a single sample and return its human-readable label.

    Args:
        model: Callable mapping a batched tensor to class scores.
        input_tensor: A single unbatched input sample.
        labels: Sequence mapping class index -> label.

    Returns:
        The label at the index of the highest-scoring class.
    """
    with torch.no_grad():
        batch = input_tensor.unsqueeze(0)  # add the batch dimension
        scores = model(batch)
        class_idx = torch.max(scores.data, 1)[1].item()
        return labels[class_idx]