from transformers import WhisperPreTrainedModel
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperDecoder,
    WhisperModel,
    shift_tokens_right,
)
from transformers import AutoModel, AutoFeatureExtractor, AutoTokenizer
from transformers import AutoConfig
from transformers.modeling_outputs import (
    BaseModelOutput,
    Seq2SeqModelOutput,
    Seq2SeqLMOutput,
)
from transformers.generation import GenerationMixin
import torch
import torch.nn as nn
import librosa
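
# GemmaWhisper pairs the Gemma-3n audio encoder with a Whisper text decoder:
# audio features go through the Gemma encoder, a linear projection maps the
# encoder states into the Whisper decoder's hidden size, and the decoder
# cross-attends to those projected states to produce text.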
|
|
class EncoderWithProjection(nn.Module):
    """Bundles the audio encoder with the projection layer.

    `get_encoder()` returns this wrapper so callers such as `generate()`
    receive encoder states already projected to the decoder's `d_model`.
    Note that it returns a plain tensor rather than a `ModelOutput`.
    """

    def __init__(self, encoder, projection):
        super().__init__()
        self.encoder = encoder
        self.projection = projection

    def forward(self, *args, **kwargs):
        encoder_outputs = self.encoder(*args, **kwargs)
        # Project the encoder's last hidden state into the decoder's hidden size.
        return self.projection(encoder_outputs[0])


class GemmaWhisper(WhisperPreTrainedModel):
    """Whisper-style encoder-decoder with the Gemma-3n audio encoder in place of the Whisper encoder."""

    def __init__(self, config):
        super().__init__(config)

        encoder_config = AutoConfig.from_pretrained(
            'mesolitica/gemma-3n-e4b-it-audio-encoder', trust_remote_code=True)
        self.encoder = AutoModel.from_config(encoder_config, trust_remote_code=True)
        self.decoder = WhisperDecoder(config)

        # Bridge the encoder hidden size to the Whisper decoder's d_model.
        self.projection = nn.Linear(
            self.encoder.config.text_config.hidden_size, self.decoder.config.d_model)

        self.post_init()

    def get_input_embeddings(self):
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.decoder.embed_tokens = value

    def get_encoder(self):
        return EncoderWithProjection(self.encoder, self.projection)

    def get_decoder(self):
        return self.decoder

    def forward(
        self,
        input_features=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        decoder_position_ids=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            # Run the Gemma audio encoder and project its last hidden state into
            # the decoder's hidden size. `encoder_outputs` is kept as a plain
            # tensor, matching what `EncoderWithProjection` returns.
            encoder_outputs = self.encoder(
                input_features=input_features,
                input_features_mask=attention_mask,
            )
            encoder_outputs = self.projection(encoder_outputs[0])

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            # `encoder_outputs` is a tensor, so wrap it in a tuple before concatenating.
            return decoder_outputs + (encoder_outputs,)

        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=None,
            encoder_hidden_states=None,
            encoder_attentions=None,
        )


class GemmaWhisperForConditionalGeneration(WhisperPreTrainedModel, GenerationMixin):
    """GemmaWhisper backbone with a linear language-modeling head for speech-to-text generation."""

    base_model_prefix = "model"
    _tied_weights_keys = ["proj_out.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = GemmaWhisper(config)
        self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.max_target_positions = config.max_target_positions

        self.post_init()

    def get_output_embeddings(self):
        return self.proj_out

    def set_output_embeddings(self, new_embeddings):
        self.proj_out = new_embeddings

    def get_input_embeddings(self) -> nn.Module:
        return self.model.get_input_embeddings()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()

    def forward(
        self,
        input_features=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        decoder_position_ids=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if labels.shape[1] > self.max_target_positions:
                raise ValueError(
                    f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens."
                )
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                # Teacher forcing: derive decoder inputs by shifting the labels right.
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )

        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        lm_logits = self.proj_out(outputs[0])

        loss = None
        if labels is not None:
            # Standard token-level cross-entropy against the (unshifted) labels.
            loss_fct = nn.CrossEntropyLoss()
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
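

# ---------------------------------------------------------------------------
# Usage sketch (not part of the model definition). A minimal, untested example
# of a training-style forward pass. The repo id 'mesolitica/gemma-whisper' is a
# PLACEHOLDER, and it is only an assumption that the encoder repo ships a
# compatible AutoFeatureExtractor emitting `input_features`, and that a
# Whisper-style tokenizer produces suitable label ids.
if __name__ == "__main__":
    feature_extractor = AutoFeatureExtractor.from_pretrained(
        'mesolitica/gemma-3n-e4b-it-audio-encoder', trust_remote_code=True)  # assumed to exist alongside the encoder
    tokenizer = AutoTokenizer.from_pretrained('mesolitica/gemma-whisper')    # placeholder repo id
    model = GemmaWhisperForConditionalGeneration.from_pretrained('mesolitica/gemma-whisper')  # placeholder repo id

    audio, sr = librosa.load('sample.wav', sr=16000)
    features = feature_extractor(audio, sampling_rate=sr, return_tensors='pt')
    labels = tokenizer('hello world', return_tensors='pt').input_ids

    with torch.no_grad():
        outputs = model(
            input_features=features.input_features,
            attention_mask=features.get('attention_mask'),  # may be None depending on the extractor
            labels=labels,
        )
    print(outputs.logits.shape, outputs.loss)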