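# Defines GemmaWhisper / GemmaWhisperForConditionalGeneration: the Gemma-3n audio
# encoder (loaded with trust_remote_code) feeding a standard WhisperDecoder, with a
# linear projection bridging the encoder hidden size and the decoder's d_model.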
from transformers import WhisperPreTrainedModel
from transformers.models.whisper.modeling_whisper import (
    WhisperEncoder,
    WhisperDecoder,
    WhisperModel,
    shift_tokens_right,
)
from transformers import AutoModel, AutoFeatureExtractor, AutoTokenizer
from transformers import AutoConfig
from transformers.modeling_outputs import (
    BaseModelOutput,
    Seq2SeqModelOutput,
    Seq2SeqLMOutput,
)
from transformers.generation import GenerationMixin
import torch
import torch.nn as nn
import librosa

class EncoderWithProjection(nn.Module):
    """Wraps the audio encoder so callers (e.g. `generate`) receive hidden states
    already projected to the decoder's hidden size."""

    def __init__(self, encoder, projection):
        super().__init__()
        self.encoder = encoder
        self.projection = projection

    def forward(self, *args, **kwargs):
        encoder_outputs = self.encoder(*args, **kwargs)
        return self.projection(encoder_outputs[0])

class GemmaWhisper(WhisperPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        encoder_config = AutoConfig.from_pretrained(
            'mesolitica/gemma-3n-e4b-it-audio-encoder', trust_remote_code=True)
        self.encoder = AutoModel.from_config(encoder_config, trust_remote_code=True)
        self.decoder = WhisperDecoder(config)
        # Project encoder hidden states to the Whisper decoder width for cross-attention.
        self.projection = nn.Linear(
            self.encoder.config.text_config.hidden_size, self.decoder.config.d_model)
        self.post_init()

    def get_input_embeddings(self):
        return self.decoder.embed_tokens

    def set_input_embeddings(self, value):
        self.decoder.embed_tokens = value

    def get_encoder(self):
        return EncoderWithProjection(self.encoder, self.projection)

    def get_decoder(self):
        return self.decoder
    def forward(
        self,
        input_features=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        decoder_position_ids=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if encoder_outputs is None:
            encoder_outputs = self.encoder(
                input_features=input_features,
                input_features_mask=attention_mask,
            )
            # Project only freshly computed encoder states; states coming from
            # `get_encoder()` (EncoderWithProjection) are already projected.
            encoder_outputs = self.projection(encoder_outputs[0])

        decoder_outputs = self.decoder(
            input_ids=decoder_input_ids,
            attention_mask=decoder_attention_mask,
            encoder_hidden_states=encoder_outputs,
            head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            inputs_embeds=decoder_inputs_embeds,
            position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )

        if not return_dict:
            # `encoder_outputs` is a projected hidden-state tensor here, so wrap it in a
            # tuple before concatenating with the decoder's tuple output.
            return decoder_outputs + (encoder_outputs,)
        return Seq2SeqModelOutput(
            last_hidden_state=decoder_outputs.last_hidden_state,
            past_key_values=decoder_outputs.past_key_values,
            decoder_hidden_states=decoder_outputs.hidden_states,
            decoder_attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
            encoder_last_hidden_state=None,
            encoder_hidden_states=None,
            encoder_attentions=None,
        )

class GemmaWhisperForConditionalGeneration(WhisperPreTrainedModel, GenerationMixin):
    base_model_prefix = "model"
    _tied_weights_keys = ["proj_out.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = GemmaWhisper(config)
        self.proj_out = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.max_target_positions = config.max_target_positions
        self.post_init()

    def get_output_embeddings(self):
        return self.proj_out

    def set_output_embeddings(self, new_embeddings):
        self.proj_out = new_embeddings

    def get_input_embeddings(self) -> nn.Module:
        return self.model.get_input_embeddings()

    def get_encoder(self):
        return self.model.get_encoder()

    def get_decoder(self):
        return self.model.get_decoder()
    def forward(
        self,
        input_features=None,
        attention_mask=None,
        decoder_input_ids=None,
        decoder_attention_mask=None,
        head_mask=None,
        decoder_head_mask=None,
        cross_attn_head_mask=None,
        encoder_outputs=None,
        past_key_values=None,
        decoder_inputs_embeds=None,
        decoder_position_ids=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
        **kwargs,
    ):
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if labels is not None:
            if labels.shape[1] > self.max_target_positions:
                raise ValueError(
                    f"Labels' sequence length {labels.shape[1]} cannot exceed the maximum allowed length of {self.max_target_positions} tokens."
                )
            if decoder_input_ids is None and decoder_inputs_embeds is None:
                decoder_input_ids = shift_tokens_right(
                    labels, self.config.pad_token_id, self.config.decoder_start_token_id
                )
        outputs = self.model(
            input_features,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            encoder_outputs=encoder_outputs,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            past_key_values=past_key_values,
            decoder_inputs_embeds=decoder_inputs_embeds,
            decoder_position_ids=decoder_position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            cache_position=cache_position,
        )
        lm_logits = self.proj_out(outputs[0])

        loss = None
        if labels is not None:
            # Standard Whisper-style cross-entropy over the (right-shifted) labels.
            loss_fct = nn.CrossEntropyLoss()
            labels = labels.to(lm_logits.device)
            loss = loss_fct(lm_logits.view(-1, self.config.vocab_size), labels.reshape(-1))

        if not return_dict:
            output = (lm_logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqLMOutput(
            loss=loss,
            logits=lm_logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )
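

# Minimal usage sketch (illustrative only). "your-org/gemma-whisper" and "sample.wav"
# are placeholders, not real checkpoint or file names; the exact feature extractor and
# its output keys depend on how the audio frontend for this model is packaged.
if __name__ == "__main__":
    model_id = "your-org/gemma-whisper"  # hypothetical checkpoint id
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_id, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    model = GemmaWhisperForConditionalGeneration.from_pretrained(model_id).eval()

    # librosa resamples the waveform to the rate the feature extractor expects.
    audio, _ = librosa.load("sample.wav", sr=feature_extractor.sampling_rate)
    features = feature_extractor(audio, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")

    with torch.no_grad():
        generated_ids = model.generate(features["input_features"], max_new_tokens=128)
    print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True))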