import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


class LanceASRConfig(PretrainedConfig):
    model_type = "lance_asr"
    is_encoder_decoder = True

    def __init__(self, vocab_size=50257, hidden_size=256, num_layers=4, num_heads=4, num_mel_bins=128, architectures=None, **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_mel_bins = num_mel_bins
        # Avoid a mutable default argument for the architectures list.
        self.architectures = architectures if architectures is not None else ["LanceASR"]
        self.is_encoder_decoder = True
        self.decoder_start_token_id = kwargs.get("decoder_start_token_id", 0)


class LanceASR(PreTrainedModel, GenerationMixin):
    config_class = LanceASRConfig
    # Audio models take features, not token ids, as their main input.
    main_input_name = "input_features"
    _supports_cache_class = False

    def __init__(self, config):
        config.is_encoder_decoder = True
        super().__init__(config)
        self.config = config

        # Convolutional frontend: project mel features to the hidden size and
        # downsample the time axis by 2.
        self.conv1 = nn.Conv1d(config.num_mel_bins, config.hidden_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2, padding=1)

        # Token embeddings for the decoder input ids.
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers,
        )
        # Some versions of GenerationMixin resolve the main model input via
        # self.encoder.main_input_name, so expose it on the raw TransformerEncoder too.
        self.encoder.main_input_name = "input_features"
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers,
        )

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

        # Default generation settings used by GenerationMixin.generate().
        self.generation_config.max_new_tokens = 250
        self.generation_config.temperature = 0.8
        self.generation_config.do_sample = True
        self.generation_config.decoder_start_token_id = self.config.decoder_start_token_id

        self.init_weights()

        # Keep the whole model in bfloat16.
        self.to(torch.bfloat16)

    def get_encoder(self):
        # generate() calls get_encoder() once to pre-compute encoder_outputs,
        # so wrap the convolutional + transformer encoder path in a module
        # whose main input is input_features.
        class EncoderWrapper(nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model
                self.main_input_name = "input_features"

            def forward(self, input_features, attention_mask=None, **kwargs):
                return self.model.forward_encoder(input_features)

            def __call__(self, *args, **kwargs):
                return self.forward(*args, **kwargs)

        return EncoderWrapper(self)

    def forward_encoder(self, input_features):
        # input_features: (batch, num_mel_bins, time)
        hidden_states = nn.functional.gelu(self.conv1(input_features))
        hidden_states = nn.functional.gelu(self.conv2(hidden_states))

        # (batch, hidden, time) -> (batch, time, hidden) for the transformer encoder.
        inputs_embeds = hidden_states.permute(0, 2, 1)
        encoder_outputs = self.encoder(inputs_embeds)
        return BaseModelOutput(last_hidden_state=encoder_outputs)

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # Shift labels one position to the right, prepend the decoder start
        # token, and replace any ignored (-100) positions with 0.
        shifted_labels = labels.new_zeros(labels.shape)
        shifted_labels[..., 1:] = labels[..., :-1].clone()
        shifted_labels[..., 0] = self.config.decoder_start_token_id
        shifted_labels.masked_fill_(shifted_labels == -100, 0)
        return shifted_labels

    def forward(self, input_features=None, decoder_input_ids=None, input_ids=None, encoder_outputs=None, labels=None, return_dict=True, use_cache=False, **kwargs):
        # Accept input_ids as an alias for decoder_input_ids.
        if decoder_input_ids is None and input_ids is not None:
            decoder_input_ids = input_ids

        # During training, build the decoder inputs from the labels.
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels)

        if encoder_outputs is None and input_features is not None:
            encoder_outputs = self.forward_encoder(input_features)

        # Accept encoder outputs as a ModelOutput, a tuple, or a raw tensor.
        if hasattr(encoder_outputs, "last_hidden_state"):
            memory = encoder_outputs.last_hidden_state
        elif isinstance(encoder_outputs, tuple):
            memory = encoder_outputs[0]
        else:
            memory = encoder_outputs

        if decoder_input_ids is None:
            raise ValueError("decoder_input_ids must be provided")
        decoder_embeds = self.embedding(decoder_input_ids)

        # Causal mask so each target position only attends to earlier positions.
        seq_len = decoder_embeds.size(1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(device=decoder_embeds.device, dtype=decoder_embeds.dtype)

        decoder_output = self.decoder(
            tgt=decoder_embeds,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_is_causal=True,
        )

        logits = self.lm_head(decoder_output)
        loss = None

        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if return_dict:
            return Seq2SeqLMOutput(loss=loss, logits=logits, encoder_last_hidden_state=memory)

        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(self, decoder_input_ids, past_key_values=None, attention_mask=None, encoder_outputs=None, **kwargs):
        # No KV cache is used, so the full decoder prefix is re-run every step.
        return {
            "decoder_input_ids": decoder_input_ids,
            "encoder_outputs": encoder_outputs,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        # No cache to reorder during beam search.
        pass


# Register the custom classes so AutoConfig / AutoModelForSeq2SeqLM can resolve
# the "lance_asr" model type, and so the custom code is saved alongside
# checkpoints for loading with trust_remote_code=True.
CONFIG_MAPPING.register("lance_asr", LanceASRConfig)
try:
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.register(LanceASRConfig, LanceASR)
except Exception:
    pass
LanceASRConfig.register_for_auto_class("AutoConfig")
LanceASR.register_for_auto_class("AutoModelForSeq2SeqLM")
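

# --- Minimal usage sketch (illustrative only, not part of the original model code) ---
# A quick CPU smoke test under assumed shapes: a dummy batch of num_mel_bins
# log-mel frames and random token labels. The model is cast back to float32
# here just to keep the CPU run simple, and exact generate() behavior can vary
# with the installed transformers version.
if __name__ == "__main__":
    config = LanceASRConfig()
    model = LanceASR(config).float().eval()

    # Dummy inputs: (batch, num_mel_bins, time) features and (batch, seq) labels.
    input_features = torch.randn(2, config.num_mel_bins, 200)
    labels = torch.randint(0, config.vocab_size, (2, 20))

    # Training-style forward pass: labels are shifted right internally to build
    # decoder_input_ids, and a cross-entropy loss is returned.
    with torch.no_grad():
        outputs = model(input_features=input_features, labels=labels)
    print("loss:", outputs.loss.item(), "logits shape:", tuple(outputs.logits.shape))

    # Autoregressive decoding via GenerationMixin.generate(); sampling is
    # disabled here only to keep the smoke test deterministic.
    generated_ids = model.generate(input_features=input_features, do_sample=False, max_new_tokens=16)
    print("generated ids shape:", tuple(generated_ids.shape))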