File size: 4,573 Bytes
4853fdc
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import torch
import torch.nn as nn

import os
import sys
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from Qformer import BertConfig, BertLMHeadModel


# Device-type detection for autocast: prefer Ascend NPU when torch_npu is
# installed, otherwise fall back to CUDA semantics.
try:
    import torch_npu
    from torch_npu.contrib import transfer_to_npu
except ModuleNotFoundError:
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "npu"

def generate_length_mask(lens, max_length=None):
    """Build a binary padding mask from per-sequence lengths.

    Args:
        lens: 1-D lengths (list, tuple, or integer tensor) of shape (N,).
        max_length: width of the mask; defaults to ``lens.max()``.

    Returns:
        Int tensor of shape (N, max_length) where ``mask[i, j] == 1``
        iff ``j < lens[i]``.
    """
    lens = torch.as_tensor(lens)
    if max_length is None:
        # Use the tensor reduction; builtin max() would iterate the tensor
        # element-by-element and return a 0-dim tensor.
        max_length = int(lens.max())
    # Broadcast one row of positions against the column of lengths instead of
    # materializing N copies via repeat(); create it on lens' device directly.
    positions = torch.arange(max_length, device=lens.device)
    mask = (positions.unsqueeze(0) < lens.view(-1, 1)).int()
    return mask

class QformerBridgeNet(torch.nn.Module):
    """Bridge that pools a variable-length audio/speech feature sequence into a
    fixed number of query embeddings via a Q-Former, then linearly projects
    them to a target hidden size.

    NOTE(review): depends on the project-local ``Qformer`` module
    (``BertConfig`` / ``BertLMHeadModel``) supporting ``query_embeds`` and
    cross-attention over ``encoder_hidden_states`` — confirm against that
    implementation.
    """

    def __init__(self, Qformer_model_name: str = "bert-base-uncased", num_query_token: int = 32, 
                 hiddin_size: int = 1024, speech_width: int = 1024, freeze_QFormer: bool = True,
                 load_from_pretrained: str = None):
        """Build the Q-Former bridge.

        Args:
            Qformer_model_name: HF config name used to build the BERT config.
            num_query_token: number of learnable query tokens.
            hiddin_size: output dim of the final projection. (Spelling
                "hiddin" is kept: it is part of the public interface and of
                checkpoint state-dict keys, e.g. ``hiddin_projection.weight``.)
            speech_width: dimensionality of the incoming audio features
                (becomes ``encoder_width`` for cross-attention).
            freeze_QFormer: if True, freeze Q-Former weights and query tokens.
            load_from_pretrained: optional checkpoint path; its state dict is
                loaded after dropping ``projection.*`` keys.
        """
        super().__init__()
        
        self.Qformer_model_name = Qformer_model_name
        self.audio_Qformer, self.audio_query_tokens, encoder_config = self.init_Qformer(num_query_token=num_query_token,  speech_width=speech_width)
        # Strip the LM head and text-embedding machinery: only the
        # query-token + cross-attention encoder path is used.
        self.audio_Qformer.cls = None
        self.audio_Qformer.bert.embeddings.word_embeddings = None
        self.audio_Qformer.bert.embeddings.position_embeddings = None
        for layer in self.audio_Qformer.bert.encoder.layer:
            # Drop the per-layer text FFN modules; presumably the query path
            # has its own feed-forward inside Qformer — TODO confirm against
            # the Qformer implementation.
            layer.output = None
            layer.intermediate = None
        
        self.freeze_QFormer = freeze_QFormer
        if freeze_QFormer:
            for name, param in self.audio_Qformer.named_parameters():
                param.requires_grad = False
            self.audio_Qformer.eval()
            self.audio_query_tokens.requires_grad = False

        # Maps Q-Former hidden size -> downstream hidden size.
        self.hiddin_projection = torch.nn.Linear(encoder_config.hidden_size, hiddin_size)
        #torch.nn.init.xavier_uniform_(self.hiddin_projection.weight, gain=torch.nn.init.calculate_gain("relu"))

        if load_from_pretrained:
            # NOTE(review): torch.load without map_location/weights_only loads
            # onto the saved device and unpickles arbitrary objects — confirm
            # checkpoints are trusted and device placement is intended.
            state_dict = torch.load(load_from_pretrained)
            del_key = ["projection.weight", "projection.bias"]
            del_state_dict = {k:v for k, v in state_dict.items() if k not in del_key}
            self.load_state_dict(del_state_dict)
            print("Load adaptor_model_pt from", load_from_pretrained)     
        
        
    def init_Qformer(self, num_query_token, speech_width, num_hidden_layers=2, cross_attention_freq=2):
        """Construct the Q-Former and its learnable query tokens.

        Returns:
            (Qformer, query_tokens, encoder_config) where query_tokens has
            shape (1, num_query_token, hidden_size), normal-initialized with
            the config's initializer_range.
        """
        encoder_config = BertConfig.from_pretrained(self.Qformer_model_name)
        encoder_config.num_hidden_layers = num_hidden_layers
        encoder_config.encoder_width = speech_width
        # insert cross-attention layer every other block
        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens, encoder_config
    
    def hidden(self, batch,):
        """Run the Q-Former over a padded feature batch.

        Args:
            batch: dict with 'embed' (padded features; presumably shaped
                (B, T, speech_width) — TODO confirm with callers) and
                'embed_len' (per-sample valid lengths).

        Returns:
            Query-token hidden states of shape (B, num_query_token, hidden).
        """
        audio_feature, lens = batch['embed'], batch['embed_len']
        # Mask out padded frames so cross-attention ignores them.
        frame_atts = generate_length_mask(lens).to(audio_feature.device)
        audio_query_tokens=self.audio_query_tokens.expand(audio_feature.shape[0], -1, -1)
        #frame_atts = torch.ones(audio_feature.size()[:-1], dtype=torch.long).to(audio_feature.device)
        
        #print(audio_query_tokens.shape, audio_feature.shape, frame_atts.shape)
        audio_query_output=self.audio_Qformer.bert(
            query_embeds=audio_query_tokens, #[32,768]
            encoder_hidden_states=audio_feature,
            encoder_attention_mask=frame_atts,
            return_dict=True,
            )
        audio_hidden = audio_query_output.last_hidden_state
        return audio_hidden

    def forward(self, batch) -> torch.Tensor:   
        """Pool batch['embed'] into projected query embeddings.

        Returns:
            dict with 'output' (B, num_query_token, hiddin_size) and 'mask'
            (all-True bool mask, since every query token is valid).

        NOTE(review): the Q-Former runs under no_grad and with autocast
        disabled even when freeze_QFormer=False, so only the projection ever
        receives gradients — confirm this is intended.
        """
        with torch.no_grad(), torch.amp.autocast(
            device_type=DEVICE_TYPE, enabled=False
        ):
            x = self.hidden(batch)
        x = self.hiddin_projection(x)

        mask = torch.ones(x.shape[:2])
        mask = (mask == 1).to(x.device)
        return {"output": x, "mask": mask}
        
        
if __name__ == '__main__':
    # Smoke test of this module's own utilities. (The previous version
    # instantiated ``T5TextEncoder``, which is defined nowhere in this file
    # and raised NameError when the script was run.)
    lens = [3, 5, 2]
    with torch.no_grad():
        mask = generate_length_mask(lens)
    print(mask)