from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import torch.nn as nn
from typing import List, Optional, Tuple, Union
import math

from transformers.models.llama.modeling_llama import LlamaDecoderLayer
from transformers.models.llama.modeling_llama import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    apply_rotary_pos_emb,
    Cache,
    repeat_kv,
)


class SinusoidalPosEmb(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        device = x.device
        half_dim = self.dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
        emb = x[:, None] * emb[None, :] * 1.0
        emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
        return emb
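
# Usage sketch (shapes chosen for illustration, not taken from the original file): a
# 1-D tensor of diffusion timesteps is mapped to a (batch, dim) sinusoidal embedding.
#   pos_emb = SinusoidalPosEmb(dim=1024)
#   step_emb = pos_emb(torch.randint(0, 1000, (8,)))  # -> shape (8, 1024)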


class LlamaAdaptiveRMSNorm(nn.Module):
    def __init__(self, hidden_size=1024, eps=1e-6, dim_cond=1024):
        super().__init__()
        self.to_weight = nn.Linear(dim_cond, hidden_size)
        nn.init.zeros_(self.to_weight.weight)
        nn.init.ones_(self.to_weight.bias)
        self.variance_epsilon = eps
        self._is_hf_initialized = True  # disable automatic init

    def forward(self, hidden_states, cond_embedding):
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        weight = self.to_weight(cond_embedding)
        if len(weight.shape) == 2:
            weight = weight.unsqueeze(1)

        return (weight * hidden_states).to(input_dtype)
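
# Note: `to_weight` has zero-initialized weights and a bias of ones, so at initialization
# the predicted gain is ~1 and the module behaves like a plain RMSNorm; the condition
# embedding then learns to modulate the per-channel scale.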


class LlamaNARDecoderLayer(LlamaDecoderLayer):
    def __init__(self, config: LlamaConfig, layer_idx: int):
        """Override to adaptive layer norm"""
        super().__init__(config, layer_idx)  # init attention, mlp, etc.
        # self.self_attn = LlamaXformersAttention(config=config, layer_idx=layer_idx)
        self.self_attn.is_causal = False  # for flash attention
        self.input_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
        )
        self.post_attention_layernorm = LlamaAdaptiveRMSNorm(
            config.hidden_size, eps=config.rms_norm_eps, dim_cond=config.hidden_size
        )

    # add `cond` in forward function
    def forward(
        self,
        hidden_states: torch.Tensor,
        cond_embedding: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """
        residual = hidden_states

        hidden_states = self.input_layernorm(
            hidden_states, cond_embedding=cond_embedding
        )

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(
            hidden_states, cond_embedding=cond_embedding
        )
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs
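
# Call sketch (shapes are illustrative): compared with the stock LlamaDecoderLayer, this
# layer uses non-causal self-attention, adaptive RMSNorms, and an extra condition, e.g.
#   layer_out = layer(hidden_states,            # (B, T, hidden_size)
#                     cond_embedding=step_emb,  # (B, hidden_size) or (B, T, hidden_size)
#                     attention_mask=mask_4d)   # (B, 1, T, T) additive mask for eager attention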


class DiffLlamaConcat(LlamaModel):
    def __init__(
        self,
        mel_dim=100,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        dropout=0.1,
        ffn_dropout=0.1,
        attention_dropout=0.0,
        config=LlamaConfig(0, 256, 1024, 1, 1),
        flash_attention=False,
    ):
        super().__init__(config)

        self.flash_attention = flash_attention

        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                        attn_implementation=(
                            "flash_attention_2" if self.flash_attention else "eager"
                        ),
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )

        self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)

        self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
        self.diff_step_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.cond_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.mel_mlp = nn.Sequential(
            nn.Linear(mel_dim, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.mel_out_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, mel_dim),
        )

        for layer in self.layers:
            layer.input_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )
            layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )

        self.embed_tokens = None

        self.post_init()
        # self.reset_parameters()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create noncausal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )
            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        x,
        diffusion_step,
        x_mask,
        cond,
        input_ids: torch.LongTensor = None,  # [num_quant, B, T]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # retrieve some shape info
        batch_size, seq_length, _ = x.shape

        # condition mlp
        cond_embedding = self.cond_mlp(cond)  # (B, T, C)

        # condition mel
        x = self.mel_mlp(x)

        # diffusion step embedding
        diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
        diffusion_step = self.diff_step_mlp(diffusion_step)  # (B, C)

        x = x + cond_embedding
        inputs_embeds = x

        # if self.flash_attention:
        #     attention_mask = None
        # else:
        attention_mask = x_mask
        # assert x_mask.shape == batch_size, seq_length

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if not self.flash_attention:
            # embed positions
            if attention_mask is None:
                attention_mask = torch.ones(
                    (batch_size, seq_length_with_past),
                    dtype=torch.bool,
                    device=inputs_embeds.device,
                )
            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask,
                (batch_size, seq_length),
                inputs_embeds,
                past_key_values_length,
            )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    # attention_mask=attention_mask if not self.flash_attention else None,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cond_embedding=diffusion_step,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        hidden_states = self.mel_out_mlp(hidden_states)

        if not output_hidden_states:
            return hidden_states
        else:
            return {
                "hidden_states": hidden_states,
                "all_hidden_states": all_hidden_states,
            }
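
# DiffLlamaConcat.forward expects (shapes are illustrative): x, the noisy mel tensor of
# shape (B, T, mel_dim); diffusion_step, a (B,) tensor of timesteps; x_mask, a (B, T)
# padding mask; and cond, a (B, T, hidden_size) frame-level condition. It returns the
# predicted mel tensor of shape (B, T, mel_dim) unless output_hidden_states is set.
# A small runnable sketch is given in the __main__ block at the end of this file.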


class DiffLlama(LlamaModel):
    def __init__(
        self,
        mel_dim=100,
        hidden_size=1024,
        num_heads=16,
        num_layers=16,
        dropout=0.1,
        ffn_dropout=0.1,
        attention_dropout=0.0,
        config=LlamaConfig(0, 256, 1024, 1, 1),
        flash_attention=False,
    ):
        super().__init__(config)

        self.flash_attention = flash_attention

        self.layers = nn.ModuleList(
            [
                LlamaNARDecoderLayer(
                    LlamaConfig(
                        hidden_size=hidden_size,
                        num_attention_heads=num_heads,
                        max_position_embeddings=4096,
                        intermediate_size=hidden_size * 4,
                        attn_implementation=(
                            "flash_attention_2" if self.flash_attention else "eager"
                        ),
                        is_causal=False,
                    ),
                    layer_idx=i,
                )
                for i in range(num_layers)
            ]
        )

        self.norm = LlamaAdaptiveRMSNorm(hidden_size, dim_cond=hidden_size)

        self.diff_step_embedding = SinusoidalPosEmb(hidden_size)
        self.diff_step_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        # `forward` projects the frame-level condition through this MLP before adding
        # it to the mel features, so it must be defined here.
        self.cond_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.mel_mlp = nn.Sequential(
            nn.Linear(mel_dim, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, hidden_size),
        )

        self.mel_out_mlp = nn.Sequential(
            nn.Linear(hidden_size, hidden_size * 4),
            nn.SiLU(),
            nn.Linear(hidden_size * 4, mel_dim),
        )

        for layer in self.layers:
            layer.input_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )
            layer.post_attention_layernorm = LlamaAdaptiveRMSNorm(
                hidden_size, dim_cond=hidden_size
            )

        self.embed_tokens = None

        self.post_init()
        # self.reset_parameters()

    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        # create noncausal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None

        def _expand_mask(
            mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None
        ):
            """
            Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
            """
            bsz, src_len = mask.size()
            tgt_len = tgt_len if tgt_len is not None else src_len

            expanded_mask = (
                mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
            )
            inverted_mask = 1.0 - expanded_mask

            return inverted_mask.masked_fill(
                inverted_mask.to(torch.bool), torch.finfo(dtype).min
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]
            ).to(inputs_embeds.device)
            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        x,
        diffusion_step,
        x_mask,
        cond,
        input_ids: torch.LongTensor = None,  # [num_quant, B, T]
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # retrieve some shape info
        batch_size, seq_length, _ = x.shape

        # condition mlp
        cond_embedding = self.cond_mlp(cond)  # (B, T, C)

        # condition mel
        x = self.mel_mlp(x)

        # diffusion step embedding
        diffusion_step = self.diff_step_embedding(diffusion_step).to(x.device)
        diffusion_step = self.diff_step_mlp(diffusion_step)  # (B, C)

        x = x + cond_embedding
        inputs_embeds = x

        attention_mask = x_mask

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length,
                seq_length + past_key_values_length,
                dtype=torch.long,
                device=device,
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:
                raise NotImplementedError

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cond_embedding=diffusion_step,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states, cond_embedding=diffusion_step)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None

        hidden_states = self.mel_out_mlp(hidden_states)

        if not output_hidden_states:
            return hidden_states
        else:
            return {
                "hidden_states": hidden_states,
                "all_hidden_states": all_hidden_states,
            }
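

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module. All sizes below are
    # illustrative assumptions chosen so the forward pass fits on CPU, and it assumes a
    # transformers release whose LlamaDecoderLayer/LlamaAttention interfaces match the
    # ones this module was written against (the layer unpacks a 3-tuple from self_attn).
    hidden_size, mel_dim, batch, frames = 256, 100, 2, 50
    model = DiffLlamaConcat(
        mel_dim=mel_dim,
        hidden_size=hidden_size,
        num_heads=4,
        num_layers=2,
        flash_attention=False,  # eager attention with the expanded non-causal mask
    )
    model.eval()

    x = torch.randn(batch, frames, mel_dim)  # noisy mel features
    diffusion_step = torch.randint(0, 1000, (batch,))  # integer diffusion timesteps
    x_mask = torch.ones(batch, frames, dtype=torch.bool)  # all frames are valid
    cond = torch.randn(batch, frames, hidden_size)  # frame-level condition

    with torch.no_grad():
        out = model(x, diffusion_step, x_mask, cond)
    print(out.shape)  # expected: (batch, frames, mel_dim)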