from typing import List, Optional, Tuple, Union

import torch
import torch.nn as nn
from transformers import LlamaConfig, LlamaModel, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.modeling_outputs import CausalLMOutputWithPast
from transformers.models.llama.modeling_llama import KwargsForCausalLM
from transformers.processing_utils import Unpack

from configuration_speechunit import SpeechUnitConfig

class SpeechUnitPreTrainedModel(PreTrainedModel):
    config_class = SpeechUnitConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, SpeechUnitModel):
            # Seed the truncated backbone with pretrained weights; the truncated
            # model's parameter names are a subset of the full checkpoint's.
            # Hoist state_dict() out of the loop so it is built only once.
            src_state = LlamaModel.from_pretrained(self.config.base_model_id).state_dict()
            with torch.no_grad():
                for name, param in module.llama_model.named_parameters():
                    param.copy_(src_state[name])

class SpeechUnitModel(SpeechUnitPreTrainedModel):
    def __init__(self, config: SpeechUnitConfig):
        super().__init__(config)

        # Truncated LLaMA backbone: same architecture, fewer hidden layers.
        llama_config = LlamaConfig.from_pretrained(config.base_model_id)
        llama_config.num_hidden_layers = config.num_hidden_layers
        self.llama_model = LlamaModel._from_config(llama_config)

        original_vocab_size, embed_dim = self.llama_model.embed_tokens.weight.shape

        # Embedding table for the 16400-entry audio token codebook.
        self.audio_embed = nn.Embedding(16400, embed_dim)

        # Learnable per-head weights for combining the head losses.
        self.token_weights = nn.Parameter(torch.ones(config.num_heads))

        # One linear prediction head per parallel audio stream.
        self.heads = nn.ModuleList(
            [nn.Linear(embed_dim, config.output_dim) for _ in range(config.num_heads)]
        )

        self.post_init()

        # Xavier-init the audio embedding *after* post_init(): post_init() applies
        # _init_weights to every nn.Embedding and would otherwise silently
        # overwrite this initialization with the default normal initializer.
        nn.init.xavier_uniform_(self.audio_embed.weight.data)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        num_logits_to_keep: int = 0,
        **kwargs: Unpack[KwargsForCausalLM],
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        # NOTE: the original body was an empty stub; this is a minimal sketch that
        # assumes input_ids index the audio codebook and that labels, if given,
        # are laid out as [batch, num_heads, seq_len].
        if inputs_embeds is None:
            inputs_embeds = self.audio_embed(input_ids)
        outputs = self.llama_model(
            attention_mask=attention_mask, position_ids=position_ids,
            past_key_values=past_key_values, inputs_embeds=inputs_embeds,
            use_cache=use_cache, cache_position=cache_position, **kwargs)
        # One projection per head -> logits of shape [batch, num_heads, seq_len, output_dim].
        logits = torch.stack([head(outputs.last_hidden_state) for head in self.heads], dim=1)
        loss = None
        if labels is not None:
            per_head = torch.stack([nn.functional.cross_entropy(
                logits[:, i].reshape(-1, logits.size(-1)), labels[:, i].reshape(-1))
                for i in range(len(self.heads))])
            # Combine per-head losses via the learnable, softmax-normalized weights.
            loss = (torch.softmax(self.token_weights, dim=0) * per_head).sum()
        return CausalLMOutputWithPast(loss=loss, logits=logits,
                                      past_key_values=outputs.past_key_values)
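
# Usage sketch (hypothetical values; assumes SpeechUnitConfig exposes the fields
# referenced above and that base_model_id points at an accessible LLaMA checkpoint):
#
#     config = SpeechUnitConfig(base_model_id="meta-llama/Llama-3.2-1B",
#                               num_hidden_layers=4, num_heads=8, output_dim=2050)
#     model = SpeechUnitModel(config)
#     out = model(input_ids=torch.randint(0, 16400, (1, 32)))
#     out.logits.shape  # -> [1, num_heads, 32, output_dim]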