opinder2906's picture
Update utils.py
8570ecf verified
raw
history blame contribute delete
811 Bytes
import torch
import torchaudio
def load_model(model_path, model_class, num_classes=10):
    """Build a model and restore its saved weights for CPU inference.

    Args:
        model_path: Location of the saved state dict; handed unchanged to
            ``torch.load``.
        model_class: Callable accepting a ``num_classes`` keyword argument
            and returning the model object to populate.
        num_classes: Value forwarded to ``model_class``. Defaults to 10.

    Returns:
        The freshly constructed model with the stored state dict loaded
        and evaluation mode switched on.
    """
    network = model_class(num_classes=num_classes)

    # map_location pins every stored tensor to the CPU on load.
    # NOTE(review): torch.load deserializes the file, so only point this at
    # checkpoints from a trusted source.
    cpu = torch.device('cpu')
    weights = torch.load(model_path, map_location=cpu)
    network.load_state_dict(weights)

    # Inference mode for layers that behave differently while training.
    network.eval()
    return network
def preprocess_audio(file_path):
    """Load an audio file and convert it to a channel-averaged mel spectrogram.

    Args:
        file_path: Path of the audio file, passed directly to
            ``torchaudio.load``.

    Returns:
        The mel spectrogram of the loaded waveform, averaged over its first
        dimension with that dimension kept at size 1.
        Presumably shaped (1, 64, frames), assuming torchaudio's default
        channels-first waveform layout -- TODO confirm against the model's
        expected input.
    """
    audio, rate = torchaudio.load(file_path)

    # Built per call because the transform is configured with the sample
    # rate reported for this particular file.
    to_mel = torchaudio.transforms.MelSpectrogram(
        sample_rate=rate,
        n_mels=64,
        n_fft=1024,
        hop_length=512,
    )

    # Collapse dim 0 (the channel axis, if the layout is channels-first) by
    # averaging; keepdim retains it as a singleton dimension.
    return to_mel(audio).mean(dim=0, keepdim=True)
def predict(model, input_tensor, labels):
    """Return the label of the highest-scoring class for one input.

    Args:
        model: Callable taking the batched input and returning a tensor of
            scores whose dimension 1 indexes the classes.
        input_tensor: A single, un-batched input; a leading batch dimension
            of size 1 is added before it is passed to ``model``.
        labels: Container indexable by an ``int`` class index (for example
            a list of class names).

    Returns:
        ``labels[i]`` where ``i`` is the index of the largest score along
        dimension 1 of the model output.
    """
    # Gradients are not needed for inference, so skip autograd bookkeeping.
    with torch.no_grad():
        outputs = model(input_tensor.unsqueeze(0))
        # argmax replaces torch.max(outputs.data, 1): the max values were
        # discarded, and the legacy `.data` access is redundant under no_grad.
        predicted = outputs.argmax(dim=1)
    # .item() requires exactly one element, which the batch of one guarantees.
    return labels[predicted.item()]