#!/usr/bin/env python3
# coding: utf-8
# @Author  : Xinhao Mei @CVSSP, University of Surrey
# @E-mail  : x.mei@surrey.ac.uk

# PANNs - BART audio captioning model

import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer

from models.audio_encoder import AudioEncoderModel
from models.audio_encoder_config import AudioEncoderConfig

logger = logging.getLogger(__name__)


class BartCaptionModel(nn.Module):
    """Audio captioning model: an audio encoder feeding a BART encoder-decoder.

    Audio is encoded by ``AudioEncoderModel``, linearly projected to BART's
    hidden size, passed through BART's encoder as ``inputs_embeds``, and
    decoded into caption tokens by BART's decoder.
    """

    def __init__(self, config, cache_dir=None):
        """Build the audio encoder, BART model, tokenizer, projection and loss.

        Args:
            config: Mapping with keys ``"audio_encoder_args"`` (keyword
                arguments for ``AudioEncoderConfig``), ``"audio_args"``
                (forwarded to ``AudioEncoderConfig`` as ``audio_args``) and
                ``"text_decoder_args"`` (with ``"name"``, the Hugging Face
                model name/path, and ``"pretrained"``, which selects loading
                pretrained BART weights vs. building BART from its config).
            cache_dir: Optional cache directory forwarded to every
                ``from_pretrained`` call.
        """
        super().__init__()

        self.config = config

        # Audio encoder.
        encoder_config = AudioEncoderConfig(
            **config["audio_encoder_args"], audio_args=config["audio_args"]
        )
        self.encoder = AudioEncoderModel(encoder_config)

        # BART text decoder and its tokenizer (tokenizer is always loaded).
        decoder_name = config["text_decoder_args"]["name"]
        decoder_pretrained = config["text_decoder_args"]["pretrained"]
        self.tokenizer = BartTokenizer.from_pretrained(decoder_name, cache_dir=cache_dir)
        if decoder_pretrained:
            self.decoder = BartForConditionalGeneration.from_pretrained(
                decoder_name, cache_dir=cache_dir
            )
        else:
            # Same architecture, freshly initialised weights. Concrete model
            # classes are instantiated directly from a config; `from_config`
            # is an `Auto*` factory method, not a `PreTrainedModel` method.
            bart_config = BartConfig.from_pretrained(decoder_name, cache_dir=cache_dir)
            self.decoder = BartForConditionalGeneration(bart_config)

        # Maps audio-encoder features to BART's hidden size.
        self.enc_to_dec_proj = nn.Linear(
            encoder_config.hidden_size, self.decoder.config.hidden_size
        )
        # -100 marks padded target positions (see forward_decoder).
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)

    @property
    def device(self):
        """Device holding the model's first parameter."""
        return next(self.parameters()).device

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """Shift input ids one token to the right.

        Prepends ``decoder_start_token_id``, drops the last column, and
        replaces any ``-100`` (ignored-label marker) with ``pad_token_id``.

        Args:
            input_ids: 2-D integer tensor indexed as ``[batch, position]``.
            pad_token_id: Id substituted for ``-100`` entries.
            decoder_start_token_id: Id placed in the first column.

        Returns:
            A new tensor with the same shape and dtype as ``input_ids``.

        Raises:
            ValueError: If ``pad_token_id`` is None.
        """
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id
        # Replace possible -100 values in labels by `pad_token_id`.
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward_encoder(self, audios):
        """Encode audio and project the result to BART's hidden size.

        Args:
            audios: Input accepted by ``AudioEncoderModel`` (format defined
                there — TODO confirm expected shape).

        Returns:
            The encoder's ``last_hidden_state`` passed through
            ``enc_to_dec_proj``.
        """
        outputs = self.encoder(audios)
        return self.enc_to_dec_proj(outputs.last_hidden_state)

    def forward_decoder(self, text, encoder_outputs):
        """Compute the captioning loss against reference captions.

        Args:
            text: Caption string(s) accepted by the BART tokenizer.
            encoder_outputs: Projected audio embeddings (from
                ``forward_encoder``), fed to BART's encoder as
                ``inputs_embeds``.

        Returns:
            Scalar label-smoothed cross-entropy loss; padding is ignored.
        """
        encoder_hidden = self.decoder.model.encoder(
            input_ids=None, inputs_embeds=encoder_outputs, return_dict=True
        )["last_hidden_state"]

        # Pad to the longest caption in the batch, truncating at 30 tokens.
        tokens = self.tokenizer(
            text, padding="longest", truncation=True, max_length=30, return_tensors="pt"
        )
        input_ids = tokens["input_ids"].to(self.device)
        attention_mask = tokens["attention_mask"].to(self.device)

        # Padding positions become -100 so the loss ignores them.
        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )
        # Teacher forcing: the decoder sees the targets shifted right by one.
        decoder_input_ids = self.shift_tokens_right(
            decoder_targets,
            self.decoder.config.pad_token_id,
            self.decoder.config.decoder_start_token_id,
        )

        decoder_outputs = self.decoder(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            inputs_embeds=None,
            labels=None,
            encoder_outputs=(encoder_hidden,),
            return_dict=True,
        )
        lm_logits = decoder_outputs["logits"]
        # Flatten using the logits' own vocabulary dimension (the LM head
        # size); tokenizer.vocab_size is not guaranteed to match it.
        return self.loss_fct(
            lm_logits.reshape(-1, lm_logits.size(-1)), decoder_targets.reshape(-1)
        )

    def forward(self, audio, text):
        """Return the training loss for a batch of audio and captions."""
        audio_embeds = self.forward_encoder(audio)
        return self.forward_decoder(text, audio_embeds)

    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=3,
                 max_length=30,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):
        """Generate captions for a batch of audio inputs.

        Args:
            samples: Audio input accepted by ``forward_encoder``.
            use_nucleus_sampling: Sample with top-p if True, else beam search.
            num_beams: Beam width (beam search only).
            max_length: Maximum generated length.
            min_length: Minimum generated length.
            top_p: Nucleus probability mass (sampling only).
            repetition_penalty: Forwarded to ``generate``.

        Returns:
            List of decoded caption strings with special tokens removed.
        """
        audio_embs = self.forward_encoder(samples)

        # Encoder pass: BART's encoder processes the projected audio embeddings.
        encoder_outputs = self.decoder.model.encoder(
            input_ids=None,
            attention_mask=None,
            head_mask=None,
            inputs_embeds=audio_embs,
            output_attentions=None,
            output_hidden_states=None,
            return_dict=True,
        )
        batch_size, seq_len = encoder_outputs["last_hidden_state"].shape[:2]

        # Decoder start token; fall back to BOS when the config does not set one.
        start_token_id = getattr(self.decoder.config, "decoder_start_token_id", None)
        if start_token_id is None:
            start_token_id = self.tokenizer.bos_token_id
        decoder_input_ids = torch.full(
            (batch_size, 1), start_token_id, dtype=torch.long, device=self.device
        )

        # The encoder outputs are not padded here, so every position is attended.
        attention_mask = torch.ones(
            (batch_size, seq_len), dtype=torch.long, device=self.device
        )

        # encoder_outputs is passed directly; inputs_embeds is deliberately
        # omitted to avoid conflicting inputs.
        generate_kwargs = {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "max_length": max_length,
            "min_length": min_length,
            "repetition_penalty": repetition_penalty,
            "decoder_input_ids": decoder_input_ids,
        }
        if use_nucleus_sampling:
            generate_kwargs.update(do_sample=True, top_p=top_p, num_return_sequences=1)
        else:
            generate_kwargs.update(num_beams=num_beams)

        outputs = self.decoder.generate(**generate_kwargs)

        # Raw token ids are useful when debugging; only materialise them if
        # debug logging is actually enabled.
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Raw generation token IDs: %s",
                outputs.tolist() if torch.is_tensor(outputs) else outputs,
            )

        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)