| import numpy as np |
| from transformers import AutoTokenizer |
| import os |
| import torch |
| from collections import OrderedDict |
| import librosa |
| from importlib_resources import files |
| import yaml |
| import argparse |
| import torchaudio |
| import torchaudio.transforms as T |
| import collections |
| import random |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.nn.parallel import DistributedDataParallel as DDP |
| import logging |
| from glob import glob |
|
|
| from mapper import get_sid_mapper, get_text_mapper |
| from transformers import GPT2LMHeadModel |
| from transformers import AutoTokenizer |
|
|
|
|
class ExpWrapper():
    r"""Bundle a GPT-2 language model, a "sid" mapper and a tokenizer.

    The wrapper owns three components:

    * ``self.gpt`` - a ``GPT2LMHeadModel`` loaded from ``config_wrapper['text_decoder']``.
    * ``self.sid_mapper`` - the module returned by ``get_sid_mapper``; its output is
      reshaped into ``sid_prefix_length`` vectors of the GPT-2 embedding size and
      used as a prefix for the language model.
    * ``self.tokenizer`` - the matching ``AutoTokenizer`` with ``'!'`` as pad token.

    It also offers checkpoint helpers, text tokenization helpers and a beam-search
    decoder that generates text from prefix embeddings.
    """

    # Message for batches whose element type cannot be collated.
    default_collate_err_msg_format = (
        "default_collate: batch must contain tensors, numpy arrays, numbers, "
        "dicts or lists; found {}")

    def __init__(self, config_wrapper, gpu_id):
        r"""Build the language model, the sid mapper and the tokenizer.

        Args:
            config_wrapper: mapping with the keys ``tok_len``, ``text_prefix_length``,
                ``sid_prefix_length``, ``norm_sid_emb``, ``text_decoder``, ``map_type``,
                ``prefix_size``, ``sid_prefix_length_clip`` and ``num_layers``.
            gpu_id: device passed to ``.to(...)`` for every module and tensor.
        """
        self.tok_len = config_wrapper['tok_len']
        self.text_prefix_length = config_wrapper['text_prefix_length']
        self.sid_prefix_length = config_wrapper['sid_prefix_length']
        self.norm_sid_emb = config_wrapper['norm_sid_emb']
        self.gpu_id = gpu_id

        self.gpt = GPT2LMHeadModel.from_pretrained(config_wrapper['text_decoder'])
        self.gpt = self.gpt.to(self.gpu_id)
        # Width of the GPT-2 token embedding table; the mapper must emit vectors of this size.
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]

        self.sid_mapper = get_sid_mapper(
            config_wrapper["map_type"], None,
            config_wrapper["prefix_size"], self.gpt_embedding_size,
            config_wrapper["sid_prefix_length"], config_wrapper["sid_prefix_length_clip"],
            config_wrapper["num_layers"])
        self.sid_mapper = self.sid_mapper.to(self.gpu_id)

        self.tokenizer = AutoTokenizer.from_pretrained(config_wrapper['text_decoder'])
        # NOTE(review): '!' is presumably an existing vocabulary entry, so no embedding
        # resize is performed here - confirm for the configured decoder.
        self.tokenizer.add_special_tokens({'pad_token': '!'})

    def init_mapper(self):
        r"""Wrap the sid mapper in DistributedDataParallel on ``self.gpu_id``."""
        self.sid_mapper = DDP(self.sid_mapper, device_ids=[self.gpu_id], find_unused_parameters=True)

    def freeze_llm(self):
        r"""Disable gradients for BOTH the sid mapper and the GPT-2 model."""
        for param in self.sid_mapper.parameters():
            param.requires_grad = False
        for param in self.gpt.parameters():
            param.requires_grad = False

    def default_collate(self, batch):
        r"""Puts each data field into a tensor with outer dimension batch size.

        Supports tensors, numpy arrays/scalars, Python numbers, strings, mappings,
        namedtuples and equally sized sequences (handled recursively).

        Raises:
            TypeError: for numpy string/object arrays and unsupported element types.
            RuntimeError: when sequences inside the batch differ in length.
        """
        elem = batch[0]
        elem_type = type(elem)
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # Inside a DataLoader worker: stack straight into shared memory
                # to avoid an extra copy when the batch is handed to the parent.
                numel = sum([x.numel() for x in batch])
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            return torch.stack(batch, 0, out=out)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
                and elem_type.__name__ != 'string_':
            if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
                # Byte-string, unicode and object arrays cannot become tensors.
                if elem.dtype.kind in ('S', 'U', 'O'):
                    raise TypeError(
                        self.default_collate_err_msg_format.format(elem.dtype))
                return self.default_collate([torch.as_tensor(b) for b in batch])
            elif elem.shape == ():
                # numpy scalars
                return torch.as_tensor(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, (str, bytes)):
            # Strings are sequences of strings; returning them untouched avoids
            # unbounded recursion in the generic sequence branch below.
            return batch
        elif isinstance(elem, collections.abc.Mapping):
            return {key: self.default_collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):
            # namedtuple: rebuild the same type field by field
            return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
        elif isinstance(elem, collections.abc.Sequence):
            elem_size = len(elem)
            if not all(len(item) == elem_size for item in batch):
                raise RuntimeError(
                    'each element in list of batch should be of equal size')
            return [self.default_collate(samples) for samples in zip(*batch)]

        raise TypeError(self.default_collate_err_msg_format.format(elem_type))

    def _map_location(self):
        r"""Return the ``map_location`` used when reading checkpoints.

        Integer ``gpu_id`` values become ``"cuda:<id>"``; anything else (for example
        ``"cpu"`` or a ``torch.device``) is passed through unchanged.
        """
        if isinstance(self.gpu_id, int):
            return f"cuda:{self.gpu_id}"
        return self.gpu_id

    def load_model(self, st, model):
        r"""Load state dict ``st`` into ``model`` and return ``model``.

        If the direct load fails, every ``module`` path component (as added by a
        DistributedDataParallel wrapper) is removed from the keys and the load is
        retried. ``st`` itself is not modified.

        Raises:
            RuntimeError: if the keys still do not match after stripping.
        """
        try:
            model.load_state_dict(st)
        except RuntimeError:
            stripped = collections.OrderedDict(
                (".".join(part for part in key.split(".") if part != "module"), value)
                for key, value in st.items())
            metadata = getattr(st, "_metadata", None)
            if metadata is not None:
                stripped._metadata = metadata
            model.load_state_dict(stripped)
        return model

    def load_sid_model(self, sid_model, snapshot_path, sid_ck_name):
        r"""Restore ``sid_model`` from ``{snapshot_path}/{sid_ck_name}``.

        The checkpoint must contain the keys ``sid_model``, ``val_loss`` and
        ``epochs_run``.

        Returns:
            Tuple ``(sid_model, val_loss, epochs_run)``.
        """
        sid_model_path = f"{snapshot_path}/{sid_ck_name}"
        snapshot = torch.load(sid_model_path, map_location=self._map_location())
        sid_model = self.load_model(snapshot["sid_model"], sid_model)
        best_val_loss = snapshot["val_loss"]
        epochs_run = snapshot["epochs_run"]
        return sid_model, best_val_loss, epochs_run

    def load_mapper(self, snapshot_path, mapper_ck_name):
        r"""Restore ``self.sid_mapper`` and ``self.epochs_run`` from a checkpoint.

        Args:
            snapshot_path: directory holding the checkpoint.
            mapper_ck_name: checkpoint file name. When empty or ``None`` the
                lexicographically last ``*mapper_*.pt`` file in the directory is used.

        Raises:
            FileNotFoundError: if no name is given and no mapper checkpoint exists.
        """
        if mapper_ck_name:
            mapper_path = f"{snapshot_path}/{mapper_ck_name}"
        else:
            candidates = sorted(glob(f"{snapshot_path}/*mapper_*.pt"))
            if not candidates:
                raise FileNotFoundError(f"No mapper checkpoint found in {snapshot_path}")
            mapper_path = candidates[-1]
        snapshot = torch.load(mapper_path, map_location=self._map_location())

        self.sid_mapper = self.load_model(snapshot["sid_mapper"], self.sid_mapper)
        self.epochs_run = snapshot["epochs_run"]
        logging.info(f"Resuming training from mapper at Epoch {self.epochs_run}")

    def save_mapper(self, epoch, snapshot_path, val_epoch_ce_llm):
        r"""Save the sid mapper weights and the epoch number; return the file path."""
        mapper = {
            "sid_mapper": self.sid_mapper.state_dict(),
            "epochs_run": epoch,
        }
        mapper_path = (
            f"{snapshot_path}/unfrozen_mapper_epoch_{str(epoch).zfill(4)}"
            f"_val_epoch_ce_llm_{val_epoch_ce_llm}.pt")
        torch.save(mapper, mapper_path)
        logging.info(f"Epoch {epoch} | Training mapper saved at {snapshot_path}")
        return mapper_path

    def _tokenize(self, text, max_length):
        r"""Tokenize one string, padded/truncated to ``max_length``, on ``self.gpu_id``."""
        tok = self.tokenizer.encode_plus(
            text=text, add_special_tokens=True,
            max_length=max_length,
            # `padding="max_length"` replaces the deprecated `pad_to_max_length=True`.
            padding="max_length", return_tensors="pt", truncation=True)
        for key in tok.keys():
            # Drop the leading batch dimension of size 1; default_collate re-adds it.
            tok[key] = tok[key].reshape(-1).to(self.gpu_id)
        return tok

    def preprocess_prompt(self, texts, max_length=10):
        r"""Load list of prompts and return tokenized text.

        Every prompt is padded or truncated to ``max_length`` tokens.
        """
        return self.default_collate([self._tokenize(ttext, max_length) for ttext in texts])

    def preprocess_prompt_single(self, texts):
        r"""Tokenize a single prompt string and return it as a batch of size one."""
        return self.preprocess_prompt([texts])

    def preprocess_text(self, texts):
        r"""Load list of texts, append ``' <|endoftext|>'`` and return tokenized text.

        Every text is padded or truncated to ``self.tok_len`` tokens.
        """
        return self.default_collate(
            [self._tokenize(ttext + ' <|endoftext|>', self.tok_len) for ttext in texts])

    def _get_text_embeddings(self, preprocessed_texts):
        r"""Look up GPT-2 token embeddings for ``preprocessed_texts['input_ids']`` without gradients."""
        with torch.no_grad():
            texts_embed = self.gpt.transformer.wte(preprocessed_texts['input_ids'])
        return texts_embed

    def get_sid_prefix(self, sid_embeddings):
        r"""Produces the prefix embeddings which are fed to the LM.

        When ``self.norm_sid_emb`` is set, each embedding is divided by its L2 norm
        over the last dimension first (assumes a 2-D ``(batch, dim)`` input - TODO
        confirm). The mapper output is reshaped to
        ``(-1, sid_prefix_length, gpt_embedding_size)``.
        """
        if self.norm_sid_emb:
            sid_embeddings = sid_embeddings / sid_embeddings.norm(2, -1).reshape(-1, 1)

        sids_prefix = self.sid_mapper(sid_embeddings).contiguous().view(
            -1, self.sid_prefix_length, self.gpt_embedding_size)
        return sids_prefix

    def get_prompt_prefix(self, texts):
        r"""Load list of text prompts and return ``(prompt embeddings, tokenized prompts)``."""
        preprocessed_texts = self.preprocess_prompt(texts)
        texts_embed = self._get_text_embeddings(preprocessed_texts)
        return texts_embed, preprocessed_texts

    def get_prompt_prefix_single(self, texts):
        r"""Same as ``get_prompt_prefix`` for one prompt string (batch of size one)."""
        preprocessed_texts = self.preprocess_prompt_single(texts)
        texts_embed = self._get_text_embeddings(preprocessed_texts)
        return texts_embed, preprocessed_texts

    def get_text_prefix(self, texts):
        r"""Load list of texts and return ``(text embeddings, tokenized texts)``."""
        preprocessed_texts = self.preprocess_text(texts)
        texts_embed = self._get_text_embeddings(preprocessed_texts)
        return texts_embed, preprocessed_texts

    def generate_beam(self, beam_size: int = 1, sids_prefix=None, entry_length=80,
                      temperature=1., stop_token: str = ' <|endoftext|>'):
        r"""Beam-search decode text from prefix embeddings.

        Args:
            beam_size: number of hypotheses kept at every step.
            sids_prefix: prefix embeddings fed to GPT-2 as ``inputs_embeds``; it is
                expanded to ``beam_size`` rows, so a batch of one is expected.
            entry_length: maximum number of generated tokens.
            temperature: divisor applied to the logits; non-positive values mean 1.0.
            stop_token: string whose FIRST token id marks a finished hypothesis.

        Returns:
            List of decoded hypotheses, best length-normalised score first.

        Raises:
            ValueError: if ``sids_prefix`` is ``None``.
        """
        if sids_prefix is None:
            raise ValueError("generate_beam requires sids_prefix")
        # NOTE(review): only the first id of the encoded stop string is used; if the
        # tokenizer splits ' <|endoftext|>' into several ids this is not the EOS id - confirm.
        stop_token_index = self.tokenizer.encode(stop_token)[0]
        tokens = None
        scores = None
        device = next(self.gpt.parameters()).device
        seq_lengths = torch.ones(beam_size, device=device)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
        with torch.no_grad():
            generated = sids_prefix
            for _ in range(entry_length):
                outputs = self.gpt(inputs_embeds=generated)
                logits = outputs.logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                # log_softmax avoids the underflow of softmax(...).log()
                logits = logits.log_softmax(-1)
                if scores is None:
                    # First step: seed the beams with the top-k tokens of the single prefix.
                    scores, next_tokens = logits.topk(beam_size, -1)
                    generated = generated.expand(beam_size, *generated.shape[1:])
                    next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                    tokens = next_tokens
                else:
                    # Finished beams may only be extended by token 0, at no cost,
                    # so their score stays frozen.
                    logits[is_stopped] = -float(np.inf)
                    logits[is_stopped, 0] = 0
                    scores_sum = scores[:, None] + logits
                    seq_lengths[~is_stopped] += 1
                    scores_sum_average = scores_sum / seq_lengths[:, None]
                    scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                    # Flat index -> (source beam, token id)
                    next_tokens_source = next_tokens // scores_sum.shape[1]
                    seq_lengths = seq_lengths[next_tokens_source]
                    next_tokens = next_tokens % scores_sum.shape[1]
                    next_tokens = next_tokens.unsqueeze(1)
                    tokens = tokens[next_tokens_source]
                    tokens = torch.cat((tokens, next_tokens), dim=1)
                    generated = generated[next_tokens_source]
                    scores = scores_sum_average * seq_lengths
                    is_stopped = is_stopped[next_tokens_source]

                # next_tokens is (beam_size, 1) in both branches.
                next_token_embed = self.gpt.transformer.wte(
                    next_tokens.squeeze(1)).view(generated.shape[0], 1, -1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                is_stopped = is_stopped | next_tokens.eq(stop_token_index).squeeze(1)
                if is_stopped.all():
                    break
        if tokens is None:
            # entry_length <= 0: nothing was generated
            return []
        scores = scores / seq_lengths
        output_list = tokens.cpu().numpy()
        output_texts = [self.tokenizer.decode(output[:int(length)])
                        for output, length in zip(output_list, seq_lengths)]
        order = scores.argsort(descending=True)
        output_texts = [output_texts[i] for i in order.tolist()]
        return output_texts
| |
| |
|
|
|
|