import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
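
# LanceASR: a compact encoder-decoder ASR model registered with the Hugging Face
# Auto* machinery. Log-mel features are subsampled by two Conv1d layers, encoded by
# an nn.TransformerEncoder, and decoded autoregressively by an nn.TransformerDecoder
# whose hidden states are projected to vocabulary logits.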


class LanceASRConfig(PretrainedConfig):
    model_type = "lance_asr"
    is_encoder_decoder = True

    def __init__(
        self,
        vocab_size=50257,
        hidden_size=256,
        num_layers=4,
        num_heads=4,
        num_mel_bins=128,
        architectures=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_mel_bins = num_mel_bins
        self.architectures = architectures if architectures is not None else ["LanceASR"]
        self.is_encoder_decoder = True
        self.decoder_start_token_id = kwargs.get("decoder_start_token_id", 0)
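

# The model: a Conv1d subsampler plus Transformer encoder over audio features, and a
# Transformer decoder with an LM head over text tokens. GenerationMixin supplies
# .generate(); the hooks below (get_encoder, prepare_inputs_for_generation, ...)
# adapt it to this encoder-decoder layout.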
class LanceASR(PreTrainedModel, GenerationMixin):
    config_class = LanceASRConfig
    _supports_cache_class = False

    def __init__(self, config):
        config.is_encoder_decoder = True
        super().__init__(config)
        self.config = config
        # Audio feature extraction (Conv subsampling)
        self.conv1 = nn.Conv1d(config.num_mel_bins, config.hidden_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2, padding=1)
        # Text embedding
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers,
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers,
        )
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        # Generation config defaults
        self.generation_config.max_new_tokens = 250
        self.generation_config.temperature = 0.8
        self.generation_config.do_sample = True
        self.generation_config.decoder_start_token_id = self.config.decoder_start_token_id
        self.init_weights()
        self.to(torch.bfloat16)
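
    # generate() calls get_encoder() once per batch so the audio is encoded a single
    # time; the wrapper exposes the conv + Transformer encoder stack under the
    # "input_features" input name expected for encoder-decoder generation.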
    def get_encoder(self):
        class EncoderWrapper(nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model
                self.main_input_name = "input_features"

            def forward(self, input_features, attention_mask=None, **kwargs):
                return self.model.forward_encoder(input_features)

            def __call__(self, *args, **kwargs):
                return self.forward(*args, **kwargs)

        return EncoderWrapper(self)
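
    # Encoder path: input_features is expected as a (batch, num_mel_bins, time) tensor of
    # log-mel features; conv1 keeps the frame rate, conv2 roughly halves it (stride 2),
    # and the result is permuted to (batch, time', hidden_size) for the batch_first encoder.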
    def forward_encoder(self, input_features):
        hidden_states = nn.functional.gelu(self.conv1(input_features))
        hidden_states = nn.functional.gelu(self.conv2(hidden_states))
        inputs_embeds = hidden_states.permute(0, 2, 1)
        encoder_outputs = self.encoder(inputs_embeds)
        return BaseModelOutput(last_hidden_state=encoder_outputs)
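
    # Teacher forcing: shift labels one step to the right, prepend the decoder start
    # token, and replace -100 padding with 0 so the embedding lookup stays valid.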
    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        shifted_labels = labels.new_zeros(labels.shape)
        shifted_labels[..., 1:] = labels[..., :-1].clone()
        shifted_labels[..., 0] = self.config.decoder_start_token_id
        shifted_labels.masked_fill_(shifted_labels == -100, 0)
        return shifted_labels
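
    # forward() serves both training (labels -> shifted decoder_input_ids -> cross-entropy
    # loss with ignore_index=-100) and generation (encoder_outputs computed once, then the
    # decoder re-run as decoder_input_ids grows).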
    def forward(
        self,
        input_features=None,
        decoder_input_ids=None,
        input_ids=None,
        encoder_outputs=None,
        labels=None,
        return_dict=True,
        use_cache=False,
        **kwargs,
    ):
        if decoder_input_ids is None and input_ids is not None:
            decoder_input_ids = input_ids
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels)
        if encoder_outputs is None and input_features is not None:
            encoder_outputs = self.forward_encoder(input_features)
        if hasattr(encoder_outputs, "last_hidden_state"):
            memory = encoder_outputs.last_hidden_state
        elif isinstance(encoder_outputs, tuple):
            memory = encoder_outputs[0]
        else:
            memory = encoder_outputs
        if decoder_input_ids is None:
            raise ValueError("decoder_input_ids must be provided")
        decoder_embeds = self.embedding(decoder_input_ids)
        seq_len = decoder_embeds.size(1)
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(
            device=decoder_embeds.device, dtype=decoder_embeds.dtype
        )
        decoder_output = self.decoder(
            tgt=decoder_embeds,
            memory=memory,
            tgt_mask=tgt_mask,
            tgt_is_causal=True,
        )
        logits = self.lm_head(decoder_output)
        loss = None
        if labels is not None:
            loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))
        if return_dict:
            return Seq2SeqLMOutput(loss=loss, logits=logits, encoder_last_hidden_state=memory)
        return (loss, logits) if loss is not None else logits
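
    # No KV cache is implemented (_supports_cache_class = False), so each generation step
    # simply re-feeds the full decoder_input_ids prefix together with the cached
    # encoder_outputs.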
    def prepare_inputs_for_generation(self, decoder_input_ids, past_key_values=None, attention_mask=None, encoder_outputs=None, **kwargs):
        return {
            "decoder_input_ids": decoder_input_ids,
            "encoder_outputs": encoder_outputs,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        pass
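

# Register the custom config/model with the Auto* factories so the "lance_asr" model
# type resolves to these classes, and tag them for custom-code (trust_remote_code)
# serialization via register_for_auto_class.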
CONFIG_MAPPING.register("lance_asr", LanceASRConfig)
try:
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.register(LanceASRConfig, LanceASR)
except Exception:
    pass
LanceASRConfig.register_for_auto_class("AutoConfig")
LanceASR.register_for_auto_class("AutoModelForSeq2SeqLM")
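

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module; shapes, token ids and
# the hand-rolled greedy loop are illustrative assumptions only). It builds an
# untrained model, runs a training-style forward pass, then decodes a few
# tokens through explicit forward() calls so the example stays self-contained.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = LanceASRConfig()
    model = LanceASR(config).eval()
    # Fake batch: 2 utterances, num_mel_bins mel channels, 200 frames. The model casts
    # itself to bfloat16 in __init__, so the features must use the same dtype.
    input_features = torch.randn(2, config.num_mel_bins, 200, dtype=torch.bfloat16)
    labels = torch.randint(0, config.vocab_size, (2, 16))
    with torch.no_grad():
        out = model(input_features=input_features, labels=labels)
    print("loss:", out.loss.item(), "logits shape:", tuple(out.logits.shape))
    # Greedy decoding by hand: encode once, then extend decoder_input_ids token by token.
    with torch.no_grad():
        enc = model.get_encoder()(input_features=input_features)
        dec_ids = torch.full((2, 1), config.decoder_start_token_id, dtype=torch.long)
        for _ in range(8):
            logits = model(encoder_outputs=enc, decoder_input_ids=dec_ids).logits
            next_ids = logits[:, -1].argmax(dim=-1, keepdim=True)
            dec_ids = torch.cat([dec_ids, next_ids], dim=-1)
    print("decoded ids:", dec_ids.tolist())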