|
|
from typing import Optional, Union |
|
|
import torch |
|
|
from torch import nn |
|
|
from transformers import HubertModel, HubertPreTrainedModel, Qwen2_5_VLForConditionalGeneration, AutoConfig |
|
|
|
|
|
class ProjectorConv1d(nn.Module):
    """Speech-to-LLM projector: a pointwise Conv1d followed by a two-layer MLP.

    Maps encoder features of width ``encoder_dim`` to the LLM embedding width
    ``llm_dim``, using a hidden width of ``config.projector_hidden_size``.
    """

    def __init__(self, config, encoder_dim, llm_dim):
        super().__init__()
        # Pointwise (kernel_size=1) convolution over the feature channels;
        # it mixes no information across time steps.
        self.conv1d = nn.Conv1d(
            in_channels=encoder_dim,
            out_channels=encoder_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.linear1 = nn.Linear(encoder_dim, config.projector_hidden_size)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(config.projector_hidden_size, llm_dim)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        """Project ``x`` of shape (batch, time, encoder_dim) to (batch, time, llm_dim)."""
        # Conv1d expects (batch, channels, time); swap the last two axes
        # around the convolution and back again.
        features = self.conv1d(x.transpose(1, 2)).transpose(1, 2)
        # NOTE(review): each ReLU is applied *before* its linear layer
        # (relu1 -> linear1 -> relu2 -> linear2). This unusual ordering is
        # reproduced exactly to stay compatible with existing checkpoints.
        features = self.linear1(self.relu1(features))
        return self.linear2(self.relu2(features))
|
|
|
|
|
class OolelSpeech(HubertPreTrainedModel):
    """Speech-conditioned LLM.

    A HuBERT encoder produces per-layer hidden states; their concatenation is
    projected (``ProjectorConv1d``) into the embedding space of a Qwen2.5-VL
    language model and prepended to the text-token embeddings before the LLM
    forward/generate call.

    ``load_llm`` must be called to attach the language model before
    ``forward``/``generate`` are used.
    """

    def __init__(self, config):
        super().__init__(config)
        self.hubert = HubertModel(config)
        # The projector consumes the concatenation of every transformer
        # layer's hidden states along the feature axis, so its input width is
        # hidden_size * num_hidden_layers. (Previously hard-coded as
        # hidden_size * 12, which only holds for 12-layer hubert-base.)
        self.projector = ProjectorConv1d(
            config,
            encoder_dim=config.hidden_size * config.num_hidden_layers,
            llm_dim=config.llm_hidden_size,
        )

        self.post_init()

    def load_llm(self, llm_name_or_path, **kwargs):
        """Attach the Qwen2.5-VL language model from a name or local path.

        Extra ``kwargs`` are forwarded to ``from_pretrained`` (e.g.
        ``torch_dtype``, ``device_map``).
        """
        self.llm = Qwen2_5_VLForConditionalGeneration.from_pretrained(llm_name_or_path, **kwargs)

    def forward_hubert(
        self,
        input_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the HuBERT encoder on raw audio ``input_values``.

        Thin pass-through to ``self.hubert``; see the transformers
        ``HubertModel`` documentation for argument semantics.
        """
        return self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def embed_inputs(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Build LLM input embeddings: projected speech features, then text embeddings.

        Returns a tensor of shape (batch, speech_len + text_len, llm_dim).

        NOTE: ``attention_mask``, ``output_attentions`` and
        ``output_hidden_states`` are accepted for signature compatibility but
        deliberately ignored — the encoder is always run with
        ``output_hidden_states=True`` because the projector needs every
        layer's features.
        """
        hubert_outputs = self.forward_hubert(
            input_values,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=return_dict,
        )

        # hidden_states[0] is the convolutional feature-extractor output;
        # concatenate only the transformer layers' states on the feature axis.
        hidden_states = torch.cat(hubert_outputs.hidden_states[1:], dim=-1)
        speech_embedding = self.projector(hidden_states)
        token_embeddings = self.llm.model.language_model.embed_tokens(input_ids)

        # Speech embeddings are prepended to the text-token embeddings.
        return torch.cat([speech_embedding, token_embeddings], dim=1)

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ):
        """Encode speech + text and run the LLM forward pass."""
        inputs_embeds = self.embed_inputs(
            input_values=input_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # BUGFIX: the caller's attention_mask covers only the text tokens,
        # while inputs_embeds also contains the prepended speech frames —
        # the lengths would not match. Extend the mask with ones over the
        # speech span so it lines up with inputs_embeds.
        if attention_mask is not None and input_ids is not None:
            speech_len = inputs_embeds.shape[1] - input_ids.shape[1]
            speech_mask = attention_mask.new_ones((attention_mask.shape[0], speech_len))
            attention_mask = torch.cat([speech_mask, attention_mask], dim=1)
        return self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

    def generate(self, input_values, input_ids, **kwargs):
        """Autoregressively generate text conditioned on speech + prompt tokens.

        Only batch size 1 is supported. Extra ``kwargs`` are forwarded to
        ``self.llm.generate``.
        """
        # Explicit raise instead of assert: asserts are stripped under -O,
        # so validation must not rely on them.
        if input_values.shape[0] != 1:
            raise ValueError("Batch generation is not supported with huggingface.")
        inputs_embeds = self.embed_inputs(
            input_values=input_values,
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=True,
        )

        return self.llm.generate(inputs_embeds=inputs_embeds, **kwargs)