| |
|
|
| import torch |
| import torch.nn as nn |
| import torchaudio |
| import torch.nn.functional as F |
| import numpy as np |
| import whisper |
| from torch import Tensor |
| from einops import rearrange |
| from typing import Optional, List |
|
|
| from peft import ( |
| LoraConfig, |
| get_peft_model |
| ) |
| from transformers import ( |
| AutoModelForCausalLM, |
| AutoTokenizer, |
| PreTrainedModel, |
| GenerationMixin, |
| AutoConfig |
| ) |
| from .modeling_whisper import AudioEncoder |
| from .configuration_symphony import SymphonyConfig |
| |
# Optional fused attention entry point. When the installed torch cannot
# provide it, the name is bound to None and SDPA_AVAILABLE is cleared.
try:
    from torch.nn.functional import scaled_dot_product_attention
except (ImportError, RuntimeError, OSError):
    scaled_dot_product_attention = None
    SDPA_AVAILABLE = False
else:
    SDPA_AVAILABLE = True


# Supported language codes -> language names. The codes are turned into
# special tokens such as "<|EN|>" when the tokenizer is set up.
LANGUAGES = dict(en="english", ko="korean")
|
|
def set_trainable_parameters(module, requires_grad=False):
    """Freeze (default) or unfreeze every parameter of ``module``.

    Args:
        module: any ``nn.Module``.
        requires_grad: value assigned to ``requires_grad`` of all parameters.

    The chosen value is also recorded on the module as ``module._requires_grad``.
    """
    module.requires_grad_(requires_grad)
    module._requires_grad = requires_grad
|
|
| |
|
|
class Compressor(nn.Module):
    """Pool a sequence into a fixed number of learned query tokens via cross-attention.

    A learned set of ``num_query`` query vectors attends over ``x`` with
    multi-head scaled dot-product attention, and the result is projected by
    ``out_proj``. Fixed sinusoidal position embeddings are added to the queries
    and to the keys (not to the values).

    Args:
        embed_dim: feature size of the queries and of ``x``; should be divisible
            by ``num_heads``.
        num_heads: number of attention heads.
        num_query: number of output tokens.
        n_ctx: sequence length of ``x``. The key position table has exactly
            ``n_ctx`` rows, so ``x.shape[1]`` must equal ``n_ctx``.
    """

    def __init__(self, embed_dim, num_heads, num_query, n_ctx):
        super().__init__()
        self.num_heads = num_heads
        self.head_dims = embed_dim // num_heads
        self.n_ctx = n_ctx

        # Shape (1, num_query, embed_dim); shared by every item in the batch.
        self.query = nn.Parameter(torch.randn(1, num_query, embed_dim))
        nn.init.normal_(self.query, mean=0.0, std=0.02)

        self.q_ln = nn.LayerNorm(embed_dim, eps=1e-5)
        self.kv_ln = nn.LayerNorm(embed_dim, eps=1e-5)

        self.kv_proj = nn.Identity()
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.register_buffer("q_pos_embeds", self.sinusoids(num_query, embed_dim))
        self.register_buffer("kv_pos_embeds", self.sinusoids(n_ctx, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Reset both LayerNorms to identity (weight 1, bias 0)."""
        nn.init.constant_(self.q_ln.bias, 0)
        nn.init.constant_(self.q_ln.weight, 1.0)
        nn.init.constant_(self.kv_ln.bias, 0)
        nn.init.constant_(self.kv_ln.weight, 1.0)

    def sinusoids(self, length, channels, max_timescale=10000):
        """Return a (length, channels) table: sines in the first half, cosines in the second.

        Raises:
            ValueError: if ``channels`` is odd.
        """
        if channels % 2 != 0:
            raise ValueError(f"channels must be even, got {channels}")
        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
        return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

    def forward(self, x: Tensor):
        """Compress ``x`` of shape (B, n_ctx, embed_dim) to (B, num_query, embed_dim)."""
        batch_size = x.shape[0]
        q = self.q_ln(self.query.to(x.device)) + self.q_pos_embeds
        # The learned queries are shared across the batch. Expand them (a view,
        # no copy) so q, k and v have the same batch size; SDPA's fused kernels
        # expect matching batch dims rather than broadcasting a size-1 batch.
        q = q.expand(batch_size, -1, -1)
        x = self.kv_ln(self.kv_proj(x))

        q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads, d=self.head_dims)
        k = rearrange(x + self.kv_pos_embeds, 'b l (h d) -> b h l d', h=self.num_heads, d=self.head_dims)
        v = rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads, d=self.head_dims)

        attn = scaled_dot_product_attention(q, k, v)
        attn = rearrange(attn, 'b h l d -> b l (h d)')
        return self.out_proj(attn)
|
|
class MHSA(nn.Module):
    """Multi-head attention with separate query/key/value projections.

    Self-attention over ``x`` when ``xa`` is None; otherwise cross-attention
    with queries from ``x`` and keys/values from ``xa``. The key projection has
    no bias; the query, value and output projections do.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dims = embed_dim // num_heads
        self.q = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def _heads(self, t):
        """Split the feature dim into heads: (b, l, h*d) -> (b, h, l, d)."""
        return rearrange(t, 'b l (h d) -> b h l d', h=self.num_heads, d=self.head_dims)

    def forward(self, x, xa=None, mask=None):
        """Attend from ``x`` to ``xa`` (or to ``x`` itself).

        ``mask`` is only checked for presence: any non-None value switches the
        attention to causal mode.
        """
        source = x if xa is None else xa
        query = self._heads(self.q(x))
        key = self._heads(self.k(source))
        value = self._heads(self.v(source))

        merged = scaled_dot_product_attention(query, key, value, is_causal=mask is not None)
        return self.out_proj(rearrange(merged, 'b h l d -> b l (h d)'))
|
|
class Attention(nn.Module):
    """Pre-norm block: residual self-attention, then residual cross-attention.

    The first sub-layer attends over the normalised ``x``. The second takes its
    queries from the normalised hidden state and its keys/values from ``xa``,
    or from that same normalised state when ``xa`` is None. ``xa`` itself is
    not normalised here, and the block has no feed-forward sub-layer.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = MHSA(embed_dim=embed_dim, num_heads=num_heads)
        self.cross_attn = MHSA(embed_dim=embed_dim, num_heads=num_heads)
        self.norm1 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.norm2 = nn.LayerNorm(embed_dim, eps=1e-5)

    def forward(self, x: Tensor, xa: Optional[Tensor] = None):
        hidden = x + self.attn(self.norm1(x))
        cross = self.cross_attn(x=self.norm2(hidden), xa=xa)
        return hidden + cross
|
|
class Downsampler(nn.Module):
    """Conv1d + GELU, then a stride-2 Conv1d + GELU, then LayerNorm.

    Takes channels-first input (B, embed_dim, L) and returns channels-last
    output (B, ceil(L / 2), embed_dim).
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        self.conv1 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.ln_post = nn.LayerNorm(embed_dim, eps=1e-5)

    def forward(self, x: Tensor):
        hidden = F.gelu(self.conv1(x))
        hidden = F.gelu(self.conv2(hidden))
        # Move channels last so LayerNorm normalises over the feature dim.
        return self.ln_post(hidden.transpose(1, 2))
|
|
| |
|
|
class SpeechEncoder(nn.Module):
    """Frozen Whisper audio encoder followed by a trainable multi-stage token compressor.

    Each waveform is converted to a log-mel spectrogram and encoded in windows
    of 3000 mel frames. Each window is reduced to ``compression_size`` tokens:

    1. ``compressor1/2/3`` pool the encoder output and two successively
       stride-2 downsampled copies of it (``stage1``/``stage2``) into
       ``stage_tokens[i]`` tokens each;
    2. ``compressor`` pools the concatenated stage tokens to ``compression_size``;
    3. two ``Attention`` blocks let those tokens cross-attend to all three
       feature levels.

    Finally ``llm_proj`` maps the tokens to the LLM hidden size. The input
    sample rate is not checked here (presumably 16 kHz, as Whisper expects).
    """

    def __init__(self, config: SymphonyConfig):
        super().__init__()

        # Kept for backward compatibility only; input placement uses the
        # device of the module's own weights (see _input_device).
        self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.whisper = AudioEncoder(
            n_mels=config.encoder_config.n_mels,
            n_ctx=config.encoder_config.n_ctx,
            n_state=config.encoder_config.n_state,
            n_head=config.encoder_config.n_head,
            n_layer=config.encoder_config.n_layer
        )
        self.n_mels = config.encoder_config.n_mels

        # The Whisper encoder is frozen; only the compressor stack trains.
        for param in self.whisper.parameters():
            param.requires_grad = False

        self.llm_proj = nn.Linear(config.encoder_config.n_state, config.llm_config.hidden_size)

        num_heads = config.encoder_config.n_head
        stage_tokens = config.encoder_config.stage_tokens
        self.compression_size = config.encoder_config.compression_size
        self.n_state = config.encoder_config.n_state
        self.low_resource = config.low_resource

        # 1500 / 750 / 375 are the sequence lengths of the three feature levels
        # for one 3000-frame window, assuming the encoder emits 1500 frames;
        # each Downsampler maps L -> ceil(L / 2).
        self.compressor1 = Compressor(config.encoder_config.n_state, num_heads, stage_tokens[0], 1500)
        self.stage1 = Downsampler(config.encoder_config.n_state)
        self.compressor2 = Compressor(config.encoder_config.n_state, num_heads, stage_tokens[1], 750)
        self.stage2 = Downsampler(config.encoder_config.n_state)
        self.compressor3 = Compressor(config.encoder_config.n_state, num_heads, stage_tokens[2], 375)
        self.compressor = Compressor(config.encoder_config.n_state, num_heads, self.compression_size, sum(stage_tokens))

        self.out_attn = nn.ModuleList([
            Attention(config.encoder_config.n_state, num_heads) for _ in range(2)
        ])

    def embed_audio(self, mel: torch.Tensor):
        """Run the frozen Whisper encoder on a batch of mel spectrograms."""
        return self.whisper(mel)

    def forward(self, wav_list: List[torch.Tensor]):
        """Encode one or more waveforms into LLM-space speech embeddings.

        Args:
            wav_list: non-empty sequence of waveforms (a list of tensors, or a
                tensor whose first dim indexes the waveforms). Each waveform is
                flattened, so its exact shape does not matter.

        Returns:
            ``(speech_features, speech_attn_mask)``. ``speech_features`` has
            shape (N, max_tokens, llm_hidden_size) and is right-padded with
            zeros. ``speech_attn_mask`` is a bool tensor of shape (N, max_tokens)
            that is True at padded positions. Both are on the features' device.

        Raises:
            ValueError: if ``wav_list`` is empty.
        """
        if len(wav_list) == 0:
            raise ValueError("wav_list must contain at least one waveform")

        if len(wav_list) == 1:
            speech_features = self.process_audio_for_llm_input(wav_list[0])
            speech_attn_mask = torch.zeros(
                1, speech_features.size(1), dtype=torch.bool, device=speech_features.device
            )
            return speech_features, speech_attn_mask

        speech_features = []
        speech_attn_mask = []
        for wav in wav_list:
            speech_feature = self.process_audio_for_llm_input(wav)
            speech_features.append(speech_feature)
            # Created on the feature device so the padded mask can be combined
            # with GPU tensors downstream.
            speech_attn_mask.append(
                torch.zeros(1, speech_feature.size(1), dtype=torch.bool, device=speech_feature.device)
            )

        speech_features = self.pad_sequence(speech_features, padding_side='right', padding_value=0.0)
        # Already (N, max_tokens); no squeeze, which would collapse it when max_tokens == 1.
        speech_attn_mask = self.pad_sequence(speech_attn_mask, padding_side='right', padding_value=True)
        return speech_features, speech_attn_mask

    def _input_device(self) -> torch.device:
        """Device where mel features must be placed: that of the Whisper weights."""
        param = next(self.whisper.parameters(), None)
        return param.device if param is not None else self.llm_proj.weight.device

    def process_audio_for_llm_input(self, wav: torch.Tensor):
        """Turn one waveform into projected speech tokens of shape (1, T, llm_hidden_size).

        The waveform is zero-padded to at least 16000 samples. Audio longer than
        3000 mel frames is split into 3000-frame windows (the last one
        zero-padded), and each window contributes ``compression_size`` tokens,
        concatenated in time order.
        """
        n_frames = 3000
        min_length = 16000
        wav = wav.flatten()

        if wav.shape[0] < min_length:
            wav = F.pad(wav, (0, min_length - wav.shape[0]))

        mels = whisper.log_mel_spectrogram(wav, n_mels=self.n_mels).unsqueeze(0).to(self._input_device())
        if mels.shape[-1] > n_frames:
            mel_segments = []
            for start in range(0, mels.shape[-1], n_frames):
                mel = mels[:, :, start:start + n_frames]
                if mel.shape[-1] < n_frames:
                    mel = self.pad_or_trim(mel, n_frames)
                mel_segments.append(mel)

            if self.low_resource:
                # Encode one window at a time (presumably to limit peak memory).
                audio_features = [self._process_mel_segment(mel) for mel in mel_segments]
                speech_tokens = torch.cat(audio_features, dim=1)
            else:
                mel_batch = torch.cat(mel_segments, dim=0)
                num_segments = mel_batch.shape[0]
                audio_features = self._process_mel_segment(mel_batch)
                # (num_segments, compression_size, n_state) -> (1, num_segments * compression_size, n_state)
                speech_tokens = audio_features.reshape(1, num_segments * self.compression_size, self.n_state)
        else:
            if mels.shape[-1] < n_frames:
                mels = self.pad_or_trim(mels, n_frames)
            speech_tokens = self._process_mel_segment(mels)

        return self.llm_proj(speech_tokens)

    def _process_mel_segment(self, mel_segment: torch.Tensor):
        """Encode a batch of 3000-frame mel windows to (B, compression_size, n_state) tokens."""
        # Encoder output presumably (B, frames, n_state); transposed to channels-first for the Conv1d stages.
        audio_feature = self.embed_audio(mel_segment)

        stage_1_token = self.compressor1(x=audio_feature)
        stage_1_feature = self.stage1(audio_feature.transpose(1, 2))
        stage_2_token = self.compressor2(x=stage_1_feature)
        stage_2_feature = self.stage2(stage_1_feature.transpose(1, 2))
        stage_3_token = self.compressor3(x=stage_2_feature)

        stage_tokens = torch.cat([stage_1_token, stage_2_token, stage_3_token], dim=1)
        compressed_tokens = self.compressor(stage_tokens)

        # Refine the compressed tokens against all three feature levels at once.
        h_audio_feature = torch.cat([audio_feature, stage_1_feature, stage_2_feature], dim=1)
        for block in self.out_attn:
            compressed_tokens = block(x=compressed_tokens, xa=h_audio_feature)

        return compressed_tokens

    def pad_sequence(self, sequences, padding_side='right', padding_value=0.0):
        """Stack tensors of shape (1, L_i, ...) into (N, max L_i, ...), filling with ``padding_value``."""
        max_len = max(seq.size(1) for seq in sequences)
        output_dims = (len(sequences), max_len) + sequences[0].shape[2:]
        output = torch.full(output_dims, padding_value, dtype=sequences[0].dtype, device=sequences[0].device)

        for i, seq in enumerate(sequences):
            length = seq.size(1)
            # The leading size-1 dim of seq is dropped by tensor assignment.
            if padding_side == 'right':
                output[i, :length, ...] = seq
            else:
                output[i, -length:, ...] = seq
        return output

    def pad_or_trim(self, array, length: int = 480000, *, axis: int = -1):
        """Zero-pad or trim ``array`` (tensor or numpy array) along ``axis`` to exactly ``length``."""
        if array.shape[axis] > length:
            index = [slice(None)] * array.ndim
            index[axis] = slice(0, length)
            array = array[tuple(index)]

        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            if torch.is_tensor(array):
                # F.pad takes (left, right) pairs starting from the last dimension.
                array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
            else:
                array = np.pad(array, pad_widths)
        return array
| |
|
|
class SymphonyPreTrainedModel(PreTrainedModel):
    """Base class binding ``SymphonyConfig`` to the Hugging Face ``PreTrainedModel`` API."""

    config_class = SymphonyConfig
    base_model_prefix = "symphony"

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` layers only.

        Weights are drawn from N(0, 0.02^2) and biases (if present) are set to
        zero; every other module type is left untouched.
        """
        if not isinstance(module, nn.Linear):
            return
        nn.init.normal_(module.weight, std=0.02)
        bias = module.bias
        if bias is not None:
            nn.init.zeros_(bias)
|
|
class SymphonyForConditionalGeneration(SymphonyPreTrainedModel, GenerationMixin):
    """Speech-conditioned causal LM.

    The prompt must contain a ``<|AUDIO|>`` placeholder token. That token is
    removed and the ``SpeechEncoder`` output is spliced into the input
    embeddings at its position before they are fed to a LoRA-wrapped LLM.
    """
    config_class = SymphonyConfig

    def __init__(self, config: SymphonyConfig):
        super().__init__(config)

        self.encoder = SpeechEncoder(config)
        self.llm = AutoModelForCausalLM.from_config(
            config.llm_config,
            trust_remote_code=True
        )
        tied_keys = getattr(self.llm, "_tied_weights_keys", None)
        if tied_keys is not None:
            # Re-prefix the LLM's tied-weight names with the submodule path.
            self._tied_weights_keys = [f"llm.{k}" for k in tied_keys]

        llm_lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_a,
            target_modules=config.llm_modules,
            lora_dropout=0.01,
            task_type="CAUSAL_LM",
        )
        self.llm = get_peft_model(self.llm, llm_lora_config)

        self.tokenizer = AutoTokenizer.from_pretrained(config.llm_config._name_or_path, use_fast=False, trust_remote_code=True)

        audio_token = ['<|AUDIO|>', '<|audio_bos|>', '<|audio_eos|>']
        task_token = ['<|ASR|>', '<|AST|>', '<|SSUM|>', '<|SQQA|>']
        language_token = [f"<|{lang.upper()}|>" for lang in LANGUAGES]
        special_tokens = audio_token + language_token + task_token
        # NOTE(review): the LLM embedding matrix is not resized here; confirm the
        # LLM vocabulary already has room for these newly added token ids.
        self.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

    def get_input_embeddings(self) -> nn.Module:
        """Returns the input embedding layer of the LLM."""
        return self.llm.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module):
        """Sets the input embedding layer of the LLM."""
        self.llm.set_input_embeddings(value)

    def process_audio(self, audio_array: np.ndarray, sample_rate: int) -> torch.Tensor:
        """Convert ``audio_array`` to a float32 tensor, resampled to 16000 Hz if needed."""
        audio = torch.tensor(audio_array, dtype=torch.float32)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio = resampler(audio)
        return audio

    def save_pretrained(self, save_directory, **kwargs):
        """Save the full model, then additionally save ``self.llm`` under ``<save_directory>/llm``."""
        super().save_pretrained(save_directory, **kwargs)
        if hasattr(self.llm, "save_pretrained"):
            self.llm.save_pretrained(f"{save_directory}/llm")

    def _audio_token_index(self, input_ids: torch.Tensor) -> int:
        """Return the column of the first ``<|AUDIO|>`` token in ``input_ids[0]``.

        The same column is used for every row, so all rows are assumed to place
        the placeholder at the same position.

        Raises:
            ValueError: if the first row contains no ``<|AUDIO|>`` token.
        """
        audio_token_id = self.tokenizer.convert_tokens_to_ids("<|AUDIO|>")
        positions = torch.nonzero(input_ids[0] == audio_token_id)
        if positions.numel() == 0:
            raise ValueError("input_ids does not contain the <|AUDIO|> placeholder token")
        return positions[0][0].item()

    def _padding_mask(self, tokens: torch.Tensor) -> torch.Tensor:
        """Long mask that is 1 where ``tokens`` differs from the pad id (all ones if there is no pad id)."""
        pad_id = self.tokenizer.pad_token_id
        if pad_id is None:
            return torch.ones_like(tokens, dtype=torch.long)
        return (tokens != pad_id).long()

    def forward(
        self,
        audio: List[torch.Tensor],
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ):
        """Run the LLM on ``left prompt + speech embeddings + right prompt``.

        Args:
            audio: waveforms passed to ``SpeechEncoder``.
            input_ids: prompt token ids (required) containing a ``<|AUDIO|>``
                placeholder at the same column in every row.
            attention_mask: optional mask aligned with ``input_ids``. When
                omitted, it is derived from ``tokenizer.pad_token_id``.
            inputs_embeds: ignored; embeddings are always built from ``input_ids``.
            labels: optional targets aligned with ``input_ids``. Speech positions
                are filled with -100, the usual ignore index of Hugging Face
                causal-LM losses.
            **kwargs: ignored.

        Returns:
            The wrapped LLM's output (``return_dict=True``).

        Raises:
            ValueError: if ``input_ids`` is missing or has no ``<|AUDIO|>`` token.
        """
        if input_ids is None:
            raise ValueError("input_ids is required and must contain the <|AUDIO|> placeholder")

        speech_query, speech_attn_mask = self.encoder(audio)
        device = speech_query.device
        token_embedding = self.llm.get_input_embeddings()

        idx = self._audio_token_index(input_ids)
        left_token, right_token = input_ids[:, :idx], input_ids[:, idx + 1:]
        left_embed = token_embedding(left_token.long()).to(device)
        right_embed = token_embedding(right_token.long()).to(device)

        if attention_mask is not None:
            left_mask, right_mask = attention_mask[:, :idx], attention_mask[:, idx + 1:]
        else:
            left_mask, right_mask = self._padding_mask(left_token), self._padding_mask(right_token)
        # Encoder mask is True at padded speech positions; the LLM expects 1 = attend.
        speech_mask = (~speech_attn_mask.bool()).long()

        inputs_embeds = torch.cat([left_embed, speech_query, right_embed], dim=1)
        attention_mask = torch.cat([
            left_mask.long().to(device), speech_mask.to(device), right_mask.long().to(device)
        ], dim=1)

        if labels is not None:
            labels = labels.to(device)
            speech_labels = torch.full(
                speech_query.shape[:2],
                fill_value=-100,
                dtype=torch.long,
                device=device
            )
            labels = torch.cat([labels[:, :idx], speech_labels, labels[:, idx + 1:]], dim=1).long()

        return self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

    def generate(self, input_ids, audio: List[torch.Tensor] = None, **kwargs):
        """Generate from a prompt, optionally conditioned on audio.

        With ``audio``, the ``<|AUDIO|>`` placeholder in ``input_ids`` is
        replaced by the speech embeddings and the LoRA-adapted LLM generates.
        Without ``audio``, the base LLM generates with the adapter disabled.

        ``attention_mask`` may be given in ``kwargs`` (aligned with
        ``input_ids``; all ones by default). ``pad_token_id`` defaults to the
        tokenizer's EOS id. All other ``kwargs`` are forwarded to
        ``self.llm.generate``.
        """
        token_embedding = self.llm.get_input_embeddings()
        # Both are passed explicitly below; pulling them out of kwargs avoids a
        # "got multiple values for keyword argument" TypeError.
        attention_mask = kwargs.pop("attention_mask", None)
        kwargs.setdefault("pad_token_id", self.tokenizer.eos_token_id)

        if audio is not None:
            speech_query, speech_attn_mask = self.encoder(audio)
            idx = self._audio_token_index(input_ids)

            left_embed = token_embedding(input_ids[:, :idx])
            right_embed = token_embedding(input_ids[:, idx + 1:])
            input_embeds = torch.cat([left_embed, speech_query.to(left_embed.device), right_embed], dim=1)

            if attention_mask is None:
                left_mask = torch.ones_like(input_ids[:, :idx])
                right_mask = torch.ones_like(input_ids[:, idx + 1:])
            else:
                left_mask, right_mask = attention_mask[:, :idx], attention_mask[:, idx + 1:]
            mask_device = input_ids.device
            attention_mask = torch.cat([
                left_mask.long().to(mask_device),
                (~speech_attn_mask.bool()).long().to(mask_device),
                right_mask.long().to(mask_device),
            ], dim=1)

            return self.llm.generate(
                inputs_embeds=input_embeds,
                attention_mask=attention_mask,
                **kwargs
            )

        input_embeds = token_embedding(input_ids)
        if attention_mask is None:
            attention_mask = torch.ones(input_embeds.shape[:2], dtype=torch.long, device=input_embeds.device)
        else:
            attention_mask = attention_mask.long().to(input_embeds.device)
        # Text-only prompts go through the base LLM with the LoRA adapter disabled.
        with self.llm.disable_adapter():
            return self.llm.generate(
                inputs_embeds=input_embeds,
                attention_mask=attention_mask,
                **kwargs
            )

    def pad_embeddings(self, sequences, padding_side='right', padding_value=0.0):
        """Pad tensors of shape (L_i, ...) into one (N, max L_i, ...) tensor."""
        max_len = max(seq.size(0) for seq in sequences)
        output_dims = (len(sequences), max_len) + sequences[0].shape[1:]
        output = torch.full(output_dims, padding_value, dtype=sequences[0].dtype, device=sequences[0].device)

        for i, seq in enumerate(sequences):
            length = seq.size(0)
            if padding_side == 'right':
                output[i, :length, ...] = seq
            else:
                output[i, -length:, ...] = seq
        return output
|
|
| |
def _register_auto_classes():
    """Expose Symphony through the transformers Auto* factories.

    The config is registered under model type "symphony" first, then the model
    class is registered for that config with ``AutoModelForCausalLM``.
    """
    AutoConfig.register("symphony", SymphonyConfig)
    AutoModelForCausalLM.register(SymphonyConfig, SymphonyForConditionalGeneration)


_register_auto_classes()
|
|