Toun / models /bert_captioning.py
babaTEEpe's picture
Upload 11 files
1bc2162 verified
#!/usr/bin/env python3
# coding: utf-8
# @Author : Xinhao Mei @CVSSP, University of Surrey
# @E-mail : x.mei@surrey.ac.uk
import numpy as np
import torch
import torch.nn as nn
import yaml
from tokenizers import Tokenizer
from transformers import BertTokenizer, AutoConfig, BertConfig, BertLMHeadModel, AutoTokenizer, PreTrainedTokenizerFast, \
BartTokenizer, BartConfig
import os
from gensim.models.word2vec import Word2Vec
# from data_handling.WordTokenizer import WordTokenizer
from models.audio_encoder_config import AudioEncoderConfig
from models.audio_encoder import AudioEncoderModel
from models.configuration_audio_encoder_decoder import AudioEncoderDecoderConfig
from models.modeling_audio_encoder_decoder import AudioEncoderDecoderModel
class BertCaptionModel(nn.Module):
    """Audio captioning model: an audio encoder coupled with a BERT text decoder.

    The encoder/decoder pair is wrapped in an ``AudioEncoderDecoderModel``
    (``self.model``); captions are tokenized with a ``BertTokenizer`` loaded
    from ``config["text_decoder_args"]["name"]``.

    Args:
        config: Nested mapping read with the keys ``"audio_encoder_args"``,
            ``"audio_args"`` and ``"text_decoder_args"`` (the latter with
            ``"name"``, ``"pretrained"`` and, when not pretrained,
            ``"bert_args"``).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        encoder_config = AudioEncoderConfig(**config["audio_encoder_args"],
                                            audio_args=config["audio_args"])
        # Kept as a registered submodule so the module's parameter layout is
        # unchanged; the forward paths in this class go through
        # ``self.model`` / ``self.model.encoder``.
        self.encoder = AudioEncoderModel(encoder_config)

        decoder_args = config["text_decoder_args"]
        decoder_name = decoder_args["name"]
        decoder_pretrained = decoder_args["pretrained"]

        self.tokenizer = BertTokenizer.from_pretrained(decoder_name)

        if decoder_pretrained:
            decoder_config = AutoConfig.from_pretrained(decoder_name,
                                                        add_cross_attention=True,
                                                        is_decoder=True)
        else:
            # Preserve the original side effect on the caller's config ...
            decoder_args["vocab_size"] = self.tokenizer.vocab_size
            # ... and make sure the value actually reaches BertConfig, which
            # is built from ``bert_args`` only. An explicit ``vocab_size`` in
            # ``bert_args`` still takes precedence.
            bert_args = dict(decoder_args["bert_args"])
            bert_args.setdefault("vocab_size", self.tokenizer.vocab_size)
            decoder_config = BertConfig(**bert_args)

        self.model_config = AudioEncoderDecoderConfig.from_encoder_decoder_configs(
            encoder_config, decoder_config)
        # Decoding starts from BERT's [CLS] token.
        self.model_config.decoder_start_token_id = self.tokenizer.cls_token_id
        self.model = AudioEncoderDecoderModel(config=self.model_config,
                                              is_pretrained=False)

    @property
    def device(self):
        """Device of the first registered parameter of this module."""
        return next(self.parameters()).device

    def _tokenize(self, text, max_length=30):
        """Tokenize raw captions into decoder inputs and language-model labels.

        Args:
            text: Raw caption string(s) accepted by the tokenizer.
            max_length: Truncation length passed to the tokenizer.

        Returns:
            Tuple ``(input_ids, attention_mask, labels)`` on ``self.device``.
            ``labels`` is a copy of ``input_ids`` with padding positions and
            the first position of every sequence set to ``-100``.
        """
        encoded = self.tokenizer(text,
                                 padding="longest",
                                 truncation=True,
                                 max_length=max_length,
                                 return_tensors="pt")
        target_device = self.device
        input_ids = encoded["input_ids"].to(target_device)
        attention_mask = encoded["attention_mask"].to(target_device)

        # Build the mask from the tensor that already lives on the model's
        # device so that ``masked_fill`` never mixes devices.
        labels = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100)
        # The first token is the start token fed to the decoder, not a
        # prediction target.
        labels[:, 0] = -100
        return input_ids, attention_mask, labels

    def forward_encoder(self, audios):
        """Encode audio features and return the encoder's last hidden state.

        Args:
            audios: Audio features passed straight to ``self.model.encoder``.
        """
        outputs = self.model.encoder(audios)
        return outputs.last_hidden_state

    def forward_decoder(self, text, audio_embeds):
        """Run the text decoder on raw captions given audio embeddings.

        Args:
            text: Raw caption string(s).
            audio_embeds: Encoder hidden states used for cross-attention
                (presumably the output of ``forward_encoder``).

        Returns:
            Tuple ``(decoder_output, decoder_targets)``: the decoder's output
            object and the label tensor that was passed to it.
        """
        input_ids, attention_mask, decoder_targets = self._tokenize(text)

        decoder_output = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            encoder_hidden_states=audio_embeds,
            encoder_attention_mask=None,
            labels=decoder_targets,
            return_dict=True
        )
        return decoder_output, decoder_targets

    def forward(self, audios, text):
        """Compute the captioning loss for a batch of audios and captions.

        Args:
            audios: Audio features forwarded as ``audio_feats``.
            text: Raw caption string(s), one per audio.

        Returns:
            The ``loss`` attribute of the encoder-decoder model's output.
        """
        input_ids, attention_mask, decoder_targets = self._tokenize(text)

        decoder_output = self.model(
            audio_feats=audios,
            decoder_input_ids=input_ids,
            decoder_attention_mask=attention_mask,
            labels=decoder_targets,
            return_dict=True
        )
        return decoder_output.loss

    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=3,
                 max_length=30,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):
        """Generate captions for a batch of audios.

        Args:
            samples: Audio inputs forwarded to ``self.model.generate``.
            use_nucleus_sampling: If true, sample with top-p; otherwise use
                beam search.
            num_beams: Beam width (beam search only).
            max_length: Maximum generated length.
            min_length: Minimum generated length.
            top_p: Nucleus probability mass (sampling only).
            repetition_penalty: Repetition penalty (beam search only; see the
                NOTE in the sampling branch).

        Returns:
            Decoded captions with special tokens removed.
        """
        # Arguments shared by both decoding strategies.
        shared_kwargs = dict(
            inputs=samples,
            max_length=max_length,
            min_length=min_length,
            bos_token_id=self.tokenizer.cls_token_id,
            eos_token_id=self.tokenizer.sep_token_id,
            pad_token_id=self.tokenizer.pad_token_id,
            decoder_start_token_id=self.model_config.decoder_start_token_id,
        )

        if use_nucleus_sampling:
            # NOTE(review): the sampling branch has always used a fixed
            # repetition penalty of 1.1 and ignores the ``repetition_penalty``
            # argument. Kept as-is so existing outputs do not change; confirm
            # whether the argument should be honoured here.
            outputs = self.model.generate(
                do_sample=True,
                top_p=top_p,
                num_return_sequences=1,
                repetition_penalty=1.1,
                **shared_kwargs
            )
        else:
            outputs = self.model.generate(
                num_beams=num_beams,
                repetition_penalty=repetition_penalty,
                **shared_kwargs
            )

        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions
if __name__ == "__main__":
    # Manual smoke test: build the model from the project settings file and
    # caption a batch of random audio features.
    # The settings path is resolved relative to the parent of the current
    # working directory.
    os.chdir("../")
    with open("settings/settings.yaml", "r") as settings_file:
        config = yaml.safe_load(settings_file)

    model = BertCaptionModel(config)

    audio_feats = torch.randn(16, 1, 64, 100)
    # Sample captions matching the batch size; not consumed by generate().
    text = ["this is a sample"] * 16

    output = model.generate(audio_feats)
    print(output)