import torch
from torch import nn
from transformers import PretrainedConfig

from flashcosyvoice.config import CosyVoice2LLMConfig
from flashcosyvoice.modules.qwen2_components.layers import (
    ParallelLMHead, Qwen2DecoderLayer, RMSNorm, VocabParallelEmbedding)


class Qwen2Model(nn.Module):
    """Qwen2 decoder-only backbone: token embedding, decoder stack, final RMSNorm."""

    def __init__(
        self,
        config: CosyVoice2LLMConfig,
    ):
        super().__init__()
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size)
        self.layers = nn.ModuleList([Qwen2DecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        # Each layer returns (hidden_states, residual); the residual add is
        # deferred so it can be fused with the next layer's input RMSNorm.
        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        # The final norm folds in the last pending residual.
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states


class Qwen2ForCausalLM(nn.Module):
    # Maps per-projection checkpoint weight names to their packed destination
    # module and shard id: q/k/v load into the fused qkv_proj, and gate/up
    # into the fused gate_up_proj.
    packed_modules_mapping = {
        "q_proj": ("qkv_proj", "q"),
        "k_proj": ("qkv_proj", "k"),
        "v_proj": ("qkv_proj", "v"),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }
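
    # A checkpoint loader can consume the mapping roughly as sketched below
    # (an illustrative sketch only: `checkpoint`, `params`, and the
    # `weight_loader` hook are assumed names, and the actual loader lives
    # outside this class):
    #
    #   for name, tensor in checkpoint.items():
    #       for src, (dst, shard_id) in Qwen2ForCausalLM.packed_modules_mapping.items():
    #           if src in name:
    #               param = params[name.replace(src, dst)]
    #               param.weight_loader(param, tensor, shard_id)
    #               break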

    def __init__(
        self,
        config: CosyVoice2LLMConfig | PretrainedConfig,
    ):
        super().__init__()
        self.model = Qwen2Model(config)
        # A `speech_vocab_size` attribute marks the CosyVoice2 speech LLM, whose
        # head predicts speech tokens; otherwise this is a plain text LLM.
        if hasattr(config, "speech_vocab_size"):
            self.lm_head = ParallelLMHead(config.speech_vocab_size, config.hidden_size, bias=getattr(config, "lm_head_bias", True))
            self.model_type = "speech_llm"
        else:
            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size, bias=False)
            self.model_type = "text_llm"
        self.tie_word_embeddings = config.tie_word_embeddings
        if self.tie_word_embeddings:
            if self.model_type == "speech_llm":
                assert config.vocab_size == config.speech_vocab_size, \
                    "vocab_size and speech_vocab_size must be the same when tie_word_embeddings is True"
            # Weight tying: the LM head shares the input embedding matrix.
            self.lm_head.weight.data = self.model.embed_tokens.weight.data

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        # Separate from forward() so callers can project only the hidden
        # states they actually need to logits.
        logits = self.lm_head(hidden_states)
        return logits
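

# Usage sketch (hedged; not part of this module's API). Assumes a config
# loaded elsewhere, e.g. CosyVoice2LLMConfig for the speech LLM or a
# transformers Qwen2 config for the text LLM, with any parallel layers
# initialized as the rest of flashcosyvoice expects:
#
#   model = Qwen2ForCausalLM(config)
#   hidden = model(input_ids, positions)    # (num_tokens, hidden_size)
#   logits = model.compute_logits(hidden)   # (num_tokens, head vocab size)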