# CoLMbo / wrapper.py
# Author: massabaali
# Commit: "Upload CoLMbo model weights and code" (f55a095, verified)
import numpy as np
from transformers import AutoTokenizer
import os
import torch
from collections import OrderedDict
import librosa
from importlib_resources import files
import yaml
import argparse
import torchaudio
import torchaudio.transforms as T
import collections
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
import logging
from glob import glob
from mapper import get_sid_mapper, get_text_mapper
from transformers import GPT2LMHeadModel
from transformers import AutoTokenizer
class ExpWrapper:
    """Couples a ``sid`` embedding mapper with a GPT-2 text decoder.

    Holds a GPT-2 language model and its tokenizer, plus a ``sid_mapper`` (built
    by ``get_sid_mapper``) that turns ``sid`` embeddings into a prefix of
    ``sid_prefix_length`` GPT-2 input embeddings. It also provides tokenisation
    and collation helpers, checkpoint load/save for the mapper, and beam-search
    decoding from an embedding prefix.
    """

    # Message used by ``default_collate`` for unsupported element types
    # (same wording as PyTorch's own collate function).
    default_collate_err_msg_format = (
        "default_collate: batch must contain tensors, numpy arrays, numbers, "
        "dicts or lists; found {}"
    )

    def __init__(self, config_wrapper, gpu_id):
        """Build the GPT-2 decoder, the sid mapper and the tokenizer.

        Args:
            config_wrapper: mapping read for the keys ``tok_len``,
                ``text_prefix_length``, ``sid_prefix_length``, ``norm_sid_emb``,
                ``text_decoder`` (model name/path for ``from_pretrained``),
                ``map_type``, ``prefix_size``, ``sid_prefix_length_clip`` and
                ``num_layers``.
            gpu_id: device passed to ``.to()``. When it is not a string or a
                ``torch.device``, checkpoints are loaded onto ``cuda:<gpu_id>``.
        """
        self.tok_len = config_wrapper['tok_len']
        self.text_prefix_length = config_wrapper['text_prefix_length']
        self.sid_prefix_length = config_wrapper['sid_prefix_length']
        self.norm_sid_emb = config_wrapper['norm_sid_emb']
        self.gpu_id = gpu_id

        self.gpt = GPT2LMHeadModel.from_pretrained(config_wrapper['text_decoder'])
        self.gpt = self.gpt.to(self.gpu_id)
        # Width of GPT-2's token embeddings; the mapper's output is viewed
        # as vectors of this size.
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]

        self.sid_mapper = get_sid_mapper(
            config_wrapper["map_type"], None,
            config_wrapper["prefix_size"], self.gpt_embedding_size,
            config_wrapper["sid_prefix_length"], config_wrapper["sid_prefix_length_clip"],
            config_wrapper["num_layers"],
        )
        self.sid_mapper = self.sid_mapper.to(self.gpu_id)
        # NOTE: a text mapper (get_text_mapper) used to be built and loaded
        # here; it is currently disabled.

        self.tokenizer = AutoTokenizer.from_pretrained(config_wrapper['text_decoder'])
        # Register '!' as the pad token used by the preprocess_* helpers.
        self.tokenizer.add_special_tokens({'pad_token': '!'})

    def init_mapper(self):
        """Wrap ``sid_mapper`` in DistributedDataParallel on ``gpu_id``."""
        self.sid_mapper = DDP(self.sid_mapper, device_ids=[self.gpu_id], find_unused_parameters=True)

    def freeze_llm(self):
        """Disable gradients for the GPT-2 decoder AND the sid mapper."""
        for module in (self.sid_mapper, self.gpt):
            for param in module.parameters():
                param.requires_grad = False

    def default_collate(self, batch):
        r"""Puts each data field into a tensor with outer dimension batch size.

        Adapted from PyTorch's ``default_collate``:
        - tensors are stacked along a new dim 0;
        - numpy arrays and numpy scalars become tensors;
        - Python floats become a float64 tensor, ints an integer tensor;
        - strings are returned as the given batch;
        - mappings, namedtuples and other sequences are collated field by
          field, recursively.

        Raises:
            TypeError: for numpy arrays of strings/bytes/objects, or for an
                element type no branch handles.
            RuntimeError: if sequence elements of ``batch`` differ in length.
        """
        elem = batch[0]
        elem_type = type(elem)
        if isinstance(elem, torch.Tensor):
            out = None
            if torch.utils.data.get_worker_info() is not None:
                # If we're in a background process, concatenate directly into a
                # shared memory tensor to avoid an extra copy
                numel = sum(x.numel() for x in batch)
                storage = elem.storage()._new_shared(numel)
                out = elem.new(storage)
            return torch.stack(batch, 0, out=out)
        elif elem_type.__module__ == 'numpy' and elem_type.__name__ not in ('str_', 'string_'):
            if elem_type.__name__ in ('ndarray', 'memmap'):
                # Arrays of strings, bytes or Python objects cannot become tensors.
                if elem.dtype.kind in ('S', 'U', 'O'):
                    raise TypeError(self.default_collate_err_msg_format.format(elem.dtype))
                return self.default_collate([torch.as_tensor(b) for b in batch])
            elif elem.shape == ():  # scalars
                return torch.as_tensor(batch)
        elif isinstance(elem, float):
            return torch.tensor(batch, dtype=torch.float64)
        elif isinstance(elem, int):
            return torch.tensor(batch)
        elif isinstance(elem, (str, bytes)):
            # Strings are Sequences too; without this branch they would be
            # transposed character by character and recurse forever.
            return batch
        elif isinstance(elem, collections.abc.Mapping):
            return {key: self.default_collate([d[key] for d in batch]) for key in elem}
        elif isinstance(elem, tuple) and hasattr(elem, '_fields'):  # namedtuple
            return elem_type(*(self.default_collate(samples) for samples in zip(*batch)))
        elif isinstance(elem, collections.abc.Sequence):
            # check to make sure that the elements in batch have consistent size
            elem_size = len(elem)
            if any(len(item) != elem_size for item in batch):
                raise RuntimeError(
                    'each element in list of batch should be of equal size')
            return [self.default_collate(samples) for samples in zip(*batch)]
        raise TypeError(self.default_collate_err_msg_format.format(elem_type))

    def load_model(self, st, model):
        """Load state dict ``st`` into ``model`` (in place) and return ``model``.

        If the direct load raises RuntimeError (for example, the keys carry a
        leading ``module.`` from a DistributedDataParallel wrapper), the load is
        retried with that leading prefix stripped. Only the leading prefix is
        removed, so inner names such as ``submodule.`` survive. ``st`` itself
        is not modified.
        """
        try:
            model.load_state_dict(st)
        except RuntimeError:
            prefix = "module."
            stripped = {
                (key[len(prefix):] if key.startswith(prefix) else key): value
                for key, value in st.items()
            }
            model.load_state_dict(stripped)
        return model

    def _map_location(self):
        """Return the ``map_location`` for ``torch.load``.

        A string or ``torch.device`` gpu_id (e.g. ``"cpu"``) is used as is;
        anything else is treated as a CUDA index (``"cuda:<gpu_id>"``).
        """
        if isinstance(self.gpu_id, (str, torch.device)):
            return self.gpu_id
        return f"cuda:{self.gpu_id}"

    def load_sid_model(self, sid_model, snapshot_path, sid_ck_name):
        """Load ``snapshot["sid_model"]`` from ``snapshot_path/sid_ck_name`` into ``sid_model``.

        The weights are loaded in place, and ``sid_model`` is also returned.
        """
        sid_model_path = f"{snapshot_path}/{sid_ck_name}"
        # torch.load unpickles the file: only load trusted checkpoints.
        snapshot = torch.load(sid_model_path, map_location=self._map_location())
        return self.load_model(snapshot["sid_model"], sid_model)

    def load_mapper(self, snapshot_path, mapper_ck_name):
        """Restore ``sid_mapper`` and ``epochs_run`` from ``snapshot_path/mapper_ck_name``.

        The checkpoint must contain the keys ``"sid_mapper"`` (a state dict) and
        ``"epochs_run"``, which is the format written by ``save_mapper``.
        """
        mapper_path = f"{snapshot_path}/{mapper_ck_name}"
        snapshot = torch.load(mapper_path, map_location=self._map_location())
        self.sid_mapper = self.load_model(snapshot["sid_mapper"], self.sid_mapper)
        self.epochs_run = snapshot["epochs_run"]
        logging.info(f"Resuming training from mapper at Epoch {self.epochs_run}")

    def save_mapper(self, epoch, snapshot_path, val_epoch_ce_llm):
        """Save the sid mapper's state dict and ``epoch`` into ``snapshot_path``.

        The file is named
        ``unfrozen_mapper_epoch_<epoch zero-padded to 4>_val_epoch_ce_llm_<val>.pt``.
        """
        mapper = {
            "sid_mapper": self.sid_mapper.state_dict(),
            "epochs_run": epoch,
        }
        torch.save(mapper, f"{snapshot_path}/unfrozen_mapper_epoch_{str(epoch).zfill(4)}_val_epoch_ce_llm_{val_epoch_ce_llm}.pt")
        logging.info(f"Epoch {epoch} | Training mapper saved at {snapshot_path}")

    def _tokenize(self, text, max_length):
        """Tokenize one string, padded/truncated to ``max_length``, as 1-D tensors on ``gpu_id``."""
        tok = self.tokenizer.encode_plus(
            text=text, add_special_tokens=True,
            max_length=max_length,
            padding="max_length", return_tensors="pt", truncation=True)
        for key in list(tok.keys()):
            tok[key] = tok[key].reshape(-1).to(self.gpu_id)
        return tok

    def preprocess_prompt(self, texts, max_length=10):
        r"""Tokenize a list of prompts (padded to ``max_length``) and collate them per key."""
        return self.default_collate([self._tokenize(text, max_length) for text in texts])

    def preprocess_prompt_single(self, texts, max_length=10):
        r"""Tokenize a single prompt string and collate it as a batch of one."""
        return self.default_collate([self._tokenize(texts, max_length)])

    def preprocess_text(self, texts):
        r"""Tokenize target texts with ' <|endoftext|>' appended, padded to ``tok_len``."""
        return self.default_collate(
            [self._tokenize(text + ' <|endoftext|>', self.tok_len) for text in texts])

    def _get_text_embeddings(self, preprocessed_texts):
        r"""Look up GPT-2 token embeddings for ``input_ids`` without tracking gradients."""
        with torch.no_grad():
            texts_embed = self.gpt.transformer.wte(preprocessed_texts['input_ids'])
        return texts_embed

    def get_sid_prefix(self, sid_embeddings):
        r"""Map ``sid`` embeddings to a GPT-2 input prefix.

        If ``norm_sid_emb`` is set, each embedding is first L2-normalised along
        its last dimension. The mapper output is viewed as
        ``(-1, sid_prefix_length, gpt_embedding_size)``.
        """
        if self.norm_sid_emb:
            sid_embeddings = sid_embeddings / sid_embeddings.norm(2, -1, keepdim=True)
        return self.sid_mapper(sid_embeddings).contiguous().view(
            -1, self.sid_prefix_length, self.gpt_embedding_size)

    def get_prompt_prefix(self, texts):
        r"""Tokenize a list of prompts; return (token embeddings, tokenized batch)."""
        preprocessed_texts = self.preprocess_prompt(texts)
        logging.debug("Preprocessed prompts: %s", preprocessed_texts)
        texts_embed = self._get_text_embeddings(preprocessed_texts)
        return texts_embed, preprocessed_texts

    def get_prompt_prefix_single(self, texts):
        r"""Tokenize a single prompt; return (token embeddings, tokenized batch of one)."""
        preprocessed_texts = self.preprocess_prompt_single(texts)
        texts_embed = self._get_text_embeddings(preprocessed_texts)
        return texts_embed, preprocessed_texts

    def get_text_prefix(self, texts):
        r"""Tokenize target texts; return (token embeddings, tokenized batch)."""
        preprocessed_texts = self.preprocess_text(texts)
        texts_embed = self._get_text_embeddings(preprocessed_texts)
        return texts_embed, preprocessed_texts

    def generate_beam(self, beam_size: int = 1, sids_prefix=None, entry_length=80, temperature=1., stop_token: str = ' <|endoftext|>'):
        """Beam-search decode text from a prefix of input embeddings.

        Args:
            beam_size: number of beams kept at each step.
            sids_prefix: input embeddings fed to GPT-2 via ``inputs_embeds``.
                Its leading (batch) dimension is expanded to ``beam_size`` on
                the first step, so it is expected to be 1.
            entry_length: maximum number of generated tokens.
            temperature: divides the logits; values <= 0 are treated as 1.0.
            stop_token: text whose first token id ends a beam.

        Returns:
            Decoded beam texts, sorted by length-normalised log-probability
            (best first). Returns an empty list if nothing was generated.

        Raises:
            ValueError: if ``sids_prefix`` is None.
        """
        if sids_prefix is None:
            raise ValueError("sids_prefix must be provided")
        stop_token_index = self.tokenizer.encode(stop_token)[0]
        tokens = None
        scores = None
        device = next(self.gpt.parameters()).device
        seq_lengths = torch.ones(beam_size, device=device)
        is_stopped = torch.zeros(beam_size, device=device, dtype=torch.bool)
        with torch.no_grad():
            generated = sids_prefix  # sid embedding prefix
            for _ in range(entry_length):
                logits = self.gpt(inputs_embeds=generated).logits
                logits = logits[:, -1, :] / (temperature if temperature > 0 else 1.0)
                logits = torch.log_softmax(logits, dim=-1)
                if scores is None:
                    # First step: seed the beams with the top-k next tokens.
                    scores, next_tokens = logits.topk(beam_size, -1)
                    generated = generated.expand(beam_size, *generated.shape[1:])
                    next_tokens, scores = next_tokens.permute(1, 0), scores.squeeze(0)
                    tokens = next_tokens
                else:
                    # A finished beam may only continue with token 0 at zero
                    # cost, so its score is carried over unchanged.
                    logits[is_stopped] = -float("inf")
                    logits[is_stopped, 0] = 0
                    scores_sum = scores[:, None] + logits
                    seq_lengths[~is_stopped] += 1
                    scores_sum_average = scores_sum / seq_lengths[:, None]
                    scores_sum_average, next_tokens = scores_sum_average.view(-1).topk(beam_size, -1)
                    # Flat indices -> (source beam, token id).
                    vocab_size = scores_sum.shape[1]
                    next_tokens_source = torch.div(next_tokens, vocab_size, rounding_mode="floor")
                    seq_lengths = seq_lengths[next_tokens_source]
                    next_tokens = (next_tokens % vocab_size).unsqueeze(1)
                    tokens = torch.cat((tokens[next_tokens_source], next_tokens), dim=1)
                    generated = generated[next_tokens_source]
                    scores = scores_sum_average * seq_lengths
                    is_stopped = is_stopped[next_tokens_source]
                flat_next = next_tokens.view(-1)
                next_token_embed = self.gpt.transformer.wte(flat_next).view(generated.shape[0], 1, -1)
                generated = torch.cat((generated, next_token_embed), dim=1)
                is_stopped = is_stopped | flat_next.eq(stop_token_index)
                if is_stopped.all():
                    break
        if tokens is None:  # entry_length <= 0: no step was taken
            return []
        scores = scores / seq_lengths
        output_list = tokens.cpu().numpy()
        output_texts = [self.tokenizer.decode(output[:int(length)]) for output, length in zip(output_list, seq_lengths)]
        order = scores.argsort(descending=True)
        return [output_texts[i] for i in order]