# oolel_speech.py (repository: llspch)
# Author: Guerte — uploaded via huggingface_hub, commit 06ab9ed (verified).
from typing import Optional, Union
import torch
from torch import nn
from transformers import HubertModel, HubertPreTrainedModel, Qwen2_5_VLForConditionalGeneration, AutoConfig
class ProjectorConv1d(nn.Module):
    """Project concatenated speech-encoder features into the LLM embedding space.

    Pipeline: a pointwise (kernel_size=1) Conv1d over the feature axis, then a
    two-step MLP mapping ``encoder_dim -> config.projector_hidden_size -> llm_dim``.
    Input and output are both (batch, time, features).
    """

    def __init__(self, config, encoder_dim, llm_dim):
        super().__init__()
        # Pointwise convolution: mixes channels per frame, keeps length.
        self.conv1d = nn.Conv1d(encoder_dim, encoder_dim, kernel_size=1)
        self.linear1 = nn.Linear(encoder_dim, config.projector_hidden_size)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(config.projector_hidden_size, llm_dim)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        # Conv1d wants (B, C, T); transpose in, transpose back out.
        y = self.conv1d(x.transpose(1, 2)).transpose(1, 2)
        # NOTE(review): each ReLU is applied *before* its linear layer (so the
        # final output is unbounded). Unusual ordering, but preserved exactly —
        # pretrained weights depend on it.
        y = self.linear1(self.relu1(y))
        return self.linear2(self.relu2(y))
class OolelSpeech(HubertPreTrainedModel):
    """Speech-conditioned LLM: HuBERT encoder + projector + Qwen2.5-VL decoder.

    The HuBERT encoder's per-layer hidden states are concatenated along the
    feature axis, projected into the LLM embedding space, and prepended to the
    text-token embeddings before being fed to the language model.

    Note: the LLM is NOT created in ``__init__`` — call :meth:`load_llm` before
    using :meth:`forward` or :meth:`generate`.
    """

    def __init__(self, config):
        super().__init__(config)
        self.hubert = HubertModel(config)
        # The projector consumes hidden_states[1:] concatenated on dim=-1
        # (see embed_inputs), hence encoder_dim = hidden_size * 12.
        # assumes config.num_hidden_layers == 12 — TODO confirm
        self.projector = ProjectorConv1d(
            config,
            encoder_dim=config.hidden_size * 12,
            llm_dim=config.llm_hidden_size,
        )
        # Initialize weights and apply final processing
        self.post_init()

    def load_llm(self, llm_name_or_path, **kwargs):
        """Attach the Qwen2.5-VL language model from a name or local path.

        Must be called before ``forward``/``generate``; ``kwargs`` are passed
        through to ``from_pretrained`` (dtype, device_map, ...).
        """
        self.llm = Qwen2_5_VLForConditionalGeneration.from_pretrained(llm_name_or_path, **kwargs)

    def forward_hubert(
        self,
        input_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the HuBERT encoder on raw audio; thin pass-through wrapper."""
        return self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def embed_inputs(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Build the LLM input embeddings: [projected speech | token embeddings].

        NOTE(review): ``attention_mask``, ``output_attentions`` and
        ``output_hidden_states`` are accepted but intentionally overridden below
        (mask=None, attentions off, hidden states forced on — the hidden states
        are required for the concatenation).
        """
        hubert_outputs = self.forward_hubert(
            input_values,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=return_dict,
        )
        # Concatenate all transformer-layer outputs (skip index 0, the initial
        # feature-extractor embedding) along the feature dim.
        hidden_states = torch.cat(hubert_outputs.hidden_states[1:], dim=-1)
        speech_embedding = self.projector(hidden_states)
        token_embeddings = self.llm.model.language_model.embed_tokens(input_ids)
        # Speech frames come first, then the text tokens, along the seq dim.
        input_embedds = torch.cat([speech_embedding, token_embeddings], dim=1)
        return input_embedds

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None
    ):
        """Full forward pass: embed speech + text, then run the LLM.

        NOTE(review): ``attention_mask`` is forwarded to the LLM unchanged, but
        the embeddings are longer than the original token sequence (speech
        frames are prepended) — verify the caller supplies a mask covering the
        combined length, or the LLM call may fail/misalign.
        """
        input_embedds = self.embed_inputs(
            input_values=input_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self.llm(
            inputs_embeds=input_embedds,
            attention_mask=attention_mask,
        )

    def generate(self, input_values, input_ids, **kwargs):
        """Autoregressive generation conditioned on a single speech input.

        ``kwargs`` are forwarded to the LLM's ``generate`` (max_new_tokens, ...).
        """
        assert input_values.shape[0] == 1, "Batch generation is not supported with huggingface."
        inputs_embeds = self.embed_inputs(
            input_values=input_values,
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=True,
        )
        return self.llm.generate(inputs_embeds=inputs_embeds, **kwargs)