| |
| |
| |
| |
|
|
|
|
| |
|
|
import logging

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BartConfig, BartForConditionalGeneration, BartTokenizer

from models.audio_encoder import AudioEncoderModel
from models.audio_encoder_config import AudioEncoderConfig
|
|
|
|
class BartCaptionModel(nn.Module):
    """Audio captioning model: an audio encoder feeding a BART text decoder.

    Features from ``AudioEncoderModel`` are linearly projected to BART's hidden
    size and passed through BART's own encoder as ``inputs_embeds``; BART's
    decoder then produces caption tokens.
    """

    def __init__(self, config, cache_dir=None):
        """Build the audio encoder, the BART decoder/tokenizer and the projection.

        Args:
            config: mapping with ``"audio_encoder_args"`` (kwargs for
                ``AudioEncoderConfig``), ``"audio_args"`` and
                ``"text_decoder_args"`` (containing ``"name"`` and
                ``"pretrained"``).
            cache_dir: optional Hugging Face cache directory forwarded to every
                ``from_pretrained`` call.
        """
        super().__init__()

        self.config = config

        # Audio encoder.
        encoder_config = AudioEncoderConfig(**config["audio_encoder_args"],
                                            audio_args=config["audio_args"])
        self.encoder = AudioEncoderModel(encoder_config)

        # Text decoder: pretrained BART weights, or only the BART architecture
        # (freshly initialised) when "pretrained" is falsy. The tokenizer is
        # loaded from the same name in both cases.
        decoder_name = config["text_decoder_args"]["name"]
        decoder_pretrained = config["text_decoder_args"]["pretrained"]
        self.tokenizer = BartTokenizer.from_pretrained(decoder_name, cache_dir=cache_dir)
        if decoder_pretrained:
            self.decoder = BartForConditionalGeneration.from_pretrained(decoder_name, cache_dir=cache_dir)
        else:
            bart_config = BartConfig.from_pretrained(decoder_name, cache_dir=cache_dir)
            # `from_config` is an Auto-class API; a concrete model class is
            # built from a config via its constructor.
            self.decoder = BartForConditionalGeneration(bart_config)

        # Maps audio-encoder features into BART's hidden size.
        self.enc_to_dec_proj = nn.Linear(encoder_config.hidden_size, self.decoder.config.hidden_size)
        # Target positions equal to -100 (padding) are excluded from the loss.
        self.loss_fct = nn.CrossEntropyLoss(ignore_index=-100, label_smoothing=0.1)

    @property
    def device(self):
        """Device of the first registered parameter."""
        # next() avoids materialising the whole parameter list on every access.
        return next(self.parameters()).device

    def shift_tokens_right(self, input_ids: torch.Tensor, pad_token_id: int, decoder_start_token_id: int):
        """Shift ``input_ids`` one position to the right to form decoder inputs.

        Column 0 becomes ``decoder_start_token_id``, the last column of the input
        is dropped, and any ``-100`` entries carried into the result are replaced
        with ``pad_token_id``.

        Raises:
            ValueError: if ``pad_token_id`` is None.
        """
        if pad_token_id is None:
            raise ValueError("self.model.config.pad_token_id has to be defined.")

        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
        shifted_input_ids[:, 0] = decoder_start_token_id
        shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
        return shifted_input_ids

    def forward_encoder(self, audios):
        """Encode ``audios`` and project the result to BART's hidden size."""
        outputs = self.encoder(audios)
        return self.enc_to_dec_proj(outputs.last_hidden_state)

    def _encode_for_decoder(self, audio_embeds):
        """Run BART's encoder on projected audio embeddings (no token ids)."""
        return self.decoder.model.encoder(
            input_ids=None,
            inputs_embeds=audio_embeds,
            return_dict=True,
        )

    def _caption_loss(self, lm_logits, targets):
        """Label-smoothed token cross-entropy; ``-100`` targets are ignored.

        The class dimension is read from the logits rather than from the
        tokenizer: the LM head's width need not equal ``tokenizer.vocab_size``.
        """
        return self.loss_fct(lm_logits.reshape(-1, lm_logits.size(-1)), targets.reshape(-1))

    def forward_decoder(self, text, encoder_outputs):
        """Compute the captioning loss of ``text`` given projected audio embeddings.

        Args:
            text: caption string or list of strings; tokenised with padding to
                the longest caption and truncation to 30 tokens.
            encoder_outputs: projected audio embeddings from ``forward_encoder``,
                fed to BART's encoder as ``inputs_embeds``.

        Returns:
            The loss produced by ``self.loss_fct`` over non-padding positions.
        """
        encoder_hidden = self._encode_for_decoder(encoder_outputs)["last_hidden_state"]

        device = self.device
        tokens = self.tokenizer(text,
                                padding='longest',
                                truncation=True,
                                max_length=30,
                                return_tensors="pt")
        input_ids = tokens["input_ids"].to(device)
        attention_mask = tokens["attention_mask"].to(device)

        # Padding positions become -100 so the loss ignores them.
        decoder_targets = input_ids.masked_fill(
            input_ids == self.tokenizer.pad_token_id, -100
        )
        # Teacher forcing: the decoder reads the targets shifted right by one.
        decoder_input_ids = self.shift_tokens_right(
            decoder_targets, self.decoder.config.pad_token_id, self.decoder.config.decoder_start_token_id
        )

        decoder_outputs = self.decoder(
            input_ids=None,
            attention_mask=None,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=attention_mask,
            inputs_embeds=None,
            labels=None,
            encoder_outputs=(encoder_hidden,),
            return_dict=True,
        )
        return self._caption_loss(decoder_outputs["logits"], decoder_targets)

    def forward(self, audio, text):
        """Return the training loss for a batch of audio and its captions."""
        audio_embeds = self.forward_encoder(audio)
        return self.forward_decoder(text, audio_embeds)

    @torch.no_grad()
    def generate(self,
                 samples,
                 use_nucleus_sampling=False,
                 num_beams=3,
                 max_length=30,
                 min_length=2,
                 top_p=0.9,
                 repetition_penalty=1.0,
                 ):
        """Generate captions for a batch of audio.

        Gradients are disabled: the result is decoded text, so building an
        autograd graph for the encoder passes would only waste memory.

        Args:
            samples: audio batch accepted by the audio encoder.
            use_nucleus_sampling: if True, top-p sampling with a single beam;
                otherwise beam search with ``num_beams`` beams.
            num_beams: beam width for beam search (ignored when sampling).
            max_length, min_length: generated sequence length bounds.
            top_p: nucleus probability mass (sampling only).
            repetition_penalty: forwarded to the Hugging Face ``generate`` call.

        Returns:
            List of decoded caption strings with special tokens removed.
        """
        audio_embs = self.forward_encoder(samples)
        encoder_outputs = self._encode_for_decoder(audio_embs)
        encoder_hidden = encoder_outputs['last_hidden_state']
        batch_size, seq_len = encoder_hidden.shape[:2]
        device = self.device

        # Prefer the decoder's configured start token, falling back to BOS.
        start_token_id = getattr(self.decoder.config, "decoder_start_token_id", self.tokenizer.bos_token_id)
        if start_token_id is None:
            start_token_id = self.tokenizer.bos_token_id

        decoder_input_ids = torch.full((batch_size, 1), start_token_id, dtype=torch.long, device=device)
        # All encoder positions are attended to.
        attention_mask = torch.ones((batch_size, seq_len), dtype=torch.long, device=device)

        generate_kwargs = {
            "input_ids": None,
            "encoder_outputs": encoder_outputs,
            "attention_mask": attention_mask,
            "max_length": max_length,
            "min_length": min_length,
            "repetition_penalty": repetition_penalty,
            "decoder_input_ids": decoder_input_ids,
        }
        if use_nucleus_sampling:
            # Pin num_beams=1 so a num_beams default in the checkpoint's
            # generation config cannot turn this into beam sampling.
            generate_kwargs.update({
                "do_sample": True,
                "top_p": top_p,
                "num_beams": 1,
                "num_return_sequences": 1,
            })
        else:
            generate_kwargs["num_beams"] = num_beams

        outputs = self.decoder.generate(**generate_kwargs)

        logger = logging.getLogger(__name__)
        if logger.isEnabledFor(logging.DEBUG):
            raw_ids = outputs.tolist() if torch.is_tensor(outputs) else outputs
            logger.debug("Raw generation token IDs: %s", raw_ids)

        return self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
|
|