| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import logging |
| import json |
| import contextlib |
| import random |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from transformers import LlamaTokenizer, StoppingCriteriaList |
| from peft import LoraConfig, TaskType, get_peft_model |
|
|
| from .Qformer import BertConfig, BertLMHeadModel |
| from .modeling_llama import LlamaForCausalLM |
| from .modeling_whisper import WhisperModel |
| from .beats.BEATs import BEATsConfig, BEATs |
| from .utils import StoppingCriteriaSub |
|
|
|
|
| class SALMONN(nn.Module): |
| @classmethod |
| def init_speech_Qformer(cls, num_query_token, speech_width, num_hidden_layers=2): |
| encoder_config = BertConfig.from_pretrained("bert-base-uncased") |
| encoder_config.num_hidden_layers = num_hidden_layers |
| encoder_config.encoder_width = speech_width |
| |
| encoder_config.add_cross_attention = True |
| encoder_config.cross_attention_freq = 1 |
| encoder_config.query_length = num_query_token |
| Qformer = BertLMHeadModel(config=encoder_config) |
| query_tokens = nn.Parameter( |
| torch.zeros(1, num_query_token, encoder_config.hidden_size) |
| ) |
| query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) |
| return Qformer, query_tokens |
|
|
| @property |
| def device(self): |
| return list(self.parameters())[0].device |
|
|
| def maybe_autocast(self, dtype=torch.float16): |
| |
| |
| enable_autocast = self.device != torch.device("cpu") |
|
|
| if enable_autocast: |
| return torch.cuda.amp.autocast(dtype=dtype) |
| else: |
| return contextlib.nullcontext() |
|
|
| def __init__( |
| self, |
| llama_path="", |
| whisper_path="", |
| freeze_whisper=True, |
| beats_path="", |
| freeze_beats=True, |
| |
| use_speech_Qformer=True, |
| num_speech_query_token=1, |
| freeze_speech_QFormer=False, |
| window_level_Qformer=True, |
| second_per_window=0.333333, |
| second_stride=0.333333, |
| |
| speech_llama_proj_model="", |
| freeze_speech_llama_proj=False, |
| |
| lora=True, |
| lora_rank=8, |
| lora_alpha=32, |
| lora_dropout=0.1, |
| |
| multi_prompt=False, |
| prompt_path="", |
| prompt_template="", |
| max_txt_len=128, |
| end_sym="</s>", |
| low_resource=False, |
| device_8bit=0, |
| ): |
| super().__init__() |
|
|
| self.beats_path = beats_path |
| self.use_speech_Qformer = use_speech_Qformer |
| self.window_level_Qformer = window_level_Qformer |
| self.second_per_window = second_per_window |
| self.second_stride = second_stride |
| self.lora = lora |
| self.multi_prompt = multi_prompt |
| self.max_txt_len = max_txt_len |
| self.end_sym = end_sym |
| self.low_resource = low_resource |
|
|
| logging.info('Loading LLaMA Tokenizer') |
| self.llama_tokenizer = LlamaTokenizer.from_pretrained(llama_path, use_fast=False) |
| self.llama_tokenizer.add_special_tokens({'pad_token': '[PAD]'}) |
| self.llama_tokenizer.padding_side = "right" |
|
|
| logging.info('Loading LLaMA Model') |
| if self.low_resource: |
| self.llama_model = LlamaForCausalLM.from_pretrained( |
| llama_path, |
| torch_dtype=torch.float16, |
| load_in_8bit=True, |
| device_map={"": device_8bit}, |
| ) |
| else: |
| self.llama_model = LlamaForCausalLM.from_pretrained( |
| llama_path, |
| torch_dtype=torch.float16, |
| ) |
|
|
| self.llama_model.resize_token_embeddings(len(self.llama_tokenizer)) |
| for name, param in self.llama_model.named_parameters(): |
| param.requires_grad = False |
| logging.info('Loading LLaMA Done') |
|
|
| if self.lora: |
| self.peft_config = LoraConfig( |
| task_type=TaskType.CAUSAL_LM, |
| inference_mode=False, |
| r=lora_rank, |
| lora_alpha=lora_alpha, |
| lora_dropout=lora_dropout, |
| ) |
| self.llama_model = get_peft_model(self.llama_model, self.peft_config) |
| self.llama_model.print_trainable_parameters() |
| logging.info('LoRA Training') |
|
|
| assert whisper_path |
| logging.info('Loading Whisper Model') |
| self.speech_encoder = WhisperModel.from_pretrained(whisper_path).encoder |
| self.ln_speech = nn.LayerNorm(self.speech_encoder.config.d_model) |
| if freeze_whisper: |
| for name, param in self.speech_encoder.named_parameters(): |
| param.requires_grad = False |
| self.speech_encoder.eval() |
| logging.info("freeze Whisper") |
| |
| if self.beats_path: |
| logging.info("Loading BEATs Model") |
| beats_ckpt = torch.load(self.beats_path, map_location='cpu') |
| beats_cfg = BEATsConfig(beats_ckpt['cfg']) |
| self.beats = BEATs(beats_cfg) |
| self.beats.load_state_dict(beats_ckpt['model']) |
| self.ln_audio = nn.LayerNorm(self.beats.cfg.encoder_embed_dim) |
| if freeze_beats: |
| for name, param in self.beats.named_parameters(): |
| param.requires_grad = False |
| self.beats.eval() |
| logging.info("freeze BEATs") |
|
|
| if self.use_speech_Qformer: |
| if self.beats_path: |
| self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer( |
| num_query_token=num_speech_query_token, speech_width=self.speech_encoder.config.d_model + self.beats.cfg.encoder_embed_dim |
| ) |
| else: |
| self.speech_Qformer, self.speech_query_tokens = self.init_speech_Qformer( |
| num_query_token=num_speech_query_token, speech_width=self.speech_encoder.config.d_model |
| ) |
| self.speech_Qformer.bert.embeddings.word_embeddings = None |
| self.speech_Qformer.bert.embeddings.position_embeddings = None |
| for layer in self.speech_Qformer.bert.encoder.layer: |
| layer.output = None |
| layer.intermediate = None |
| self.speech_Qformer.cls = None |
| if freeze_speech_QFormer: |
| for name, param in self.speech_Qformer.named_parameters(): |
| param.requires_grad = False |
| self.speech_Qformer.eval() |
| self.speech_query_tokens.requires_grad = False |
| logging.info("freeze Speech QFormer") |
|
|
| logging.info('Loading speech LLAMA proj') |
| self.speech_llama_proj = nn.Linear( |
| self.speech_Qformer.config.hidden_size, self.llama_model.config.hidden_size |
| ) |
| if speech_llama_proj_model: |
| logging.info("Loading speech LLAMA proj from {}".format(speech_llama_proj_model)) |
| speech_llama_proj_weight = torch.load(speech_llama_proj_model, map_location="cpu") |
| self.load_state_dict(speech_llama_proj_weight['model'], strict=False) |
| if freeze_speech_llama_proj: |
| for name, param in self.speech_llama_proj.named_parameters(): |
| param.requires_grad = False |
| self.speech_llama_proj.eval() |
| logging.info("freeze speech LLAMA proj") |
| else: |
| |
| raise NotImplementedError |
|
|
| |
| self.prompt_dict = {} |
| if prompt_path: |
| try: |
| raw_prompts = json.load(open(prompt_path, "r")) |
| except: |
| print("Failed to load prompt! Try to use utf-8 encoding.") |
| raw_prompts = json.load(open(prompt_path, "r", encoding='utf-8')) |
| for task in raw_prompts.keys(): |
| filted_prompts = [raw_prompt for raw_prompt in raw_prompts[task] if "<SpeechHere>" in raw_prompt] |
| self.prompt_dict[task] = [prompt_template.format(p) for p in filted_prompts] |
| print("Loading training prompts done!") |
|
|
    def _encode_auditory_feature(self, speech_embeds, audio_embeds=None):
        """Fuse encoder features through the speech Q-Former and project to LLaMA space.

        Args:
            speech_embeds: Whisper encoder output, shape (B, T, C_whisper)
                (3-D is required by the unfold logic below).
            audio_embeds: optional BEATs output, (B, T', C_beats); when present
                it is length-padded and concatenated channel-wise.

        Returns:
            (speech_embeds, speech_atts): LLaMA-space embeddings and an
            all-ones attention mask of matching length.
        """
        with self.maybe_autocast():
            if self.use_speech_Qformer:
                speech_embeds = self.ln_speech(speech_embeds)
                if audio_embeds is not None:
                    audio_embeds = self.ln_audio(audio_embeds)
                    # Zero-pad whichever stream is shorter in time so both match,
                    # then fuse by concatenating along the feature dimension.
                    if audio_embeds.size(1) < speech_embeds.size(1):
                        audio_embeds = F.pad(audio_embeds, (0, 0, 0, speech_embeds.size(1) - audio_embeds.size(1)))
                    elif audio_embeds.size(1) > speech_embeds.size(1):
                        speech_embeds = F.pad(speech_embeds, (0, 0, 0, audio_embeds.size(1) - speech_embeds.size(1)))
                    speech_embeds = torch.cat((speech_embeds, audio_embeds), dim=-1)
                speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long).to(speech_embeds.device)

                if self.window_level_Qformer:
                    # Slice the sequence into fixed-size (possibly overlapping)
                    # windows and run the Q-Former once per window.
                    # The 1500/30 factor converts seconds to encoder frames
                    # (presumably Whisper's 30 s input -> 1500 frames — NOTE(review): confirm).
                    B, T, C = speech_embeds.shape
                    kernel = round(1500 * self.second_per_window / 30.0)
                    stride = round(1500 * self.second_stride / 30.0)
                    kernel = (1, kernel)
                    stride = (1, stride)
                    # Treat (B, C, 1, T) as an image and unfold along time.
                    speech_embeds_tr = speech_embeds.transpose(1, 2).unsqueeze(2)
                    speech_embeds_overlap = F.unfold(speech_embeds_tr, kernel_size=kernel, dilation=1, padding=0, stride=stride)
                    _, _, L = speech_embeds_overlap.shape  # L = number of windows
                    speech_embeds_overlap = speech_embeds_overlap.view(B, -1, kernel[1], L)
                    speech_embeds_overlap = torch.permute(speech_embeds_overlap, [0, 3, 2, 1])
                    # Fold windows into the batch dim: (B * L, window_len, C).
                    speech_embeds = speech_embeds_overlap.reshape(-1, kernel[1], C)
                    speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long, device=speech_embeds.device)

                # Learned queries attend to the (windowed) features.
                query_tokens = self.speech_query_tokens.expand(speech_embeds.shape[0], -1, -1)
                query_output = self.speech_Qformer.bert(
                    query_embeds=query_tokens,
                    encoder_hidden_states=speech_embeds,
                    encoder_attention_mask=speech_atts,
                    return_dict=True,
                )
                # Project Q-Former outputs into the LLaMA embedding space.
                speech_embeds = self.speech_llama_proj(query_output.last_hidden_state)

                if self.window_level_Qformer:
                    # Un-fold the window dim: (B, L * num_query_token, llama_hidden).
                    speech_embeds = speech_embeds.view(B, -1, speech_embeds.size(2)).contiguous()

                speech_atts = torch.ones(speech_embeds.size()[:-1], dtype=torch.long).to(speech_embeds.device)
            else:
                raise NotImplementedError

        return speech_embeds, speech_atts
|
|
| def encode_speech(self, spectrogram, raw_wav=None, audio_padding_mask=None): |
| with self.maybe_autocast(): |
| speech_embeds = self.speech_encoder(spectrogram, return_dict=True).last_hidden_state |
|
|
| if self.beats_path and raw_wav is not None: |
| audio_embeds, _ = self.beats.extract_features(raw_wav, padding_mask=audio_padding_mask, feature_only=True) |
| else: |
| audio_embeds = None |
|
|
| return self._encode_auditory_feature(speech_embeds, audio_embeds=audio_embeds) |
|
|
| def prompt_wrap(self, embeds, atts, prompt, multi_prompt=False): |
| if prompt: |
| if multi_prompt: |
| p_before = [] |
| p_after = [] |
| for i, p in enumerate(prompt): |
| b, a = p.split("<SpeechHere>") |
| p_before.append(b) |
| p_after.append(a) |
| |
| p_before_tokens = self.llama_tokenizer( |
| p_before, return_tensors="pt", add_special_tokens=False |
| ).to(embeds.device) |
| p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids) if not self.lora else self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids) |
|
|
| |
| p_after_tokens = self.llama_tokenizer( |
| p_after, return_tensors="pt", padding="longest", add_special_tokens=False |
| ).to(embeds.device) |
| p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids) if not self.lora else self.llama_model.model.model.embed_tokens(p_after_tokens.input_ids) |
|
|
| wrapped_embeds = torch.cat([p_before_embeds, embeds, p_after_embeds], dim=1) |
| wrapped_atts = torch.cat([p_before_tokens.attention_mask, atts, p_after_tokens.attention_mask], dim=1) |
| else: |
| batch_size = embeds.shape[0] |
| p_before, p_after = prompt.split("<SpeechHere>") |
|
|
| p_before_tokens = self.llama_tokenizer( |
| p_before, return_tensors="pt", add_special_tokens=False |
| ).to(embeds.device) |
| p_after_tokens = self.llama_tokenizer( |
| p_after, return_tensors="pt", add_special_tokens=False |
| ).to(embeds.device) |
| p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) if not self.lora else self.llama_model.model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1) |
| p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) if not self.lora else self.llama_model.model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1) |
|
|
| wrapped_embeds = torch.cat([p_before_embeds, embeds, p_after_embeds], dim=1) |
| wrapped_atts = torch.cat([p_before_tokens.attention_mask, atts, p_after_tokens.attention_mask], dim=1) |
| return wrapped_embeds, wrapped_atts |
| else: |
| return embeds, atts |
|
|
    def forward(self, samples, verbose=False):
        """Compute the LM training loss for a batch of (speech, text) samples.

        Args:
            samples: dict with "spectrogram", "text" (target strings), "task"
                (per-sample task names), and optionally "raw_wav",
                "padding_mask", and "Q" (question strings for QA prompts).
            verbose: when True, additionally return greedy token-accuracy counts.

        Returns:
            {"loss": ...} or {"loss": ..., "correct": ..., "total": ...}.
        """
        # NOTE(review): this permanently flips self.multi_prompt to True for
        # the model's lifetime once a mixed-task or QA batch is seen.
        task = list(set(samples["task"]))
        if len(task) > 1 or "QA" in task:
            self.multi_prompt = True

        # Pick a random prompt per sample (or one shared prompt); "{}"
        # placeholders are filled with the sample's question when present.
        # (The comprehension variable shadows the outer `task` list.)
        if self.prompt_dict:
            if self.multi_prompt:
                prompt = [random.choice(self.prompt_dict[task]) for task in samples["task"]]
                if "Q" in samples:
                    prompt = [p.format(q) if '{}' in p else p for p, q in zip(prompt, samples["Q"]) ]
            else:
                prompt = random.choice(self.prompt_dict[samples["task"][0]])

        # Encode speech (and optional raw audio for the BEATs branch).
        spectrogram = samples["spectrogram"]
        raw_wav = samples.get("raw_wav", None)
        audio_padding_mask = samples.get("padding_mask", None)

        speech_embeds, speech_atts = self.encode_speech(spectrogram, raw_wav=raw_wav, audio_padding_mask=audio_padding_mask)

        # Wrap the speech embeddings with the embedded prompt text.
        if self.prompt_dict:
            speech_embeds, speech_atts = self.prompt_wrap(speech_embeds, speech_atts, prompt, multi_prompt=self.multi_prompt)

        # Tokenize targets; end_sym marks the end of each answer.
        text = [t + self.end_sym for t in samples["text"]]
        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
        ).to(spectrogram.device)
        # LoRA wraps LLaMA one module deeper, hence the two embedding paths.
        to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids) if not self.lora else self.llama_model.model.model.embed_tokens(to_regress_tokens.input_ids)
        # Labels: pad tokens masked to -100 so the loss ignores them.
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )
        # The BOS + speech/prompt prefix contributes no loss either
        # (+1 accounts for the BOS token prepended below).
        empty_targets = (
            torch.ones(
                [speech_atts.shape[0], speech_atts.shape[1] + 1],
                dtype=torch.long
            ).to(spectrogram.device).fill_(-100)
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        # Prepend the BOS embedding; reuse the first speech-mask column as its mask.
        batch_size = speech_embeds.shape[0]
        bos = torch.ones(
            [batch_size, 1],
            dtype=to_regress_tokens.input_ids.dtype,
            device=to_regress_tokens.input_ids.device,
        ) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.llama_model.model.embed_tokens(bos) if not self.lora else self.llama_model.model.model.embed_tokens(bos)
        atts_bos = speech_atts[:, :1]

        # Final LM input: [BOS][speech+prompt][target text].
        inputs_embeds = torch.cat([bos_embeds, speech_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, speech_atts, to_regress_tokens.attention_mask], dim=1)

        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )
            loss = outputs.loss

        if verbose:
            # Greedy token accuracy over the answer region; logits are shifted
            # by one position relative to the labels they predict.
            nvocab = self.llama_model.config.vocab_size
            results = outputs.logits[:, empty_targets.size(1) - 1: -1, :].contiguous().view(-1, nvocab).argmax(dim=-1)
            labels = targets[:, empty_targets.size(1):].contiguous().view(-1)
            mask = (labels != -100)
            correct = (results[mask] == labels[mask]).float().sum()
            total = len(labels[mask])

        if verbose:
            return {"loss": loss, "correct": correct, "total": total}

        return {"loss": loss}
|
|
| def generate(self, samples, generate_cfg, prompts=None): |
| batch_size = samples["spectrogram"].shape[0] |
|
|
| spectrogram = samples["spectrogram"] |
| raw_wav = samples.get("raw_wav", None) |
| audio_padding_mask = samples.get("padding_mask", None) |
|
|
| speech_embeds, speech_atts = self.encode_speech(spectrogram, raw_wav=raw_wav, audio_padding_mask=audio_padding_mask) |
|
|
| if prompts is not None: |
| speech_embeds, speech_atts = self.prompt_wrap(speech_embeds, speech_atts, prompts, multi_prompt=True) |
|
|
| bos = torch.ones( |
| [batch_size, 1], |
| dtype=torch.int32, |
| device=speech_embeds.device, |
| ) * self.llama_tokenizer.bos_token_id |
| bos_embeds = self.llama_model.model.embed_tokens(bos) if not self.lora else self.llama_model.model.model.embed_tokens(bos) |
| atts_bos = speech_atts[:, :1] |
|
|
| embeds = torch.cat([bos_embeds, speech_embeds], dim=1) |
| attns = torch.cat([atts_bos, speech_atts], dim=1) |
|
|
| stop_words_ids = [torch.tensor([2]).cuda()] |
| stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)]) |
| outputs = self.llama_model.generate( |
| inputs_embeds=embeds, |
| max_new_tokens=generate_cfg.get("max_new_tokens", 200), |
| stopping_criteria=stopping_criteria, |
| num_beams=generate_cfg.get("num_beams", 4), |
| do_sample=generate_cfg.get("do_sample", False), |
| min_length=generate_cfg.get("min_length", 1), |
| temperature=generate_cfg.get("temperature", 1.0), |
| top_p=generate_cfg.get("top_p", 0.9), |
| repetition_penalty=generate_cfg.get("repetition_penalty", 1.0), |
| length_penalty=generate_cfg.get("length_penalty", 1.0), |
| attention_mask=attns, |
| ) |
| text = self.llama_tokenizer.batch_decode(outputs, add_special_tokens=False) |
|
|
| return text |
|
|
| @classmethod |
| def from_config(cls, config): |
| llama_path = config.get("llama_path") |
| whisper_path = config.get("whisper_path") |
| freeze_whisper = config.get("freeze_whisper", True) |
| beats_path = config.get("beats_path", "") |
| freeze_beats = config.get("freeze_beats", True) |
|
|
| use_speech_Qformer = config.get("use_speech_Qformer", True) |
| num_speech_query_token = config.get("num_speech_query_token", 1) |
| freeze_speech_QFormer = config.get("freeze_speech_QFormer", False) |
| window_level_Qformer = config.get("window_level_Qformer", True) |
| second_per_window = config.get("second_per_window", 0.333333) |
| second_stride = config.get("second_stride", 0.333333) |
|
|
| speech_llama_proj_model = config.get("speech_llama_proj_model", "") |
| freeze_speech_llama_proj = config.get("freeze_speech_llama_proj", False) |
|
|
| lora = config.get("lora", True) |
| lora_rank = config.get("lora_rank", 8) |
| lora_alpha = config.get("lora_alpha", 32) |
| lora_dropout = config.get("lora_dropout", 0.1) |
|
|
| multi_prompt = config.get("multi_prompt", False) |
| prompt_path = config.get("prompt_path", "") |
| prompt_template = config.get("prompt_template", "") |
| max_txt_len = config.get("max_txt_len", 128) |
| end_sym = config.get("end_sym", "</s>") |
| low_resource = config.get("low_resource", False) |
| device_8bit = config.get("device_8bit", 0) |
|
|
| model = cls( |
| llama_path=llama_path, |
| whisper_path=whisper_path, |
| freeze_whisper=freeze_whisper, |
| beats_path=beats_path, |
| freeze_beats=freeze_beats, |
| use_speech_Qformer=use_speech_Qformer, |
| num_speech_query_token=num_speech_query_token, |
| freeze_speech_QFormer=freeze_speech_QFormer, |
| window_level_Qformer=window_level_Qformer, |
| second_per_window=second_per_window, |
| second_stride=second_stride, |
| speech_llama_proj_model=speech_llama_proj_model, |
| freeze_speech_llama_proj=freeze_speech_llama_proj, |
| lora=lora, |
| lora_rank=lora_rank, |
| lora_alpha=lora_alpha, |
| lora_dropout=lora_dropout, |
| multi_prompt=multi_prompt, |
| prompt_path=prompt_path, |
| prompt_template=prompt_template, |
| max_txt_len=max_txt_len, |
| end_sym=end_sym, |
| low_resource=low_resource, |
| device_8bit=device_8bit, |
| ) |
|
|
| ckpt_path = config.get("ckpt", "") |
| if ckpt_path: |
| logging.info("Load SALMONN ckpt from: {}".format(ckpt_path)) |
| ckpt = torch.load(ckpt_path, map_location="cpu") |
| model.load_state_dict(ckpt['model'], strict=False) |
|
|
| return model |
|
|