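"""Audio tower wrapper around the Qwen2-Audio encoder.

Long mel-spectrogram inputs are split into 30-second windows (3000 frames),
encoded in one batched pass, and re-joined along the time axis. A shape
walk-through sits at the bottom of the file.
"""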
import torch
from transformers import PretrainedConfig, Qwen2AudioEncoder

from .audio_encoder import AudioTower


class Qwen2AudioTower(AudioTower):
    def __init__(self, model_name_or_path: str, config: PretrainedConfig):
        super().__init__(model_name_or_path, config)
        self.audio_tower = Qwen2AudioEncoder.from_pretrained(
            model_name_or_path, attn_implementation="flash_attention_2"
        )
        self.is_loaded = True
        # One encoder window covers 30 s of audio, i.e. 3000 mel frames at the
        # 100 frames/s feature rate of the Qwen2-Audio feature extractor.
        self.audio_chunk_unit_duration = 30
        self.audio_chunk_unit_length = 3000

    def forward(self, sounds):
        if isinstance(sounds, list):
            sound_features = []
            audio_output_lengths = []
            for sound in sounds:
                # Unwrap feature-extractor outputs (dicts or BatchFeature objects)
                # down to the raw mel tensor of shape (1, n_mels, seq_len).
                if hasattr(sound, "input_features") or (isinstance(sound, dict) and "input_features" in sound):
                    sound = sound["input_features"]

                sound_feature = self.forward_audio_tower_batch(sound)
                sound_feature = sound_feature.to(sound.dtype)
                sound_features.append(sound_feature)
                audio_output_lengths.append(sound_feature.shape[1])
            if len(sound_features) > 0:
                # Concatenate along the time axis and drop the singleton batch dim.
                sound_features = torch.cat(sound_features, dim=1).squeeze(0)
        else:
            raise NotImplementedError("Not implemented for this encoder")

        return sound_features, audio_output_lengths

    def forward_audio_tower_batch(self, inp):
        """
        Process long audio input by splitting it into fixed-size 30-second chunks,
        padding the last chunk if needed, batching the chunks together, and running
        them through the audio tower in a single forward pass.

        Args:
            inp: Tensor of shape (batch_size, n_mels, seq_len)

        Returns:
            Tensor of shape (batch_size, num_chunks * chunk_seq_len, hidden_size)
        """
        batch_size, n_mels, seq_len = inp.shape
        chunk_length = self.audio_chunk_unit_length
        # Ceiling division: one extra (padded) chunk for any trailing remainder.
        num_chunks = (seq_len + chunk_length - 1) // chunk_length

        padded_chunks = []
        for i in range(num_chunks):
            start_idx = i * chunk_length
            end_idx = min(start_idx + chunk_length, seq_len)

            chunk = inp[:, :, start_idx:end_idx]
            # Zero-pad the final chunk on the time axis up to the full chunk length.
            if chunk.shape[2] < chunk_length:
                pad_len = chunk_length - chunk.shape[2]
                chunk = torch.nn.functional.pad(chunk, (0, pad_len), mode="constant", value=0)

            padded_chunks.append(chunk)

        # Stack along a new chunk dimension, then fold it into the batch dimension.
        # Stacking (rather than concatenating on dim 0) keeps each sample's chunks
        # contiguous, so the reshape back below does not mix chunks across batch items.
        all_chunks = torch.stack(padded_chunks, dim=1).reshape(
            batch_size * num_chunks, n_mels, chunk_length
        )

        # Encode all chunks in one batched forward pass.
        chunk_outputs = self.audio_tower(all_chunks)
        hidden_states = chunk_outputs.last_hidden_state

        # Unfold the chunk dimension back out and lay the chunks end to end in time.
        _, chunk_seq_len, hidden_size = hidden_states.shape
        hidden_states = hidden_states.reshape(batch_size, num_chunks * chunk_seq_len, hidden_size)

        return hidden_states
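

# A minimal shape walk-through of the chunk/pad/batch pattern implemented above.
# This is a sketch under assumptions: the stand-in encoder below is a plain
# function that only mimics the temporal downsampling of the real Qwen2-Audio
# encoder (roughly 4x, to a 1280-dim hidden state), so the block runs without
# downloading any weights. Because this module uses a relative import, run it as
# `python -m <package>.<module>` rather than directly.
if __name__ == "__main__":
    def fake_encoder_states(x):
        # (N, n_mels, 3000) -> (N, 750, 1280); both factors are illustrative.
        return torch.randn(x.shape[0], x.shape[2] // 4, 1280)

    batch_size, n_mels, seq_len = 1, 128, 7500  # ~75 s of audio at 100 frames/s
    chunk_length = 3000
    num_chunks = (seq_len + chunk_length - 1) // chunk_length  # ceil(7500 / 3000) = 3

    inp = torch.randn(batch_size, n_mels, seq_len)
    chunks = []
    for i in range(num_chunks):
        chunk = inp[:, :, i * chunk_length:(i + 1) * chunk_length]
        if chunk.shape[2] < chunk_length:  # zero-pad the trailing remainder
            chunk = torch.nn.functional.pad(chunk, (0, chunk_length - chunk.shape[2]))
        chunks.append(chunk)

    batched = torch.stack(chunks, dim=1).reshape(batch_size * num_chunks, n_mels, chunk_length)
    hidden = fake_encoder_states(batched)
    merged = hidden.reshape(batch_size, -1, hidden.shape[-1])
    print(merged.shape)  # torch.Size([1, 2250, 1280]): 3 chunks x 750 frames laid end to end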