| | import logging |
| | from utils import logging_utils |
| | logging_utils.config_logger() |
| | import torch |
| | import random |
| | import numpy as np |
| | from data.dataloader import extract_data |
| | import torchaudio |
| | from models.RagaNet import BaseRagaClassifier, ResNetRagaClassifier, Wav2VecTransformer, count_parameters |
| | from collections import OrderedDict |
| |
|
# Seed the global NumPy and Python RNGs (in that order) with a fixed value so
# any randomness drawn from them is reproducible; torch's RNG is not seeded here.
for _seed_rng in (np.random.seed, random.seed):
    _seed_rng(123)
del _seed_rng
| |
|
| |
|
class Evaluator:
    """Raga classifier wrapper for inference on raw audio.

    The constructor builds the model selected by ``params.model``, restores its
    weights from a checkpoint and puts it in eval mode; :meth:`inference` then
    returns the top-k ragas with their averaged probabilities.
    """

    def __init__(self, params, checkpoint_path=None):
        """Build the model and load its weights.

        Args:
            params: Config object. Reads ``device``, ``model`` and
                ``best_checkpoint_path`` here, and ``sample_rate``,
                ``clip_length``, ``normalize`` and ``num_classes`` in
                :meth:`inference`. Also passed to ``extract_data`` and to the
                model constructors.
            checkpoint_path: Optional checkpoint file to load instead of
                ``params.best_checkpoint_path``.

        Raises:
            ValueError: If ``params.model`` is not 'base', 'resnet' or 'wav2vec'.
        """
        self.params = params
        self.device = self.params.device

        _, self.raga2label = extract_data(self.params)
        self.raga_list = list(self.raga2label.keys())
        self.label_list = list(self.raga2label.values())
        # Inverse lookup (class index -> raga name), built once instead of a
        # linear list search per prediction. setdefault keeps the first raga for
        # a repeated label, like list.index(). Assumes labels are integer-like
        # class indices -- TODO confirm against extract_data.
        self.label2raga = {}
        for raga, label in self.raga2label.items():
            self.label2raga.setdefault(int(label), raga)

        model_classes = {
            'base': BaseRagaClassifier,
            'resnet': ResNetRagaClassifier,
            'wav2vec': Wav2VecTransformer,
        }
        if params.model not in model_classes:
            message = "Model must be either 'base', 'resnet', or 'wav2vec'"
            logging.error(message)
            # Fail here rather than with an AttributeError on self.model later.
            raise ValueError(message)
        self.model = model_classes[params.model](params).to(self.device)

        # Load the same checkpoint that is logged.
        if checkpoint_path is None:
            checkpoint_path = params.best_checkpoint_path
        logging.info("Loading checkpoint %s", checkpoint_path)
        self.restore_checkpoint(checkpoint_path)
        self.model.eval()

    def normalize(self, audio):
        """Standardize ``audio`` along dim 1 (per channel for the
        ``(channels, samples)`` tensors built in :meth:`inference`).

        ``1e-5`` is added to the std so constant (e.g. silent) input does not
        divide by zero.
        """
        mean = torch.mean(audio, dim=1, keepdim=True)
        std = torch.std(audio, dim=1, keepdim=True)
        return (audio - mean) / (std + 1e-5)

    def pad_audio(self, audio):
        """Right-pad ``audio`` with zeros along dim 1 to one training clip
        (``sample_rate * clip_length`` samples)."""
        target = self.params.sample_rate * self.params.clip_length
        return torch.nn.functional.pad(audio, pad=(0, target - audio.shape[1]), value=0)

    def inference(self, k, audio):
        """Return the ``k`` most probable ragas for an audio clip.

        Args:
            k: Number of ragas to return.
            audio: ``(sample_rate, samples)`` tuple; ``samples`` is a 1-D array
                (mono) or a 2-D array laid out as ``(n_samples, channels)``.

        Returns:
            Dict mapping raga name to probability (Python float), inserted in
            descending order of probability.

        Raises:
            ValueError: If the preprocessed audio contains NaN values.
        """
        waveform = self._prepare_audio(audio)
        probs = self._predict_probs(waveform)
        return self._topk_ragas(probs, k)

    def _prepare_audio(self, audio):
        """Turn ``(sample_rate, samples)`` into a float32 ``(channels, samples)``
        tensor on ``self.device``: resampled to ``params.sample_rate``,
        optionally normalized, and zero-padded to at least one training clip."""
        sample_rate, samples = audio
        waveform = torch.tensor(samples)
        if waveform.dim() == 1:
            # Mono input is duplicated into two identical channels.
            waveform = waveform.unsqueeze(0).repeat(2, 1)
        else:
            waveform = waveform.T
        waveform = waveform.to(torch.float32)

        resample = torchaudio.transforms.Resample(
            orig_freq=sample_rate, new_freq=self.params.sample_rate
        )
        waveform = resample(waveform)

        if self.params.normalize:
            waveform = self.normalize(waveform)

        if waveform.shape[1] < self.params.sample_rate * self.params.clip_length:
            waveform = self.pad_audio(waveform)

        # Explicit check: an assert would be stripped under ``python -O``.
        # NaNs can come from the input itself or e.g. from normalizing a
        # single-sample clip (unbiased std of one value is NaN).
        if torch.isnan(waveform).any():
            raise ValueError("Audio contains NaN values after preprocessing")
        return waveform.to(self.device)

    def _predict_probs(self, waveform):
        """Average class probabilities over consecutive, non-overlapping
        training-length windows of ``waveform`` (dim 1).

        Trailing samples that do not fill a whole window are ignored.

        Raises:
            ValueError: If ``waveform`` is shorter than one window.
        """
        window = self.params.sample_rate * self.params.clip_length
        num_clips = waveform.shape[1] // window
        if num_clips == 0:
            raise ValueError("Audio is shorter than one training clip")

        prob_sum = torch.zeros(self.params.num_classes, device=self.device)
        with torch.no_grad():
            for i in range(num_clips):
                clip = waveform[:, i * window:(i + 1) * window].unsqueeze(0)
                logits = self.model(clip).reshape(-1, self.params.num_classes)
                # softmax == exp(x) / sum(exp(x)) but cannot overflow to NaN.
                # Only the first row is used (presumably one row per clip --
                # TODO confirm the model's output shape).
                prob_sum += torch.softmax(logits, dim=1)[0]
        return prob_sum / num_clips

    def _topk_ragas(self, probs, k):
        """Map the ``k`` largest entries of ``probs`` (indexed by class label)
        to ``{raga: probability}`` in descending order of probability."""
        sorted_probs, labels = probs.sort(descending=True)
        top_probs = sorted_probs[:k].tolist()
        top_labels = labels[:k].tolist()
        return {self.label2raga[label]: prob for label, prob in zip(top_labels, top_probs)}

    def restore_checkpoint(self, checkpoint_path):
        """Load model weights and training counters from ``checkpoint_path``.

        If the state dict does not load directly, it is retried with a leading
        ``'module.'`` stripped from each key -- the prefix that
        ``torch.nn.DataParallel``-wrapped models put on their parameter names.

        Sets ``self.iters`` and ``self.startEpoch`` from the checkpoint.
        """
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        state_dict = checkpoint['model_state']
        try:
            self.model.load_state_dict(state_dict)
        except RuntimeError:
            prefix = 'module.'
            stripped = OrderedDict(
                (key[len(prefix):] if key.startswith(prefix) else key, value)
                for key, value in state_dict.items()
            )
            self.model.load_state_dict(stripped)

        self.iters = checkpoint['iters']
        self.startEpoch = checkpoint['epoch']
| |
|
| |
|
| | |
| | |
| | |
| | |