File size: 6,880 Bytes
#!/usr/bin/env python3
# coding: utf-8
# @Author : Xinhao Mei @CVSSP, University of Surrey
# @E-mail : x.mei@surrey.ac.uk
# PANNs - BART audio captioning model
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BartForConditionalGeneration, BartTokenizer, BartConfig

from models.audio_encoder import AudioEncoderModel
from models.audio_encoder_config import AudioEncoderConfig
class BartCaptionModel(nn.Module):
    """Audio captioning model: an audio encoder feeding a BART decoder.

    The audio encoder output is linearly projected to the BART hidden size,
    passed through BART's own encoder stack as ``inputs_embeds``, and the
    BART decoder is trained / sampled on top of that.

    Args:
        config: Mapping with at least the keys ``"audio_encoder_args"``,
            ``"audio_args"`` and ``"text_decoder_args"``; the latter must
            provide ``"name"`` (BART checkpoint name) and ``"pretrained"``
            (whether to load pretrained decoder weights).
        cache_dir: Optional cache directory forwarded to ``from_pretrained``.
    """

    def __init__(self, config, cache_dir=None):
        super().__init__()

        self.config = config

        # encoder
        encoder_config = AudioEncoderConfig(**config["audio_encoder_args"],
                                            audio_args=config["audio_args"])
        self.encoder = AudioEncoderModel(encoder_config)

        # bart decoder
        decoder_name = config["text_decoder_args"]["name"]
        decoder_pretrained = config["text_decoder_args"]["pretrained"]

        # The tokenizer is always loaded from the named checkpoint, whether
        # or not the decoder weights themselves are pretrained.
        self.tokenizer = BartTokenizer.from_pretrained(decoder_name, cache_dir=cache_dir)

        if decoder_pretrained:
            self.decoder = BartForConditionalGeneration.from_pretrained(
                decoder_name, cache_dir=cache_dir
            )
        else:
            bart_config = BartConfig.from_pretrained(decoder_name, cache_dir=cache_dir)
            # Concrete model classes are instantiated from a config through
            # their constructor (randomly initialised weights); `from_config`
            # only exists on the `AutoModel*` factory classes.
            self.decoder = BartForConditionalGeneration(bart_config)

        # Maps audio-encoder features to the BART hidden size.
        self.enc_to_dec_proj = nn.Linear(encoder_config.hidden_size,
                                         self.decoder.config.hidden_size)

        # Targets equal to -100 (padding, see `forward_decoder`) are ignored.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)

    @property
    def device(self):
        """Device of the first registered parameter of this module."""
        # `next` avoids materialising the full parameter list on every access.
        return next(self.parameters()).device

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """
        Shift input ids one token to the right.

        Args:
            input_ids: 2-D tensor of token ids; may contain ``-100`` entries.
            pad_token_id: Id that replaces any ``-100`` left after shifting.
            decoder_start_token_id: Id written into the first column.

        Returns:
            A new tensor of the same shape; ``input_ids`` is not modified.

        Raises:
            ValueError: If ``pad_token_id`` is ``None``.
        """
        # Validate before doing any tensor work.
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id

        # replace possible -100 values in labels by `pad_token_id`
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)

        return shifted_input_ids

    def forward_encoder(self, audios):
        """Encode audio and project the features to the decoder hidden size.

        Args:
            audios: Input forwarded unchanged to the audio encoder.

        Returns:
            ``enc_to_dec_proj`` applied to the encoder's ``last_hidden_state``.
        """
        outputs = self.encoder(audios)
        outputs = self.enc_to_dec_proj(outputs.last_hidden_state)
        return outputs

    def forward_decoder(self, text, encoder_outputs):
        """Compute the captioning loss for a batch.

        Args:
            text: Caption text(s) handed to the tokenizer (padded to the
                longest item, truncated to 30 tokens).
            encoder_outputs: Projected audio features, as returned by
                `forward_encoder`.

        Returns:
            Label-smoothed cross-entropy loss over non-padding target tokens.
        """
        # Run the projected audio features through BART's encoder stack.
        encoder_outputs = self.decoder.model.encoder(
            input_ids=None,
            inputs_embeds=encoder_outputs,
            return_dict=True
        )["last_hidden_state"]

        text = self.tokenizer(text,
                              padding='longest',
                              truncation=True,
                              max_length=30,
                              return_tensors="pt")

        input_ids = text["input_ids"].to(self.device)
        attention_mask = text["attention_mask"].to(self.device)

        # Padding positions become -100 so the loss ignores them.
        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )

        decoder_input_ids = self.shift_tokens_right(
            decoder_targets, self.decoder.config.pad_token_id, self.decoder.config.decoder_start_token_id
        )

        decoder_outputs = self.decoder(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            inputs_embeds=None,
            labels=None,
            encoder_outputs=(encoder_outputs,),
            return_dict=True
        )

        lm_logits = decoder_outputs["logits"]
        # Flatten using the logits' own last dimension rather than
        # `tokenizer.vocab_size`, so the reshape stays correct even if the
        # model's output vocabulary size differs from the tokenizer's.
        loss = self.loss_fct(lm_logits.reshape(-1, lm_logits.size(-1)),
                             decoder_targets.reshape(-1))

        return loss

    def forward(self, audio, text):
        """Return the training loss for a batch of audio and caption text."""
        audio_embeds = self.forward_encoder(audio)
        loss = self.forward_decoder(text, audio_embeds)
        return loss

    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=3,
                 max_length=30,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):
        """Generate captions for a batch of audio inputs.

        Args:
            samples: Audio input forwarded to `forward_encoder`.
            use_nucleus_sampling: If true, sample with ``top_p``; otherwise
                decode with ``num_beams`` beams.
            num_beams: Beam count used when not sampling.
            max_length: ``max_length`` passed to the decoder's ``generate``.
            min_length: ``min_length`` passed to the decoder's ``generate``.
            top_p: Nucleus probability mass used when sampling.
            repetition_penalty: Passed to the decoder's ``generate``.

        Returns:
            Decoded captions (special tokens skipped), as produced by the
            tokenizer's ``batch_decode``.
        """
        audio_embs = self.forward_encoder(samples)

        # Encoder pass: we use the BART encoder to process the audio embeddings
        encoder_outputs = self.decoder.model.encoder(
            input_ids=None,
            attention_mask=None,
            head_mask=None,
            inputs_embeds=audio_embs,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True)
        encoder_hidden = encoder_outputs['last_hidden_state']

        # Prepare decoder input (start token)
        # Some versions use decoder_start_token_id, others use bos_token_id
        start_token_id = getattr(self.decoder.config, "decoder_start_token_id", self.tokenizer.bos_token_id)
        if start_token_id is None:
            start_token_id = self.tokenizer.bos_token_id

        # One start token per batch item.
        decoder_input_ids = torch.full(
            (encoder_hidden.size(0), 1),
            start_token_id,
            dtype=torch.long,
            device=self.device,
        )

        # We only need the attention mask for the encoder outputs if they were padded,
        # but here they come straight from the audio encoder, so we create a
        # simple all-ones mask.
        attention_mask = torch.ones(encoder_hidden.shape[:2], dtype=torch.long, device=self.device)

        # Use the standard generate method of the BART model
        # We pass encoder_outputs directly. We DO NOT pass inputs_embeds here to avoid conflicts.
        generate_kwargs = {
            "input_ids": None,  # Because we use encoder_outputs
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "max_length": max_length,
            "min_length": min_length,
            "repetition_penalty": repetition_penalty,
            "decoder_input_ids": decoder_input_ids,  # Initial decoder input
        }

        if use_nucleus_sampling:
            generate_kwargs.update({
                "do_sample": True,
                "top_p": top_p,
                "num_return_sequences": 1,
            })
        else:
            generate_kwargs.update({
                "num_beams": num_beams,
            })

        outputs = self.decoder.generate(**generate_kwargs)

        # Raw token ids go to the debug log instead of stdout; the guard
        # skips the `.tolist()` conversion when debug logging is off.
        log = logging.getLogger(__name__)
        if log.isEnabledFor(logging.DEBUG):
            log.debug(
                "Raw generation token IDs: %s",
                outputs.tolist() if torch.is_tensor(outputs) else outputs,
            )

        captions = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        return captions
|