import os
import sys
from typing import Any

sys.path.append("../")

import linecache
import mmap
import pickle as pkl
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchaudio
import transformers
from accelerate import Accelerator, DistributedDataParallelKwargs
from autoregressive import TS_model
from cleaners import english_cleaners
from librosa.filters import mel as librosa_mel_fn
from mel_spec import get_mel_spectrogram
from meta_stats import process_file, process_file_for_heads
from stft import STFT
from torch.utils.data import (DataLoader, Dataset, WeightedRandomSampler,
                              get_worker_info)
from tqdm import tqdm
from utilities import get_mask_from_lengths

import wandb
from config import config
from Text import code_labels, labels, text_labels

# Seed every RNG in play so the data ordering is reproducible across runs.
torch.manual_seed(config.seed_value)
np.random.seed(config.seed_value)
random.seed(config.seed_value)

print(text_labels)

# Symbol <-> integer-id lookup tables for text characters and semantic codes.
text_enc = {j: i for i, j in enumerate(text_labels)}
text_dec = {i: j for i, j in enumerate(text_labels)}

code_enc = {j: i for i, j in enumerate(code_labels)}
code_dec = {i: j for i, j in enumerate(code_labels)}
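
# Cheap import-time sanity check (added as an illustration): the encoder and
# decoder tables are inverses of each other by construction.
assert all(text_dec[text_enc[s]] == s for s in text_labels)
assert all(code_dec[code_enc[s]] == s for s in code_labels)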


def read_specific_line(filename, line_number):
    # linecache caches file contents, so repeated random line reads are cheap.
    line = linecache.getline(filename, line_number)
    return line.strip()


CLIP_LENGTH = config.CLIP_LENGTH
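
# In scale mode each transcript line is assumed (inferred from the parsing in
# __getitem__ below) to be a "|"-joined record of the form:
#   lang|path|text|semantic ids|ref1_path,score<TAB>ref2_path,score...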


class semantic_dataset_batch(Dataset):
    def __init__(
        self,
        transcript_path,
        semantic_path=None,
        ref_mels_path=None,
        ref_k=3,
        scale=False,
        process_id=None,
        total_processes=None,
    ):
        super().__init__()
        self.scale = scale
        if not scale:
            # Small-data mode: load transcripts, semantic codes and reference
            # mels fully into memory.
            with open(transcript_path, "r") as file:
                data = file.read().strip("\n").split("\n")

            with open(semantic_path, "r") as file:
                semb = file.read().strip("\n").split("\n")

            with open(ref_mels_path, "rb") as file:
                self.ref_mels = pkl.load(file)

            semb = {
                i.split("\t")[0]: [j for j in i.split("\t")[1].split()] for i in semb
            }
            data = {i.split("|")[0]: i.split("|")[1].strip().lower() for i in data}

            self.data = [[i, semb[i], data[i]] for i in data.keys()]

            self.data_len = len(self.data)

        else:
            # Large-data mode: shard the transcript across processes and read
            # it through an mmap instead of loading it into memory.
            print(transcript_path)

            self.heads, self.weights, self.count = process_file_for_heads(
                transcript_path, total_processes, process_id
            )
            print("length :", self.count)
            self.data_len = self.count
            self.transcript_path = transcript_path

            # Build a line-number -> byte-offset index so __getitem__ can
            # seek straight to any transcript line.
            line_index = {}
            with open(transcript_path, "rb") as file:
                mmapped_file = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
                line_number = 0
                offset = 0
                while offset < len(mmapped_file):
                    line_index[line_number] = offset
                    newline = mmapped_file.find(b"\n", offset)
                    if newline == -1:
                        # Last line lacks a trailing newline; stop instead of
                        # spinning on find() returning -1.
                        break
                    offset = newline + 1
                    line_number += 1
            self.mmapped_file = mmapped_file
            self.line_index = line_index

        self.process_id = process_id
        self.total_processes = total_processes
        self.iterator = None

        self.ref_k = ref_k
        self.max_wav_value = config.MAX_WAV_VALUE
        self.stft_fn = STFT(config.filter_length, config.hop_length, config.win_length)

        mel_basis = librosa_mel_fn(
            sr=config.sampling_rate,
            n_fft=config.filter_length,
            n_mels=config.n_mel_channels,
            fmin=config.mel_fmin,
            fmax=config.mel_fmax,
        )

        self.mel_basis = torch.from_numpy(mel_basis).float()
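
    # Assumed semantics, inferred from get_head() below: `heads` maps a
    # bucket key to a list of transcript line numbers and `weights` gives
    # each bucket's relative sampling frequency.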

    def get_mel(self, filepath):
        audio_norm, sampling_rate = torchaudio.load(filepath)

        melspec = get_mel_spectrogram(audio_norm, sampling_rate).squeeze(0)
        # Energy extraction is currently disabled; an empty list keeps the
        # (mel, energy) return interface intact.
        energy = []

        return melspec, list(energy)
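
    # Note: self.stft_fn and self.mel_basis are set up in __init__ but unused
    # here; they appear to be kept for an alternative extraction path.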

    def __len__(self):
        if self.scale:
            return self.data_len
        return len(self.data)

    def get_worker_heads(self):
        # Split each head's line list across DataLoader workers so every
        # worker iterates over its own (wrapped) slice.
        self.worker_id = get_worker_info().id
        self.num_worker = get_worker_info().num_workers
        new_heads = {}
        for i in self.heads:
            segment_size = (len(self.heads[i]) + self.num_worker - 1) // self.num_worker
            start_idx = self.worker_id * segment_size
            end_idx = start_idx + segment_size

            if end_idx > len(self.heads[i]):
                # Wrap around so every worker gets a full-sized segment.
                segment = (
                    self.heads[i][start_idx:]
                    + self.heads[i][: end_idx - len(self.heads[i])]
                )
            else:
                segment = self.heads[i][start_idx:end_idx]
            new_heads[i] = segment
        self.heads = new_heads
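
    # E.g. 10 lines across 4 workers: segment_size = ceil(10 / 4) = 3, so
    # worker 3 takes indices 9..11 and wraps around to [line 9, line 0, line 1].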

    def get_head(self):
        # Weighted round-robin over the heads: each head contributes `weight`
        # consecutive line numbers per cycle and is reshuffled in place
        # whenever its cursor wraps around.
        self.get_worker_heads()

        self.indices = [0] * len(self.heads)

        while True:
            for n, (head, weight) in enumerate(zip(self.heads, self.weights)):
                for i in range(weight):
                    if self.indices[n] < len(self.heads[head]):
                        yield self.heads[head][self.indices[n]]
                        self.indices[n] += 1
                    else:
                        self.indices[n] = 0
                        random.shuffle(self.heads[head])
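
    # E.g. with two heads and weights [2, 1] the stream looks roughly like
    # a, a, b, a, a, b, ... (a yield is skipped on the pass that resets and
    # reshuffles an exhausted head).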

    def __getitem__(self, index) -> Any:
        if self.scale and self.iterator is None:
            self.iterator = self.get_head()
        if not self.scale:
            # self.data rows are [path, semantic codes, text]; the in-memory
            # transcript carries no language column, so default to English
            # (an assumption for this branch).
            path, semb, text = self.data[index]
            lang = "en"
            ref_mels = self.ref_mels[path][: self.ref_k]

        else:
            # Ignore the incoming index: draw the next transcript line number
            # from the weighted head iterator and read it via the mmap index.
            index = next(self.iterator)

            self.mmapped_file.seek(self.line_index[index])
            line = self.mmapped_file.readline().decode("utf-8")

            lang, path, text, semb_ids, ref_mels = line.split("|")

            semb = semb_ids.split()
            ref_mels = [i.split(",") for i in ref_mels.split("\t")][: self.ref_k]

        if len(semb) < 25:
            # Skip clips with fewer than 25 semantic tokens (~0.5 s at 50 Hz).
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        if len(ref_mels) == 0:
            # No reference clips at all: fall back to the target clip itself;
            # the loop below pads the list up to ref_k.
            ref_mels.append((path, 1))

        while len(ref_mels) < self.ref_k:
            ref_mels.append(ref_mels[-1])

        text = text.lower().strip()

        text_ids = [text_enc["<S>"]] + [text_enc[i] for i in text] + [text_enc["<E>"]]
        semb_ids = (
            [code_enc["<SST>"]] + [code_enc[i] for i in semb] + [code_enc["<EST>"]]
        )
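
        # Both sequences now carry explicit start/end delimiters, e.g. the
        # text "hi" becomes the ids of <S> h i <E> and the codes are wrapped
        # in <SST> ... <EST>.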

        def get_random_portion(mel, mask_lengths):
            # Crop every reference mel to a CLIP_LENGTH-frame window, picking
            # a random start frame for clips longer than the window.
            clip = mask_lengths <= CLIP_LENGTH
            ref_mel = mel[:, :, :CLIP_LENGTH].clone()
            for n, z in enumerate(clip):
                if not z:
                    start = np.random.randint(0, mask_lengths[n].item() - CLIP_LENGTH)
                    ref_mel[n, :, :] = mel[n, :, start : start + CLIP_LENGTH].clone()
            return ref_mel
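
        # E.g. a 1000-frame reference with CLIP_LENGTH = 300 keeps a random
        # 300-frame window; shorter references keep their padded prefix.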

        try:
            ref_mels = [self.get_mel(path)[0] for path, score in ref_mels]
        except Exception as e:
            print(index, e)
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        # Drop references that failed to load and pad back up to ref_k by
        # repeating the last good one.
        ref_c = []
        for i in range(self.ref_k):
            if ref_mels[i] is None:
                continue
            ref_c.append(ref_mels[i])

        if len(ref_c) == 0:
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        while len(ref_c) < self.ref_k:
            ref_c.append(ref_c[-1])

        ref_mels = ref_c

        # Pad all references to a common length with near-zero noise, then
        # crop each one to a CLIP_LENGTH window.
        max_target_len = max([x.size(1) for x in ref_mels])
        ref_mels_padded = (
            torch.randn((self.ref_k, config.n_mel_channels, max_target_len)) * 1e-9
        )
        mel_length = []
        for i, mel in enumerate(ref_mels):
            ref_mels_padded[i, :, : mel.size(1)] = mel
            mel_length.append(mel.shape[-1])

        ref_mels = get_random_portion(ref_mels_padded, torch.tensor(mel_length))

        return {
            "text_ids": text_ids,
            "semb_ids": semb_ids,
            "ref_mels": ref_mels,
            "lang": torch.tensor(config.lang_index[lang]),
        }


def get_padded_seq(sequences, pad_random, before=False, pad__=0):
    # Pad every sequence to the batch maximum, either with near-zero random
    # noise or with a constant pad value, before or after the content.
    max_len = max([len(s) for s in sequences])
    seq_len = []
    for i in range(len(sequences)):
        seq_len.append(len(sequences[i]))
        if pad_random:
            pad_ = list((np.random.rand(max_len - len(sequences[i]))) * 1e-9)
        else:
            pad_ = [pad__] * (max_len - len(sequences[i]))
        if not before:
            sequences[i] = sequences[i] + pad_
        else:
            sequences[i] = pad_ + sequences[i]

    return sequences, seq_len
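
# E.g. (illustrative): get_padded_seq([[1, 2], [3]], pad_random=False, pad__=0)
# returns ([[1, 2], [3, 0]], [2, 1]). Note the input list is padded in place.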


def collate(batch):
    text_ids = []
    semb_ids = []

    ref_mels = []
    langs = []

    for b in batch:
        text_ids.append(b["text_ids"])
        semb_ids.append(b["semb_ids"])

        ref_mels.append(b["ref_mels"])
        langs.append(b["lang"])

    # Pad token sequences with their respective end-of-sequence symbols.
    text_ids, text_len = get_padded_seq(
        text_ids, pad_random=False, before=False, pad__=text_enc["<E>"]
    )
    code, code_len = get_padded_seq(semb_ids, pad_random=False, pad__=code_enc["<EST>"])

    # Stack reference mels into (batch, ref_k, n_mels, time), padding the
    # time axis with near-zero noise.
    ref_max_target_len = max([x.size(-1) for x in ref_mels])
    ref_mels_padded = (
        torch.randn(
            (
                len(batch),
                ref_mels[0].shape[0],
                config.n_mel_channels,
                ref_max_target_len,
            )
        )
        * 1e-9
    )

    for i, mel in enumerate(ref_mels):
        ref_mels_padded[i, :, :, : mel.size(-1)] = mel

    return (
        torch.tensor(text_ids),
        torch.tensor(code),
        torch.tensor(text_len),
        torch.tensor(code_len),
        ref_mels_padded,
        torch.tensor(langs),
    )
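
# Collate output shapes, with B = batch size: text ids (B, T_text), codes
# (B, T_code), text/code lengths and langs (B,), and reference mels
# (B, ref_k, n_mel_channels, CLIP_LENGTH) since items arrive pre-cropped.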


def get_dataset(transcript_path, get_process_id, total_processes):
    return semantic_dataset_batch(
        transcript_path,
        scale=True,
        process_id=get_process_id,
        total_processes=total_processes,
    )


if __name__ == "__main__":
    accelerator = Accelerator(
        gradient_accumulation_steps=config.ts_gradient_accumulation_steps
    )

    get_process_id = accelerator.process_index
    total_processes = accelerator.num_processes

    # With scale=True the semantic and reference-mel paths below are unused;
    # everything is parsed from the transcript file itself.
    train_dataset_ = semantic_dataset_batch(
        config.data_path + "/transcript_train_20s_final_normalized_filtered.txt",
        "../" + config.data_path + "/semt.txt",
        "../" + config.data_path + "/ref_clips.pkl",
        scale=True,
        process_id=get_process_id,
        total_processes=total_processes,
    )

    train_dataset = DataLoader(
        train_dataset_,
        pin_memory=True,
        persistent_workers=True,
        num_workers=config.ts_num_workers,
        batch_size=config.ts_batch_size,
        shuffle=False,
        drop_last=False,
        collate_fn=collate,
        sampler=None,
    )
    print("batch", config.ts_batch_size)

    train_dataloader = accelerator.prepare(train_dataset)

    import math
    from collections import defaultdict

    def calculate_duration(code_len):
        # Semantic codes are emitted at 50 Hz, so (code_len + 1) / 50 is the
        # clip duration in seconds, rounded up to the nearest half second.
        return math.ceil(((code_len + 1) / 50) * 2) / 2
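
    # Worked example: 99 codes -> (99 + 1) / 50 = 2.0, * 2 = 4.0,
    # ceil -> 4, / 2 -> 2.0 seconds.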

    sampling = defaultdict(int)
    dataset = []
    batch_data = {}
    batch = 0
    batch_data[batch] = defaultdict(int)
    for n, data in enumerate(tqdm(train_dataloader)):

        text_ids, code, text_len, code_len, ref_clips, langs = data

        # Bucket each item by rounded duration (i - 2 strips the <SST>/<EST>
        # delimiters) to profile what each gradient-accumulation batch sees.
        for i, j in zip(code_len, text_ids):
            dur = calculate_duration(i - 2)

            dataset.append(list(j.detach().cpu().numpy()))

            if dur > 19.5:
                batch_data[batch]["20_sentence"] += 1
            elif dur <= 5:
                batch_data[batch]["5s"] += 1
            elif dur <= 10:
                batch_data[batch]["10s"] += 1
            elif dur <= 15:
                batch_data[batch]["15s"] += 1
            elif dur <= 20:
                batch_data[batch]["20s"] += 1

        if (n + 1) % config.ts_gradient_accumulation_steps == 0:
            batch += 1
            batch_data[batch] = defaultdict(int)

    # Ensure the output directory exists before dumping per-process stats.
    os.makedirs("Sampling_data_meta", exist_ok=True)
    with open(
        f"Sampling_data_meta/sampling_{accelerator.process_index}.pkl", "wb"
    ) as file:
        pkl.dump(batch_data, file)
    with open(
        f"Sampling_data_meta/sampling_dataset_{accelerator.process_index}.pkl", "wb"
    ) as file:
        pkl.dump(dataset, file)
    print(batch_data[0])