# Heavily influenced by https://github.com/facebookresearch/audiocraft/blob/main/audiocraft/modules/conditioners.py

import torch
import logging, warnings
import string
import typing as tp
import gc

from .adp import NumberEmbedder
from ..inference.utils import set_audio_channels
from .factory import create_pretransform_from_config
from .pretransforms import Pretransform
from ..training.utils import copy_state_dict
from .utils import load_ckpt_state_dict

from torch import nn

from .Qformer import BertConfig, BertLMHeadModel, BertAttention, BertIntermediate, BertOutput
from transformers import BertTokenizer

class Conditioner(nn.Module):
    def __init__(
            self,
            dim: int,
            output_dim: int,
            project_out: bool = False
            ):
        super().__init__()

        self.dim = dim
        self.output_dim = output_dim
        self.proj_out = nn.Linear(dim, output_dim) if (dim != output_dim or project_out) else nn.Identity()

    def forward(self, x: tp.Any) -> tp.Any:
        raise NotImplementedError()

class IntConditioner(Conditioner):
    def __init__(self,
                 output_dim: int,
                 min_val: int = 0,
                 max_val: int = 512
                 ):
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val
        self.int_embedder = nn.Embedding(max_val - min_val + 1, output_dim).requires_grad_(True)

    def forward(self, ints: tp.List[int], device=None) -> tp.Any:
        # self.int_embedder.to(device)
        ints = torch.tensor(ints).to(device)
        ints = ints.clamp(self.min_val, self.max_val)

        int_embeds = self.int_embedder(ints).unsqueeze(1)

        return [int_embeds, torch.ones(int_embeds.shape[0], 1).to(device)]

class NumberConditioner(Conditioner):
    '''
    Conditioner that takes a list of floats, normalizes them to a given range, and returns a list of embeddings
    '''
    def __init__(self,
                 output_dim: int,
                 min_val: float = 0,
                 max_val: float = 1
                 ):
        super().__init__(output_dim, output_dim)

        self.min_val = min_val
        self.max_val = max_val

        self.embedder = NumberEmbedder(features=output_dim)

    def forward(self, floats: tp.List[float], device=None) -> tp.Any:
        # Cast the inputs to floats
        floats = [float(x) for x in floats]

        floats = torch.tensor(floats).to(device)
        floats = floats.clamp(self.min_val, self.max_val)

        normalized_floats = (floats - self.min_val) / (self.max_val - self.min_val)

        # Cast floats to the same dtype as the embedder
        embedder_dtype = next(self.embedder.parameters()).dtype
        normalized_floats = normalized_floats.to(embedder_dtype)

        float_embeds = self.embedder(normalized_floats).unsqueeze(1)

        return [float_embeds, torch.ones(float_embeds.shape[0], 1).to(device)]

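# Illustrative usage sketch (not part of the original module); the conditioner name
# and value range below are assumptions for demonstration only. NumberConditioner
# clamps each float to [min_val, max_val], rescales it to [0, 1], and returns a
# [embeddings, mask] pair:
#
#   seconds_cond = NumberConditioner(output_dim=768, min_val=0.0, max_val=30.0)
#   embeds, mask = seconds_cond([0.0, 12.5, 30.0], device="cpu")
#   # embeds: [3, 1, 768], mask: [3, 1]
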
class CLAPTextConditioner(Conditioner):
    def __init__(self,
                 output_dim: int,
                 clap_ckpt_path,
                 use_text_features=False,
                 feature_layer_ix: int = -1,
                 audio_model_type="HTSAT-base",
                 enable_fusion=True,
                 project_out: bool = False,
                 finetune: bool = False):
        super().__init__(768 if use_text_features else 512, output_dim, project_out=project_out)

        self.use_text_features = use_text_features
        self.feature_layer_ix = feature_layer_ix
        self.finetune = finetune

        # Suppress logging from transformers
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                import laion_clap
                from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict

                model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')

                if self.finetune:
                    self.model = model
                else:
                    self.__dict__["model"] = model

                state_dict = clap_load_state_dict(clap_ckpt_path)
                self.model.model.load_state_dict(state_dict, strict=False)

                if self.finetune:
                    self.model.model.text_branch.requires_grad_(True)
                    self.model.model.text_branch.train()
                else:
                    self.model.model.text_branch.requires_grad_(False)
                    self.model.model.text_branch.eval()

            finally:
                logging.disable(previous_level)

        del self.model.model.audio_branch

        gc.collect()
        torch.cuda.empty_cache()

    def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
        prompt_tokens = self.model.tokenizer(prompts)
        attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
        prompt_features = self.model.model.text_branch(
            input_ids=prompt_tokens["input_ids"].to(device=device, non_blocking=True),
            attention_mask=attention_mask,
            output_hidden_states=True
        )["hidden_states"][layer_ix]

        return prompt_features, attention_mask

    def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
        self.model.to(device)

        if self.use_text_features:
            if len(texts) == 1:
                text_features, text_attention_mask = self.get_clap_features([texts[0], ""], layer_ix=self.feature_layer_ix, device=device)
                text_features = text_features[:1, ...]
                text_attention_mask = text_attention_mask[:1, ...]
            else:
                text_features, text_attention_mask = self.get_clap_features(texts, layer_ix=self.feature_layer_ix, device=device)
            return [self.proj_out(text_features), text_attention_mask]

        # Fix for CLAP bug when only one text is passed
        if len(texts) == 1:
            text_embedding = self.model.get_text_embedding([texts[0], ""], use_tensor=True)[:1, ...]
        else:
            text_embedding = self.model.get_text_embedding(texts, use_tensor=True)

        text_embedding = text_embedding.unsqueeze(1).to(device)

        return [self.proj_out(text_embedding), torch.ones(text_embedding.shape[0], 1).to(device)]

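# Illustrative usage sketch (not part of the original module); the checkpoint path is
# a placeholder and must point at a real LAION-CLAP checkpoint compatible with the
# chosen audio_model_type:
#
#   clap_text = CLAPTextConditioner(output_dim=768, clap_ckpt_path="/path/to/clap.ckpt")
#   embeds, mask = clap_text(["warm analog synth pad"], device="cuda")
#   # embeds: [1, 1, output_dim] after proj_out, mask: [1, 1]
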
class CLAPAudioConditioner(Conditioner):
    def __init__(self,
                 output_dim: int,
                 clap_ckpt_path,
                 audio_model_type="HTSAT-base",
                 enable_fusion=True,
                 project_out: bool = False,
                 finetune: bool = False):
        super().__init__(512, output_dim, project_out=project_out)

        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        self.finetune = finetune

        # Suppress logging from transformers
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                import laion_clap
                from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict

                model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')

                if self.finetune:
                    self.model = model
                else:
                    self.__dict__["model"] = model

                state_dict = clap_load_state_dict(clap_ckpt_path)
                self.model.model.load_state_dict(state_dict, strict=False)

                if self.finetune:
                    self.model.model.audio_branch.requires_grad_(True)
                    self.model.model.audio_branch.train()
                else:
                    self.model.model.audio_branch.requires_grad_(False)
                    self.model.model.audio_branch.eval()

            finally:
                logging.disable(previous_level)

        del self.model.model.text_branch

        gc.collect()
        torch.cuda.empty_cache()

    def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Any = "cuda") -> tp.Any:
        self.model.to(device)

        if isinstance(audios, list) or isinstance(audios, tuple):
            audios = torch.cat(audios, dim=0)

        # Convert to mono
        mono_audios = audios.mean(dim=1)
        with torch.cuda.amp.autocast(enabled=False):
            audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)

        audio_embedding = audio_embedding.unsqueeze(1).to(device)

        return [self.proj_out(audio_embedding), torch.ones(audio_embedding.shape[0], 1).to(device)]

def disabled_train(self, mode=True):
    """Overwrite model.train with this function to make sure train/eval mode
    does not change anymore."""
    return self


# Define BertLayer for cross attention
from transformers.modeling_utils import (
    PreTrainedModel,
    apply_chunking_to_forward,
    find_pruneable_heads_and_indices,
    prune_linear_layer,
)

class BertLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.chunk_size_feed_forward = config.chunk_size_feed_forward
        self.seq_len_dim = 1
        self.attention = BertAttention(config)
        self.crossattention = BertAttention(
            config, is_cross_attention=True
        )
        self.has_cross_attention = True
        self.intermediate = BertIntermediate(config)
        self.output = BertOutput(config)

        self.intermediate_query = BertIntermediate(config)
        self.output_query = BertOutput(config)

    def feed_forward_chunk(self, attention_output):
        intermediate_output = self.intermediate(attention_output)
        layer_output = self.output(intermediate_output, attention_output)
        return layer_output

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        past_key_value=None,
        output_attentions=False,
        query_length=0,
    ):
        # decoder uni-directional self-attention cached key/values tuple is at positions 1,2
        self_attn_past_key_value = None
        self_attention_outputs = self.attention(
            hidden_states,
            attention_mask,
            head_mask,
            output_attentions=output_attentions,
            past_key_value=self_attn_past_key_value,
        )
        attention_output = self_attention_outputs[0]
        outputs = self_attention_outputs[1:-1]
        present_key_value = self_attention_outputs[-1]

        cross_attention_outputs = self.crossattention(
            attention_output,
            attention_mask,
            head_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            output_attentions=output_attentions,
        )
        # use the cross-attended hidden states for the feed-forward block
        attention_output = cross_attention_outputs[0]
        outputs = (
            outputs + cross_attention_outputs[1:-1]
        )  # add cross attentions if we output attention weights

        layer_output = apply_chunking_to_forward(
            self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
        )
        outputs = (layer_output,) + outputs

        outputs = outputs + (present_key_value,)

        return outputs

class T5Conditioner(Conditioner):

    T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b",
                 "google/flan-t5-small", "google/flan-t5-base", "google/flan-t5-large",
                 "google/flan-t5-xl", "google/flan-t5-xxl"]

    T5_MODEL_DIMS = {
        "t5-small": 512,
        "t5-base": 768,
        "t5-large": 1024,
        "t5-3b": 1024,
        "t5-11b": 1024,
        "t5-xl": 2048,
        "t5-xxl": 4096,
        "google/flan-t5-small": 512,
        "google/flan-t5-base": 768,
        "google/flan-t5-large": 1024,
        "google/flan-t5-3b": 1024,
        "google/flan-t5-11b": 1024,
        "google/flan-t5-xl": 2048,
        "google/flan-t5-xxl": 4096,
    }

    def init_Qformer(cls, num_query_token, vision_width, freeze, cross_attention_freq=2):
        encoder_config = BertConfig.from_pretrained("bert-base-uncased")
        encoder_config.encoder_width = vision_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel.from_pretrained(
            "bert-base-uncased", config=encoder_config
        )
        qformer_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", truncation_side="right")
        qformer_tokenizer.add_special_tokens({"bos_token": "[DEC]"})
        Qformer.resize_token_embeddings(len(qformer_tokenizer))
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)

        # optional, if not loading weights
        Qformer.cls = None
        Qformer.bert.embeddings.word_embeddings = None
        Qformer.bert.embeddings.position_embeddings = None
        for layer in Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        if freeze:
            for name, param in Qformer.named_parameters():
                param.requires_grad = False
            Qformer = Qformer.eval()
            Qformer.train = disabled_train
            query_tokens.requires_grad = False
            print("freeze Qformer")

        return Qformer, query_tokens

    def __init__(
            self,
            output_dim: int,
            t5_model_name: str = "t5-base",
            max_length: int = 128,
            enable_grad: bool = False,
            project_out: bool = False
            ):
        assert t5_model_name in self.T5_MODELS, f"Unknown T5 model name: {t5_model_name}"
        super().__init__(self.T5_MODEL_DIMS[t5_model_name], output_dim, project_out=project_out)

        from transformers import T5EncoderModel, AutoTokenizer, AutoModel
        from peft import (
            LoraConfig,
            get_peft_model,
            get_peft_model_state_dict,
            set_peft_model_state_dict,
        )

        self.qformer_proj_norm = nn.LayerNorm(768)
        self.audio_proj_norm_qformer = nn.LayerNorm(768, elementwise_affine=False)

        # Cross attention layer between the Qformer queries and the T5 embeddings
        bert_config = BertConfig.from_pretrained('bert-base-uncased')
        bert_config.encoder_width = 768
        self.cross_attend = BertLayer(bert_config)
        # self.proj_out_cross = nn.Linear(1024, 768)

        self.max_length = max_length
        self.enable_grad = enable_grad

        # Suppress logging from transformers
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                # self.tokenizer = T5Tokenizer.from_pretrained(t5_model_name, model_max_length=max_length)
                # model = T5EncoderModel.from_pretrained(t5_model_name, max_length=max_length).train(enable_grad).requires_grad_(enable_grad)
                self.tokenizer = AutoTokenizer.from_pretrained(t5_model_name)

                ckpt = torch.load('/fs/nexus-projects/brain_project/try_t5.pt')
                model = T5EncoderModel.from_pretrained(t5_model_name).train(enable_grad).requires_grad_(enable_grad).to(torch.float16)
                model.load_state_dict(ckpt, strict=True)

                self.llm_model = AutoModel.from_pretrained(
                    "meta-llama/Meta-Llama-3.1-8B",
                    load_in_8bit=False,
                    # torch_dtype=torch.float16,
                    device_map="auto",
                )
                self.llm_model_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B")

                self.num_new_tokens = 64
                self.IGNORE_TOKEN_ID = -100

                # 1. Add the special token "<ad>" so the system prompt can refer to audio generation
                self.llm_model_tokenizer.add_tokens(["<ad>"], special_tokens=False)
                # 2. Add the tokens "<ad_0>" ... "<ad_63>" that summarize the caption for audio generation
                new_token_list = [f"<ad_{i}>" for i in range(self.num_new_tokens)]
                self.llm_model_tokenizer.add_tokens(new_token_list, special_tokens=False)
                # 3. Count the new tokens and resize the LLM embeddings to the extended vocabulary
                self.num_new_tokens = self.num_new_tokens + 1
                self.llm_model.resize_token_embeddings(len(self.llm_model_tokenizer))
                self.llm_model_tokenizer.ad_start_token_id = self.llm_model_tokenizer.convert_tokens_to_ids("<ad_0>")

                # ------------------- #
                # build a new LM head over the extended vocabulary
                self.lm_head = nn.Linear(4096, len(self.llm_model_tokenizer), bias=False)
                # store the new vocabulary size
                self.vocab_size = len(self.llm_model_tokenizer)

                # 4. Initialize the new token embeddings with the mean of the original embeddings
                input_embeddings = self.llm_model.model.embed_tokens.weight.data
                output_embeddings = self.llm_model.lm_head.weight.data

                self.original_LLM_word_embedding_0 = input_embeddings[0]
                self.original_LLM_language_model_head_0 = output_embeddings[0]
                # ------------- #
                input_embeddings_avg = input_embeddings[:-self.num_new_tokens].mean(dim=0, keepdim=True)
                output_embeddings_avg = output_embeddings[:-self.num_new_tokens].mean(dim=0, keepdim=True)
                # ------------- #
                input_embeddings[-self.num_new_tokens:] = input_embeddings_avg
                output_embeddings[-self.num_new_tokens:] = output_embeddings_avg
                # ------------- #
                self.llm_model.model.embed_tokens.weight.data = input_embeddings
                self.lm_head.weight.data = output_embeddings

                # 5. Initialize the Qformer
                self.Qformer, self.query_tokens = self.init_Qformer(32, 768, True)
                self.llm_to_qformer_projection = nn.Linear(4096, 768)

                # Add LoRA modules to the LLM
                config = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"], lora_dropout=0.05, bias="none", task_type='CAUSAL_LM')
                self.llm_model = get_peft_model(self.llm_model, config)

            finally:
                logging.disable(previous_level)

        if self.enable_grad:
            self.model = model
        else:
            self.__dict__["model"] = model

    def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:

        llm_input_ids = []
        llm_input_attention_mask = []
        llm_targets_list = []
        qformer_input_attention_mask = []

        for text in texts:
            # define a system prompt for the LLM
            llm_caption_system = "A chat between a curious user and an artificial intelligence assistant. The assistant can generate <ad>. "
            # construct the prompt for the LLM
            llm_caption_interim = "Please generate an audio for the following caption: " + text
            llm_caption_last = " Here is the audio for the given caption: [ad]"

            # expand the [ad] placeholder into the <ad_0> ... <ad_63> summary tokens
            append_str = ""
            for i in range(self.num_new_tokens - 1):
                append_str += f" <ad_{i}>"
            llm_caption_last = llm_caption_last.replace(" [ad]", append_str)

            # add the system prompt to the LLM prompt
            llm_caption = llm_caption_system + llm_caption_interim + llm_caption_last

            # tokenize the prompt
            input_ids_max_len = 512
            llm_caption_input_ids = self.llm_model_tokenizer(
                llm_caption,
                return_tensors="pt",
                padding="max_length",
                max_length=input_ids_max_len,
                truncation=True,
            ).input_ids[0]

            # generate LLM targets: ignore the BOS token, the instruction, and the padding
            llm_targets = llm_caption_input_ids.clone()
            llm_targets[:1] = self.IGNORE_TOKEN_ID
            total_padding_len = int(llm_targets.ne(self.llm_model_tokenizer.pad_token_id).sum())
            instruction_len = len(
                self.llm_model_tokenizer(
                    llm_caption_system + llm_caption_interim,
                    max_length=input_ids_max_len,
                    truncation=True,
                ).input_ids) - 2
            llm_targets[1:(1 + instruction_len)] = self.IGNORE_TOKEN_ID
            llm_targets[total_padding_len:] = self.IGNORE_TOKEN_ID

            # append all
            llm_input_ids.append(llm_caption_input_ids)
            llm_input_attention_mask.append(llm_caption_input_ids.ne(self.llm_model_tokenizer.pad_token_id))
            llm_targets_list.append(llm_targets)
            qformer_input_attention_mask.append(llm_caption_input_ids.ge(self.llm_model_tokenizer.ad_start_token_id))

        llm_input_ids = torch.stack([torch.tensor(input_id) for input_id in llm_input_ids]).to(device)
        llm_input_attention_mask = torch.stack([torch.tensor(attention_mask) for attention_mask in llm_input_attention_mask]).to(device)
        llm_targets = torch.stack([torch.tensor(llm_target) for llm_target in llm_targets_list]).to(device)
        qformer_input_attention_mask = torch.stack([torch.tensor(qformer_attention_mask) for qformer_attention_mask in qformer_input_attention_mask]).to(device)

        # LLM Model
        llm_outputs = self.llm_model.model(
            input_ids=llm_input_ids,
            attention_mask=llm_input_attention_mask,
            position_ids=None,
            past_key_values=None,
            use_cache=None,
            output_attentions=False,
            output_hidden_states=False,
            return_dict=False,
        )

        hidden_states_llm = llm_outputs[0]
        shift_labels = llm_targets[..., 1:].contiguous()

        # Next-token prediction language modeling loss (labels moved to the logits device to support model parallelism)
        hidden_states = hidden_states_llm.to(torch.float32)
        logits = self.lm_head(hidden_states)
        shift_logits = logits[..., :-1, :].contiguous()
        # Flatten the tokens
        ce_loss_fct = nn.CrossEntropyLoss()
        shift_logits = shift_logits.view(-1, self.vocab_size)
        shift_labels = shift_labels.view(-1)
        shift_labels = shift_labels.to(shift_logits.device)
        LM_loss = ce_loss_fct(shift_logits, shift_labels) * 1.0  # self.config.llm_loss_weight

        # Qformer: project the LLM hidden states down to the Qformer width
        hidden_states_llm = self.llm_to_qformer_projection(hidden_states_llm[:, :512, :])  # cut it to 512 max
        audio_input_for_qformer = self.qformer_proj_norm(hidden_states_llm)
        # audio_atts = torch.ones(audio_input_for_qformer.size()[:-1], dtype=torch.long).to(device)  # could restrict the attention to only the non-padded tokens

        query_tokens = self.query_tokens.expand(audio_input_for_qformer.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=audio_input_for_qformer,
            encoder_attention_mask=qformer_input_attention_mask,
            return_dict=True,
        )
        query_output = self.audio_proj_norm_qformer(query_output.last_hidden_state)

        # T5 model
        self.model.to(device)
        self.proj_out.to(device)
        # self.proj_out_cross.to(device)

        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device).to(torch.bool)

        self.model.eval()

        with torch.cuda.amp.autocast(dtype=torch.float16), torch.set_grad_enabled(self.enable_grad):
            embeddings = self.model(
                input_ids=input_ids, attention_mask=attention_mask
            )["last_hidden_state"]

        embeddings = self.proj_out(embeddings.float())
        embeddings = embeddings * attention_mask.unsqueeze(-1).float()

        # ----------------- #
        # cross attention between the Qformer queries and the T5 embeddings
        qformer_attns = torch.ones(query_output.size()[:-1], dtype=torch.long).to(device)
        qformer_attns = self.get_bert_extended_attention_mask(qformer_attns, query_output.size()[:-1], device, False)
        # BertLayer returns a tuple; the layer output is its first element
        embeddings = self.cross_attend(query_output, attention_mask=qformer_attns, encoder_hidden_states=embeddings, encoder_attention_mask=attention_mask)[0]

        return embeddings, attention_mask, LM_loss

class PhonemeConditioner(Conditioner):
    """
    A conditioner that turns text into phonemes and embeds them using a lookup table
    Only works for English text

    Args:
        output_dim: the dimension of the output embeddings
        max_length: the maximum number of phonemes to embed
        project_out: whether to add another linear projection to the output embeddings
    """

    def __init__(
            self,
            output_dim: int,
            max_length: int = 1024,
            project_out: bool = False,
            ):
        super().__init__(output_dim, output_dim, project_out=project_out)

        from g2p_en import G2p

        self.max_length = max_length

        self.g2p = G2p()

        # Reserving 0 for padding, 1 for ignored
        self.phoneme_embedder = nn.Embedding(len(self.g2p.phonemes) + 2, output_dim)

    def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        self.phoneme_embedder.to(device)
        self.proj_out.to(device)

        batch_phonemes = [self.g2p(text) for text in texts]  # shape [batch_size, length]

        phoneme_ignore = [" ", *string.punctuation]

        # Replace ignored phonemes with a placeholder and cut to max length
        batch_phonemes = [[p if p not in phoneme_ignore else "_" for p in phonemes][:self.max_length] for phonemes in batch_phonemes]

        # Convert to ids
        phoneme_ids = [[self.g2p.p2idx[p] + 2 if p in self.g2p.p2idx else 1 for p in phonemes] for phonemes in batch_phonemes]

        # Pad to match longest and make a mask tensor for the padding
        longest = max([len(ids) for ids in phoneme_ids])
        phoneme_ids = [ids + [0] * (longest - len(ids)) for ids in phoneme_ids]

        phoneme_ids = torch.tensor(phoneme_ids).to(device)

        # Convert to embeddings
        phoneme_embeds = self.phoneme_embedder(phoneme_ids)

        phoneme_embeds = self.proj_out(phoneme_embeds)

        return phoneme_embeds, torch.ones(phoneme_embeds.shape[0], phoneme_embeds.shape[1]).to(device)

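# Illustrative usage sketch (not part of the original module); requires the g2p_en
# package and its NLTK data. The output dimension and text are assumptions for
# demonstration only:
#
#   phoneme_cond = PhonemeConditioner(output_dim=768, max_length=128)
#   embeds, mask = phoneme_cond(["sing a quiet melody"], device="cpu")
#   # embeds: [1, num_phonemes, 768], mask: [1, num_phonemes]
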
class TokenizerLUTConditioner(Conditioner):
    """
    A conditioner that embeds text using a lookup table on a pretrained tokenizer's vocabulary

    Args:
        tokenizer_name: the name of the tokenizer from the Hugging Face transformers library
        output_dim: the dimension of the output embeddings
        max_length: the maximum length of the text to embed
        project_out: whether to add another linear projection to the output embeddings
    """

    def __init__(
            self,
            tokenizer_name: str,  # Name of a tokenizer from the Hugging Face transformers library
            output_dim: int,
            max_length: int = 1024,
            project_out: bool = False,
            ):
        super().__init__(output_dim, output_dim, project_out=project_out)

        from transformers import AutoTokenizer

        # Suppress logging from transformers
        previous_level = logging.root.manager.disable
        logging.disable(logging.ERROR)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            try:
                self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
            finally:
                logging.disable(previous_level)

        self.max_length = max_length

        self.token_embedder = nn.Embedding(len(self.tokenizer), output_dim)

    def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        self.proj_out.to(device)

        encoded = self.tokenizer(
            texts,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )

        input_ids = encoded["input_ids"].to(device)
        attention_mask = encoded["attention_mask"].to(device).to(torch.bool)

        embeddings = self.token_embedder(input_ids)

        embeddings = self.proj_out(embeddings)

        embeddings = embeddings * attention_mask.unsqueeze(-1).float()

        return embeddings, attention_mask

class PretransformConditioner(Conditioner):
    """
    A conditioner that uses a pretransform's encoder for conditioning

    Args:
        pretransform: an instantiated pretransform to use for conditioning
        output_dim: the dimension of the output embeddings
    """
    def __init__(self, pretransform: Pretransform, output_dim: int):
        super().__init__(pretransform.encoded_channels, output_dim)

        self.pretransform = pretransform

    def forward(self, audio: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]], device: tp.Union[torch.device, str]) -> tp.Tuple[torch.Tensor, torch.Tensor]:
        self.pretransform.to(device)
        self.proj_out.to(device)

        if isinstance(audio, list) or isinstance(audio, tuple):
            audio = torch.cat(audio, dim=0)

        # Convert audio to the pretransform's input channel count
        audio = set_audio_channels(audio, self.pretransform.io_channels)

        latents = self.pretransform.encode(audio)

        latents = self.proj_out(latents)

        return [latents, torch.ones(latents.shape[0], latents.shape[2]).to(latents.device)]

class MultiConditioner(nn.Module):
    """
    A module that applies multiple conditioners to an input dictionary based on the keys

    Args:
        conditioners: a dictionary of conditioners with keys corresponding to the keys of the conditioning input dictionary (e.g. "prompt")
        default_keys: a dictionary of default keys to use if the key is not in the input dictionary (e.g. {"prompt_t5": "prompt"})
    """
    def __init__(self, conditioners: tp.Dict[str, Conditioner], default_keys: tp.Dict[str, str] = {}):
        super().__init__()

        self.conditioners = nn.ModuleDict(conditioners)
        self.default_keys = default_keys

    def forward(self, batch_metadata: tp.List[tp.Dict[str, tp.Any]], device: tp.Union[torch.device, str]) -> tp.Dict[str, tp.Any]:
        output = {}

        for key, conditioner in self.conditioners.items():
            condition_key = key

            conditioner_inputs = []

            for x in batch_metadata:

                if condition_key not in x:
                    if condition_key in self.default_keys:
                        condition_key = self.default_keys[condition_key]
                    else:
                        raise ValueError(f"Conditioner key {condition_key} not found in batch metadata")

                # Unwrap the condition info if it's a single-element list or tuple, this is to support collation functions that wrap everything in a list
                if isinstance(x[condition_key], (list, tuple)) and len(x[condition_key]) == 1:
                    conditioner_input = x[condition_key][0]
                else:
                    conditioner_input = x[condition_key]

                # if isinstance(conditioner_input, dict):
                #     if len(conditioner_inputs) == 0:
                #         conditioner_inputs = {C: [] for C in conditioner_input}
                #     for cond in conditioner_input:
                #         conditioner_inputs[cond].append(conditioner_input[cond])
                # else:
                conditioner_inputs.append(conditioner_input)

            output[key] = conditioner(conditioner_inputs, device)

        return output

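# Illustrative usage sketch (not part of the original module); the conditioner names
# and metadata keys below are assumptions for demonstration only. Each conditioner is
# dispatched on its key, with default_keys used as a fallback lookup:
#
#   multi_cond = MultiConditioner(
#       conditioners={
#           "prompt": TokenizerLUTConditioner("t5-base", output_dim=768),
#           "seconds_total": NumberConditioner(output_dim=768, min_val=0, max_val=30),
#       },
#       default_keys={"seconds_total": "duration"},
#   )
#   batch_metadata = [{"prompt": "rain on a tin roof", "duration": 10.0}]
#   conditioning = multi_cond(batch_metadata, device="cpu")
#   # conditioning["prompt"] -> (embeddings, mask); conditioning["seconds_total"] -> [embeds, mask]
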
def create_multi_conditioner_from_conditioning_config(config: tp.Dict[str, tp.Any]) -> MultiConditioner:
    """
    Create a MultiConditioner from a conditioning config dictionary

    Args:
        config: the conditioning config dictionary
    """
    conditioners = {}
    cond_dim = config["cond_dim"]

    default_keys = config.get("default_keys", {})

    for conditioner_info in config["configs"]:
        id = conditioner_info["id"]

        conditioner_type = conditioner_info["type"]

        conditioner_config = {"output_dim": cond_dim}

        conditioner_config.update(conditioner_info["config"])

        if conditioner_type == "t5":
            conditioners[id] = T5Conditioner(**conditioner_config)
        elif conditioner_type == "clap_text":
            conditioners[id] = CLAPTextConditioner(**conditioner_config)
        elif conditioner_type == "clap_audio":
            conditioners[id] = CLAPAudioConditioner(**conditioner_config)
        elif conditioner_type == "int":
            conditioners[id] = IntConditioner(**conditioner_config)
        elif conditioner_type == "number":
            conditioners[id] = NumberConditioner(**conditioner_config)
        elif conditioner_type == "phoneme":
            conditioners[id] = PhonemeConditioner(**conditioner_config)
        elif conditioner_type == "lut":
            conditioners[id] = TokenizerLUTConditioner(**conditioner_config)
        elif conditioner_type == "pretransform":
            sample_rate = conditioner_config.pop("sample_rate", None)
            assert sample_rate is not None, "Sample rate must be specified for pretransform conditioners"

            pretransform = create_pretransform_from_config(conditioner_config.pop("pretransform_config"), sample_rate=sample_rate)

            if conditioner_config.get("pretransform_ckpt_path", None) is not None:
                pretransform.load_state_dict(load_ckpt_state_dict(conditioner_config.pop("pretransform_ckpt_path")))

            conditioners[id] = PretransformConditioner(pretransform, **conditioner_config)
        else:
            raise ValueError(f"Unknown conditioner type: {conditioner_type}")

    return MultiConditioner(conditioners, default_keys=default_keys)
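
# Illustrative conditioning config sketch (not part of the original module); the ids,
# types, and values are assumptions for demonstration only, mirroring the structure
# parsed above ("cond_dim", "default_keys", and a list of {"id", "type", "config"}):
#
#   conditioning_config = {
#       "cond_dim": 768,
#       "default_keys": {"prompt_t5": "prompt"},
#       "configs": [
#           {"id": "prompt_t5", "type": "t5", "config": {"t5_model_name": "t5-base", "max_length": 128}},
#           {"id": "seconds_total", "type": "number", "config": {"min_val": 0, "max_val": 30}},
#       ],
#   }
#   multi_conditioner = create_multi_conditioner_from_conditioning_config(conditioning_config)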