Spaces:
Runtime error
Runtime error
File size: 811 Bytes
7e8c9b0 8570ecf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 |
import torch
import torchaudio
def load_model(model_path, model_class, num_classes=10):
    """Instantiate *model_class* and load trained weights from *model_path*.

    Args:
        model_path: Path to a checkpoint containing a ``state_dict``.
        model_class: Callable that builds the model; called as
            ``model_class(num_classes=num_classes)``.
        num_classes: Number of output classes (default 10).

    Returns:
        The model with weights loaded, switched to eval mode, on CPU.
    """
    model = model_class(num_classes=num_classes)
    # weights_only=True restricts unpickling to tensor/primitive data,
    # guarding against arbitrary code execution from an untrusted checkpoint.
    # A state_dict contains only tensors, so this is safe here.
    state_dict = torch.load(
        model_path,
        map_location=torch.device('cpu'),
        weights_only=True,
    )
    model.load_state_dict(state_dict)
    model.eval()  # disable dropout/batch-norm updates for inference
    return model
def preprocess_audio(file_path):
    """Load an audio file and turn it into a mono mel-spectrogram tensor.

    Args:
        file_path: Path to an audio file readable by ``torchaudio.load``.

    Returns:
        A tensor of shape ``(1, 64, time)`` — the mel-spectrogram averaged
        across channels, with a leading channel axis of size 1 retained.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_mels=64,
        n_fft=1024,
        hop_length=512,
    )
    spectrogram = mel_transform(waveform)
    # Collapse stereo/multi-channel audio to mono by averaging over the
    # channel dimension; keepdim preserves the axis the model expects.
    mono_spec = spectrogram.mean(dim=0, keepdim=True)
    return mono_spec
def predict(model, input_tensor, labels):
    """Run a single-sample forward pass and return the predicted label.

    Args:
        model: A callable model; receives a batch of shape ``(1, ...)``.
        input_tensor: One unbatched input sample (a batch dim is added).
        labels: Sequence mapping class index -> human-readable label.

    Returns:
        The label at the index of the highest-scoring output class.
    """
    with torch.no_grad():
        outputs = model(input_tensor.unsqueeze(0))
        # argmax over the class dimension; no_grad already guarantees no
        # autograd tracking, so the legacy `.data` access is unnecessary.
        predicted = torch.argmax(outputs, dim=1)
    return labels[predicted.item()]
|