from typing import Optional, Union

import torch
from torch import nn
from transformers import (
    AutoConfig,
    HubertModel,
    HubertPreTrainedModel,
    Qwen2_5_VLForConditionalGeneration,
)


class ProjectorConv1d(nn.Module):
    """Project concatenated HuBERT hidden states into the LLM embedding space.

    A pointwise (kernel_size=1) Conv1d over the feature dimension followed by
    a two-layer MLP.

    NOTE(review): each ReLU is applied *before* its linear layer and the final
    output is left un-activated. This looks unconventional (the usual pattern
    is linear -> relu -> linear), but it is preserved as-is: reordering would
    invalidate any trained checkpoint. Confirm against the training code.
    """

    def __init__(self, config, encoder_dim: int, llm_dim: int):
        super().__init__()
        # kernel_size=1 / stride=1 / padding=0 => per-frame channel mixing,
        # sequence length is unchanged.
        self.conv1d = nn.Conv1d(
            in_channels=encoder_dim,
            out_channels=encoder_dim,
            kernel_size=1,
            stride=1,
            padding=0,
        )
        self.linear1 = nn.Linear(encoder_dim, config.projector_hidden_size)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(config.projector_hidden_size, llm_dim)
        self.relu2 = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map (batch, time, encoder_dim) -> (batch, time, llm_dim)."""
        # Conv1d expects channels-first: (batch, channels, time).
        x = x.transpose(1, 2)
        x = self.conv1d(x)
        x = x.transpose(1, 2)
        x = self.relu1(x)
        x = self.linear1(x)
        x = self.relu2(x)
        x = self.linear2(x)
        return x


class OolelSpeech(HubertPreTrainedModel):
    """Speech-to-LLM model: HuBERT encoder + projector feeding a Qwen2.5-VL LLM.

    The language model is attached lazily via :meth:`load_llm`; until it is
    called, :meth:`embed_inputs`, :meth:`forward` and :meth:`generate` raise
    ``AttributeError`` because ``self.llm`` does not exist.
    """

    def __init__(self, config):
        super().__init__(config)
        self.hubert = HubertModel(config)
        # embed_inputs concatenates the hidden states of the 12 transformer
        # layers (hidden_states[1:], skipping the pre-transformer features),
        # hence encoder_dim = hidden_size * 12.
        self.projector = ProjectorConv1d(
            config,
            encoder_dim=config.hidden_size * 12,
            llm_dim=config.llm_hidden_size,
        )
        # Initialize weights and apply final processing.
        self.post_init()

    def load_llm(self, llm_name_or_path, **kwargs):
        """Attach the Qwen2.5-VL language model.

        Must be called before ``forward``/``generate``.
        """
        self.llm = Qwen2_5_VLForConditionalGeneration.from_pretrained(
            llm_name_or_path, **kwargs
        )

    def forward_hubert(
        self,
        input_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the HuBERT encoder; thin pass-through wrapper."""
        return self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def embed_inputs(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> torch.Tensor:
        """Build LLM input embeddings as [speech frames | text tokens].

        ``attention_mask`` / ``output_attentions`` / ``output_hidden_states`` /
        ``return_dict`` are accepted for signature compatibility but are not
        forwarded: the encoder call hard-requires ``output_hidden_states=True``
        (all layer states are consumed below) and ``return_dict=True``.

        Fix: ``return_dict`` was previously passed through, so an explicit
        ``return_dict=False`` made the encoder return a tuple and the
        ``.hidden_states`` attribute access below crashed.
        """
        hubert_outputs = self.forward_hubert(
            input_values,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,
        )
        # Concatenate every transformer layer's hidden state along the feature
        # dim, skipping index 0 (the pre-transformer feature projection).
        hidden_states = torch.cat(hubert_outputs.hidden_states[1:], dim=-1)
        speech_embedding = self.projector(hidden_states)
        token_embeddings = self.llm.model.language_model.embed_tokens(input_ids)
        # Speech frames are prepended to the text-token embeddings.
        return torch.cat([speech_embedding, token_embeddings], dim=1)

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode speech, splice in text embeddings, and run the LLM.

        Fix: the caller's ``attention_mask`` covers only the text tokens, but
        ``inputs_embeds`` has the speech frames prepended; passing it through
        unchanged produced a sequence-length mismatch in the LLM. We now
        prepend ones for the speech frames (they are always attended to).
        """
        input_embedds = self.embed_inputs(
            input_values=input_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        if attention_mask is not None:
            # assumes attention_mask is (batch, text_len) matching input_ids
            # — TODO confirm against the data collator.
            speech_len = input_embedds.shape[1] - attention_mask.shape[1]
            speech_mask = attention_mask.new_ones(
                attention_mask.shape[0], speech_len
            )
            attention_mask = torch.cat([speech_mask, attention_mask], dim=1)
        return self.llm(
            inputs_embeds=input_embedds,
            attention_mask=attention_mask,
        )

    def generate(self, input_values, input_ids, **kwargs):
        """Autoregressively generate text conditioned on speech + prompt.

        Raises:
            ValueError: if ``input_values`` has batch size != 1 (a plain
                ``raise`` replaces the old ``assert``, which is stripped
                under ``python -O``).
        """
        if input_values.shape[0] != 1:
            raise ValueError("Batch generation is not supported with huggingface.")
        inputs_embeds = self.embed_inputs(
            input_values=input_values,
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=True,
        )
        return self.llm.generate(inputs_embeds=inputs_embeds, **kwargs)