from typing import Optional

import torch
from torch import nn
from transformers import HubertModel, HubertPreTrainedModel, Qwen2_5_VLForConditionalGeneration

class ProjectorConv1d(nn.Module):
    """Projects concatenated HuBERT hidden states into the LLM embedding space.

    A pointwise (kernel size 1) Conv1d over the feature dimension, followed by a
    two-layer MLP mapping encoder_dim -> config.projector_hidden_size -> llm_dim.
    """

    def __init__(self, config, encoder_dim, llm_dim):
        super().__init__()
        self.conv1d = nn.Conv1d(in_channels=encoder_dim,
                                out_channels=encoder_dim,
                                kernel_size=1,
                                stride=1,
                                padding=0)
        self.linear1 = nn.Linear(encoder_dim, config.projector_hidden_size)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(config.projector_hidden_size, llm_dim)
        self.relu2 = nn.ReLU()

    def forward(self, x):
        # x: (batch, time, encoder_dim) -> (batch, encoder_dim, time) for Conv1d
        x = x.transpose(1, 2)
        x = self.conv1d(x)
        x = x.transpose(1, 2)
        # Two-layer MLP; the ReLU activations are applied before each linear layer.
        x = self.relu1(x)
        x = self.linear1(x)
        x = self.relu2(x)
        x = self.linear2(x)
        return x

class OolelSpeech(HubertPreTrainedModel):
    """Speech-language model: a HuBERT audio encoder whose layer outputs are
    projected into the input embedding space of a Qwen2.5-VL language model."""

    def __init__(self, config):
        super().__init__(config)
        self.hubert = HubertModel(config)
        # The projector consumes the concatenation of the 12 transformer layer
        # outputs, hence encoder_dim = hidden_size * 12.
        self.projector = ProjectorConv1d(
            config,
            encoder_dim=config.hidden_size * 12,
            llm_dim=config.llm_hidden_size,
        )
        # Initialize weights and apply final processing
        self.post_init()
    
    def load_llm(self, llm_name_or_path, **kwargs):
        self.llm = Qwen2_5_VLForConditionalGeneration.from_pretrained(llm_name_or_path, **kwargs)
    
    def forward_hubert(
        self,
        input_values: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ):

        return self.hubert(
            input_values,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

    def embed_inputs(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ):

        hubert_outputs = self.forward_hubert(
            input_values,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=True,
            return_dict=True,  # required so that .hidden_states is available below
        )

        # Concatenate the outputs of all 12 transformer layers along the feature
        # dimension, skipping the pre-transformer embedding output at index 0.
        hidden_states = torch.cat(hubert_outputs.hidden_states[1:], dim=-1)
        speech_embedding = self.projector(hidden_states)
        token_embeddings = self.llm.model.language_model.embed_tokens(input_ids)

        # Prepend the projected speech embeddings to the text token embeddings.
        inputs_embeds = torch.cat([speech_embedding, token_embeddings], dim=1)

        return inputs_embeds

    def forward(
        self,
        input_values: Optional[torch.Tensor] = None,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        ):

        inputs_embeds = self.embed_inputs(
            input_values=input_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self.llm(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
        )

    def generate(self, input_values, input_ids, **kwargs):
        assert input_values.shape[0] == 1, "Batch generation is not supported with the Hugging Face generate API."
        inputs_embeds = self.embed_inputs(
            input_values=input_values,
            input_ids=input_ids,
            attention_mask=None,
            output_attentions=False,
            output_hidden_states=True,
        )

        return self.llm.generate(inputs_embeds=inputs_embeds, **kwargs)
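

# ------------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical inference
# example showing how the pieces above fit together. The checkpoint paths, the
# feature extractor, the prompt, and the audio file name are assumptions, not
# taken from this repository.
if __name__ == "__main__":
    import torchaudio  # assumed available for audio loading/resampling
    from transformers import AutoFeatureExtractor, AutoTokenizer

    # Hypothetical checkpoint path for the HuBERT encoder + projector weights.
    model = OolelSpeech.from_pretrained("path/to/oolel-speech-checkpoint")
    model.load_llm("Qwen/Qwen2.5-VL-7B-Instruct")  # assumed LLM checkpoint
    model.eval()

    feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/hubert-base-ls960")  # assumed
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")

    # Load a waveform, downmix to mono, and resample to 16 kHz (hypothetical file).
    waveform, sr = torchaudio.load("example.wav")
    waveform = torchaudio.functional.resample(waveform, sr, 16_000).mean(dim=0)

    input_values = feature_extractor(waveform.numpy(), sampling_rate=16_000, return_tensors="pt").input_values
    input_ids = tokenizer("Transcribe the audio.", return_tensors="pt").input_ids  # assumed prompt

    with torch.no_grad():
        output_ids = model.generate(input_values, input_ids, max_new_tokens=128)

    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))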