from typing import Optional

import torch
from torch import nn
from transformers import (
    HubertModel,
    HubertPreTrainedModel,
    Qwen2_5_VLForConditionalGeneration,
)

class ProjectorConv1d(nn.Module):
    """Projects concatenated HuBERT hidden states into the LLM embedding space."""

    def __init__(self, config, encoder_dim, llm_dim):
        super().__init__()
        # Pointwise (kernel_size=1) convolution over the feature dimension.
        self.conv1d = nn.Conv1d(
            in_channels=encoder_dim,
            out_channels=encoder_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.linear1 = nn.Linear(encoder_dim, config.projector_hidden_size)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(config.projector_hidden_size, llm_dim)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        # x: (batch, time, encoder_dim); Conv1d expects (batch, channels, time).
        x = x.transpose(1, 2)
        x = self.conv1d(x)
        x = x.transpose(1, 2)
        # Each ReLU is applied before its linear layer, preserving the
        # ordering used by the trained checkpoint.
        x = self.relu1(x)
        x = self.linear1(x)
        x = self.relu2(x)
        x = self.linear2(x)
        return x
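

# Illustrative shape check (not part of the original file): a minimal sketch
# of the projector's tensor flow, assuming a HuBERT-base-sized encoder
# (hidden_size=768, 12 transformer layers concatenated) and hypothetical
# projector/LLM widths; adjust to your actual config.
#
#   from types import SimpleNamespace
#   cfg = SimpleNamespace(projector_hidden_size=2048)       # assumed value
#   proj = ProjectorConv1d(cfg, encoder_dim=768 * 12, llm_dim=3584)
#   frames = torch.randn(1, 50, 768 * 12)                   # (batch, time, features)
#   assert proj(frames).shape == (1, 50, 3584)              # projected to LLM width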

class OolelSpeech(HubertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.hubert = HubertModel(config)
        # The projector consumes the hidden states of all 12 transformer
        # layers concatenated along the feature dimension.
        self.projector = ProjectorConv1d(
            config,
            encoder_dim=config.hidden_size * 12,
            llm_dim=config.llm_hidden_size,
        )
        # Initialize weights and apply final processing.
        self.post_init()

    def load_llm(self, llm_name_or_path, **kwargs):
        # Attach the language model after construction so the HuBERT wrapper
        # can be instantiated without it.
        self.llm = Qwen2_5_VLForConditionalGeneration.from_pretrained(llm_name_or_path, **kwargs)

    def forward_hubert(
        self,
        input_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        return self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def embed_inputs(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        # output_hidden_states is forced on: the projector needs every layer.
        hubert_outputs = self.forward_hubert(
            input_values,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=return_dict,
        )
        # Concatenate the 12 transformer-layer outputs (skipping the embedding
        # output at index 0) along the feature dimension.
        hidden_states = torch.cat(hubert_outputs.hidden_states[1:], dim=-1)
        speech_embedding = self.projector(hidden_states)
        token_embeddings = self.llm.model.language_model.embed_tokens(input_ids)
        # Prepend the projected speech frames to the text token embeddings.
        inputs_embeds = torch.cat([speech_embedding, token_embeddings], dim=1)
        return inputs_embeds

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        inputs_embeds = self.embed_inputs(
            input_values=input_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # attention_mask must cover the full embedded sequence, i.e. the
        # projected speech frames followed by the text tokens.
        return self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

    def generate(self, input_values, input_ids, **kwargs):
        assert input_values.shape[0] == 1, "Batch generation is not supported with huggingface."
        inputs_embeds = self.embed_inputs(
            input_values=input_values,
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=True,
        )
        return self.llm.generate(inputs_embeds=inputs_embeds, **kwargs)
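

# Usage sketch (illustrative, not part of the original file). The checkpoint
# path and LLM repo id below are placeholders; the HuBERT config must define
# projector_hidden_size and llm_hidden_size for __init__ above to work, and
# the tokenizer is assumed to match the attached Qwen2.5-VL model.
#
#   model = OolelSpeech.from_pretrained("path/to/oolel-speech")
#   model.load_llm("Qwen/Qwen2.5-VL-7B-Instruct", torch_dtype=torch.bfloat16)
#   model.eval()
#   input_values = torch.randn(1, 16000)  # one second of 16 kHz audio
#   input_ids = tokenizer("Transcribe:", return_tensors="pt").input_ids
#   output_ids = model.generate(input_values, input_ids, max_new_tokens=64)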