|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchaudio |
|
|
import torch.nn.functional as F |
|
|
import numpy as np |
|
|
import whisper |
|
|
from torch import Tensor |
|
|
from einops import rearrange |
|
|
from typing import Optional, List |
|
|
|
|
|
from peft import ( |
|
|
LoraConfig, |
|
|
get_peft_model |
|
|
) |
|
|
from transformers import ( |
|
|
AutoModelForCausalLM, |
|
|
AutoTokenizer, |
|
|
PreTrainedModel, |
|
|
GenerationMixin, |
|
|
AutoConfig |
|
|
) |
|
|
from .modeling_whisper import AudioEncoder |
|
|
from .configuration_symphony import SymphonyConfig |
|
|
|
|
|
# Optional fused attention kernel: resolved once at import time. When the
# import fails, ``scaled_dot_product_attention`` is left as None and the flag
# records that it is unavailable.
try:
    from torch.nn.functional import scaled_dot_product_attention
except (ImportError, RuntimeError, OSError):
    scaled_dot_product_attention = None
    SDPA_AVAILABLE = False
else:
    SDPA_AVAILABLE = True


# Language code -> language name. The upper-cased codes are used to build the
# ``<|EN|>`` / ``<|KO|>`` special tokens.
LANGUAGES = {
    "en": "english",
    "ko": "korean",
}
|
|
|
|
|
def set_trainable_parameters(module, requires_grad=False):
    """Set ``requires_grad`` on every parameter of *module* and record it.

    Args:
        module: object exposing ``parameters()`` (e.g. an ``nn.Module``);
            the default recursive parameter iteration is used.
        requires_grad: flag applied to every parameter; ``False`` freezes.

    The same flag is also stored on ``module._requires_grad``.
    """
    flag = requires_grad
    for weight in module.parameters():
        weight.requires_grad_(flag)
    setattr(module, "_requires_grad", flag)
|
|
|
|
|
|
|
|
|
|
|
class Compressor(nn.Module):
    """Cross-attention pooler from a length-``n_ctx`` sequence to ``num_query`` tokens.

    A learned query bank (plus fixed sinusoidal position codes) attends over
    the layer-normed input (keys get their own sinusoidal codes; values get
    none). The merged head outputs pass through ``out_proj``.
    """

    def __init__(self, embed_dim, num_heads, num_query, n_ctx):
        super().__init__()
        self.num_heads = num_heads
        self.head_dims = embed_dim // num_heads
        self.n_ctx = n_ctx

        # Learned query bank; the randn draw is immediately replaced by an
        # N(0, 0.02) initialisation.
        self.query = nn.Parameter(torch.randn(1, num_query, embed_dim))
        nn.init.normal_(self.query, mean=0.0, std=0.02)

        self.q_ln = nn.LayerNorm(embed_dim, eps=1e-5)
        self.kv_ln = nn.LayerNorm(embed_dim, eps=1e-5)

        # Keys/values use the (normed) input directly: no learned projection.
        self.kv_proj = nn.Identity()
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Fixed, non-learned positional tables stored as buffers.
        self.register_buffer("q_pos_embeds", self.sinusoids(num_query, embed_dim))
        self.register_buffer("kv_pos_embeds", self.sinusoids(n_ctx, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Reset both LayerNorms to identity (weight 1, bias 0)."""
        for norm in (self.q_ln, self.kv_ln):
            nn.init.constant_(norm.bias, 0)
            nn.init.constant_(norm.weight, 1.0)

    def sinusoids(self, length, channels, max_timescale=10000):
        """Return a ``(length, channels)`` table: sines in the first half, cosines in the second.

        ``channels`` must be even.
        """
        assert channels % 2 == 0
        half = channels // 2
        step = np.log(max_timescale) / (half - 1)
        inv_timescales = torch.exp(-step * torch.arange(half))
        positions = torch.arange(length).unsqueeze(1)
        scaled_time = positions * inv_timescales.unsqueeze(0)
        return torch.cat([scaled_time.sin(), scaled_time.cos()], dim=1)

    def forward(self, x: Tensor):
        """Pool ``x`` into ``num_query`` tokens.

        ``x`` is split as ``(b, l, h*d)``; its length ``l`` must equal
        ``n_ctx`` for the key positional add to broadcast. The query bank has
        batch size 1 and is broadcast against ``x`` inside attention.
        """
        split = 'b l (h d) -> b h l d'
        head_kw = {"h": self.num_heads, "d": self.head_dims}

        queries = self.q_ln(self.query.to(x.device))
        context = self.kv_ln(self.kv_proj(x))

        q = rearrange(queries + self.q_pos_embeds, split, **head_kw)
        k = rearrange(context + self.kv_pos_embeds, split, **head_kw)
        v = rearrange(context, split, **head_kw)

        pooled = scaled_dot_product_attention(q, k, v)
        return self.out_proj(rearrange(pooled, 'b h l d -> b l (h d)'))
|
|
|
|
|
class MHSA(nn.Module):
    """Multi-head attention on top of ``scaled_dot_product_attention``.

    Self-attention when ``xa`` is None, otherwise cross-attention with keys
    and values taken from ``xa``. The key projection has no bias.
    """

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dims = embed_dim // num_heads
        self.q = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k = nn.Linear(embed_dim, embed_dim, bias=False)
        self.v = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

    def _split_heads(self, t):
        """Reshape ``(b, l, h*d)`` into ``(b, h, l, d)``."""
        return rearrange(t, 'b l (h d) -> b h l d', h=self.num_heads, d=self.head_dims)

    def forward(self, x, xa=None, mask=None):
        """Attend from ``x`` to ``xa`` (or to ``x`` itself).

        ``mask`` is never applied as a tensor: any non-None value only
        switches SDPA into causal mode.
        """
        source = xa if xa is not None else x
        q = self._split_heads(self.q(x))
        k = self._split_heads(self.k(source))
        v = self._split_heads(self.v(source))

        attn = scaled_dot_product_attention(q, k, v, is_causal=mask is not None)
        return self.out_proj(rearrange(attn, 'b h l d -> b l (h d)'))
|
|
|
|
|
class Attention(nn.Module):
    """Pre-norm residual block: self-attention, then cross-attention into ``xa``."""

    def __init__(self, embed_dim, num_heads):
        super().__init__()
        self.attn = MHSA(embed_dim=embed_dim, num_heads=num_heads)
        self.cross_attn = MHSA(embed_dim=embed_dim, num_heads=num_heads)
        self.norm1 = nn.LayerNorm(embed_dim, eps=1e-5)
        self.norm2 = nn.LayerNorm(embed_dim, eps=1e-5)

    def forward(self, x: Tensor, xa: Optional[Tensor] = None):
        """Apply both residual sub-blocks in order.

        If ``xa`` is None, the second sub-block also degenerates to
        self-attention.
        """
        sub_blocks = (
            (self.norm1, self.attn, None),
            (self.norm2, self.cross_attn, xa),
        )
        for norm, layer, context in sub_blocks:
            x = x + layer(x=norm(x), xa=context)
        return x
|
|
|
|
|
class Downsampler(nn.Module):
    """Two GELU conv layers (the second with stride 2), then a LayerNorm.

    Input is channels-first ``(batch, embed_dim, time)``. Output is
    channels-last ``(batch, ceil(time / 2), embed_dim)``.
    """

    def __init__(self, embed_dim: int):
        super().__init__()
        self.conv1 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(embed_dim, embed_dim, kernel_size=3, stride=2, padding=1)
        self.ln_post = nn.LayerNorm(embed_dim, eps=1e-5)

    def forward(self, x: Tensor):
        hidden = x
        for conv in (self.conv1, self.conv2):
            hidden = F.gelu(conv(hidden))
        # Move channels last so LayerNorm normalises over embed_dim.
        return self.ln_post(hidden.transpose(1, 2))
|
|
|
|
|
|
|
|
|
|
|
class SpeechEncoder(nn.Module):
    """Whisper-based speech encoder producing embeddings in the LLM hidden space.

    Per 3000-frame mel segment: a frozen Whisper ``AudioEncoder`` produces
    frame features. Three ``Compressor`` stages, separated by two stride-2
    ``Downsampler`` stages, pool multi-scale tokens. A final ``Compressor``
    reduces them to ``compression_size`` tokens. Two ``Attention`` blocks
    refine those tokens against the concatenated multi-scale features, and
    ``llm_proj`` maps them to the LLM hidden size.
    """

    def __init__(self, config: SymphonyConfig):
        super().__init__()

        # Fallback input device, used only if the Whisper encoder exposes no
        # parameters (see ``_input_device``).
        self._device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.whisper = AudioEncoder(
            n_mels=config.encoder_config.n_mels,
            n_ctx=config.encoder_config.n_ctx,
            n_state=config.encoder_config.n_state,
            n_head=config.encoder_config.n_head,
            n_layer=config.encoder_config.n_layer
        )
        self.n_mels = config.encoder_config.n_mels

        # The Whisper backbone is frozen.
        for param in self.whisper.parameters():
            param.requires_grad = False

        self.llm_proj = nn.Linear(config.encoder_config.n_state, config.llm_config.hidden_size)

        num_heads = config.encoder_config.n_head
        stage_tokens = config.encoder_config.stage_tokens
        self.compression_size = config.encoder_config.compression_size
        self.n_state = config.encoder_config.n_state
        self.low_resource = config.low_resource

        # Context lengths 1500 -> 750 -> 375 track the two stride-2 downsamplers.
        # NOTE(review): assumes the Whisper encoder emits 1500 frames per segment.
        self.compressor1 = Compressor(config.encoder_config.n_state, num_heads, stage_tokens[0], 1500)
        self.stage1 = Downsampler(config.encoder_config.n_state)
        self.compressor2 = Compressor(config.encoder_config.n_state, num_heads, stage_tokens[1], 750)
        self.stage2 = Downsampler(config.encoder_config.n_state)
        self.compressor3 = Compressor(config.encoder_config.n_state, num_heads, stage_tokens[2], 375)
        self.compressor = Compressor(config.encoder_config.n_state, num_heads, self.compression_size, sum(stage_tokens))

        self.out_attn = nn.ModuleList([
            Attention(config.encoder_config.n_state, num_heads) for _ in range(2)
        ])

    def _input_device(self) -> torch.device:
        """Device holding the Whisper weights, i.e. where the mels must live."""
        param = next(self.whisper.parameters(), None)
        return param.device if param is not None else torch.device(self._device)

    def embed_audio(self, mel: torch.Tensor):
        """Run the frozen Whisper encoder on a mel batch."""
        return self.whisper(mel)

    def forward(self, wav_list: List[torch.Tensor]):
        """Encode one or more waveforms.

        Args:
            wav_list: a list/tuple of waveform tensors, or a tensor whose
                first dimension indexes utterances.

        Returns:
            ``(speech_features, speech_attn_mask)``. Features are
            ``(batch, tokens, llm_hidden)``. The mask is boolean
            ``(batch, tokens)``, ``True`` at right-padding positions, and
            lives on the features' device.

        Raises:
            ValueError: if ``wav_list`` is empty.
        """
        if len(wav_list) == 0:
            raise ValueError("SpeechEncoder.forward received no audio")

        if len(wav_list) == 1:
            # Unwrap a one-element list/tuple. A (1, T) tensor is passed
            # through as-is; ``flatten`` downstream handles it.
            wav = wav_list[0] if isinstance(wav_list, (list, tuple)) else wav_list
            speech_features = self.process_audio_for_llm_input(wav)
            speech_attn_mask = torch.zeros(
                1, speech_features.size(1), dtype=torch.bool, device=speech_features.device
            )
            return speech_features, speech_attn_mask

        speech_features = []
        speech_attn_mask = []
        for wav in wav_list:
            speech_feature = self.process_audio_for_llm_input(wav)
            speech_features.append(speech_feature)
            # Keep the mask on the feature device so callers can concatenate
            # it with other on-device masks.
            speech_attn_mask.append(
                torch.zeros(1, speech_feature.size(1), dtype=torch.bool, device=speech_feature.device)
            )

        speech_features = self.pad_sequence(speech_features, padding_side='right', padding_value=0.0)
        # pad_sequence already yields (batch, max_len) for (1, L) masks.
        speech_attn_mask = self.pad_sequence(speech_attn_mask, padding_side='right', padding_value=True)
        return speech_features, speech_attn_mask

    def process_audio_for_llm_input(self, wav: torch.Tensor):
        """Turn one waveform into ``(1, tokens, llm_hidden)`` embeddings.

        The waveform is flattened and zero-padded to at least 16000 samples,
        then converted to log-mel. Audio longer than 3000 mel frames is cut
        into 3000-frame segments (the last one zero-padded). Each segment
        yields ``compression_size`` tokens, concatenated in time order.
        """
        n_frames = 3000
        min_length = 16000
        wav = wav.flatten()

        if wav.shape[0] < min_length:
            wav = F.pad(wav, (0, min_length - wav.shape[0]))

        mels = whisper.log_mel_spectrogram(wav, n_mels=self.n_mels).unsqueeze(0).to(self._input_device())

        if mels.shape[-1] > n_frames:
            mel_segments = []
            for i in range(0, mels.shape[-1], n_frames):
                mel = mels[:, :, i:i + n_frames]
                if mel.shape[-1] < n_frames:
                    mel = self.pad_or_trim(mel, n_frames)
                mel_segments.append(mel)

            if self.low_resource:
                # One segment at a time to bound peak memory.
                audio_features = [self._process_mel_segment(mel) for mel in mel_segments]
                speech_tokens = torch.cat(audio_features, dim=1)
            else:
                mel_batch = torch.cat(mel_segments, dim=0)
                num_segments = mel_batch.shape[0]
                audio_features = self._process_mel_segment(mel_batch)
                speech_tokens = audio_features.reshape(1, num_segments * self.compression_size, self.n_state)
        else:
            if mels.shape[-1] < n_frames:
                mels = self.pad_or_trim(mels, n_frames)
            speech_tokens = self._process_mel_segment(mels)

        return self.llm_proj(speech_tokens)

    def _process_mel_segment(self, mel_segment: torch.Tensor):
        """Encode and compress a batch of mel segments into ``compression_size`` tokens each."""
        # assumes (batch, frames, n_state) encoder output: transposed to
        # channels-first for the conv downsamplers.
        audio_feature = self.embed_audio(mel_segment)

        stage_1_token = self.compressor1(x=audio_feature)
        stage_1_feature = self.stage1(audio_feature.transpose(1, 2))
        stage_2_token = self.compressor2(x=stage_1_feature)
        stage_2_feature = self.stage2(stage_1_feature.transpose(1, 2))
        stage_3_token = self.compressor3(x=stage_2_feature)

        stage_tokens = torch.cat([stage_1_token, stage_2_token, stage_3_token], dim=1)
        compressed_tokens = self.compressor(stage_tokens)

        # Refine the compressed tokens against all three feature resolutions.
        h_audio_feature = torch.cat([audio_feature, stage_1_feature, stage_2_feature], dim=1)
        for block in self.out_attn:
            compressed_tokens = block(x=compressed_tokens, xa=h_audio_feature)

        return compressed_tokens

    def pad_sequence(self, sequences, padding_side='right', padding_value=0.0):
        """Stack ``(1, L_i, ...)`` tensors into ``(N, max L, ...)``, filling with ``padding_value``."""
        max_len = max(seq.size(1) for seq in sequences)
        output_dims = (len(sequences), max_len) + sequences[0].shape[2:]
        output = torch.full(output_dims, padding_value, dtype=sequences[0].dtype, device=sequences[0].device)

        for i, seq in enumerate(sequences):
            length = seq.size(1)
            if padding_side == 'right':
                output[i, :length, ...] = seq
            else:
                output[i, -length:, ...] = seq
        return output

    def pad_or_trim(self, array, length: int = 480000, *, axis: int = -1):
        """
        Pad (with zeros) or trim ``array`` along ``axis`` to exactly ``length``.

        Accepts torch tensors and numpy arrays.
        """
        if torch.is_tensor(array):
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            # A negative width makes F.pad crop, so this both pads and trims.
            return F.pad(array, [pad for sizes in pad_widths[::-1] for pad in sizes])

        if array.shape[axis] > length:
            array = array.take(indices=range(length), axis=axis)
        if array.shape[axis] < length:
            pad_widths = [(0, 0)] * array.ndim
            pad_widths[axis] = (0, length - array.shape[axis])
            array = np.pad(array, pad_widths)
        return array
|
|
|
|
|
|
|
|
class SymphonyPreTrainedModel(PreTrainedModel):
    """Base class binding ``SymphonyConfig`` to the transformers ``PreTrainedModel`` machinery."""

    config_class = SymphonyConfig
    base_model_prefix = "symphony"

    def _init_weights(self, module):
        """Initialise ``nn.Linear`` layers (weights ~ N(0, 0.02), bias = 0).

        Any other module type is left untouched.
        """
        if not isinstance(module, nn.Linear):
            return
        nn.init.normal_(module.weight, mean=0.0, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
|
|
|
|
|
class SymphonyForConditionalGeneration(SymphonyPreTrainedModel, GenerationMixin):
    """Speech-conditioned causal LM.

    ``SpeechEncoder`` output replaces the ``<|AUDIO|>`` placeholder token in
    the LLM's input-embedding sequence. The LLM is wrapped with a LoRA
    adapter through PEFT.
    """

    config_class = SymphonyConfig

    def __init__(self, config: SymphonyConfig):
        super().__init__(config)

        self.encoder = SpeechEncoder(config)
        self.llm = AutoModelForCausalLM.from_config(
            config.llm_config,
            trust_remote_code=True
        )
        # Expose the inner model's tied-weight keys under the ``llm.`` prefix.
        # ``getattr`` guards LLM classes that do not define the attribute.
        tied_keys = getattr(self.llm, "_tied_weights_keys", None)
        if tied_keys is not None:
            self._tied_weights_keys = [f"llm.{k}" for k in tied_keys]

        llm_lora_config = LoraConfig(
            r=config.lora_r,
            lora_alpha=config.lora_a,
            target_modules=config.llm_modules,
            lora_dropout=0.01,
            task_type="CAUSAL_LM",
        )
        self.llm = get_peft_model(self.llm, llm_lora_config)

        self.tokenizer = AutoTokenizer.from_pretrained(config.llm_config._name_or_path, use_fast=False, trust_remote_code=True)

        audio_token = ['<|AUDIO|>', '<|audio_bos|>', '<|audio_eos|>']
        task_token = ['<|ASR|>', '<|AST|>', '<|SSUM|>', '<|SQQA|>']
        language_token = [f"<|{lang.upper()}|>" for lang in LANGUAGES]
        special_tokens = audio_token + language_token + task_token
        # NOTE(review): the LLM embedding matrix is not resized here; this
        # relies on the configured vocab already covering these token ids.
        self.tokenizer.add_special_tokens({"additional_special_tokens": special_tokens})

    def get_input_embeddings(self) -> nn.Module:
        """Returns the input embedding layer of the LLM."""
        return self.llm.get_input_embeddings()

    def set_input_embeddings(self, value: nn.Module):
        """Sets the input embedding layer of the LLM."""
        self.llm.set_input_embeddings(value)

    def process_audio(self, audio_array: np.ndarray, sample_rate: int) -> torch.Tensor:
        """Convert an array to a float32 tensor, resampling to 16 kHz if needed."""
        audio = torch.tensor(audio_array, dtype=torch.float32)
        if sample_rate != 16000:
            resampler = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)
            audio = resampler(audio)
        return audio

    def save_pretrained(self, save_directory, **kwargs):
        """Save the full model, then additionally save the LLM (adapter) under ``<dir>/llm``."""
        super().save_pretrained(save_directory, **kwargs)
        if hasattr(self.llm, "save_pretrained"):
            self.llm.save_pretrained(f"{save_directory}/llm")

    def _find_audio_token_index(self, input_ids: torch.Tensor) -> int:
        """Return the column of the first ``<|AUDIO|>`` token in row 0 of ``input_ids``.

        Only row 0 is inspected, so every row is assumed to place the
        placeholder at the same column.

        Raises:
            ValueError: if row 0 contains no ``<|AUDIO|>`` token.
        """
        audio_token_id = self.tokenizer.convert_tokens_to_ids("<|AUDIO|>")
        positions = torch.nonzero(input_ids[0] == audio_token_id)
        if positions.numel() == 0:
            raise ValueError("input_ids does not contain the <|AUDIO|> placeholder token")
        return positions[0][0].item()

    def forward(
        self,
        audio: List[torch.Tensor],
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs
    ):
        """Run the LLM on text embeddings with the encoded speech spliced in.

        ``input_ids`` must contain one ``<|AUDIO|>`` placeholder, which is
        replaced by the speech embeddings. The passed ``attention_mask`` and
        ``inputs_embeds`` are ignored and rebuilt: text positions are masked
        by ``pad_token_id``, speech positions by the encoder's padding mask.
        When ``labels`` are given, speech positions get label -100 (ignored
        by the loss).

        Raises:
            ValueError: if ``input_ids`` is None or lacks ``<|AUDIO|>``.
        """
        if input_ids is None:
            raise ValueError("input_ids containing an <|AUDIO|> placeholder are required")

        idx = self._find_audio_token_index(input_ids)
        speech_query, speech_attn_mask = self.encoder(audio)
        token_embedding = self.llm.get_input_embeddings()

        left_token, right_token = input_ids[:, :idx], input_ids[:, idx + 1:]
        left_embed = token_embedding(left_token.long()).to(speech_query.device)
        right_embed = token_embedding(right_token.long()).to(speech_query.device)

        left_mask = (left_token != self.tokenizer.pad_token_id).long().to(self.device)
        right_mask = (right_token != self.tokenizer.pad_token_id).long().to(self.device)
        # Encoder mask is True at padding; the LLM expects 1 for real tokens.
        # Move it next to the text masks so the concatenation is same-device.
        speech_attn_mask = (speech_attn_mask.int() <= 0).long().to(left_mask.device)

        inputs_embeds = torch.cat([left_embed, speech_query, right_embed], dim=1)
        attention_mask = torch.cat([left_mask, speech_attn_mask, right_mask], dim=1)

        if labels is not None:
            speech_labels = torch.full(
                (speech_query.shape[0], int(speech_query.shape[1])),
                fill_value=-100,
                dtype=torch.long,
                device=labels.device
            )
            labels = torch.cat(
                [labels[:, :idx], speech_labels, labels[:, idx + 1:]], dim=1
            ).long()

        outputs = self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )
        return outputs

    def generate(self, input_ids, audio: List[torch.Tensor] = None, **kwargs):
        """Generate with the encoded speech spliced in at ``<|AUDIO|>``.

        Without audio, generation runs on the plain text with the LoRA
        adapter disabled.

        NOTE(review): the text attention mask here is all ones, so padding
        inside ``input_ids`` is attended to (unlike ``forward``).
        """
        token_embedding = self.llm.get_input_embeddings()
        if audio is not None:
            speech_query, speech_attn_mask = self.encoder(audio)
            idx = self._find_audio_token_index(input_ids)

            left_embed = token_embedding(input_ids[:, :idx])
            right_embed = token_embedding(input_ids[:, idx + 1:])
            input_embeds = torch.cat([left_embed, speech_query, right_embed], dim=1)

            left_mask = torch.ones_like(input_ids[:, :idx]).to(input_ids.device)
            right_mask = torch.ones_like(input_ids[:, idx + 1:]).to(input_ids.device)
            attention_mask = torch.cat(
                [left_mask, (~speech_attn_mask).long().to(input_ids.device), right_mask], dim=1
            )

            generated_ids = self.llm.generate(
                inputs_embeds=input_embeds,
                attention_mask=attention_mask,
                pad_token_id=self.tokenizer.eos_token_id,
                **kwargs
            )
        else:
            input_embeds = token_embedding(input_ids)
            attention_mask = torch.ones(
                [input_embeds.size(0), input_embeds.size(1)], dtype=torch.long, device=input_embeds.device
            )
            with self.llm.disable_adapter():
                generated_ids = self.llm.generate(
                    inputs_embeds=input_embeds,
                    attention_mask=attention_mask,
                    pad_token_id=self.tokenizer.eos_token_id,
                    **kwargs
                )
        return generated_ids

    def pad_embeddings(self, sequences, padding_side='right', padding_value=0.0):
        """Pads a list of ``(L_i, ...)`` tensors into ``(N, max L, ...)``."""
        max_len = max(seq.size(0) for seq in sequences)
        output_dims = (len(sequences), max_len) + sequences[0].shape[1:]
        output = torch.full(output_dims, padding_value, dtype=sequences[0].dtype, device=sequences[0].device)

        for i, seq in enumerate(sequences):
            length = seq.size(0)
            if padding_side == 'right':
                output[i, :length, ...] = seq
            else:
                output[i, -length:, ...] = seq
        return output
|
|
|
|
|
|
|
|
# Import-time registration: map the "symphony" model type to SymphonyConfig,
# and SymphonyConfig to SymphonyForConditionalGeneration, in the transformers
# Auto* factories.
# NOTE(review): AutoConfig.register is expected to require
# SymphonyConfig.model_type == "symphony" — confirm in configuration_symphony.
AutoConfig.register("symphony", SymphonyConfig)


AutoModelForCausalLM.register(SymphonyConfig, SymphonyForConditionalGeneration)