|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
import os |
|
|
import sys |
|
|
sys.path.append(os.path.dirname(os.path.abspath(__file__))) |
|
|
from Qformer import BertConfig, BertLMHeadModel |
|
|
|
|
|
|
|
|
# Select the device-type string handed to ``torch.amp.autocast`` below.
# When the optional ``torch_npu`` package can be imported the NPU backend is
# used; otherwise the code falls back to CUDA.
try:
    import torch_npu
    # NOTE(review): ``transfer_to_npu`` is not referenced by name in the
    # visible code; presumably it is imported for its import-time side
    # effects -- confirm before removing.
    from torch_npu.contrib import transfer_to_npu
except ModuleNotFoundError:
    DEVICE_TYPE = "cuda"
else:
    DEVICE_TYPE = "npu"
|
|
|
|
|
def generate_length_mask(lens, max_length=None):
    """Build a 0/1 padding mask from per-sequence lengths.

    Args:
        lens: Sequence lengths; anything ``torch.as_tensor`` accepts (list,
            tuple, array or tensor) with one entry per sequence.
        max_length: Number of columns in the mask.  Defaults to the largest
            value in ``lens`` (0 when ``lens`` is empty).

    Returns:
        An int32 tensor of shape ``(len(lens), max_length)`` on the same
        device as ``lens``; entry ``[i, t]`` is 1 when ``t < lens[i]`` and 0
        otherwise.
    """
    lens = torch.as_tensor(lens)
    if max_length is None:
        # Reduce inside torch instead of the builtin ``max``, which walks the
        # tensor element by element in Python and raises on an empty batch.
        max_length = int(lens.max()) if lens.numel() > 0 else 0
    # Create the position indices directly on the target device and let
    # broadcasting compare them against every length at once.
    positions = torch.arange(int(max_length), device=lens.device)
    return (positions.unsqueeze(0) < lens.view(-1, 1)).int()
|
|
|
|
|
class QformerBridgeNet(torch.nn.Module):
    """Q-Former bridge that turns variable-length features into query tokens.

    A small ``BertLMHeadModel`` (see :meth:`init_Qformer`) is fed a fixed set
    of learned query embeddings together with the features in
    ``batch['embed']`` as encoder hidden states; a linear layer then maps the
    query outputs to ``hiddin_size`` channels.

    Args:
        Qformer_model_name: Name or path given to
            ``BertConfig.from_pretrained``.
        num_query_token: Number of learned query embeddings.
        hiddin_size: Output width of the final linear projection (the
            spelling is kept because it is part of the public interface).
        speech_width: Stored as ``encoder_width`` in the BERT config;
            presumably the channel width of ``batch['embed']``.
        freeze_QFormer: When true, the Q-Former parameters and the query
            tokens get ``requires_grad = False`` and the Q-Former is put in
            eval mode.
        load_from_pretrained: Optional path of a checkpoint whose state dict
            is loaded into this module (minus the ``projection.*`` entries).
    """

    def __init__(self, Qformer_model_name: str = "bert-base-uncased", num_query_token: int = 32,
                 hiddin_size: int = 1024, speech_width: int = 1024, freeze_QFormer: bool = True,
                 load_from_pretrained: str = None):
        super().__init__()

        self.Qformer_model_name = Qformer_model_name
        self.audio_Qformer, self.audio_query_tokens, encoder_config = self.init_Qformer(
            num_query_token=num_query_token, speech_width=speech_width
        )
        # Setting these sub-modules to None removes their parameters from the
        # model; presumably they are not needed when only query embeddings
        # are fed to the Q-Former -- confirm against the Qformer module.
        self.audio_Qformer.cls = None
        self.audio_Qformer.bert.embeddings.word_embeddings = None
        self.audio_Qformer.bert.embeddings.position_embeddings = None
        for layer in self.audio_Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None

        self.freeze_QFormer = freeze_QFormer
        if freeze_QFormer:
            for param in self.audio_Qformer.parameters():
                param.requires_grad = False
            self.audio_Qformer.eval()
            self.audio_query_tokens.requires_grad = False

        self.hiddin_projection = torch.nn.Linear(encoder_config.hidden_size, hiddin_size)

        if load_from_pretrained:
            # map_location="cpu" lets a checkpoint saved on an accelerator be
            # read on any host; load_state_dict then copies the values into
            # the existing parameters wherever they live.
            # NOTE(review): torch.load unpickles the file -- only load
            # checkpoints from a trusted source.
            state_dict = torch.load(load_from_pretrained, map_location="cpu")
            del_key = ["projection.weight", "projection.bias"]
            del_state_dict = {k: v for k, v in state_dict.items() if k not in del_key}
            self.load_state_dict(del_state_dict)
            print("Load adaptor_model_pt from", load_from_pretrained)

    def init_Qformer(self, num_query_token, speech_width, num_hidden_layers=2, cross_attention_freq=2):
        """Create the Q-Former, its query tokens and the config used for both.

        Args:
            num_query_token: Number of query embeddings (also stored as
                ``query_length`` in the config).
            speech_width: Stored as ``encoder_width`` in the config.
            num_hidden_layers: Number of BERT layers to build.
            cross_attention_freq: Stored as ``cross_attention_freq`` in the
                config.

        Returns:
            Tuple ``(Qformer, query_tokens, encoder_config)`` where
            ``query_tokens`` is an ``nn.Parameter`` of shape
            ``(1, num_query_token, hidden_size)`` drawn from a normal
            distribution with std ``encoder_config.initializer_range``.
        """
        encoder_config = BertConfig.from_pretrained(self.Qformer_model_name)
        encoder_config.num_hidden_layers = num_hidden_layers
        encoder_config.encoder_width = speech_width

        encoder_config.add_cross_attention = True
        encoder_config.cross_attention_freq = cross_attention_freq
        encoder_config.query_length = num_query_token
        Qformer = BertLMHeadModel(config=encoder_config)
        query_tokens = nn.Parameter(
            torch.zeros(1, num_query_token, encoder_config.hidden_size)
        )
        query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range)
        return Qformer, query_tokens, encoder_config

    def hidden(self, batch,):
        """Run the Q-Former over ``batch`` and return its last hidden state.

        Args:
            batch: Mapping with ``'embed'`` (the features passed as
                ``encoder_hidden_states``; assumed ``(batch, time, width)``
                -- TODO confirm) and ``'embed_len'`` (valid length of each
                item).

        Returns:
            ``last_hidden_state`` of the Q-Former output.
        """
        audio_feature, lens = batch['embed'], batch['embed_len']
        # Pin the mask width to the feature length so the two always line up,
        # even when the batch is padded beyond the longest valid length.
        frame_atts = generate_length_mask(
            lens, max_length=audio_feature.shape[1]
        ).to(audio_feature.device)
        # One copy of the learned queries per batch item (expand, no copy).
        audio_query_tokens = self.audio_query_tokens.expand(audio_feature.shape[0], -1, -1)

        audio_query_output = self.audio_Qformer.bert(
            query_embeds=audio_query_tokens,
            encoder_hidden_states=audio_feature,
            encoder_attention_mask=frame_atts,
            return_dict=True,
        )
        audio_hidden = audio_query_output.last_hidden_state
        return audio_hidden

    def forward(self, batch) -> dict:
        """Encode ``batch`` and project the query outputs.

        Args:
            batch: Mapping with the keys read by :meth:`hidden`.

        Returns:
            Dict with ``"output"`` (the projected query outputs) and
            ``"mask"`` (an all-True bool tensor shaped like the first two
            dimensions of ``"output"``, on the same device).
        """
        # NOTE(review): gradients are disabled here regardless of
        # ``freeze_QFormer``, so the Q-Former never receives gradients even
        # when that flag is False -- confirm this is intended.
        with torch.no_grad(), torch.amp.autocast(
            device_type=DEVICE_TYPE, enabled=False
        ):
            x = self.hidden(batch)
        # NOTE(review): the projection is applied outside the no-grad /
        # autocast-disabled region -- confirm against the upstream file.
        x = self.hiddin_projection(x)

        # Every query token is valid, so the mask is all True; build it
        # directly as bool on the output's device.
        mask = torch.ones(x.shape[:2], dtype=torch.bool, device=x.device)
        return {"output": x, "mask": mask}
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Smoke test for QformerBridgeNet with random features.  The previous
    # version instantiated ``T5TextEncoder``, which this file neither defines
    # nor imports, so running the script raised NameError.
    # NOTE(review): building the model calls ``BertConfig.from_pretrained``,
    # which presumably needs the "bert-base-uncased" config to be available
    # locally or downloadable.
    bridge = QformerBridgeNet()
    bridge.eval()
    batch = {
        # Last dimension matches the default ``speech_width`` of 1024.
        "embed": torch.randn(2, 50, 1024),
        "embed_len": torch.tensor([50, 30]),
    }
    with torch.no_grad():
        output = bridge(batch)
    print(output["output"].shape, output["mask"].shape)
|
|
|