File size: 9,402 Bytes
05f7466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
from __future__ import annotations

from typing import Optional, List, Tuple, Union, Dict

import torch
from torch import Tensor, nn
from transformers import Qwen2ForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_arkasr import ArkasrConfig
from .modeling_audio import WhisperSpecialEncoder


class AudioMLPAdapter(nn.Module):
    """Encode audio, merge neighbouring frames and project them with an MLP.

    Pipeline: ``WhisperSpecialEncoder`` -> ``LayerNorm`` -> concatenate every
    ``merge_factor`` consecutive frames along the feature axis -> two-layer
    MLP producing vectors of size ``config.hidden_size``.
    """

    def __init__(self, config: ArkasrConfig):
        """Create the encoder, the external LayerNorm and the projection MLP.

        Args:
            config: Model configuration. Reads ``whisper_config``,
                ``merge_factor`` and ``hidden_size``; optionally ``use_rope``
                (defaults to ``False``) and ``mlp_adapter_act`` (defaults to
                ``"gelu"``; any name other than ``gelu``/``relu``/``selu``
                falls back to GELU).
        """
        super().__init__()
        encoder_cfg = config.whisper_config
        self.merge_factor = int(config.merge_factor)

        # Audio encoder.
        self.whisper = WhisperSpecialEncoder(
            encoder_cfg,
            use_rope=getattr(config, "use_rope", False),
        )
        # Replace the encoder's own `layer_norm` with a no-op; normalisation
        # is applied by `self.layer_norm` in `forward` instead.
        self.whisper.layer_norm = nn.Identity()
        self.layer_norm = nn.LayerNorm(encoder_cfg.hidden_size)

        requested_act = getattr(config, "mlp_adapter_act", "gelu")
        known_acts = {
            "gelu": nn.GELU(),
            "relu": nn.ReLU(),
            "selu": nn.SELU(),
        }
        activation = known_acts.get(requested_act, nn.GELU())

        merged_width = encoder_cfg.hidden_size * self.merge_factor
        llm_width = config.hidden_size
        inner_width = llm_width * 2

        self.adapting = nn.Sequential(
            nn.Linear(merged_width, inner_width),
            activation,
            nn.Linear(inner_width, llm_width),
        )

    def forward(self, audios: Tensor) -> Tensor:
        """Turn a batch of audio inputs into LLM-sized feature sequences.

        Args:
            audios: ``(B, mel, T)`` or ``(B, raw_len)``; the accepted layout
                is decided by ``WhisperSpecialEncoder``.

        Returns:
            Adapted features of shape ``(B, Seq_Audio, LLM_Hidden_Dim)``.
        """
        batch = audios.size(0)
        group = self.merge_factor

        frames = self.layer_norm(self.whisper(audios)[0])  # (B, T, D)

        n_frames = frames.size(1)
        if n_frames % group != 0:
            usable = (n_frames // group) * group
            if usable > 0:
                # Drop the trailing frames that do not fill a whole group.
                frames = frames[:, :usable, :]
            else:
                # Very short audio: zero-pad up to exactly one group.
                deficit = group - n_frames
                if deficit > 0:
                    filler = frames.new_zeros((batch, deficit, frames.size(-1)))
                    frames = torch.cat([frames, filler], dim=1)

        merged = frames.reshape(batch, -1, frames.size(-1) * group)
        return self.adapting(merged)  # (B, T / merge_factor, hidden)


class ArkasrForConditionalGeneration(Qwen2ForCausalLM):
    """Qwen2 causal LM whose audio placeholder tokens carry encoded audio.

    Positions in ``input_ids`` equal to ``config.audio_token_id`` have their
    token embeddings replaced by the output of ``AudioMLPAdapter`` before the
    decoder runs. Injection only happens on the first step (empty KV cache).
    """

    config_class = ArkasrConfig
    _no_split_modules = ["WhisperSpecialEncoder"]

    def __init__(self, config: ArkasrConfig):
        """Build the language model and attach the audio encoder.

        Raises:
            ValueError: If ``config`` does not define ``audio_token_id``.
        """
        super().__init__(config)
        self.audio_encoder = AudioMLPAdapter(config)

        self.audio_token_id = getattr(config, "audio_token_id", None)
        if self.audio_token_id is None:
            raise ValueError("`audio_token_id` must be defined in config.")

    @staticmethod
    def _cache_seq_len(past_key_values) -> int:
        """Return the number of positions already held in the KV cache.

        Accepts ``None``, an object exposing ``get_seq_length()``, or a
        legacy nested sequence indexed as ``past_key_values[0][0]``.
        This is deliberately best-effort: anything that cannot be inspected
        is reported as an empty cache (``0``) rather than raising.
        """
        if past_key_values is None:
            return 0
        if hasattr(past_key_values, "get_seq_length"):
            try:
                return int(past_key_values.get_seq_length())
            except Exception:
                return 0
        try:
            return int(past_key_values[0][0].shape[-2])
        except Exception:
            return 0

    def _inject_audio_embeddings_batch_encode_then_loop_scatter(
        self,
        input_ids: torch.LongTensor,          # (B, S)
        inputs_embeds: torch.FloatTensor,     # (B, S, H)
        audios: Tensor,                       # (B, ...)
    ) -> torch.FloatTensor:
        """Replace audio placeholder embeddings with encoded audio features.

        All rows that contain at least one audio token are encoded in a single
        batched call to the audio encoder; the features are then written back
        row by row at that row's audio-token positions.

        Why this shape:
        - the encoder runs only once;
        - writing back per row cannot misalign features across rows;
        - rows without any audio token are skipped and left untouched.

        Alignment: if a row has ``n_i`` audio tokens and the encoder produced
        ``Sa`` frames, the features are truncated (``Sa > n_i``) or zero-padded
        (``Sa < n_i``) to exactly ``n_i`` rows; no error is raised.

        Args:
            input_ids: Token ids, ``(B, S)``.
            inputs_embeds: Token embeddings, ``(B, S, H)``. Not modified.
            audios: Audio inputs, one entry per batch row, ``(B, ...)``.

        Returns:
            ``inputs_embeds`` itself when no row contains an audio token,
            otherwise a new tensor with the audio positions overwritten.
        """
        hidden = inputs_embeds.size(-1)

        # Which rows need injection, and how many placeholders each one has.
        mask = input_ids == self.audio_token_id            # (B, S)
        per_counts = mask.sum(dim=1)                       # (B,)
        need_idx = (per_counts > 0).nonzero(as_tuple=False).squeeze(1)  # (K,)

        if need_idx.numel() == 0:
            return inputs_embeds

        # Encode only the rows that need it: (K, ...). The index is moved to
        # the audio tensor's device because index_select requires both
        # operands on the same device.
        audios_sub = audios.index_select(0, need_idx.to(audios.device))
        feats_sub = self.audio_encoder(audios_sub)         # (K, Sa, H)
        feats_sub = feats_sub.to(
            device=inputs_embeds.device, dtype=inputs_embeds.dtype
        )
        n_frames = feats_sub.size(1)

        # Work on a copy: writing in place would silently alter a tensor the
        # caller passed in, and fails outright on a leaf that requires grad.
        inputs_embeds = inputs_embeds.clone()

        # One host transfer each instead of an `.item()` sync per row.
        rows = need_idx.tolist()
        counts = per_counts[need_idx].tolist()

        for k, (row, n_tokens) in enumerate(zip(rows, counts)):
            feat = feats_sub[k]                            # (Sa, H)

            # Align the features to this row's number of audio tokens.
            if n_frames < n_tokens:
                pad = feat.new_zeros((n_tokens - n_frames, hidden))
                feat = torch.cat([feat, pad], dim=0)
            elif n_frames > n_tokens:
                feat = feat[:n_tokens]

            positions = mask[row].nonzero(as_tuple=False).squeeze(1)  # (n_i,)
            inputs_embeds[row, positions.to(inputs_embeds.device), :] = feat

        return inputs_embeds

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audios: Optional[Tensor] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        labels: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        logits_to_keep: int | torch.Tensor = 0,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder, injecting audio features on the first step.

        Args:
            input_ids: Token ids. Needed to locate audio placeholders; when
                only ``inputs_embeds`` is given no audio is injected.
            audios: Audio inputs, one entry per batch row.
            attention_mask, position_ids, past_key_values, use_cache,
            output_attentions, output_hidden_states: Forwarded unchanged to
                the underlying ``self.model``.
            inputs_embeds: Precomputed embeddings; computed from
                ``input_ids`` when omitted.
            labels: When given, a loss is computed with ``self.loss_function``.
            logits_to_keep: Positive int keeps only the last N positions for
                the LM head; a tensor is used as an index on the sequence
                axis; ``0`` projects every position.
            **kwargs: Passed to ``self.loss_function`` only.

        Returns:
            A ``CausalLMOutputWithPast``.

        Raises:
            ValueError: If neither ``input_ids`` nor ``inputs_embeds`` is given.
        """
        if inputs_embeds is None:
            if input_ids is None:
                raise ValueError("Either `input_ids` or `inputs_embeds` must be provided.")
            inputs_embeds = self.model.embed_tokens(input_ids)

        # Inject only on the first step (empty cache) so later generation
        # steps do not re-run the audio encoder.
        past_len = self._cache_seq_len(past_key_values)
        if audios is not None and input_ids is not None and past_len == 0:
            inputs_embeds = self._inject_audio_embeddings_batch_encode_then_loop_scatter(
                input_ids=input_ids,
                inputs_embeds=inputs_embeds,
                audios=audios,
            )

        outputs = self.model(
            input_ids=None,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
        )

        hidden_states = outputs[0]

        # Project only the requested positions to avoid needless lm_head work.
        if isinstance(logits_to_keep, int) and logits_to_keep > 0:
            hidden_for_logits = hidden_states[:, -logits_to_keep:, :]
        elif isinstance(logits_to_keep, torch.Tensor):
            hidden_for_logits = hidden_states[:, logits_to_keep, :]
        else:
            hidden_for_logits = hidden_states

        logits = self.lm_head(hidden_for_logits)

        loss = None
        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Once the cache holds any positions, only the last token id is fed.
        ``audios`` is always passed through; ``forward`` ignores it unless the
        cache is empty, so the encoder is not re-run on later steps.

        ``inputs_embeds`` replaces ``input_ids`` on the first step only. The
        first step is detected by cache length rather than by
        ``past_key_values is None``, so a cache object that exists but is
        still empty is treated the same way as no cache.

        Returns:
            dict of arguments for ``forward``.
        """
        past_len = self._cache_seq_len(past_key_values)
        if past_len > 0:
            input_ids = input_ids[:, -1:]

        model_inputs = {
            "input_ids": input_ids,
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache"),
            "attention_mask": attention_mask,
            # Passed through on every step; forward only injects when the
            # cache is empty, so later steps do not encode the audio again.
            "audios": kwargs.get("audios", None),
        }

        if inputs_embeds is not None and past_len == 0:
            model_inputs["inputs_embeds"] = inputs_embeds
            del model_inputs["input_ids"]

        return model_inputs


# Names exported by `from <this module> import *`.
__all__ = ["ArkasrForConditionalGeneration", "AudioMLPAdapter"]