#!/usr/bin/env python3
# coding: utf-8
# @Author : Xinhao Mei @CVSSP, University of Surrey
# @E-mail : x.mei@surrey.ac.uk
import numpy as np
import torch
import torch.nn as nn
import yaml
from tokenizers import Tokenizer
from transformers import BertTokenizer, AutoConfig, BertConfig, BertLMHeadModel, AutoTokenizer, PreTrainedTokenizerFast, \
BartTokenizer, BartConfig
import os
from gensim.models.word2vec import Word2Vec
# from data_handling.WordTokenizer import WordTokenizer
from models.audio_encoder_config import AudioEncoderConfig
from models.audio_encoder import AudioEncoderModel
from models.configuration_audio_encoder_decoder import AudioEncoderDecoderConfig
from models.modeling_audio_encoder_decoder import AudioEncoderDecoderModel
class BertCaptionModel(nn.Module):
    """Audio captioning model pairing an audio encoder with a BERT decoder.

    The encoder/decoder pair is wrapped in an ``AudioEncoderDecoderModel``
    stored on ``self.model``; captions are tokenized with a ``BertTokenizer``
    loaded from ``config["text_decoder_args"]["name"]``.

    Args:
        config: Nested mapping read with the keys ``audio_encoder_args``,
            ``audio_args`` and ``text_decoder_args`` (the latter providing
            ``name``, ``pretrained`` and, when ``pretrained`` is falsy,
            ``bert_args``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        encoder_config = AudioEncoderConfig(**config["audio_encoder_args"],
                                            audio_args=config["audio_args"])
        # NOTE(review): no method of this class uses ``self.encoder`` (they go
        # through ``self.model.encoder``). It is kept so parameter names and
        # existing checkpoints stay unchanged -- confirm whether it is needed.
        self.encoder = AudioEncoderModel(encoder_config)

        decoder_args = config["text_decoder_args"]
        decoder_name = decoder_args["name"]
        self.tokenizer = BertTokenizer.from_pretrained(decoder_name)

        if decoder_args["pretrained"]:
            decoder_config = AutoConfig.from_pretrained(decoder_name,
                                                        add_cross_attention=True,
                                                        is_decoder=True)
        else:
            vocab_size = self.tokenizer.vocab_size
            # Kept for backward compatibility: the original code recorded the
            # tokenizer's vocabulary size on the caller's config here.
            decoder_args["vocab_size"] = vocab_size
            # BertConfig is built from ``bert_args`` only, so the vocabulary
            # size must be forwarded there to have any effect. An explicit
            # ``vocab_size`` in ``bert_args`` still takes precedence.
            bert_args = dict(decoder_args["bert_args"] or {})
            bert_args.setdefault("vocab_size", vocab_size)
            decoder_config = BertConfig(**bert_args)

        self.model_config = AudioEncoderDecoderConfig.from_encoder_decoder_configs(
            encoder_config, decoder_config)
        # Decoding starts from the tokenizer's [CLS] token.
        self.model_config.decoder_start_token_id = self.tokenizer.cls_token_id
        self.model = AudioEncoderDecoderModel(config=self.model_config,
                                              is_pretrained=False)

    @property
    def device(self):
        """Device of the first registered parameter of this module."""
        # ``next`` avoids building a list of every parameter on each access.
        return next(self.parameters()).device

    def _tokenize(self, text):
        """Tokenize raw captions and derive the decoder label tensor.

        Args:
            text: Raw caption string(s) passed straight to the tokenizer.

        Returns:
            Tuple ``(input_ids, attention_mask, decoder_targets)``, all on
            ``self.device``. ``decoder_targets`` is a copy of ``input_ids``
            with padding positions and the first position set to -100.
        """
        encoded = self.tokenizer(text,
                                 padding="longest",
                                 truncation=True,
                                 max_length=30,
                                 return_tensors="pt")
        input_ids = encoded["input_ids"].to(self.device)
        attention_mask = encoded["attention_mask"].to(self.device)

        # Build the mask from the tensor that already lives on the target
        # device; mixing in the tokenizer's original output fails once the
        # model has been moved off the CPU.
        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100)
        # The first token is the decoder start token, so it is excluded too.
        decoder_targets[:, 0] = -100
        return input_ids, attention_mask, decoder_targets

    def forward_encoder(self, audios):
        """Run only the audio encoder and return its ``last_hidden_state``.

        Args:
            audios: Audio features accepted by ``self.model.encoder``.
        """
        outputs = self.model.encoder(audios)
        return outputs.last_hidden_state

    def forward_decoder(self, text, audio_embeds):
        """Run only the text decoder on precomputed audio embeddings.

        Args:
            text: Raw caption string(s).
            audio_embeds: Encoder hidden states passed to the decoder as
                ``encoder_hidden_states`` (no encoder attention mask is used).

        Returns:
            Tuple ``(decoder_output, decoder_targets)``.
        """
        input_ids, attention_mask, decoder_targets = self._tokenize(text)
        decoder_output = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=audio_embeds,
            encoder_attention_mask=None,
            labels=decoder_targets,
            return_dict=True
        )
        return decoder_output, decoder_targets

    def forward(self, audios, text):
        """Compute the captioning loss for a batch.

        Args:
            audios: Audio features forwarded to the model as ``audio_feats``.
            text: Raw caption string(s) used as decoder inputs and labels.

        Returns:
            The ``loss`` attribute of the encoder-decoder model output.
        """
        input_ids, attention_mask, decoder_targets = self._tokenize(text)
        decoder_output = self.model(
            audio_feats=audios,
            decoder_input_ids=input_ids,
            decoder_attention_mask=attention_mask,
            labels=decoder_targets,
            return_dict=True
        )
        return decoder_output.loss

    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=3,
                 max_length=30,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):
        """Generate captions for a batch of audio inputs.

        Args:
            samples: Audio inputs passed to ``self.model.generate`` as
                ``inputs``.
            use_nucleus_sampling: If true, sample with ``top_p``; otherwise
                decode with ``num_beams`` beams.
            num_beams: Beam count (beam-search branch only).
            max_length: Maximum generated length.
            min_length: Minimum generated length.
            top_p: Nucleus probability mass (sampling branch only).
            repetition_penalty: Repetition penalty for the beam-search
                branch. See the note in the sampling branch below.

        Returns:
            Decoded captions with special tokens removed.
        """
        # Arguments common to both decoding strategies.
        shared_kwargs = dict(
            inputs=samples,
            max_length=max_length,
            min_length=min_length,
            bos_token_id=self.tokenizer.cls_token_id,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            decoder_start_token_id=self.model_config.decoder_start_token_id,
        )
        if use_nucleus_sampling:
            # NOTE(review): this branch has always used a fixed penalty of 1.1
            # and ignores ``repetition_penalty``. Left unchanged so existing
            # sampling results do not shift -- confirm whether it should
            # honour the argument.
            outputs = self.model.generate(
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                repetition_penalty=1.1,
                **shared_kwargs
            )
        else:
            outputs = self.model.generate(
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                **shared_kwargs
            )
        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
if __name__ == '__main__':
    # Smoke test: build the model from the project settings file and print
    # the captions generated for a batch of random inputs.
    # The settings path below is resolved relative to the parent directory.
    os.chdir("../")
    with open("settings/settings.yaml", "r") as settings_file:
        config = yaml.safe_load(settings_file)

    model = BertCaptionModel(config)

    # Random tensor standing in for a batch of 16 audio inputs.
    audio_feats = torch.randn(16, 1, 64, 100)
    # Matching batch of dummy captions (not consumed by generate()).
    text = ["this is a sample"] * 16

    output = model.generate(audio_feats)
    print(output)
|