from typing import Callable, List, Optional, Dict
import math
from functools import partial
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
import numpy as np
import librosa
import miniaudio
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from mae import MaskedAutoencoderViT


def load_audio(
        path: str,
        sr: int = 32000,
        duration: int = 20,
) -> np.ndarray:
    # stream the file as mono float32 at the target sample rate,
    # reading at most sr * duration frames
    g = miniaudio.stream_file(
        path,
        output_format=miniaudio.SampleFormat.FLOAT32,
        nchannels=1,
        sample_rate=sr,
        frames_to_read=sr * duration,
    )
    signal = np.array(next(g))
    return signal


def mel_spectrogram(
        signal: np.ndarray,
        sr: int = 32000,
        n_fft: int = 800,
        hop_length: int = 320,
        n_mels: int = 128,
) -> np.ndarray:
    mel_spec = librosa.feature.melspectrogram(
        y=signal,
        sr=sr,
        n_fft=n_fft,
        hop_length=hop_length,
        n_mels=n_mels,
        window='hann',
        pad_mode='constant',
    )
    mel_spec = librosa.power_to_db(mel_spec)  # (freq, time)
    return mel_spec.T  # (time, freq)


def normalize(arr: np.ndarray, eps: float = 1e-8) -> np.ndarray:
    return (arr - arr.mean()) / (arr.std() + eps)


device = 'cuda:0'
seed = 42
train_size = 0.8  # 80% train, 20% test
batch_size_train = 10
batch_size_test = 32
num_workers = 1
lr = 1e-3
epochs = 200
detection_epoch = 20  # early-stopping patience: stop after this many epochs without improvement
sr = 32000
n_fft = 800  # 25 ms
hop_length = 320  # 10 ms
duration = 10000  # seconds; 10000 ~= Inf, i.e. read the whole audio file
feature_length = 2048  # mel-spectrogram length (the MAE is trained on 2048x128 mel spectrograms)
patch_size = 16  # the MAE splits the mel spectrogram into 16x16 patches
feature_padding = True
header = 'mean'
mlp_num_neurons = [768, 10]
mlp_activation_layer = nn.ReLU
mlp_bias = True

torch.manual_seed(seed)
np.random.seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# =============================== model ===============================
mae = MaskedAutoencoderViT(
    img_size=(2048, 128),
    patch_size=16,
    in_chans=1,
    embed_dim=768,
    depth=12,
    num_heads=12,
    decoder_mode=1,
    no_shift=False,
    decoder_embed_dim=512,
    norm_layer=partial(nn.LayerNorm, eps=1e-6),
    norm_pix_loss=False,
    pos_trainable=False,
)

# Load pre-trained weights
ckpt_path = 'music-mae-32kHz.pth'
mae.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
mae.to(device)
mae.eval()

# =============================== data ===============================
fp = Path('GTZAN-dataset/genres_original')

audio_data = dict()  # {genre: [audio_file1, audio_file2, ...]}
for d in fp.iterdir():
    if d.is_dir():
        for f in d.iterdir():
            if f.is_file():
                # GTZAN files are named like 'blues.00000.wav', so the
                # genre is the first dot-separated component
                genre = f.name.split('.')[0]
                if genre not in audio_data:
                    audio_data[genre] = [str(f)]
                else:
                    audio_data[genre].append(str(f))

# split each genre 80/20 separately so the classes stay balanced
audio_data_train = dict()
audio_data_test = dict()
for k, v in audio_data.items():
    train_data, test_data = train_test_split(v, train_size=train_size, random_state=seed, shuffle=True)
    audio_data_train[k] = train_data
    audio_data_test[k] = test_data
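
# Optional sanity check (not part of the original script): GTZAN ships 10
# genres with 100 clips each, so the 80/20 split should give 80 train and
# 20 test files per genre.
for k in sorted(audio_data):
    print(f"{k}: {len(audio_data_train[k])} train / {len(audio_data_test[k])} test")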

@torch.no_grad()
def infer_mae_embedding(data: Dict) -> Dict:
    emb_data = dict()  # {genre: [embed1, embed2, ...]}
    for k, v in tqdm(data.items(), desc='infer mae embedding', total=len(data)):
        for f in v:
            try:
                mel_spec = mel_spectrogram(
                    load_audio(f, duration=duration),
                    sr=sr, n_fft=n_fft, hop_length=hop_length,
                )
            except Exception as e:
                print(e)
                print(f)
                continue

            # pad the mel spectrogram to a multiple of patch_size
            input_length = mel_spec.shape[0]
            n = math.ceil(input_length / patch_size)
            if input_length < patch_size * n:
                pad_length = patch_size * n - input_length
                mel_spec = np.pad(mel_spec, ((0, pad_length), (0, 0)),
                                  mode='constant', constant_values=mel_spec.min())

            # if the padded mel spectrogram is longer than feature_length,
            # split it into snippets and average their embeddings
            input_length = mel_spec.shape[0]
            embeds = []
            for i in range(0, input_length, feature_length):
                snippet = mel_spec[i:i + feature_length]
                snippet = normalize(snippet)
                snippet = snippet[None, None, :, :]  # (1, 1, time, freq)
                x = torch.from_numpy(snippet).to(device)
                y = mae.forward_encoder_no_mask(x, header=header)  # (1, 768)
                y = y / y.norm(p=2, dim=-1, keepdim=True)  # L2-normalize
                y = y.cpu().numpy().squeeze()
                embeds.append(y)
            y = np.mean(embeds, axis=0)  # (768,)

            if k not in emb_data:
                emb_data[k] = [y]
            else:
                emb_data[k].append(y)
    return emb_data


audio_emb_train = infer_mae_embedding(audio_data_train)
audio_emb_test = infer_mae_embedding(audio_data_test)

# sort the genre names so the label indices are deterministic across runs
label_set = sorted(audio_emb_train.keys())
label_map = {label: i for i, label in enumerate(label_set)}
print(label_map)
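
# Optional (not part of the original script): the two MAE passes above
# dominate the runtime, so the embeddings can be cached and reloaded on
# later runs. A minimal sketch; 'embeddings.npz' is an arbitrary file name.
np.savez(
    'embeddings.npz',
    **{f'train_{k}': np.stack(v) for k, v in audio_emb_train.items()},
    **{f'test_{k}': np.stack(v) for k, v in audio_emb_test.items()},
)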

class MLP(torch.nn.Sequential):
    def __init__(
            self,
            num_neurons: List[int],
            activation_layer: Optional[Callable[..., torch.nn.Module]] = torch.nn.ReLU,
            bias: bool = True,
            dropout: float = 0.0,
    ):
        layers = []
        for c_in, c_out in zip(num_neurons[:-1], num_neurons[1:]):
            layers.append(torch.nn.Linear(c_in, c_out, bias=bias))
            layers.append(activation_layer())
            layers.append(torch.nn.Dropout(dropout))
        # drop the trailing activation and dropout so the network
        # ends with a Linear layer that outputs raw logits
        layers.pop()
        layers.pop()
        super().__init__(*layers)


class SimpleDataset(Dataset):
    def __init__(self, dict_data: Dict, label_map: Dict):
        self.embed_with_label = []
        for k, v in dict_data.items():
            for emb in v:
                self.embed_with_label.append((emb, label_map[k]))

    def __len__(self):
        return len(self.embed_with_label)

    def __getitem__(self, idx):
        return self.embed_with_label[idx]


train_dataset = SimpleDataset(audio_emb_train, label_map)
test_dataset = SimpleDataset(audio_emb_test, label_map)
print(f"len(train_dataset): {len(train_dataset)}")
print(f"len(test_dataset): {len(test_dataset)}")


def train_one_epoch(model, device, dataloader, loss_fn, optimizer):
    model.train()
    for batch in dataloader:
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        y_logit = model(x)
        loss = loss_fn(y_logit, y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


@torch.no_grad()
def eval_one_epoch(model, device, dataloader, loss_fn):
    model.eval()
    total_loss = 0.0
    total_correct = 0.0
    total_num = 0.0
    for batch in dataloader:
        x, y = batch
        x = x.to(device)
        y = y.to(device)
        y_logit = model(x)
        loss = loss_fn(y_logit, y)
        total_loss += loss.item() * x.shape[0]
        total_correct += (y_logit.argmax(dim=-1) == y).sum().item()
        total_num += x.shape[0]
    loss = total_loss / total_num
    acc = total_correct / total_num
    return loss, acc


mlp = MLP(
    num_neurons=mlp_num_neurons,
    activation_layer=mlp_activation_layer,
    bias=mlp_bias,
    dropout=0.0,
)
print(mlp)
mlp.to(device)

optimizer = Adam(mlp.parameters(), lr=lr)
loss_fn = nn.CrossEntropyLoss()

train_dataloader = DataLoader(train_dataset, batch_size=batch_size_train, shuffle=True, num_workers=num_workers)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size_test, shuffle=False, num_workers=num_workers)

test_loss, test_accuracy = eval_one_epoch(mlp, device, test_dataloader, loss_fn)
print(f"init: test loss {test_loss:.4f}, test accuracy {test_accuracy:.4f}")

best_accuracy = 0.0
at = 0  # epoch of the best test accuracy so far
for epoch in range(epochs):
    train_one_epoch(mlp, device, train_dataloader, loss_fn, optimizer)
    test_loss, test_accuracy = eval_one_epoch(mlp, device, test_dataloader, loss_fn)
    print(f"epoch {epoch}: test loss {test_loss:.4f}, test accuracy {test_accuracy:.4f}")
    if test_accuracy > best_accuracy:
        best_accuracy = test_accuracy
        at = epoch
    if epoch - at >= detection_epoch:
        print(f"early stop at epoch {epoch}")
        print(f"best accuracy: {best_accuracy:.4f} at epoch {at}")
        break
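
# Optional inference sketch (not part of the original script): classify a new
# audio file with the frozen MAE encoder plus the trained MLP head, reusing the
# same preprocessing as infer_mae_embedding. The path in the usage comment at
# the bottom is hypothetical.
@torch.no_grad()
def predict_genre(path: str) -> str:
    mel_spec = mel_spectrogram(load_audio(path, duration=duration),
                               sr=sr, n_fft=n_fft, hop_length=hop_length)
    # pad to a multiple of patch_size, as in infer_mae_embedding
    n = math.ceil(mel_spec.shape[0] / patch_size)
    pad_length = patch_size * n - mel_spec.shape[0]
    if pad_length > 0:
        mel_spec = np.pad(mel_spec, ((0, pad_length), (0, 0)),
                          mode='constant', constant_values=mel_spec.min())
    # embed each feature_length snippet and average, as in training
    embeds = []
    for i in range(0, mel_spec.shape[0], feature_length):
        snippet = normalize(mel_spec[i:i + feature_length])[None, None, :, :]
        y = mae.forward_encoder_no_mask(torch.from_numpy(snippet).to(device), header=header)
        y = y / y.norm(p=2, dim=-1, keepdim=True)
        embeds.append(y.cpu().numpy().squeeze())
    emb = torch.from_numpy(np.mean(embeds, axis=0)).to(device)
    mlp.eval()
    pred = mlp(emb[None, :]).argmax(dim=-1).item()
    # invert label_map to go from class index back to genre name
    inv_label_map = {i: label for label, i in label_map.items()}
    return inv_label_map[pred]

# Example usage: print(predict_genre('my_song.wav'))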