"""Data pipeline for text -> semantic-token training.

Provides `semantic_dataset_batch` (a PyTorch Dataset that streams
transcript lines from an mmapped file, bucketed by duration "heads"),
a padding/collate pipeline, and a `__main__` harness that audits how
the head-based sampling distributes clip durations across batches.
"""

import os
import sys
from typing import Any

sys.path.append("../")

import linecache
import mmap
import pickle as pkl
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import transformers
from accelerate import Accelerator, DistributedDataParallelKwargs
from autoregressive import TS_model
from cleaners import english_cleaners
from librosa.filters import mel as librosa_mel_fn
from mel_spec import get_mel_spectrogram
from meta_stats import process_file, process_file_for_heads
from stft import STFT
from torch.utils.data import (DataLoader, Dataset, WeightedRandomSampler,
                              get_worker_info)
from tqdm import tqdm
from utilities import get_mask_from_lengths

import wandb
from config import config
from Text import code_labels, labels, text_labels

# Seed every RNG the pipeline touches so sampling is reproducible.
torch.manual_seed(config.seed_value)
np.random.seed(config.seed_value)
random.seed(config.seed_value)

print(text_labels)

# Token <-> id lookup tables. The empty string "" doubles as the
# BOS/EOS/pad symbol in both vocabularies (see __getitem__ / collate).
text_enc = {j: i for i, j in enumerate(text_labels)}
text_dec = {i: j for i, j in enumerate(text_labels)}
code_enc = {j: i for i, j in enumerate(code_labels)}
code_dec = {i: j for i, j in enumerate(code_labels)}


def read_specific_line(filename, line_number):
    """Return line `line_number` (1-based) of `filename`, stripped of
    surrounding whitespace. Uses linecache so repeated reads are cheap."""
    line = linecache.getline(filename, line_number)
    return line.strip()


# Fixed number of mel frames taken from each reference clip.
CLIP_LENGTH = config.CLIP_LENGTH


class semantic_dataset_batch(Dataset):
    """Dataset yielding (text ids, semantic-code ids, reference mels, lang).

    Two modes:
      * scale=False: everything (transcript, semantic tokens, reference-mel
        index) is loaded into memory up front.
      * scale=True: transcript lines are read on demand through an mmap;
        `process_file_for_heads` pre-buckets line indices into duration
        "heads" with per-head sampling weights, and `get_head` cycles
        through them so each worker/process sees its own shard.
    """

    def __init__(
        self,
        transcript_path,
        semantic_path=None,
        ref_mels_path=None,
        ref_k=3,
        scale=False,
        process_id=None,
        total_processes=None,
    ):
        super().__init__()
        self.scale = scale
        # Always defined so __getitem__ can test it in either mode
        # (previously only set in the scale branch -> AttributeError
        # for scale=False datasets).
        self.iterator = None

        if not scale:
            # In-memory mode: transcript lines are "key|text", semantic
            # file lines are "key\t<space-separated tokens>".
            with open(transcript_path, "r") as file:
                data = file.read().strip("\n").split("\n")[:]
            with open(semantic_path, "r") as file:
                semb = file.read().strip("\n").split("\n")
            with open(ref_mels_path, "rb") as file:
                self.ref_mels = pkl.load(file)
            semb = {
                i.split("\t")[0]: [j for j in i.split("\t")[1].split()]
                for i in semb
            }
            data = {i.split("|")[0]: i.split("|")[1].strip().lower() for i in data}
            # NOTE(review): rows here have 3 fields but __getitem__ unpacks
            # 4 (lang, path, semb, text) — confirm the intended transcript
            # schema before using scale=False.
            self.data = [[i, semb[i], data[i]] for i in data.keys()]
        else:
            print(transcript_path)
            self.heads, self.weights, self.count = process_file_for_heads(
                transcript_path, total_processes, process_id
            )
            print("length :", self.count)
            self.data_len = self.count
            self.transcript_path = transcript_path

            # Build a line-number -> byte-offset index over the mmapped
            # transcript so __getitem__ can seek straight to any line.
            line_index = {}
            with open(transcript_path, "rb") as file:
                mmapped_file = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
                line_number = 0
                offset = 0
                while offset < len(mmapped_file):
                    line_index[line_number] = offset
                    offset = mmapped_file.find(b"\n", offset) + 1
                    line_number += 1
            self.mmapped_file = mmapped_file
            self.line_index = line_index
            self.process_id = process_id
            self.total_processes = total_processes

        # Shared by both modes.
        self.ref_k = ref_k
        self.max_wav_value = config.MAX_WAV_VALUE
        self.stft_fn = STFT(config.filter_length, config.hop_length, config.win_length)
        mel_basis = librosa_mel_fn(
            sr=config.sampling_rate,
            n_fft=config.filter_length,
            n_mels=config.n_mel_channels,
            fmin=config.mel_fmin,
            fmax=config.mel_fmax,
        )
        self.mel_basis = torch.from_numpy(mel_basis).float()

    def get_mel(self, filepath):
        """Load `filepath` and return (mel spectrogram, energy).

        Energy is currently disabled and returned as an empty list;
        the mel is computed by the shared `get_mel_spectrogram` helper
        at the file's native sampling rate.
        """
        audio_norm, sampling_rate = torchaudio.load(filepath)
        melspec = get_mel_spectrogram(audio_norm, sampling_rate).squeeze(0)
        energy = []
        return melspec, list(energy)

    def __len__(self):
        if self.scale:
            return self.data_len
        return len(self.data)

    def get_worker_heads(self):
        """Shard every head's line-index list across DataLoader workers.

        Each worker keeps a contiguous `segment_size` slice of each head,
        wrapping around to the start when its slice would run past the end,
        so all workers get equally sized, disjoint-ish segments.
        """
        self.worker_id = get_worker_info().id
        self.num_worker = get_worker_info().num_workers
        new_heads = {}
        for i in self.heads:
            segment_size = (len(self.heads[i]) + self.num_worker - 1) // self.num_worker
            start_idx = self.worker_id * segment_size
            end_idx = start_idx + segment_size
            if end_idx > len(self.heads[i]):
                # Wrap around to the beginning of the head.
                segment = (
                    self.heads[i][start_idx:]
                    + self.heads[i][: end_idx - len(self.heads[i])]
                )
            else:
                segment = self.heads[i][start_idx:end_idx]
            new_heads[i] = segment
        self.heads = new_heads

    def get_head(self):
        """Infinite generator of transcript line indices.

        Round-robins over the heads, emitting `weight` consecutive
        indices from each head per pass; when a head is exhausted it is
        reshuffled and restarted, so iteration never terminates.
        """
        self.get_worker_heads()
        self.indices = [0] * len(self.heads)
        while True:
            for n, (head, weight) in enumerate(zip(self.heads, self.weights)):
                for i in range(weight):
                    if self.indices[n] < len(self.heads[head]):
                        yield self.heads[head][self.indices[n]]
                        self.indices[n] += 1
                    else:
                        self.indices[n] = 0
                        random.shuffle(self.heads[head])  # reshuffle on wrap

    def __getitem__(self, index) -> Any:
        """Return one training example.

        In scale mode the incoming `index` is ignored in favor of the
        head-sampler's next line index. Degenerate examples (too-short
        code sequences, unloadable reference clips) are skipped by
        recursing on the next index.
        """
        if self.iterator is None and self.scale:
            self.iterator = self.get_head()
        if not self.scale:
            lang, path, semb, text = self.data[index]
            ref_mels = self.ref_mels[path][: self.ref_k]
        else:
            index = next(self.iterator)
            self.mmapped_file.seek(self.line_index[index])
            line = self.mmapped_file.readline().decode("utf-8")
            # Line schema: lang|path|text|semantic ids|ref clips
            lang, path, text, semb_ids, ref_mels = line.split("|")
            semb = semb_ids.split()
            ref_mels = [i.split(",") for i in ref_mels.split("\t")][: self.ref_k]

        # Skip clips whose code sequence is implausibly short (<25 tokens).
        if len(semb) < 25:
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        # No reference clips listed: fall back to the sample's own audio.
        if len(ref_mels) == 0:
            ref_mels.append((path, 1))
            ref_mels.append((path, 1))
            ref_mels.append((path, 1))
        while len(ref_mels) < self.ref_k:
            ref_mels.append(ref_mels[-1])

        text = text.lower().strip()
        # "" acts as BOS/EOS on both streams.
        text_ids = [text_enc[""]] + [text_enc[i] for i in text] + [text_enc[""]]
        semb_ids = (
            [code_enc[""]] + [code_enc[i] for i in semb] + [code_enc[""]]
        )

        def get_random_portion(mel, mask_lengths):
            # Take a CLIP_LENGTH-frame window from each ref mel; clips
            # longer than CLIP_LENGTH get a random start offset.
            clip = mask_lengths <= CLIP_LENGTH
            ref_mel = mel[:, :, :CLIP_LENGTH].clone()
            for n, z in enumerate(clip):
                if not z:
                    start = np.random.randint(0, mask_lengths[n].item() - CLIP_LENGTH)
                    ref_mel[n, :, :] = mel[n, :, start : start + CLIP_LENGTH].clone()
            return ref_mel

        try:
            ref_mels = [self.get_mel(path)[0] for path, score in ref_mels]
        except Exception as e:
            # Unreadable audio: log and move to the next sample.
            print(index, e)
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        # Drop failed (None) mels, then top back up to ref_k by repeating.
        ref_c = []
        for i in range(self.ref_k):
            if ref_mels[i] is None:
                continue
            ref_c.append(ref_mels[i])
        if len(ref_c) == 0:
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)
        if len(ref_c) != self.ref_k:
            while len(ref_c) < self.ref_k:
                ref_c.append(ref_c[-1])
        ref_mels = ref_c

        # Pad refs to a common length with near-zero noise, then crop each
        # to a CLIP_LENGTH window.
        max_target_len = max([x.size(1) for x in ref_mels])
        ref_mels_padded = (
            torch.randn((self.ref_k, config.n_mel_channels, max_target_len))
        ) * 1e-9
        mel_length = []
        for i, mel in enumerate(ref_mels):
            ref_mels_padded[i, :, : mel.size(1)] = mel
            mel_length.append(mel.shape[-1])
        ref_mels = get_random_portion(ref_mels_padded, torch.tensor(mel_length))

        return {
            "text_ids": text_ids,
            "semb_ids": semb_ids,
            "ref_mels": ref_mels,
            "lang": torch.tensor(config.lang_index[lang]),
        }


def get_padded_seq(sequences, pad_random, before=False, pad__=0):
    """Pad `sequences` (lists) in place to the batch max length.

    Args:
        sequences: list of token-id lists; mutated and also returned.
        pad_random: pad with tiny random floats instead of `pad__`.
        before: prepend padding instead of appending.
        pad__: pad value used when pad_random is False.

    Returns:
        (padded sequences, original lengths).
    """
    max_len = max([len(s) for s in sequences])
    seq_len = []
    for i in range(len(sequences)):
        seq_len.append(len(sequences[i]))
        if pad_random:
            pad_ = list((np.random.rand(max_len - len(sequences[i]))) * 1e-9)
        else:
            pad_ = [pad__] * (max_len - len(sequences[i]))
        if not before:
            sequences[i] = sequences[i] + pad_
        else:
            sequences[i] = pad_ + sequences[i]
    return sequences, seq_len


def collate(batch):
    """Collate dataset dicts into batched tensors.

    Returns (text_ids, code_ids, text_lens, code_lens, ref_mels, langs);
    ref mels are padded to the batch max frame count with near-zero noise.
    """
    text_ids = []
    semb_ids = []
    ref_mels = []
    langs = []
    for b in batch:
        text_ids.append(b["text_ids"])
        semb_ids.append(b["semb_ids"])
        ref_mels.append(b["ref_mels"])
        langs.append(b["lang"])

    text_ids, text_len = get_padded_seq(
        text_ids, pad_random=False, before=False, pad__=text_enc[""]
    )
    code, code_len = get_padded_seq(semb_ids, pad_random=False, pad__=code_enc[""])

    ref_max_target_len = max([x.size(-1) for x in ref_mels])
    ref_mels_padded = (
        torch.randn(
            (
                len(batch),
                ref_mels[0].shape[0],
                config.n_mel_channels,
                ref_max_target_len,
            )
        )
    ) * 1e-9
    for i, mel in enumerate(ref_mels):
        ref_mels_padded[i, :, :, : mel.size(-1)] = mel

    return (
        torch.tensor(text_ids),
        torch.tensor(code),
        torch.tensor(text_len),
        torch.tensor(code_len),
        ref_mels_padded,
        torch.tensor(langs),
    )


def get_dataset(transcript_path, get_process_id, total_processes):
    """Build a scale-mode dataset sharded for this accelerator process."""
    return semantic_dataset_batch(
        transcript_path,
        scale=True,
        process_id=get_process_id,
        total_processes=total_processes,
    )


if __name__ == "__main__":
    # Sampling audit: iterate the prepared dataloader and count how many
    # clips of each duration bucket land in each gradient-accumulation
    # batch, dumping per-process stats to Sampling_data_meta/.
    accelerator = Accelerator(
        gradient_accumulation_steps=config.ts_gradient_accumulation_steps
    )
    get_process_id = accelerator.process_index
    total_processes = accelerator.num_processes

    train_dataset_ = semantic_dataset_batch(
        config.data_path + "/transcript_train_20s_final_normalized_filtered.txt",
        "../" + config.data_path + "/semt.txt",
        "../" + config.data_path + "/ref_clips.pkl",
        scale=True,
        process_id=get_process_id,
        total_processes=total_processes,
    )
    train_dataset = DataLoader(
        train_dataset_,
        pin_memory=True,
        persistent_workers=True,
        num_workers=config.ts_num_workers,
        batch_size=config.ts_batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=collate,
        sampler=None,
    )
    print("batch", config.ts_batch_size)
    train_dataloader = accelerator.prepare(train_dataset)

    import math
    from collections import defaultdict

    def calculate_duration(code_len):
        # 50 codes per second, rounded UP to the nearest half second.
        return math.ceil(((code_len + 1) / 50) * 2) / 2

    sampling = defaultdict(int)
    dataset = []
    batch_data = {}
    batch = 0
    batch_data[batch] = defaultdict(int)
    for n, data in enumerate(tqdm(train_dataloader)):
        text_ids, code, text_len, code_len, ref_clips, langs = data
        for i, j in zip(code_len, text_ids):
            dur = calculate_duration(i - 2)  # -2 strips BOS/EOS tokens
            dataset.append(list(j.detach().cpu().numpy()))
            if dur > 19.5:
                batch_data[batch]["20_sentence"] += 1
                continue
            if dur <= 5:
                batch_data[batch]["5s"] += 1
                continue
            elif dur <= 10:
                batch_data[batch]["10s"] += 1
                continue
            elif dur <= 15:
                batch_data[batch]["15s"] += 1
                continue
            elif dur <= 20:
                batch_data[batch]["20s"] += 1
                continue
        # New stats bucket per effective (accumulated) batch.
        if (n + 1) % config.ts_gradient_accumulation_steps == 0:
            batch += 1
            batch_data[batch] = defaultdict(int)

    with open(
        f"Sampling_data_meta/sampling_{accelerator.process_index}.pkl", "wb"
    ) as file:
        pkl.dump(batch_data, file)
    with open(
        f"Sampling_data_meta/sampling_dataset_{accelerator.process_index}.pkl", "wb"
    ) as file:
        pkl.dump(dataset, file)
    print(batch_data[0])