| | """ |
| | Copyright (c) 2023, salesforce.com, inc. |
| | All rights reserved. |
| | SPDX-License-Identifier: BSD-3-Clause |
| | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
| | """ |
| | import logging |
| | import torch |
| | import torch.nn as nn |
| | from torch.cuda.amp import autocast as autocast |
| | |
| | from lavis.models.blip2_models.blip2 import disabled_train |
| | from model.blip2 import Blip2Base |
| | from transformers import AutoTokenizer |
| | from transformers import OPTForCausalLM |
| | from transformers import AutoTokenizer, AutoModelForCausalLM |
| | from opendelta import LoraModel |
| | from opendelta.delta_models.lora import LoraConfig as DeltaLoraConfig |
| | from transformers import BertTokenizer, BitsAndBytesConfig |
| | from model.help_funcs import hf_enable_gradient_checkpointing |
| | import json |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | opt_model_list = [ |
| | "facebook/galactica-125m", |
| | "facebook/galactica-1.3b", |
| | "facebook/galactica-6.7b", |
| | "facebook/galactica-30b", |
| | ] |
def get_gpu_memory(device=0):
    """Return (free, used, total) GPU memory on `device`, in GiB."""
    free, total = torch.cuda.mem_get_info(device)
    free = free / (1024 ** 3)
    total = total / (1024 ** 3)
    return free, total - free, total


def mask_by_len(input, lens, fill_value=0):
    '''
    Fill the first `lens[i]` positions of row i of `input` with `fill_value`, in place.
    input: shape = [N, D]
    lens: shape = [N]
    '''
    mask = torch.arange(input.shape[1], device=input.device).reshape(1, -1)
    mask = mask < lens.reshape(-1, 1)
    input[mask] = fill_value
    return input
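# Illustrative example (hypothetical tensors) of mask_by_len's in-place prefix masking:
# >>> x = torch.ones(2, 4)
# >>> mask_by_len(x, torch.tensor([1, 3]))
# tensor([[0., 1., 1., 1.],
#         [0., 0., 0., 1.]])
# Note: `input` is modified in place; the returned tensor aliases it.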
class Blip2OPT_new(Blip2Base):
    """
    BLIP-2-style model that connects a protein language model (e.g. ESM-2)
    to a causal LLM (e.g. Galactica) through a Q-Former:
        - plm: protein encoder, frozen or LoRA-tuned
        - Qformer: pools protein embeddings into a fixed number of query tokens
        - llm_model: causal LLM conditioned on the projected query tokens
    """
    def __init__(
        self,
        bert_name,
        num_query_token=32,
        cross_attention_freq=2,
        plm_model="facebook/esm2_t30_150M_UR50D",
        plm_tune='freeze',
        llm_name="facebook/galactica-1.3b",
        llm_tune='freeze',
        peft_dir='',
        args=None,
    ):
        super().__init__()
        self.args = args
        self.enable_gradient_checkpointing = args.enable_gradient_checkpointing
        # Protein encoder (ESM-2) and its post-encoder layernorm.
        self.plm_tokenizer, self.plm, self.ln_layer = self.init_protein_encoder(plm_model)
        self.plm_tune = plm_tune
        if plm_tune == 'freeze':
            for name, param in self.plm.named_parameters():
                param.requires_grad = False
            self.plm = self.plm.eval()
            self.plm.train = disabled_train
            logging.info("freeze plm encoder")
        elif plm_tune == 'lora':
            lora_config = DeltaLoraConfig(
                args.lora_r,
                args.lora_alpha,
                args.lora_dropout,
                modified_modules=["query", "value"],
            )
            self.delta = LoraModel.from_config(lora_config, self.plm)
            self.delta.freeze_module(set_state_dict=False)
            self.delta.log()
        else:
            raise NotImplementedError()

        self.num_query_token = num_query_token
        self.qformer_tokenizer, self.Qformer, self.query_tokens = self.init_Qformer(
            bert_name, num_query_token, self.plm.num_features, cross_attention_freq
        )
        # The Q-Former is used purely as a cross-attention pooler here, so drop
        # its MLM head, word/position embeddings, and per-layer feed-forward blocks.
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
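        # After this pruning the Q-Former acts as a pooler only: query_tokens
        # (shape [1, num_query_token, hidden], as in BLIP-2) cross-attend to the
        # protein embeddings, and their outputs are all that is consumed downstream.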
        # Causal LLM and its tokenizer.
        self.llm_model, self.llm_tokenizer = self.load_llm(llm_name)
        self.eos_token_id = self.llm_tokenizer.eos_token_id
        self.pad_token_id = self.llm_tokenizer.pad_token_id
        if llm_tune == 'freeze':
            for name, param in self.llm_model.named_parameters():
                param.requires_grad = False
        elif llm_tune == 'full':
            for name, param in self.llm_model.named_parameters():
                param.requires_grad = True
        elif llm_tune == 'lora':
            lora_config = DeltaLoraConfig(args.lora_r, args.lora_alpha, args.lora_dropout)
            self.delta = LoraModel.from_config(lora_config, self.llm_model)
            self.delta.freeze_module(set_state_dict=False)
            self.delta.log()
        elif llm_tune == 'mid_lora':
            lora_config = DeltaLoraConfig(
                args.lora_r, args.lora_alpha, args.lora_dropout,
                modified_modules=["q_proj", "v_proj", "k_proj", "out_proj", "fc1", "fc2"],
            )
            self.delta = LoraModel.from_config(lora_config, self.llm_model)
            self.delta.freeze_module(set_state_dict=False)
            self.delta.log()
        elif llm_tune == 'peft_lora':
            config = PeftLoraConfig(
                r=args.lora_r,
                lora_alpha=args.lora_alpha,
                lora_dropout=args.lora_dropout,
                bias="none",
                task_type="CAUSAL_LM",
            )
            self.llm_model = get_peft_model(self.llm_model, config)
            # Cast modules to the dtypes commonly used for bf16 LoRA fine-tuning:
            # LoRA layers to bf16, norms to fp32 for stability, and any fp32
            # embedding/lm_head weights down to bf16.
            for name, module in self.llm_model.named_modules():
                if isinstance(module, LoraLayer):
                    module = module.to(torch.bfloat16)
                if 'norm' in name:
                    module = module.to(torch.float32)
                if 'lm_head' in name or 'embed_tokens' in name:
                    if hasattr(module, 'weight') and module.weight.dtype == torch.float32:
                        module = module.to(torch.bfloat16)
        else:
            raise NotImplementedError()
        # Project Q-Former outputs into the LLM embedding space.
        self.opt_proj = nn.Linear(self.Qformer.config.hidden_size, self.llm_model.config.hidden_size)

    def load_llm(self, llm_model, load_4bit=False, enable_gradient_checkpointing=True):
        llm_tokenizer = AutoTokenizer.from_pretrained(llm_model, use_fast=False, padding_side='right')
        llm_tokenizer.add_special_tokens({'pad_token': '<pad>'})
        llm_tokenizer.add_special_tokens({'additional_special_tokens': ['<PROT>', '<TEXT>']})

        llm_model = AutoModelForCausalLM.from_pretrained(llm_model, torch_dtype=torch.bfloat16)
        # Resize embeddings to cover the newly added special tokens.
        llm_model.resize_token_embeddings(len(llm_tokenizer))
        return llm_model, llm_tokenizer
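    # Hypothetical sketch (not wired up here): `load_4bit` is accepted but unused.
    # If honored, the load could pass a quantization config, e.g.:
    #   quant = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
    #   AutoModelForCausalLM.from_pretrained(llm_model, quantization_config=quant)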
    def forward(self, batch):
        prot_batch, prompt_batch, text_dict = batch
        text_seqs = text_dict['targets']
        batch_size = prompt_batch['input_ids'].size(0)

        # Encode the protein sequence and pool it with the Q-Former queries.
        prot_embeds = self.plm(**prot_batch, return_dict=True)
        prot_embeds = prot_embeds.last_hidden_state
        if self.plm_tune == 'freeze':
            prot_embeds = prot_embeds.detach()
        prot_embeds = self.ln_layer(prot_embeds)
        device = prot_embeds.device
        query_tokens = self.query_tokens.expand(prot_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=prot_embeds,
            encoder_attention_mask=prot_batch['attention_mask'],
            return_dict=True,
        )
        prot_tokens = self.opt_proj(query_output.last_hidden_state)
        prot_mask = torch.ones(prot_tokens.shape[:2], dtype=torch.long, device=device)
        # Embed the text prompt.
        prompt_embeds = self.llm_model.get_input_embeddings()(prompt_batch['input_ids'])
        prompt_mask = prompt_batch['attention_mask']

        # Tokenize and embed the target text; pad positions are excluded from the loss.
        text_batch = self.llm_tokenizer(
            list(text_seqs),
            padding='longest',
            truncation=True,
            max_length=1024,
            return_tensors='pt',
        ).to(device)
        target_embeds = self.llm_model.get_input_embeddings()(text_batch['input_ids'])
        target_mask = text_batch['attention_mask']
        targets = text_batch['input_ids'].masked_fill(text_batch['input_ids'] == self.llm_tokenizer.pad_token_id, -100)
        embedding_layer = self.llm_model.get_input_embeddings()

        def embed_special_str(token_str):
            # Tokenize a literal string (no special tokens added), embed it, and
            # broadcast the result across the batch.
            ids = self.llm_tokenizer(token_str, add_special_tokens=False).input_ids
            ids_tensor = torch.tensor([ids], device=device)
            embs = embedding_layer(ids_tensor)
            return embs.expand(batch_size, -1, -1)
        # ChatML-style role markers wrapping the user turn.
        embed_im_start = embed_special_str("<|im_start|>user\n")
        embed_im_end = embed_special_str("<|im_end|>\n")
        embed_assistant = embed_special_str("<|im_start|>assistant\n")

        # User turn plus assistant header; the target text follows as the assistant turn.
        user_embeds = torch.cat([embed_im_start, prot_tokens, prompt_embeds, embed_im_end, embed_assistant], dim=1)
        user_mask = torch.ones(user_embeds.shape[:2], dtype=torch.long, device=device)

        assistant_embeds = target_embeds
        assistant_mask = target_mask

        inputs_embeds = torch.cat([user_embeds, assistant_embeds], dim=1)
        attention_mask = torch.cat([user_mask, assistant_mask], dim=1)

        # Supervise only the assistant tokens; the entire user turn is ignored by the loss.
        ignore_labels = torch.full(user_embeds.shape[:2], -100, dtype=torch.long, device=device)
        assistant_labels = targets
        labels = torch.cat([ignore_labels, assistant_labels], dim=1)
        outputs = self.llm_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )
        loss = outputs.loss
        return loss
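    # Layout of one assembled training example (illustrative):
    #   embeds: <|im_start|>user\n | prot_tokens | prompt | <|im_end|>\n | <|im_start|>assistant\n | target text
    #   labels: -100 .......................................................................... | target ids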
    @torch.no_grad()
    def generate(
        self,
        samples,
        do_sample=False,
        num_beams=5,
        max_length=128,
        min_length=1,
        top_p=0.9,
        repetition_penalty=1.0,
        length_penalty=1.0,
        num_captions=1,
        temperature=1,
    ):
        """
        Args:
            samples (dict): A dictionary containing:
                - prot_batch: tokenized protein sequences for the protein encoder
                - prompt_batch: tokenized text prompts for the LLM
            do_sample (bool): Whether to use nucleus sampling instead of beam search.
            num_beams (int): Number of beams for beam search. 1 means no beam search.
            max_length (int): Maximum number of new tokens to generate.
            min_length (int): The minimum length of the sequence to be generated.
            top_p (float): The cumulative probability for nucleus sampling.
            repetition_penalty (float): The parameter for repetition penalty. 1.0 means no penalty.
            num_captions (int): Number of output sequences per input.
        Returns:
            output_text (list): A list of strings of length batch_size * num_captions.
        """
        prot_batch = samples['prot_batch']
        prompt_batch = samples['prompt_batch']

        device = prompt_batch['input_ids'].device
        batch_size = prompt_batch['input_ids'].size(0)

        # Encode the protein and pool with the Q-Former queries, as in forward().
        prot_embeds = self.plm(**prot_batch, return_dict=True).last_hidden_state
        prot_embeds = self.ln_layer(prot_embeds)
        query_tokens = self.query_tokens.expand(prot_embeds.shape[0], -1, -1)
        query_output = self.Qformer.bert(
            query_embeds=query_tokens,
            encoder_hidden_states=prot_embeds,
            encoder_attention_mask=prot_batch['attention_mask'],
            return_dict=True,
        )
        prot_tokens = self.opt_proj(query_output.last_hidden_state)
        prompt_input_ids = prompt_batch['input_ids']
        prompt_attention_mask = prompt_batch['attention_mask']
        prompt_embeds = self.llm_model.get_input_embeddings()(prompt_input_ids)

        embedding_layer = self.llm_model.get_input_embeddings()

        def embed_special_str(token_str):
            # Same helper as in forward(): embed a literal string and broadcast
            # it across the batch.
            ids = self.llm_tokenizer(token_str, add_special_tokens=False).input_ids
            ids_tensor = torch.tensor([ids], device=device)
            embs = embedding_layer(ids_tensor)
            return embs.expand(batch_size, -1, -1)
        # ChatML-style role markers, mirroring forward().
        embed_im_start = embed_special_str("<|im_start|>user\n")
        embed_im_end = embed_special_str("<|im_end|>\n")
        embed_assistant = embed_special_str("<|im_start|>assistant\n")

        # User turn followed by the assistant header the model continues from.
        user_embeds = torch.cat([embed_im_start, prot_tokens, prompt_embeds, embed_im_end], dim=1)
        assistant_prefix = embed_assistant
        inputs_embeds = torch.cat([user_embeds, assistant_prefix], dim=1)

        user_mask = torch.ones(user_embeds.shape[:2], dtype=torch.long, device=device)
        assistant_mask = torch.ones((batch_size, embed_assistant.size(1)), dtype=torch.long, device=device)
        attention_mask = torch.cat([user_mask, assistant_mask], dim=1)
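        # Note: when generate() is called with `inputs_embeds` only (no `input_ids`),
        # recent versions of transformers return just the newly generated token ids,
        # so the outputs below can be decoded directly without slicing off the prompt.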
        outputs = self.llm_model.generate(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            do_sample=do_sample,
            top_p=top_p,
            temperature=temperature,
            num_beams=num_beams,
            max_new_tokens=max_length,
            min_length=min_length,
            eos_token_id=self.eos_token_id,
            repetition_penalty=repetition_penalty,
            length_penalty=length_penalty,
            num_return_sequences=num_captions,
            use_cache=True,
            cache_implementation="hybrid",
        )
        output_text = self.llm_tokenizer.batch_decode(outputs, skip_special_tokens=True)
        output_text = [text.strip() for text in output_text]
        return output_text
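# Illustrative usage (hypothetical batches built by the project's collators):
#   model = Blip2OPT_new(bert_name="bert-base-uncased", args=args)
#   loss = model((prot_batch, prompt_batch, {"targets": target_texts}))
#   texts = model.generate({"prot_batch": prot_batch, "prompt_batch": prompt_batch})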