# Lance-ASR / lance_asr_model.py
# Author: NeuraCraft -- commit 963642f ("Upload LanceASR").
import logging

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
class LanceASRConfig(PretrainedConfig):
    """Configuration for the ``LanceASR`` encoder-decoder speech model.

    Args:
        vocab_size: Size of the output vocabulary (rows of the decoder token
            embedding, columns of ``lm_head``).
        hidden_size: Model width shared by the conv front-end, the encoder
            and the decoder.
        num_layers: Number of layers in each of the encoder and decoder stacks.
        num_heads: Attention heads per Transformer layer. PyTorch's multi-head
            attention requires ``hidden_size`` to be divisible by it.
        num_mel_bins: Number of input feature channels (in_channels of the
            first convolution).
        architectures: Architecture names stored on the config. Defaults to
            ``["LanceASR"]``. A fresh list is built per instance so configs
            never share or alias one list object. ``None`` is stored as-is.
        **kwargs: Forwarded to ``PretrainedConfig``. ``decoder_start_token_id``
            defaults to 0 when it is not supplied.
    """

    model_type = "lance_asr"
    is_encoder_decoder = True

    def __init__(self, vocab_size=50257, hidden_size=256, num_layers=4, num_heads=4, num_mel_bins=128, architectures=("LanceASR",), **kwargs):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_mel_bins = num_mel_bins
        # The default is an immutable tuple. Copying into a new list avoids the
        # shared-mutable-default pitfall of the previous ``["LanceASR"]`` default.
        self.architectures = list(architectures) if architectures is not None else None
        # Force encoder-decoder mode on the instance, in case the base-class
        # initialisation (or kwargs) set it otherwise.
        self.is_encoder_decoder = True
        # ``super().__init__`` receives its own copy of kwargs, so this key is
        # still readable here. Fall back to 0 when the caller did not set it.
        self.decoder_start_token_id = kwargs.get("decoder_start_token_id", 0)
class LanceASR(PreTrainedModel, GenerationMixin):
    """Compact encoder-decoder speech recognition model.

    Audio features pass through two GELU-activated ``Conv1d`` layers and an
    ``nn.TransformerEncoder``. The second conv has stride 2, which halves the
    time axis. Token ids are embedded and decoded by a causal
    ``nn.TransformerDecoder`` that cross-attends to the encoder output.
    ``lm_head`` then projects the result to vocabulary logits.

    NOTE(review): neither stack adds a positional encoding. The Transformers
    only get order information implicitly, from the local context of the
    convolutions and from the decoder's causal mask. Adding positional
    encodings would change the parameter set and invalidate existing
    checkpoints, so this is only flagged here.
    """

    config_class = LanceASRConfig
    # The encoder consumes ``input_features``. Declaring it lets ``generate``
    # hand the audio tensor straight to the encoder instead of first
    # synthesising ``input_ids`` from ``bos_token_id``.
    main_input_name = "input_features"
    # No key/value cache is implemented: each generation step re-runs the
    # decoder over the whole prefix.
    _supports_cache_class = False

    def __init__(self, config):
        # Set before super().__init__ so the base class already sees an
        # encoder-decoder config.
        config.is_encoder_decoder = True
        super().__init__(config)
        # Audio front-end:
        # - conv1 maps num_mel_bins channels to hidden_size and keeps the
        #   time length.
        # - conv2 (stride 2) downsamples time to ceil(frames / 2).
        self.conv1 = nn.Conv1d(config.num_mel_bins, config.hidden_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2, padding=1)
        # Decoder token embedding.
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers,
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        # Label positions equal to -100 are excluded from the loss.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)

        # Default decoding: sample at temperature 0.8, up to 250 new tokens.
        self.generation_config.max_new_tokens = 250
        self.generation_config.temperature = 0.8
        self.generation_config.do_sample = True
        self.generation_config.decoder_start_token_id = self.config.decoder_start_token_id

        self.init_weights()
        # Parameters are held in bfloat16. forward_encoder / forward cast their
        # tensor inputs to match.
        self.to(torch.bfloat16)

    def get_encoder(self):
        """Return an encoder view for ``GenerationMixin.generate``.

        The wrapper exposes ``main_input_name`` and a ``forward`` that accepts
        (and ignores) ``attention_mask`` and any extra keyword arguments. It
        delegates to :meth:`forward_encoder` and returns its
        ``BaseModelOutput``. A new wrapper is built on every call. Presumably
        this is because storing it on ``self`` would make the model a submodule
        of itself.
        """
        class EncoderWrapper(nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model
                self.main_input_name = "input_features"

            def forward(self, input_features, attention_mask=None, **kwargs):
                return self.model.forward_encoder(input_features)

            def __call__(self, *args, **kwargs):
                return self.forward(*args, **kwargs)

        return EncoderWrapper(self)

    def forward_encoder(self, input_features):
        """Encode a batch of audio features.

        Args:
            input_features: Tensor of shape ``(batch, num_mel_bins, frames)``,
                the layout ``conv1`` expects. Presumably log-mel features from
                the matching feature extractor (TODO: confirm). It is cast to
                the conv weight dtype, so float32 input works with the bf16
                model.

        Returns:
            ``BaseModelOutput`` whose ``last_hidden_state`` has shape
            ``(batch, ceil(frames / 2), hidden_size)``.
        """
        input_features = input_features.to(self.conv1.weight.dtype)
        hidden_states = nn.functional.gelu(self.conv1(input_features))
        hidden_states = nn.functional.gelu(self.conv2(hidden_states))
        # Conv1d yields (batch, channels, time). The encoder is batch_first and
        # expects (batch, time, channels).
        hidden_states = hidden_states.permute(0, 2, 1)
        return BaseModelOutput(last_hidden_state=self.encoder(hidden_states))

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Build teacher-forcing decoder inputs by shifting ``labels`` right.

        - Position 0 becomes ``config.decoder_start_token_id``.
        - The last label is dropped.
        - Any -100 ignore markers carried into the shifted sequence are
          replaced with 0, so every entry is a valid embedding index.
        """
        shifted_labels = labels.new_zeros(labels.shape)
        shifted_labels[..., 1:] = labels[..., :-1].clone()
        shifted_labels[..., 0] = self.config.decoder_start_token_id
        shifted_labels.masked_fill_(shifted_labels == -100, 0)
        return shifted_labels

    def forward(self, input_features=None, decoder_input_ids=None, input_ids=None, encoder_outputs=None, labels=None, return_dict=True, use_cache=False, **kwargs):
        """Run the encoder (unless ``encoder_outputs`` is given) and the decoder.

        Args:
            input_features: Audio features for :meth:`forward_encoder`. Ignored
                when ``encoder_outputs`` is supplied.
            decoder_input_ids: Decoder token ids of shape ``(batch, tgt_len)``.
            input_ids: Alias for ``decoder_input_ids``. Used only when the
                latter is ``None``.
            encoder_outputs: Precomputed encoder output. Accepted forms:
                - an object with ``last_hidden_state``;
                - a tuple/list whose first item is the hidden states;
                - the hidden-state tensor itself.
            labels: Target ids ``(batch, tgt_len)``. Positions equal to -100 are
                ignored by the loss. When no decoder ids are given, they are
                derived from ``labels`` by a right shift.
            return_dict: Return a ``Seq2SeqLMOutput`` (the default) instead of
                a plain tensor or tuple.
            use_cache: Accepted for API compatibility; unused (no KV cache).
            **kwargs: Ignored.

        Returns:
            - When ``return_dict`` is true:
              ``Seq2SeqLMOutput(loss, logits, encoder_last_hidden_state)``.
            - Otherwise: ``(loss, logits)`` if labels were given, else the
              bare ``logits`` tensor of shape ``(batch, tgt_len, vocab_size)``.

        Raises:
            ValueError: If no decoder ids can be obtained, or if neither
                ``input_features`` nor ``encoder_outputs`` is provided.
        """
        if decoder_input_ids is None and input_ids is not None:
            decoder_input_ids = input_ids
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels)
        # Validate before spending time on the encoder.
        if decoder_input_ids is None:
            raise ValueError("decoder_input_ids must be provided")

        if encoder_outputs is None:
            if input_features is None:
                raise ValueError("Either input_features or encoder_outputs must be provided")
            encoder_outputs = self.forward_encoder(input_features)
        if hasattr(encoder_outputs, "last_hidden_state"):
            memory = encoder_outputs.last_hidden_state
        elif isinstance(encoder_outputs, (tuple, list)):
            memory = encoder_outputs[0]
        else:
            memory = encoder_outputs

        decoder_embeds = self.embedding(decoder_input_ids)
        # Externally supplied encoder states may not match the model dtype.
        memory = memory.to(decoder_embeds.dtype)
        seq_len = decoder_embeds.size(1)
        # Additive float mask (-inf above the diagonal) for causal self-attention.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(device=decoder_embeds.device, dtype=decoder_embeds.dtype)
        decoder_output = self.decoder(
            tgt=decoder_embeds,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_is_causal=True,
        )
        logits = self.lm_head(decoder_output)

        loss = None
        if labels is not None:
            # Compute cross-entropy in float32 for numerical stability, since
            # the logits are bfloat16. Use reshape because labels coming from a
            # collator may be non-contiguous.
            loss = self.loss_fct(logits.float().reshape(-1, self.config.vocab_size), labels.reshape(-1))
        if return_dict:
            return Seq2SeqLMOutput(loss=loss, logits=logits, encoder_last_hidden_state=memory)
        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(self, decoder_input_ids, past_key_values=None, attention_mask=None, encoder_outputs=None, **kwargs):
        """Select the forward() inputs for one generation step.

        Only the full decoder prefix and the cached encoder output are
        forwarded. ``past_key_values`` and ``attention_mask`` are ignored
        because the model keeps no KV cache.
        """
        return {
            "decoder_input_ids": decoder_input_ids,
            "encoder_outputs": encoder_outputs,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        """Return the cache unchanged. There is no KV cache to reorder for beam search."""
        return past_key_values
# Registration with the transformers Auto* machinery:
# - CONFIG_MAPPING maps model_type "lance_asr" to LanceASRConfig.
# - The seq2seq model mapping maps LanceASRConfig to LanceASR.
# - register_for_auto_class records which Auto class each should be loaded
#   through when the code is shared as custom (remote) code.
CONFIG_MAPPING.register("lance_asr", LanceASRConfig)
try:
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.register(LanceASRConfig, LanceASR)
except Exception as err:
    # Deliberately best-effort: a failed registration should not make the
    # module unimportable. The failure is reported rather than swallowed.
    logging.getLogger(__name__).warning(
        "Could not register LanceASR in MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING: %s", err
    )
LanceASRConfig.register_for_auto_class("AutoConfig")
LanceASR.register_for_auto_class("AutoModelForSeq2SeqLM")