Commit 963642f · NeuraCraft committed · 1 Parent(s): 4514b3d

Upload LanceASR

Files changed (3)
  1. config.json +19 -0
  2. generation_config.json +8 -0
  3. lance_asr_model.py +163 -0
config.json ADDED
@@ -0,0 +1,19 @@
+{
+  "architectures": [
+    "LanceASR"
+  ],
+  "auto_map": {
+    "AutoConfig": "lance_asr_model.LanceASRConfig",
+    "AutoModelForSeq2SeqLM": "lance_asr_model.LanceASR"
+  },
+  "decoder_start_token_id": 100257,
+  "dtype": "bfloat16",
+  "hidden_size": 768,
+  "is_encoder_decoder": true,
+  "model_type": "lance_asr",
+  "num_heads": 12,
+  "num_layers": 4,
+  "num_mel_bins": 128,
+  "transformers_version": "4.57.3",
+  "vocab_size": 100277
+}
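
Note: the "auto_map" entries above are what let the stock Auto classes resolve
this repository's custom code. A minimal loading sketch, assuming the repo id
is NeuraCraft/LanceASR (inferred from the committer and model name, not stated
in the diff); trust_remote_code=True is required because the model class lives
in lance_asr_model.py inside the repo:

    from transformers import AutoConfig, AutoModelForSeq2SeqLM

    # trust_remote_code=True lets transformers import lance_asr_model.py from
    # the repository, following the "auto_map" entries in config.json.
    config = AutoConfig.from_pretrained("NeuraCraft/LanceASR", trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained("NeuraCraft/LanceASR", trust_remote_code=True)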
generation_config.json ADDED
@@ -0,0 +1,8 @@
+{
+  "_from_model_config": true,
+  "decoder_start_token_id": 100257,
+  "do_sample": true,
+  "max_new_tokens": 250,
+  "temperature": 0.8,
+  "transformers_version": "4.57.3"
+}
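
Note: with these defaults a bare generate() call samples (do_sample=True,
temperature=0.8) for up to 250 new tokens. Every field can be overridden per
call; a sketch of a deterministic decode with the model loaded above, where
log_mel is a hypothetical (batch, 128, time) log-mel tensor:

    # Per-call overrides take precedence over generation_config.json.
    predicted_ids = model.generate(
        input_features=log_mel,
        do_sample=False,       # greedy decoding instead of the sampled default
        max_new_tokens=100,    # shorter cap than the configured 250
    )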
lance_asr_model.py ADDED
@@ -0,0 +1,163 @@
+import torch
+import torch.nn as nn
+from transformers import GenerationMixin, PretrainedConfig, PreTrainedModel
+from transformers.modeling_outputs import BaseModelOutput, Seq2SeqLMOutput
+from transformers.models.auto.configuration_auto import CONFIG_MAPPING
+from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING
+
+
+class LanceASRConfig(PretrainedConfig):
+    model_type = "lance_asr"
+    is_encoder_decoder = True
+
+    def __init__(self, vocab_size=50257, hidden_size=256, num_layers=4, num_heads=4,
+                 num_mel_bins=128, decoder_start_token_id=0, architectures=None, **kwargs):
+        # Bind decoder_start_token_id as an explicit parameter: PretrainedConfig.__init__
+        # pops it from kwargs, so reading it back with kwargs.get() after super().__init__()
+        # would silently reset it to the default on reload.
+        kwargs.setdefault("is_encoder_decoder", True)
+        super().__init__(decoder_start_token_id=decoder_start_token_id, **kwargs)
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_layers = num_layers
+        self.num_heads = num_heads
+        self.num_mel_bins = num_mel_bins
+        # Avoid a mutable default argument for the architectures list.
+        self.architectures = architectures if architectures is not None else ["LanceASR"]
+
+
+class LanceASR(PreTrainedModel, GenerationMixin):
+    config_class = LanceASRConfig
+    main_input_name = "input_features"  # lets generate() route audio features to the encoder
+    _supports_cache_class = False
+
+    def __init__(self, config):
+        config.is_encoder_decoder = True
+        super().__init__(config)
+
+        # Audio feature extraction (conv subsampling, 2x downsampling along time)
+        self.conv1 = nn.Conv1d(config.num_mel_bins, config.hidden_size, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2, padding=1)
+
+        # Text embedding
+        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)
+
+        self.encoder = nn.TransformerEncoder(
+            nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
+            num_layers=config.num_layers,
+        )
+        self.decoder = nn.TransformerDecoder(
+            nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
+            num_layers=config.num_layers,
+        )
+
+        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
+        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
+
+        # Generation config defaults (mirrors generation_config.json)
+        self.generation_config.max_new_tokens = 250
+        self.generation_config.temperature = 0.8
+        self.generation_config.do_sample = True
+        self.generation_config.decoder_start_token_id = self.config.decoder_start_token_id
+
+        self.init_weights()
+        self.to(torch.bfloat16)
+
+    def get_encoder(self):
+        # generate() calls get_encoder() once to precompute encoder states; this thin
+        # wrapper exposes the conv front end plus the transformer encoder as one module.
+        class EncoderWrapper(nn.Module):
+            def __init__(self, model):
+                super().__init__()
+                self.model = model
+                self.main_input_name = "input_features"
+
+            def forward(self, input_features, attention_mask=None, **kwargs):
+                return self.model.forward_encoder(input_features)
+
+        return EncoderWrapper(self)
+
+    def forward_encoder(self, input_features):
+        # input_features: (batch, num_mel_bins, time) log-mel spectrogram
+        hidden_states = nn.functional.gelu(self.conv1(input_features))
+        hidden_states = nn.functional.gelu(self.conv2(hidden_states))
+
+        inputs_embeds = hidden_states.permute(0, 2, 1)  # -> (batch, time, hidden)
+        encoder_outputs = self.encoder(inputs_embeds)
+        return BaseModelOutput(last_hidden_state=encoder_outputs)
+
+    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
+        # Shift labels one position right and prepend the decoder start token.
+        shifted_labels = labels.new_zeros(labels.shape)
+        shifted_labels[..., 1:] = labels[..., :-1].clone()
+        shifted_labels[..., 0] = self.config.decoder_start_token_id
+        shifted_labels.masked_fill_(shifted_labels == -100, 0)
+        return shifted_labels
+
+    def forward(self, input_features=None, decoder_input_ids=None, input_ids=None,
+                encoder_outputs=None, labels=None, return_dict=True, use_cache=False, **kwargs):
+        if decoder_input_ids is None and input_ids is not None:
+            decoder_input_ids = input_ids
+
+        if decoder_input_ids is None and labels is not None:
+            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels)
+
+        if decoder_input_ids is None:
+            raise ValueError("decoder_input_ids must be provided")
+
+        if encoder_outputs is None and input_features is not None:
+            encoder_outputs = self.forward_encoder(input_features)
+
+        if hasattr(encoder_outputs, "last_hidden_state"):
+            memory = encoder_outputs.last_hidden_state
+        elif isinstance(encoder_outputs, tuple):
+            memory = encoder_outputs[0]
+        else:
+            memory = encoder_outputs
+
+        decoder_embeds = self.embedding(decoder_input_ids)
+        seq_len = decoder_embeds.size(1)
+        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(
+            device=decoder_embeds.device, dtype=decoder_embeds.dtype
+        )
+
+        decoder_output = self.decoder(
+            tgt=decoder_embeds,
+            memory=memory,
+            tgt_mask=tgt_mask,
+            tgt_is_causal=True,
+        )
+
+        logits = self.lm_head(decoder_output)
+        loss = None
+
+        if labels is not None:
+            # Compute the loss in float32 for numerical stability under bfloat16.
+            loss = self.loss_fct(logits.float().view(-1, self.config.vocab_size), labels.view(-1))
+
+        if return_dict:
+            return Seq2SeqLMOutput(loss=loss, logits=logits, encoder_last_hidden_state=memory)
+
+        return (loss, logits) if loss is not None else logits
+
+    def prepare_inputs_for_generation(self, decoder_input_ids, past_key_values=None,
+                                      attention_mask=None, encoder_outputs=None, **kwargs):
+        return {
+            "decoder_input_ids": decoder_input_ids,
+            "encoder_outputs": encoder_outputs,
+        }
+
+    def _reorder_cache(self, past_key_values, beam_idx):
+        # No KV cache is kept, so there is nothing to reorder during beam search.
+        return past_key_values
+
+
+# Register with the Auto classes so the "auto_map" entries in config.json resolve.
+try:
+    CONFIG_MAPPING.register("lance_asr", LanceASRConfig)
+    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.register(LanceASRConfig, LanceASR)
+except Exception:
+    # Already registered, e.g. on a repeated import of this module.
+    pass
+LanceASRConfig.register_for_auto_class("AutoConfig")
+LanceASR.register_for_auto_class("AutoModelForSeq2SeqLM")
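
Note: the commit ships no feature extractor or tokenizer, so an end-to-end
sketch has to make assumptions: a 128-bin log-mel front end (num_mel_bins=128
in config.json), and tiktoken's cl100k_base vocabulary, which the config hints
at (vocab_size=100277 and decoder_start_token_id=100257 match cl100k_base,
where 100257 is <|endoftext|>). Treat both as guesses, not documented choices:

    import tiktoken
    import torch
    import torchaudio

    # Hypothetical front end: mono audio file -> (1, 128, time) log-mel tensor.
    waveform, sr = torchaudio.load("sample.wav")
    mel = torchaudio.transforms.MelSpectrogram(sample_rate=sr, n_mels=128)(waveform)
    log_mel = torch.log(mel.clamp(min=1e-10)).to(torch.bfloat16)  # weights are bfloat16

    with torch.no_grad():
        token_ids = model.generate(input_features=log_mel)

    # Assumed cl100k_base vocabulary; drop ids >= 100257 (special tokens).
    enc = tiktoken.get_encoding("cl100k_base")
    print(enc.decode([t for t in token_ids[0].tolist() if t < 100257]))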