# SYMPHONY/modeling_symphony.py
import torch
import torch.nn as nn
import torchaudio
import torch.nn.functional as F
import numpy as np
import whisper
from torch import Tensor
from einops import rearrange
from typing import Optional, List
from peft import (
    LoraConfig,
    get_peft_model
)
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    PreTrainedModel,
    GenerationMixin,
    AutoConfig
)

from .modeling_whisper import AudioEncoder
from .configuration_symphony import SymphonyConfig

# Check for scaled_dot_product_attention availability. Every attention module
# below calls it unconditionally, so when the fused op cannot be imported we
# bind an equivalent plain-PyTorch implementation instead of None.
try:
    from torch.nn.functional import scaled_dot_product_attention
    SDPA_AVAILABLE = True
except (ImportError, RuntimeError, OSError):
    SDPA_AVAILABLE = False

    def scaled_dot_product_attention(query, key, value, is_causal=False):
        """Compute softmax(Q K^T / sqrt(d)) V, optionally with a lower-triangular mask.

        Covers only the subset of the torch SDPA API used in this file
        (no explicit mask, no dropout, default scale).
        """
        scale = query.size(-1) ** -0.5
        scores = torch.matmul(query, key.transpose(-2, -1)) * scale
        if is_causal:
            q_len, k_len = query.size(-2), key.size(-2)
            keep = torch.ones(q_len, k_len, dtype=torch.bool, device=query.device).tril()
            scores = scores.masked_fill(~keep, float("-inf"))
        return torch.matmul(scores.softmax(dim=-1), value)


# Language code -> language name; used to build the <|EN|>/<|KO|> special tokens.
LANGUAGES = {
    "en": "english",
    "ko": "korean"
}


def set_trainable_parameters(module, requires_grad=False):
    """Set ``requires_grad`` on every parameter of ``module`` and record the flag on it."""
    for param in module.parameters():
        param.requires_grad = requires_grad
    module._requires_grad = requires_grad


# --- Helper Modules (Compressor, MHSA, Attention, Downsampler) ---

class Compressor(nn.Module):
    """Pool a sequence into ``num_query`` learned query tokens via multi-head attention.

    Keys get a fixed sinusoidal positional embedding of length ``n_ctx``, so the
    input's sequence dimension must equal ``n_ctx``.
    """

    def __init__(self, embed_dim, num_heads, num_query, n_ctx):
        super().__init__()
        self.num_heads = num_heads
        self.head_dims = embed_dim // num_heads
        self.n_ctx = n_ctx
        self.query = nn.Parameter(torch.randn(1, num_query, embed_dim))
        nn.init.normal_(self.query, mean=0.0, std=0.02)
        self.q_ln = nn.LayerNorm(embed_dim, eps=1e-5)
        self.kv_ln = nn.LayerNorm(embed_dim, eps=1e-5)
        self.kv_proj = nn.Identity()
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.register_buffer("q_pos_embeds", self.sinusoids(num_query, embed_dim))
        self.register_buffer("kv_pos_embeds", self.sinusoids(n_ctx, embed_dim))
        self.init_weights()

    def init_weights(self):
        """Reset both LayerNorms to identity (weight 1, bias 0)."""
        nn.init.constant_(self.q_ln.bias, 0)
        nn.init.constant_(self.q_ln.weight, 1.0)
        nn.init.constant_(self.kv_ln.bias, 0)
        nn.init.constant_(self.kv_ln.weight, 1.0)

    def sinusoids(self, length, channels, max_timescale=10000):
        """Return a ``(length, channels)`` sin/cos positional table (sin half first)."""
        assert channels % 2 == 0
        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
        return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)

    def forward(self, x: Tensor):
        """Attend from the learned queries to ``x``; returns ``(batch, num_query, embed_dim)``."""
        q = self.q_ln(self.query.to(x.device))
        # The learned query has batch size 1; expand it (a view, no copy) to the
        # key batch so q/k/v share a batch dimension, as fused SDPA kernels require.
        q = (q + self.q_pos_embeds).expand(x.size(0), -1, -1)
        x = self.kv_ln(self.kv_proj(x))
        q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads, d=self.head_dims)
        # Positional information is added to keys only; values stay position-free.
        k = rearrange(x + self.kv_pos_embeds, 'b l (h d) -> b h l d', h=self.num_heads, d=self.head_dims)
        v = rearrange(x, 'b l (h d) -> b h l d', h=self.num_heads, d=self.head_dims)
        attn = scaled_dot_product_attention(q, k, v)
        attn = rearrange(attn, 'b h l d -> b l (h d)')
        x = self.out_proj(attn)
        return x


class MHSA(nn.Module):
    """Multi-head attention; self-attention when ``xa`` is None, cross-attention otherwise."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dims = embed_dim // num_heads
        self.q = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def forward(self, x, xa=None, mask=None):
        """Attend from ``x`` to ``xa`` (or itself).

        ``mask`` is only tested for None: any non-None value switches on causal masking.
        """
        q = self.q(x)
        k = self.k(x if xa is None else xa)
        v = self.v(x if xa is None else xa)
        q = rearrange(q, 'b l (h d) -> b h l d', h=self.num_heads, d=self.head_dims)
        k = rearrange(k, 'b l (h d) -> b h l d', h=self.num_heads, d=self.head_dims)
        v = rearrange(v, 'b l (h d) -> b h l d', h=self.num_heads, d=self.head_dims)
        attn = scaled_dot_product_attention(q, k, v, is_causal=mask is not None)
        attn = rearrange(attn, 'b h l d -> b l (h d)')
        out = self.out_proj(attn)
        return out


class Attention(nn.Module):
    """Pre-norm block: residual self-attention followed by residual cross-attention to ``xa``."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = MHSA(embed_dim=embed_dim, num_heads=num_heads)
        self.cross_attn = MHSA(embed_dim=embed_dim, num_heads=num_heads)
        self.norm1 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.norm2 = nn.LayerNorm(embed_dim, eps=1e-5)

    def forward(self, x: Tensor, xa: Optional[Tensor] = None):
        x = x + self.attn(self.norm1(x))
        x = x + self.cross_attn(x=self.norm2(x), xa=xa)
        return x


class Downsampler(nn.Module):
    """Two 1-D convs (the second with stride 2) that roughly halve the time axis.

    Takes channel-first ``(batch, dim, time)`` input and returns channel-last
    ``(batch, time', dim)`` output, layer-normalised.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        self.conv1 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.ln_post = nn.LayerNorm(embed_dim, eps=1e-5)

    def forward(self, x: Tensor):
        x = F.gelu(self.conv1(x))
        x = F.gelu(self.conv2(x))
        x = x.permute(0, 2, 1)
        x = self.ln_post(x)
        return x


# --- Speech Encoder Module ---

class SpeechEncoder(nn.Module):
    """Frozen Whisper encoder plus hierarchical compression to a fixed token budget per 30 s window."""

    def __init__(self, config: SymphonyConfig):
        super().__init__()
        # Construction-time device guess; only used when Whisper's parameter
        # device cannot be read (see _input_device).
        self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # Initialize the Whisper encoder from its specific sub-configuration
        self.whisper = AudioEncoder(
            n_mels=config.encoder_config.n_mels,
            n_ctx=config.encoder_config.n_ctx,
            n_state=config.encoder_config.n_state,
            n_head=config.encoder_config.n_head,
            n_layer=config.encoder_config.n_layer
        )
        self.n_mels = config.encoder_config.n_mels
        # Freeze the Whisper encoder as it's not trained
        for param in self.whisper.parameters():
            param.requires_grad = False
        # Initialize the projection layer to match the LLM's hidden dimension
        self.llm_proj = nn.Linear(config.encoder_config.n_state, config.llm_config.hidden_size)
        # Initialize the hierarchical compressors using parameters from the config
        num_heads = config.encoder_config.n_head
        stage_tokens = config.encoder_config.stage_tokens
        self.compression_size = config.encoder_config.compression_size
        self.n_state = config.encoder_config.n_state
        self.low_resource = config.low_resource
        # Key lengths 1500 / 750 / 375 must match the Whisper output length and
        # the two stride-2 Downsamplers; assumes 30 s (3000-frame) mel windows.
        self.compressor1 = Compressor(config.encoder_config.n_state, num_heads, stage_tokens[0], 1500)
        self.stage1 = Downsampler(config.encoder_config.n_state)
        self.compressor2 = Compressor(config.encoder_config.n_state, num_heads, stage_tokens[1], 750)
        self.stage2 = Downsampler(config.encoder_config.n_state)
        self.compressor3 = Compressor(config.encoder_config.n_state, num_heads, stage_tokens[2], 375)
        self.compressor = Compressor(config.encoder_config.n_state, num_heads, self.compression_size, sum(stage_tokens))
        self.out_attn = nn.ModuleList([
            Attention(config.encoder_config.n_state, num_heads) for _ in range(2)
        ])

    def embed_audio(self, mel: torch.Tensor):
        """Run the frozen Whisper encoder on a batch of mel spectrograms."""
        return self.whisper(mel)

    def forward(self, wav_list: List[torch.Tensor]):
        """Encode one or more waveforms into LLM-space speech tokens.

        Returns ``(features, mask)`` where ``mask`` is True at padded positions.
        """
        if isinstance(wav_list, (list, tuple)) and len(wav_list) == 0:
            raise ValueError("wav_list must contain at least one waveform")
        if len(wav_list) <= 1:
            # A one-element list holds the single waveform; tensors (e.g. shape
            # (1, T)) are passed through unchanged, as before.
            wav = wav_list[0] if isinstance(wav_list, (list, tuple)) else wav_list
            speech_features = self.process_audio_for_llm_input(wav)
            speech_attn_mask = torch.zeros(1, speech_features.size(1)).bool().to(speech_features.device)
            return speech_features, speech_attn_mask

        speech_features = []
        speech_attn_mask = []
        for wav in wav_list:
            speech_feature = self.process_audio_for_llm_input(wav)
            speech_features.append(speech_feature)
            # Build the mask on the feature's device so downstream concatenation
            # with device-resident text masks does not hit a device mismatch.
            speech_attn_mask.append(
                torch.zeros(1, speech_feature.size(1), dtype=torch.bool, device=speech_feature.device)
            )
        speech_features = self.pad_sequence(speech_features, padding_side='right', padding_value=0.0)
        # Result is already (num_wavs, max_len); no squeeze needed (squeezing
        # would collapse it to 1-D when max_len == 1).
        speech_attn_mask = self.pad_sequence(speech_attn_mask, padding_side='right', padding_value=True)
        return speech_features, speech_attn_mask

    def _input_device(self) -> torch.device:
        """Device holding Whisper's weights; falls back to the construction-time guess."""
        param = next(self.whisper.parameters(), None)
        if param is None or param.device.type == "meta":
            return torch.device(self._device)
        return param.device

    def process_audio_for_llm_input(self, wav: torch.Tensor):
        """Turn one waveform into ``(1, n_tokens, llm_hidden)`` speech embeddings.

        Audio shorter than ``min_length`` samples is zero-padded; audio longer than
        one 3000-frame mel window is split into windows that are encoded separately
        and concatenated along the token axis.
        """
        n_frames = 3000
        min_length = 16000
        wav = wav.flatten()
        if wav.shape[0] < min_length:
            wav = F.pad(wav, (0, min_length - wav.shape[0]))
        mels = whisper.log_mel_spectrogram(wav, n_mels=self.n_mels).unsqueeze(0).to(self._input_device())

        if mels.shape[-1] > n_frames:
            # Segment and process long audio
            mel_segments = []
            for start in range(0, mels.shape[-1], n_frames):
                mel = mels[:, :, start:start + n_frames]
                if mel.shape[-1] < n_frames:
                    mel = self.pad_or_trim(mel, n_frames)
                mel_segments.append(mel)
            if self.low_resource:
                # One window at a time to bound peak memory.
                audio_features = [self._process_mel_segment(mel) for mel in mel_segments]
                speech_tokens = torch.cat(audio_features, dim=1)
            else:
                # Batch Inference Mode: all windows in one forward pass.
                mel_segments = torch.cat(mel_segments, dim=0)
                B, _, _ = mel_segments.shape
                audio_features = self._process_mel_segment(mel_segments)
                speech_tokens = audio_features.view(1, B * self.compression_size, self.n_state)
        else:
            if mels.shape[-1] < n_frames:
                mels = self.pad_or_trim(mels, n_frames)
            speech_tokens = self._process_mel_segment(mels)
        return self.llm_proj(speech_tokens)

    def _process_mel_segment(self, mel_segment: torch.Tensor):
        """Encode mel windows and compress each to ``compression_size`` tokens."""
        # Feature extraction and hierarchical compression
        audio_feature = self.embed_audio(mel_segment)
        stage_1_token = self.compressor1(x=audio_feature)
        stage_1_feature = self.stage1(audio_feature.transpose(1, 2))
        stage_2_token = self.compressor2(x=stage_1_feature)
        stage_2_feature = self.stage2(stage_1_feature.transpose(1, 2))
        stage_3_token = self.compressor3(x=stage_2_feature)
        stage_tokens = torch.cat([stage_1_token, stage_2_token, stage_3_token], dim=1)
        compressed_tokens = self.compressor(stage_tokens)
        # Cross-attention with hierarchical features
        h_audio_feature = torch.cat([audio_feature, stage_1_feature, stage_2_feature], dim=1)
        for block in self.out_attn:
            compressed_tokens = block(x=compressed_tokens, xa=h_audio_feature)
        return compressed_tokens

    def pad_sequence(self, sequences, padding_side='right', padding_value=0.0):
        """Stack tensors shaped ``(1, L_i, ...)`` into ``(N, max L_i, ...)``, padding along dim 1."""
        max_len = max(seq.size(1) for seq in sequences)
        output_dims = (len(sequences), max_len) + sequences[0].shape[2:]
        output = torch.full(output_dims, padding_value, dtype=sequences[0].dtype, device=sequences[0].device)
        for i, seq in enumerate(sequences):
            length = seq.size(1)
            if padding_side == 'right':
                output[i, :length, ...] = seq
            else:
                output[i, -length:, ...] = seq
        return output

    def pad_or_trim(self, array, length: int = 480000, *, axis: int = -1):
        """
        Zero-pad or trim ``array`` along ``axis`` to exactly ``length`` entries.

        Works for both torch tensors and numpy arrays.
        """
        size = array.shape[axis]
        if torch.is_tensor(array):
            if size > length:
                array = array.narrow(axis, 0, length)
            elif size < length:
                pad_widths = [(0, 0)] * array.ndim
                pad_widths[axis] = (0, length - size)
                # F.pad expects (left, right) pairs starting from the last dim.
                array = F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])
        else:
            if size > length:
                # np.pad rejects negative widths, so trim explicitly.
                array = array.take(indices=range(length), axis=axis)
            elif size < length:
                pad_widths = [(0, 0)] * array.ndim
                pad_widths[axis] = (0, length - size)
                array = np.pad(array, pad_widths)
        return array


# --- Main Model Class ---

class SymphonyPreTrainedModel(PreTrainedModel):
    """Base class wiring SymphonyConfig into the HF ``PreTrainedModel`` machinery."""

    config_class = SymphonyConfig
    base_model_prefix = "symphony"

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)


class SymphonyForConditionalGeneration(SymphonyPreTrainedModel, GenerationMixin):
    """Speech encoder + LoRA-adapted causal LLM; speech tokens replace the ``<|AUDIO|>`` placeholder."""

    config_class = SymphonyConfig

    def __init__(self, config: SymphonyConfig):
        super().__init__(config)
        # Initialize the two main components using their respective sub-configs
        self.encoder = SpeechEncoder(config)
        self.llm = AutoModelForCausalLM.from_config(
            config.llm_config,
            trust_remote_code=True
        )
        # Not every remote-code LLM defines _tied_weights_keys.
        tied_keys = getattr(self.llm, "_tied_weights_keys", None)
        if tied_keys is not None:
            self._tied_weights_keys = [f"llm.{k}" for k in tied_keys]
        llm_lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_a,
            target_modules=config.llm_modules,
            lora_dropout=0.01,
            task_type="CAUSAL_LM",
        )
        self.llm = get_peft_model(self.llm, llm_lora_config)
        self.tokenizer = AutoTokenizer.from_pretrained(config.llm_config._name_or_path, use_fast=False, trust_remote_code=True)

        # Add special tokens
        audio_token = ['<|AUDIO|>', '<|audio_bos|>', '<|audio_eos|>']
        task_token = ['<|ASR|>', '<|AST|>', '<|SSUM|>', '<|SQQA|>']
        language_token = [f"<|{lang.upper()}|>" for lang in LANGUAGES]
        special_tokens = audio_token + language_token + task_token
        self.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

    def get_input_embeddings(self) -> nn.Module:
        """Returns the input embedding layer of the LLM."""
        return self.llm.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module):
        """Sets the input embedding layer of the LLM."""
        self.llm.set_input_embeddings(value)

    def process_audio(self, audio_array: np.ndarray, sample_rate: int) -> torch.Tensor:
        """Convert raw audio to a float32 tensor resampled to 16 kHz."""
        audio = torch.tensor(audio_array, dtype=torch.float32)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio = resampler(audio)
        return audio

    def save_pretrained(self, save_directory, **kwargs):
        """Save the full model, then additionally save the (PEFT) LLM under ``<dir>/llm``."""
        super().save_pretrained(save_directory, **kwargs)
        if hasattr(self.llm, "save_pretrained"):
            self.llm.save_pretrained(f"{save_directory}/llm")

    def _audio_token_index(self, input_ids: torch.Tensor) -> int:
        """Position of the first ``<|AUDIO|>`` token in the first row of ``input_ids``.

        All rows are assumed to share this position.
        """
        audio_token_id = self.tokenizer.convert_tokens_to_ids("<|AUDIO|>")
        positions = torch.nonzero(input_ids[0] == audio_token_id)
        if positions.numel() == 0:
            raise ValueError("input_ids must contain the <|AUDIO|> placeholder token")
        return positions[0][0].item()

    def forward(
        self,
        audio: List[torch.Tensor],
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ):
        """Splice speech embeddings in place of ``<|AUDIO|>`` and run the LLM.

        ``attention_mask`` and ``inputs_embeds`` arguments are ignored; both are
        rebuilt from ``input_ids`` and the encoded audio. ``labels`` is optional.
        """
        speech_query, speech_attn_mask = self.encoder(audio)
        device = speech_query.device
        token_embedding = self.llm.get_input_embeddings()

        idx = self._audio_token_index(input_ids)
        left_token, right_token = input_ids[:, :idx], input_ids[:, idx + 1:]

        left_embed = token_embedding(left_token.long()).to(device)
        right_embed = token_embedding(right_token.long()).to(device)
        left_mask = (left_token != self.tokenizer.pad_token_id).long().to(device)
        right_mask = (right_token != self.tokenizer.pad_token_id).long().to(device)
        # Encoder mask is True at padding; the LLM wants 1 = attend.
        speech_mask = (speech_attn_mask.int() <= 0).long().to(device)

        inputs_embeds = torch.cat([left_embed, speech_query, right_embed], dim=1)
        attention_mask = torch.cat([left_mask, speech_mask, right_mask], dim=1)

        if labels is not None:
            # Speech positions are excluded from the loss (-100).
            speech_labels = torch.full(
                (speech_query.shape[0], int(speech_query.shape[1])),
                fill_value=-100,
                dtype=torch.long,
                device=device
            )
            labels = torch.cat(
                [labels[:, :idx].to(device), speech_labels, labels[:, idx + 1:].to(device)],
                dim=1
            ).long()

        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )
        return outputs

    def generate(self, input_ids, audio: List[torch.Tensor] = None, **kwargs):
        """Generate text; with ``audio`` the LoRA adapter is active, without it it is disabled."""
        token_embedding = self.llm.get_input_embeddings()
        if audio is not None:
            speech_query, speech_attn_mask = self.encoder(audio)
            idx = self._audio_token_index(input_ids)
            left_embed = token_embedding(input_ids[:, :idx]).to(speech_query.device)
            right_embed = token_embedding(input_ids[:, idx + 1:]).to(speech_query.device)
            input_embeds = torch.cat([left_embed, speech_query, right_embed], dim=1)

            # Create attention mask.
            # NOTE(review): text positions are always attended here, whereas
            # forward() masks pad_token_id; confirm batched prompts are unpadded.
            left_mask = torch.ones_like(input_ids[:, :idx]).to(input_ids.device)
            right_mask = torch.ones_like(input_ids[:, idx + 1:]).to(input_ids.device)
            attention_mask = torch.cat(
                [left_mask, (~speech_attn_mask).long().to(input_ids.device), right_mask],
                dim=1
            )
            generated_ids = self.llm.generate(
                inputs_embeds=input_embeds,
                attention_mask=attention_mask,
                pad_token_id=self.tokenizer.eos_token_id,
                **kwargs
            )
        else:
            input_embeds = token_embedding(input_ids)
            attention_mask = torch.ones(
                [input_embeds.size(0), input_embeds.size(1)],
                dtype=torch.long,
                device=input_embeds.device
            )
            # Text-only prompts use the base LLM without the LoRA adapter.
            with self.llm.disable_adapter():
                generated_ids = self.llm.generate(
                    inputs_embeds=input_embeds,
                    attention_mask=attention_mask,
                    pad_token_id=self.tokenizer.eos_token_id,
                    **kwargs
                )
        return generated_ids

    def pad_embeddings(self, sequences, padding_side='right', padding_value=0.0):
        """Pads a list of tensors to the same length."""
        max_len = max(seq.size(0) for seq in sequences)
        output_dims = (len(sequences), max_len) + sequences[0].shape[1:]
        output = torch.full(output_dims, padding_value, dtype=sequences[0].dtype, device=sequences[0].device)
        for i, seq in enumerate(sequences):
            length = seq.size(0)
            if padding_side == 'right':
                output[i, :length, ...] = seq
            else:
                output[i, -length:, ...] = seq
        return output


# Register the model with AutoModelForCausalLM
AutoConfig.register("symphony", SymphonyConfig)
AutoModelForCausalLM.register(SymphonyConfig, SymphonyForConditionalGeneration)