File size: 6,056 Bytes
963642f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig, GenerationMixin
from transformers.modeling_outputs import Seq2SeqLMOutput, BaseModelOutput
from transformers.models.auto.configuration_auto import CONFIG_MAPPING
from transformers.models.auto.modeling_auto import MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING

class LanceASRConfig(PretrainedConfig):
    """Configuration for the LanceASR encoder-decoder speech-recognition model.

    Stores the sizes of the conv front-end, transformer encoder/decoder and
    vocabulary. Marked ``is_encoder_decoder`` so transformers' generation
    utilities route through the encoder/decoder path.
    """

    model_type = "lance_asr"
    is_encoder_decoder = True

    def __init__(self, vocab_size=50257, hidden_size=256, num_layers=4, num_heads=4, num_mel_bins=128, architectures=None, **kwargs):
        """
        Args:
            vocab_size: Size of the output vocabulary (GPT-2 sized default).
            hidden_size: Width of embeddings, conv channels and transformer layers.
            num_layers: Number of layers in both the encoder and the decoder.
            num_heads: Attention heads per transformer layer.
            num_mel_bins: Number of input mel filterbank channels.
            architectures: Model class names for auto-loading; defaults to
                ``["LanceASR"]``. Uses the None-sentinel idiom instead of a
                mutable list default, so instances never share (and cannot
                corrupt) a single default list object.
            **kwargs: Forwarded to ``PretrainedConfig`` (may carry
                ``decoder_start_token_id``).
        """
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.num_mel_bins = num_mel_bins
        # Fix: the original used a mutable list as the parameter default,
        # which is shared across every call; create a fresh list per instance.
        self.architectures = ["LanceASR"] if architectures is None else architectures
        self.is_encoder_decoder = True
        # Default BOS-equivalent token fed to the decoder at generation start.
        self.decoder_start_token_id = kwargs.get("decoder_start_token_id", 0)

class LanceASR(PreTrainedModel, GenerationMixin):
    """Minimal encoder-decoder ASR model.

    Pipeline: log-mel features -> two Conv1d layers (the second with stride 2
    for 2x temporal subsampling) -> TransformerEncoder -> TransformerDecoder
    (causal, cross-attending to the encoder output) -> LM head over the
    vocabulary. Integrates with transformers' ``generate()`` via
    ``GenerationMixin``; no KV cache is implemented.
    """

    config_class = LanceASRConfig
    # No cache implementation exists; generation recomputes the decoder each step.
    _supports_cache_class = False

    def __init__(self, config):
        # Force the flag so generate() takes the encoder-decoder path even if
        # a caller built the config without it.
        config.is_encoder_decoder = True
        super().__init__(config)
        self.config = config
        
        # Audio feature extraction (Conv subsampling)
        # conv1 maps mel bins -> hidden_size; conv2 halves the time axis (stride=2).
        self.conv1 = nn.Conv1d(config.num_mel_bins, config.hidden_size, kernel_size=3, padding=1)
        self.conv2 = nn.Conv1d(config.hidden_size, config.hidden_size, kernel_size=3, stride=2, padding=1)
        
        # Text embedding
        self.embedding = nn.Embedding(config.vocab_size, config.hidden_size)

        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers
        )
        self.decoder = nn.TransformerDecoder(
            nn.TransformerDecoderLayer(d_model=config.hidden_size, nhead=config.num_heads, batch_first=True),
            num_layers=config.num_layers
        )

        # Projects decoder hidden states to vocabulary logits.
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size)
        # -100 matches the label-padding convention used by transformers collators.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        
        # Generation config defaults
        self.generation_config.max_new_tokens = 250
        self.generation_config.temperature = 0.8
        self.generation_config.do_sample = True
        self.generation_config.decoder_start_token_id = self.config.decoder_start_token_id

        self.init_weights()

        # NOTE(review): unconditionally casts all parameters to bfloat16 at
        # construction time — verify the target hardware supports bf16 and
        # that callers expect a non-fp32 model by default.
        self.to(torch.bfloat16)

    def get_encoder(self):
        """Return a standalone encoder module, as required by generate().

        transformers' encoder-decoder generation calls ``get_encoder()`` once
        and feeds its output to every decoding step; the wrapper exposes
        ``main_input_name`` so generate() knows which kwarg carries audio.
        """
        class EncoderWrapper(nn.Module):
            def __init__(self, model):
                super().__init__()
                self.model = model
                self.main_input_name = "input_features"
            def forward(self, input_features, attention_mask=None, **kwargs):
                # attention_mask is accepted but ignored: the encoder runs unmasked.
                return self.model.forward_encoder(input_features)
            def __call__(self, *args, **kwargs):
                # NOTE(review): overriding __call__ bypasses nn.Module hooks;
                # presumably intentional to skip hook overhead — confirm.
                return self.forward(*args, **kwargs)
        return EncoderWrapper(self)

    def forward_encoder(self, input_features):
        """Encode audio features.

        Args:
            input_features: Tensor shaped (batch, num_mel_bins, frames) —
                Conv1d consumes mel bins as channels.

        Returns:
            BaseModelOutput whose ``last_hidden_state`` is
            (batch, frames // 2, hidden_size) after the stride-2 subsampling.
        """
        hidden_states = nn.functional.gelu(self.conv1(input_features))
        hidden_states = nn.functional.gelu(self.conv2(hidden_states))
        
        # (batch, channels, time) -> (batch, time, channels) for batch_first attention.
        inputs_embeds = hidden_states.permute(0, 2, 1)
        encoder_outputs = self.encoder(inputs_embeds)
        return BaseModelOutput(last_hidden_state=encoder_outputs)

    def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
        """Right-shift labels by one to build teacher-forcing decoder inputs.

        Position 0 becomes ``decoder_start_token_id``; any -100 padding that
        was shifted in is replaced with token id 0 so the embedding lookup
        stays in range (loss still ignores those positions via the labels).
        """
        shifted_labels = labels.new_zeros(labels.shape)
        shifted_labels[..., 1:] = labels[..., :-1].clone()
        shifted_labels[..., 0] = self.config.decoder_start_token_id
        shifted_labels.masked_fill_(shifted_labels == -100, 0)
        return shifted_labels

    def forward(self, input_features=None, decoder_input_ids=None, input_ids=None, encoder_outputs=None, labels=None, return_dict=True, use_cache=False, **kwargs):
        """Run the full encoder-decoder pass.

        Args:
            input_features: Audio features (batch, num_mel_bins, frames);
                ignored when ``encoder_outputs`` is already provided.
            decoder_input_ids: Teacher-forcing token ids; derived from
                ``labels`` (shifted right) when omitted.
            input_ids: Alias accepted for ``decoder_input_ids`` (some callers
                pass text ids under this name).
            encoder_outputs: Precomputed encoder output (used by generate()).
            labels: Target ids with -100 at ignored positions; enables loss.
            return_dict: When True, return a ``Seq2SeqLMOutput``.
            use_cache: Accepted for API compatibility; no cache is used.

        Returns:
            ``Seq2SeqLMOutput`` (loss, logits, encoder_last_hidden_state), or
            a ``(loss, logits)`` tuple / bare logits when ``return_dict`` is False.
        """
        if decoder_input_ids is None and input_ids is not None:
            decoder_input_ids = input_ids
            
        if decoder_input_ids is None and labels is not None:
            decoder_input_ids = self.prepare_decoder_input_ids_from_labels(labels)
            
        if encoder_outputs is None and input_features is not None:
            encoder_outputs = self.forward_encoder(input_features)
            
        # generate() may hand encoder_outputs back as a ModelOutput, tuple, or raw tensor.
        memory = encoder_outputs.last_hidden_state if hasattr(encoder_outputs, "last_hidden_state") else (encoder_outputs[0] if isinstance(encoder_outputs, tuple) else encoder_outputs)

        if decoder_input_ids is not None:
            decoder_embeds = self.embedding(decoder_input_ids)
        else:
            raise ValueError("decoder_input_ids must be provided")
            
        seq_len = decoder_embeds.size(1)
        # Additive causal mask; cast to the embedding dtype (bf16) so -inf
        # entries match the attention computation's dtype.
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(seq_len).to(device=decoder_embeds.device, dtype=decoder_embeds.dtype)

        decoder_output = self.decoder(
            tgt=decoder_embeds, 
            memory=memory, 
            tgt_mask=tgt_mask, 
            tgt_is_causal=True
        )

        logits = self.lm_head(decoder_output)
        loss = None

        if labels is not None:
            # No extra shift here: decoder_input_ids were already right-shifted,
            # so logits at position t predict labels at position t.
            loss = self.loss_fct(logits.view(-1, self.config.vocab_size), labels.view(-1))

        if return_dict:
            return Seq2SeqLMOutput(loss=loss, logits=logits, encoder_last_hidden_state=memory)

        return (loss, logits) if loss is not None else logits

    def prepare_inputs_for_generation(self, decoder_input_ids, past_key_values=None, attention_mask=None, encoder_outputs=None, **kwargs):
        """Assemble per-step kwargs for generate().

        No KV cache exists, so the full ``decoder_input_ids`` prefix is passed
        every step; ``past_key_values`` and ``attention_mask`` are dropped.
        """
        return {
            "decoder_input_ids": decoder_input_ids,
            "encoder_outputs": encoder_outputs,
        }

    def _reorder_cache(self, past_key_values, beam_idx):
        # NOTE(review): no-op because no cache is kept. Returns None — if beam
        # search is ever used with a cache, this must reorder along beam_idx.
        pass

# Register the custom config/model with transformers' Auto* machinery so
# AutoConfig / AutoModelForSeq2SeqLM can resolve model_type "lance_asr".
CONFIG_MAPPING.register("lance_asr", LanceASRConfig)
try:
    # Private mapping registration can fail across transformers versions;
    # best-effort only, since the auto-class registration below also covers loading.
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING.register(LanceASRConfig, LanceASR)
except Exception:
    pass
# Enables save_pretrained to emit auto_map entries for trust_remote_code loading.
LanceASRConfig.register_for_auto_class("AutoConfig")
LanceASR.register_for_auto_class("AutoModelForSeq2SeqLM")