import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING


class LanceASRConfig(PretrainedConfig):
    model_type = "lance_asr"
    is_encoder_decoder = True

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=256,
        num_layers=4,
        num_heads=4,
        num_mel_bins=128,
        architectures=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_mel_bins = num_mel_bins
        # Avoid a mutable default argument; fall back to the model class name.
        self.architectures = architectures if architectures is not None else ["LanceASR"]
        self.is_encoder_decoder = True
        self.decoder_start_token_id = kwargs.get("decoder_start_token_id", 0)


class LanceASR(PreTrainedModel, GenerationMixin):
    config_class = LanceASRConfig
    # generate() resolves audio inputs through this name instead of `input_ids`.
    main_input_name = "input_features"
    _supports_cache_class = False

    def __init__(self, config):
        config.is_encoder_decoder = True
        super().__init__(config)
        self.config = config

        # Audio front-end: conv subsampling maps (batch, num_mel_bins, frames)
        # to (batch, hidden_size, ~frames / 2). Note there is no positional
        # encoding; the model relies on the convolutions for local order cues.
        self.conv1 = nn.Conv1d(config.num_mel_bins, config.hidden_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2, padding=1)

        # Text embedding for the decoder side.
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers,
        )
        # generate() also inspects `self.encoder.main_input_name` when resolving
        # inputs; keep it in sync so audio features are routed correctly.
        self.encoder.main_input_name = self.main_input_name

        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers,
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

        # Generation config defaults.
        self.generation_config.max_new_tokens = 250
        self.generation_config.temperature = 0.8
        self.generation_config.do_sample = True
        self.generation_config.decoder_start_token_id = self.config.decoder_start_token_id

        self.init_weights()
        self.to(torch.bfloat16)

    def get_encoder(self):
        # generate() calls get_encoder() to precompute encoder states once per
        # batch. The wrapper gives the conv + transformer front-end the module
        # interface (and `main_input_name`) the generation loop expects.
        # nn.Module.__call__ already dispatches to forward, so no __call__
        # override is needed (and overriding it would bypass module hooks).
        class EncoderWrapper(nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model
                self.main_input_name = "input_features"

            def forward(self, input_features, attention_mask=None, **kwargs):
                return self.model.forward_encoder(input_features)

        return EncoderWrapper(self)

    def forward_encoder(self, input_features):
        # input_features: (batch, num_mel_bins, frames) log-mel spectrogram.
        hidden_states = nn.functional.gelu(self.conv1(input_features))
        hidden_states = nn.functional.gelu(self.conv2(hidden_states))
        # Conv1d is channels-first; the transformer expects (batch, time, hidden).
        inputs_embeds = hidden_states.permute(0, 2, 1)
        encoder_outputs = self.encoder(inputs_embeds)
        return BaseModelOutput(last_hidden_state=encoder_outputs)

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        # Shift labels right: prepend the start token, drop the last label,
        # then replace any ignore-index (-100) positions with 0 (padding).
        shifted_labels = labels.new_zeros(labels.shape)
        shifted_labels[..., 1:] = labels[..., :-1].clone()
        shifted_labels[..., 0] = self.config.decoder_start_token_id
        shifted_labels.masked_fill_(shifted_labels == -100, 0)
        return shifted_labels
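
    # Worked example of the shift above (illustrative values), assuming
    # decoder_start_token_id = 0 and -100 marking ignored label positions:
    #   labels:            [ 5, 17, -100, -100]
    #   decoder_input_ids: [ 0,  5,   17,    0]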

    def forward(
        self,
        input_features=None,
        decoder_input_ids=None,
        input_ids=None,
        encoder_outputs=None,
        labels=None,
        return_dict=True,
        use_cache=False,
        **kwargs,
    ):
        # Accept `input_ids` as an alias for `decoder_input_ids`, and derive
        # decoder inputs from labels (teacher forcing) when neither is given.
        if decoder_input_ids is None and input_ids is not None:
            decoder_input_ids = input_ids
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels)
        if encoder_outputs is None and input_features is not None:
            encoder_outputs = self.forward_encoder(input_features)

        # generate() may hand back encoder outputs as a ModelOutput, a tuple,
        # or a bare tensor; normalize all three to the memory tensor.
        if hasattr(encoder_outputs, "last_hidden_state"):
            memory = encoder_outputs.last_hidden_state
        elif isinstance(encoder_outputs, tuple):
            memory = encoder_outputs[0]
        else:
            memory = encoder_outputs

        if decoder_input_ids is None:
            raise ValueError("decoder_input_ids must be provided")
        decoder_embeds = self.embedding(decoder_input_ids)

        # Causal mask so each target position attends only to earlier positions.
        seq_len = decoder_embeds.size(1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(
            device=decoder_embeds.device, dtype=decoder_embeds.dtype
        )
        decoder_output = self.decoder(
            tgt=decoder_embeds,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_is_causal=True,
        )
        logits = self.lm_head(decoder_output)

        loss = None
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if return_dict:
            return Seq2SeqLMOutput(loss=loss, logits=logits, encoder_last_hidden_state=memory)
        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(
        self, decoder_input_ids, past_key_values=None, attention_mask=None, encoder_outputs=None, **kwargs
    ):
        # No KV cache: the decoder re-runs on the full prefix every step and
        # reuses the precomputed encoder outputs.
        return {
            "decoder_input_ids": decoder_input_ids,
            "encoder_outputs": encoder_outputs,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        # No cache is kept, so there is nothing to reorder for beam search.
        return past_key_values


# Register the model with the Auto* machinery so AutoConfig / AutoModelForSeq2SeqLM
# can resolve the "lance_asr" model type, locally and via trust_remote_code.
CONFIG_MAPPING.register("lance_asr", LanceASRConfig)
try:
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.register(LanceASRConfig, LanceASR)
except Exception:
    pass
LanceASRConfig.register_for_auto_class("AutoConfig")
LanceASR.register_for_auto_class("AutoModelForSeq2SeqLM")
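

# Minimal smoke test: a sketch, assuming log-mel inputs shaped
# (batch, num_mel_bins, frames). The random tensors below are placeholders for
# a real feature extractor and tokenizer, which live outside this module.
if __name__ == "__main__":
    model = LanceASR(LanceASRConfig())

    # Dummy batch: 2 clips of 100 frames; dtype matches the bfloat16 weights.
    feats = torch.randn(2, model.config.num_mel_bins, 100, dtype=torch.bfloat16)
    labels = torch.randint(0, model.config.vocab_size, (2, 12))

    # Training-style pass: decoder inputs are derived from labels internally.
    out = model(input_features=feats, labels=labels)
    print("loss:", out.loss.item(), "logits:", tuple(out.logits.shape))

    # Inference via GenerationMixin; sampling settings come from generation_config.
    token_ids = model.generate(input_features=feats, max_new_tokens=16)
    print("generated:", tuple(token_ids.shape))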