File size: 5,707 Bytes
05e39c3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
from typing import Optional

import torch
from torch import Tensor, nn
from transformers import LlamaForCausalLM
from transformers.modeling_outputs import CausalLMOutputWithPast

from .configuration_glmasr import GlmasrConfig
from .modeling_audio import WhisperSpecialEncoder


class AudioMLPAdapter(nn.Module):
    """Encode audio with a Whisper-style encoder and project it to the LM width.

    Pipeline: ``WhisperSpecialEncoder`` (with its own final ``layer_norm``
    disabled) -> adapter ``layer_norm`` -> fold ``merge_factor`` consecutive
    frames into one vector -> two-layer MLP producing
    ``config.lm_config.hidden_size`` features. The adapter also owns a
    two-row embedding table whose rows are used as begin/end-of-audio vectors.
    """

    def __init__(self, config: GlmasrConfig):
        super().__init__()
        enc_cfg = config.whisper_config
        self.merge_factor = config.merge_factor

        self.whisper = WhisperSpecialEncoder(enc_cfg, use_rope=config.use_rope)
        # Disable the encoder's own final normalisation; this adapter applies
        # its own ``layer_norm`` to the encoder output instead.
        self.whisper.layer_norm = nn.Identity()
        self.layer_norm = nn.LayerNorm(enc_cfg.hidden_size)

        # Unknown names raise KeyError, exactly as a plain dict lookup would.
        activation_cls = {
            "gelu": nn.GELU,
            "relu": nn.ReLU,
            "selu": nn.SELU,
        }[config.mlp_adapter_act]

        in_features = enc_cfg.hidden_size * self.merge_factor
        lm_dim = config.lm_config.hidden_size
        mid_features = lm_dim * 2
        self.adapting = nn.Sequential(
            nn.Linear(in_features, mid_features),
            activation_cls(),
            nn.Linear(mid_features, lm_dim),
        )
        # Row 0 -> begin-of-audio vector, row 1 -> end-of-audio vector.
        self.audio_bos_eos_token = nn.Embedding(2, lm_dim)

    def forward(self, audios: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Encode ``audios`` and return ``(adapted, boa, eoa)``.

        ``adapted`` has shape ``(batch, -1, lm_hidden)``, where every output
        step concatenates ``merge_factor`` consecutive encoder frames.
        NOTE(review): assumes the encoder's first output is
        ``(batch, frames, hidden)`` with ``frames`` divisible by
        ``merge_factor``; otherwise the reshape raises. ``boa`` and ``eoa``
        are rows 0 and 1 of the begin/end embedding table, each shaped
        ``(1, lm_hidden)``.
        """
        normed = self.layer_norm(self.whisper(audios)[0])
        batch_size = audios.size(0)
        stacked = normed.reshape(batch_size, -1, normed.size(-1) * self.merge_factor)
        table = self.audio_bos_eos_token.weight
        begin_vec = table[0].unsqueeze(0)
        end_vec = table[1].unsqueeze(0)
        return self.adapting(stacked), begin_vec, end_vec


class GlmasrModel(LlamaForCausalLM):
    """Llama causal LM whose input embeddings can be spliced with audio features.

    The language model is built from ``config.lm_config`` (so ``self.config``
    is the LM config). The audio path is an ``AudioMLPAdapter`` built from the
    full ``GlmasrConfig``, which is also kept on ``self.all_config``.
    """

    config_class = GlmasrConfig

    def __init__(self, config: GlmasrConfig):
        super().__init__(config.lm_config)
        self.audio_encoder = AudioMLPAdapter(config)
        self.all_config = config

    @staticmethod
    def _cache_seq_length(past_key_values) -> int:
        """Return how many positions ``past_key_values`` already holds (0 if none).

        Prefers the cache object's ``get_seq_length()``; ``len()`` of a
        pre-allocated cache object may count layers even while nothing is
        cached. Otherwise the value is treated as a legacy per-layer sequence
        of ``(key, value)`` tensors, and dim 2 of the first key is read.
        """
        if past_key_values is None:
            return 0
        get_seq_length = getattr(past_key_values, "get_seq_length", None)
        if callable(get_seq_length):
            return int(get_seq_length())
        if len(past_key_values) == 0:
            return 0
        return past_key_values[0][0].shape[2]

    def _splice_audio(
        self,
        language_embs: Tensor,
        audios: Tensor,
        audio_offsets: Optional[list[list[int]]],
        audio_length: Optional[list[list[int]]],
    ) -> None:
        """Overwrite rows of ``language_embs`` in place with encoded audio.

        For sample ``b`` and each pair ``(offset, length)`` in
        ``zip(audio_offsets[b], audio_length[b])``:

        - rows ``[offset, offset + length)`` receive the first ``length``
          adapted frames of the next encoded audio;
        - row ``offset - 1`` receives the begin-of-audio vector;
        - row ``offset + length`` receives the end-of-audio vector.

        Encoded audios are consumed in order across the whole batch.

        Raises:
            ValueError: if offsets or lengths are missing, or if an offset is
                < 1 (its begin-of-audio row would silently wrap to the end of
                the sequence).
        """
        if audio_offsets is None or audio_length is None:
            raise ValueError(
                "audio_offsets and audio_length are required when audios are provided"
            )
        audio_embs, boa, eoa = self.audio_encoder(audios)
        index = 0
        for batch, (offsets, lengths) in enumerate(zip(audio_offsets, audio_length)):
            for offset, length in zip(offsets, lengths):
                if offset < 1:
                    raise ValueError(
                        f"audio offset must be >= 1 to leave room for the "
                        f"begin-of-audio embedding, got {offset} in sample {batch}"
                    )
                language_embs[batch, offset : offset + length] = audio_embs[index, :length]
                language_embs[batch, offset - 1] = boa
                language_embs[batch, offset + length] = eoa
                index += 1

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        audios: Optional[Tensor] = None,
        audio_offsets: Optional[list[list[int]]] = None,
        audio_length: Optional[list[list[int]]] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        past_key_values: Optional[tuple] = None,
        use_cache: Optional[bool] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Embed ``input_ids``, splice in audio, and run the Llama decoder.

        Args:
            input_ids: Token ids. They are clamped into ``[0, vocab_size - 1]``
                before the embedding lookup. ``input_ids`` must be given;
                ``inputs_embeds`` in ``kwargs`` is ignored.
            audios: Input to ``AudioMLPAdapter``. It is only used while the KV
                cache is empty.
            audio_offsets / audio_length: Per-sample lists of segment start
                positions and lengths (see ``_splice_audio``).
            attention_mask, position_ids, past_key_values, use_cache, **kwargs:
                Forwarded to ``self.model``.

        Returns:
            ``CausalLMOutputWithPast`` with ``loss=None`` and ``lm_head`` logits
            for every position.
        """
        # Clamp so the embedding lookup stays in range. Presumably audio
        # placeholder positions carry out-of-vocab ids that get overwritten
        # by the audio splice below.
        tokens = torch.clamp(input_ids, 0, self.config.vocab_size - 1)
        language_embs = self.model.embed_tokens(tokens)

        # Inject audio only when nothing is cached yet. ``past_key_values`` is
        # an explicit parameter, so it is checked directly; it never appears
        # in ``kwargs``.
        if audios is not None and self._cache_seq_length(past_key_values) == 0:
            self._splice_audio(language_embs, audios, audio_offsets, audio_length)

        # Drop generation bookkeeping that must not reach the inner model.
        kwargs.pop("inputs_embeds", None)
        kwargs.pop("is_first_forward", None)

        outputs = self.model(
            inputs_embeds=language_embs,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            **kwargs,
        )
        logits = self.lm_head(outputs[0])
        return CausalLMOutputWithPast(
            loss=None,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def _update_model_kwargs_for_generation(self, *args, **kwargs):
        """Extend the parent's kwarg update for the next decoding step.

        Marks subsequent steps as not-first (so ``prepare_inputs_for_generation``
        drops ``audios``). When explicit ``position_ids`` are carried, it also
        appends one new position equal to the last one plus 1.
        """
        model_kwargs = super()._update_model_kwargs_for_generation(*args, **kwargs)
        model_kwargs["is_first_forward"] = False
        position_ids = model_kwargs.get("position_ids")
        if position_ids is not None:
            next_pos = position_ids[..., -1:] + 1
            model_kwargs["position_ids"] = torch.cat([position_ids, next_pos], dim=-1)
        return model_kwargs

    def prepare_inputs_for_generation(
        self,
        *args,
        past_key_values: Optional[tuple] = None,
        attention_mask: Optional[Tensor] = None,
        position_ids: Optional[Tensor] = None,
        use_cache: Optional[bool] = None,
        is_first_forward: bool = True,
        **kwargs,
    ):
        """Build per-step model inputs, carrying audio kwargs through generation.

        - Delegates to the parent implementation.
        - Copies any ``audio*`` kwargs the parent did not already include.
        - On the first step with a non-empty cache, drops the cached prefix
          from ``input_ids`` and ``position_ids``.
        - On later steps, sets ``audios`` to ``None`` so audio is not re-encoded.
        """
        prepared = super().prepare_inputs_for_generation(
            *args,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            is_first_forward=is_first_forward,
            **kwargs,
        )
        for key, value in kwargs.items():
            if key not in prepared and key.startswith("audio"):
                prepared[key] = value
        if is_first_forward:
            cached_len = self._cache_seq_length(past_key_values)
            if cached_len > 0:
                prepared["input_ids"] = prepared["input_ids"][:, cached_len:]
                if prepared.get("position_ids") is not None:
                    prepared["position_ids"] = prepared["position_ids"][:, cached_len:]
        if not is_first_forward:
            prepared["audios"] = None
        return prepared


__all__ = ["GlmasrModel"]