| import os |
| import librosa |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.optim as optim |
| import torchaudio |
| from torch.utils.data import Dataset, DataLoader |
|
|
|
|
| from hparams import Hparams |
| from model_cnn import Model |
| from dataset import MyDataset |
|
|
|
|
# Module-level configuration pulled from the shared hyper-parameter table.
args = Hparams.args
device = args['device']  # torch device string used for training/inference
split = 'train'  # dataset split used by this script


tone_class = 5  # number of tone classes — presumably 4 Mandarin tones + neutral; TODO confirm
NUM_EPOCHS = 100  # default number of training epochs
|
|
|
|
|
|
|
|
| |
|
|
| |
| |
|
|
|
|
|
|
def move_data_to_device(data, device):
    """Return a list with every tensor in *data* moved to *device*.

    Non-tensor items are passed through unchanged. The original version
    silently dropped them, which would break callers that unpack a fixed
    number of elements from the result.

    Args:
        data: iterable of items, some of which may be torch.Tensors.
        device: target device for the tensors.

    Returns:
        list with tensors on *device* and all other items untouched.
    """
    ret = []
    for item in data:
        if isinstance(item, torch.Tensor):
            ret.append(item.to(device))
        else:
            # Bug fix: keep non-tensor items instead of dropping them.
            ret.append(item)
    return ret
|
|
def collate_fn(batch):
    """Pad a batch of (mel, f0, word, tone) samples and stack each field.

    mel and f0 are zero-padded along the frame axis to the longest length in
    the batch (never below 1600 frames); the word/tone label sequences are
    zero-padded to a fixed length of 50.

    NOTE(review): the frame maximum also folds in the word/tone sequence
    lengths — this looks unintentional but is preserved here; confirm.
    """
    pad = torch.nn.functional.pad

    frames = 1600
    for sample in batch:
        frames = max(frames, sample[0].shape[0], sample[1].shape[0],
                     sample[2].shape[0], sample[3].shape[0])

    mels, f0s, words, tones = [], [], [], []
    for sample in batch:
        mel, f0, word, tone = sample[0], sample[1], sample[2], sample[3]
        mels.append(pad(mel, (0, 0, 0, frames - mel.shape[0])))
        f0s.append(pad(f0, (0, frames - f0.shape[0])))
        words.append(pad(word, (0, 50 - word.shape[0])))
        tones.append(pad(tone, (0, 50 - tone.shape[0])))

    return torch.stack(mels), torch.stack(f0s), torch.stack(words), torch.stack(tones)
|
|
def get_data_loader(split, args):
    """Build a shuffled DataLoader over MyDataset for the given split.

    NOTE(review): the dataset index is truncated to its first 32 entries —
    this looks like leftover debugging code; confirm before a full run.
    """
    dataset = MyDataset(
        dataset_root=args['dataset_root'],
        split=split,
        sampling_rate=args['sampling_rate'],
        sample_length=args['sample_length'],
        frame_size=args['frame_size'],
    )
    dataset.dataset_index = dataset.dataset_index[:32]
    dataset.index = dataset.index[:32]

    return DataLoader(
        dataset,
        batch_size=args['batch_size'],
        num_workers=args['num_workers'],
        pin_memory=True,
        shuffle=True,
        collate_fn=collate_fn,
    )
|
|
|
|
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
|
|
def process_sequence(seq):
    """Collapse consecutive runs of equal elements into a single element.

    E.g. [1, 1, 2, 2, 1] -> [1, 2, 1]. Used to de-duplicate repeated
    frame-level predictions before CTC-style decoding.
    """
    collapsed = []
    previous = object()  # sentinel: compares unequal to any real element
    for item in seq:
        if item != previous:
            collapsed.append(item)
        previous = item
    return collapsed
|
|
|
|
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
|
|
class ASR_Model:
    '''
    Main class for training the syllable/tone recognition model and making
    predictions on audio files.
    '''
    def __init__(self, device="cpu", model_path=None, pinyin_path='pinyin.txt'):
        """Build (or load) the model and the pinyin <-> id lookup tables.

        Args:
            device: torch device string the model should live on.
            model_path: optional path to a model saved with torch.save(model);
                when given it replaces the freshly initialized network.
            pinyin_path: text file with one pinyin syllable per line; the
                line number becomes the syllable id.
        """
        self.device = device

        # syllable string -> integer id (line order in pinyin_path).
        self.pinyin = {}
        with open(pinyin_path, 'r') as f:
            for i, line in enumerate(f.readlines()):
                self.pinyin[line.replace('\n', '')] = i
        # Reverse lookup: id -> syllable string.
        self.idx2char = {idx: char for char, idx in self.pinyin.items()}

        # Network output dimension. NOTE(review): presumably
        # num_syllables * 5 tones + 1 blank = 2036 — confirm against
        # pinyin.txt and model_cnn.
        num_class = 2036

        self.model = Model(syllable_class=num_class).to(self.device)
        self.sampling_rate = args['sampling_rate']
        if model_path is not None:
            # Loads a whole pickled model object, replacing the freshly
            # constructed network above.
            self.model = torch.load(model_path)
            print('Model loaded.')
        else:
            print('Model initialized.')
        self.model.to(device)

    def fit(self, args, NUM_EPOCHS=100):
        """Train with CTC loss and checkpoint the best validation model.

        Args:
            args: hyper-parameter dict (save_model_dir, dataset paths, ...).
            NUM_EPOCHS: number of training epochs.
        """
        save_model_dir = args['save_model_dir']
        # Bug fix: os.mkdir raises when parent dirs are missing or the dir
        # already exists; makedirs with exist_ok handles both.
        os.makedirs(save_model_dir, exist_ok=True)

        loss_fn = nn.CTCLoss()
        optimizer = optim.Adam(self.model.parameters(), lr=0.001)

        train_loader = get_data_loader(split='train', args=args)
        # NOTE(review): validation also uses the 'train' split — confirm
        # whether a separate 'valid' split should be used here.
        valid_loader = get_data_loader(split='train', args=args)

        print('Start training...')
        min_valid_loss = float('inf')

        for epoch in range(NUM_EPOCHS):
            # Bug fix: train() must be re-enabled every epoch because the
            # validation pass below switches the model to eval(); the
            # original called train() once, so epochs 2..N trained with the
            # model stuck in eval mode.
            self.model.train()
            for idx, data in enumerate(train_loader):
                mel, f0, word, tone = move_data_to_device(data, device)
                # Frames whose first mel bin is non-zero; frames added by
                # collate_fn's padding are all-zero.
                input_length = (mel[:, :, 0] != 0.0).sum(axis=1)
                mel = mel.unsqueeze(1)  # add a channel axis for the CNN

                output = self.model(mel)
                # CTCLoss expects (T, N, C).
                output = output.permute(1, 0, 2)

                # Time axis downsampled by 4 inside the CNN — presumably;
                # TODO confirm against model_cnn.
                output_len = input_length // 4
                # Non-zero label positions; 0 is the pad/blank id.
                target_len = (tone != 0).sum(axis=1)
                # Joint syllable+tone label (5 tone classes per syllable).
                target = word * 5 + tone

                loss = loss_fn(output, target, output_len, target_len)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                if idx % 100 == 0:
                    print(f'Epoch {epoch+1},Iteration {idx+1}, Loss: {loss.item()}')

            # ---- validation pass ----
            self.model.eval()
            with torch.no_grad():
                losses = []
                for idx, data in enumerate(valid_loader):
                    mel, f0, word, tone = move_data_to_device(data, device)
                    input_length = (mel[:, :, 0] != 0.0).sum(axis=1)
                    mel = mel.unsqueeze(1)

                    out = self.model(mel)
                    out = out.permute(1, 0, 2)

                    output_len = input_length // 4
                    target_len = (tone != 0).sum(axis=1)
                    target = word * 5 + tone

                    loss = loss_fn(out, target, output_len, target_len)
                    losses.append(loss.item())
                loss = np.mean(losses)

            # Keep the checkpoint with the lowest validation loss.
            if loss < min_valid_loss:
                min_valid_loss = loss
                target_model_path = save_model_dir + '/best_model.pth'
                torch.save(self.model, target_model_path)

    def to_pinyin(self, num):
        """Map a joint class id back to (pinyin, tone); id 0 (blank/pad) -> None."""
        if num == 0:
            return None
        pinyin, tone = self.idx2char[(num - 1) // 5], (num - 1) % 5 + 1
        return pinyin, tone

    def getsentence(self, words):
        """Map a sequence of syllable ids to their pinyin strings."""
        words = words.tolist()
        return [self.idx2char[int(word)] for word in words]

    def predict(self, audio_fp):
        """Predict a sequence of (pinyin, tone) pairs for one audio file."""
        waveform, sample_rate = torchaudio.load(audio_fp)
        waveform = torchaudio.transforms.Resample(sample_rate, self.sampling_rate)(waveform)
        mel_spec = torchaudio.transforms.MelSpectrogram(
            sample_rate=self.sampling_rate, n_fft=2048, hop_length=100, n_mels=256)(waveform)
        mel_spec = torch.mean(mel_spec, 0)  # average channels to mono

        # Improvement: the original re-loaded the audio with librosa and
        # computed an f0 track that was never used — that dead work is
        # dropped here. mel_spec is already a tensor, so the
        # torch.tensor(tensor) copy (and its warning) is avoided too.
        mel = mel_spec.T.unsqueeze(0).unsqueeze(0)

        self.model.eval()
        with torch.no_grad():
            output = self.model(mel.to(self.device))

        # Greedy decode, collapse repeats, drop blanks, map ids to pinyin.
        seq = process_sequence(output[0].cpu().numpy().argmax(-1))
        result = [self.to_pinyin(c) for c in seq if c != 0]

        return result
|
|
|
|
|
|
|
|