| | """ |
| | Copyright (c) 2023, salesforce.com, inc. |
| | All rights reserved. |
| | SPDX-License-Identifier: BSD-3-Clause |
| | For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause |
| | """ |
| | import logging |
| | import random |
| | import os |
| | import torch |
| | from torch.cuda.amp import autocast as autocast |
| | import torch.nn as nn |
| |
|
| | from minigpt4.common.registry import registry |
| | from minigpt4.models.blip2 import Blip2Base, disabled_train |
| | from minigpt4.models.modeling_llama import LlamaForCausalLM |
| | from transformers import LlamaTokenizer |
| |
|
| |
|
@registry.register_model("mini_gpt4")
class MiniGPT4(Blip2Base):
    """
    MiniGPT-4: a BLIP-2 vision encoder and Q-Former bridged to a frozen LLaMA/Vicuna
    language model through a single trainable linear projection.
    """

    PRETRAINED_MODEL_CONFIG_DICT = {
        "pretrain_vicuna": "configs/models/minigpt4.yaml",
    }

    def __init__(
        self,
        vit_model="eva_clip_g",
        q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth",
        img_size=224,
        drop_path_rate=0,
        use_grad_checkpoint=False,
        vit_precision="fp16",
        freeze_vit=True,
        freeze_qformer=True,
        num_query_token=32,
        llama_model="",
        llama_cache_dir='',
        prompt_path="",
        prompt_template="",
        max_txt_len=32,
        end_sym='\n',
    ):
        super().__init__()

        self.tokenizer = self.init_tokenizer()

        print('Loading VIT')
        self.visual_encoder, self.ln_vision = self.init_vision_encoder(
            vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision
        )
        if freeze_vit:
            for name, param in self.visual_encoder.named_parameters():
                param.requires_grad = False
            self.visual_encoder = self.visual_encoder.eval()
            self.visual_encoder.train = disabled_train
            for name, param in self.ln_vision.named_parameters():
                param.requires_grad = False
            self.ln_vision = self.ln_vision.eval()
            self.ln_vision.train = disabled_train
            logging.info("freeze vision encoder")
        print('Loading VIT Done')

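        # The Q-Former compresses the frozen ViT feature map into a fixed number of
        # learned query tokens. Its text branch (word/position embeddings and the
        # per-layer feed-forward outputs) is discarded below because only the vision
        # path is used in this model.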
        print('Loading Q-Former')
        self.Qformer, self.query_tokens = self.init_Qformer(
            num_query_token, self.visual_encoder.num_features
        )
        self.Qformer.cls = None
        self.Qformer.bert.embeddings.word_embeddings = None
        self.Qformer.bert.embeddings.position_embeddings = None
        for layer in self.Qformer.bert.encoder.layer:
            layer.output = None
            layer.intermediate = None
        self.load_from_pretrained(url_or_filename=q_former_model)

        if freeze_qformer:
            for name, param in self.Qformer.named_parameters():
                param.requires_grad = False
            self.Qformer = self.Qformer.eval()
            self.Qformer.train = disabled_train
            self.query_tokens.requires_grad = False
            logging.info("freeze Qformer")
        print('Loading Q-Former Done')

        print('Loading LLAMA')
        self.llama_tokenizer = LlamaTokenizer.from_pretrained('camenduru/MiniGPT4', use_fast=False)
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token

        if llama_cache_dir:
            # Route weight downloads through the user-provided cache directory when given.
            self.llama_model = LlamaForCausalLM.from_pretrained(
                'camenduru/MiniGPT4', load_in_8bit=True, torch_dtype=torch.float16,
                device_map="auto", cache_dir=llama_cache_dir
            )
        else:
            self.llama_model = LlamaForCausalLM.from_pretrained(
                'camenduru/MiniGPT4', load_in_8bit=True, torch_dtype=torch.float16, device_map="auto"
            )
        # The language model is always kept frozen.
        for name, param in self.llama_model.named_parameters():
            param.requires_grad = False
        print('Loading LLAMA Done')

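        # Single linear layer mapping Q-Former query outputs into the LLaMA
        # token-embedding space. With freeze_vit and freeze_qformer enabled, this is
        # the only module updated during training.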
        self.llama_proj = nn.Linear(
            self.Qformer.config.hidden_size, self.llama_model.config.hidden_size
        )
        self.max_txt_len = max_txt_len
        self.end_sym = end_sym

        if prompt_path:
            with open(prompt_path, 'r') as f:
                raw_prompts = f.read().splitlines()
            filtered_prompts = [raw_prompt for raw_prompt in raw_prompts if "<ImageHere>" in raw_prompt]
            self.prompt_list = [prompt_template.format(p) for p in filtered_prompts]
            print('Load {} training prompts'.format(len(self.prompt_list)))
            print('Prompt Example \n{}'.format(random.choice(self.prompt_list)))
        else:
            self.prompt_list = []

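    # Low-GPU-memory path: the ViT and its LayerNorm are moved to CPU in fp32 and the
    # images are encoded there, so only the Q-Former/LLaMA side needs GPU memory.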
    def vit_to_cpu(self):
        self.ln_vision.to("cpu")
        self.ln_vision.float()
        self.visual_encoder.to("cpu")
        self.visual_encoder.float()

    def encode_img(self, image):
        device = image.device
        self.vit_to_cpu()
        image = image.to("cpu")
        with self.maybe_autocast():
            # Run the frozen ViT on CPU, then move the features back to the original device.
            image_embeds = self.ln_vision(self.visual_encoder(image)).to(device)
            image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device)

            query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1)
            query_output = self.Qformer.bert(
                query_embeds=query_tokens,
                encoder_hidden_states=image_embeds,
                encoder_attention_mask=image_atts,
                return_dict=True,
            )

            # Project the query outputs into LLaMA's embedding space. Note: `image` was
            # rebound to CPU above, so the attention mask must use the saved `device`.
            inputs_llama = self.llama_proj(query_output.last_hidden_state)
            atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(device)
        return inputs_llama, atts_llama

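    # prompt_wrap splices the projected image embeddings into a text prompt at the
    # '<ImageHere>' placeholder, e.g. "###Human: <Img><ImageHere></Img> ...".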
    def prompt_wrap(self, img_embeds, atts_img, prompt):
        if prompt:
            batch_size = img_embeds.shape[0]
            p_before, p_after = prompt.split('<ImageHere>')
            p_before_tokens = self.llama_tokenizer(
                p_before, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
            p_after_tokens = self.llama_tokenizer(
                p_after, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
            p_before_embeds = self.llama_model.model.embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
            p_after_embeds = self.llama_model.model.embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
            wrapped_img_embeds = torch.cat([p_before_embeds, img_embeds, p_after_embeds], dim=1)
            wrapped_atts_img = atts_img[:, :1].expand(-1, wrapped_img_embeds.shape[1])
            return wrapped_img_embeds, wrapped_atts_img
        else:
            return img_embeds, atts_img

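    # Training forward pass. The language-model input is assembled as
    # [BOS] [prompt-wrapped image embeddings] [target text tokens]; labels for the
    # BOS and image positions are set to -100 so the loss is computed on text only.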
    def forward(self, samples):
        image = samples["image"]
        img_embeds, atts_img = self.encode_img(image)
        if 'question_split' in samples:  # VQA-style batch (samples is a dict, so use `in` rather than hasattr)
            print('VQA Batch')
            vqa_prompt = '###Human: <Img><ImageHere></Img> '
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, vqa_prompt)
        elif self.prompt_list:
            prompt = random.choice(self.prompt_list)
            img_embeds, atts_img = self.prompt_wrap(img_embeds, atts_img, prompt)

        self.llama_tokenizer.padding_side = "right"

        text = [t + self.end_sym for t in samples["text_input"]]

        to_regress_tokens = self.llama_tokenizer(
            text,
            return_tensors="pt",
            padding="longest",
            truncation=True,
            max_length=self.max_txt_len,
            add_special_tokens=False
        ).to(image.device)

        # Ignore padding positions in the loss.
        targets = to_regress_tokens.input_ids.masked_fill(
            to_regress_tokens.input_ids == self.llama_tokenizer.pad_token_id, -100
        )

        # Also mask the BOS token and every image-embedding position with -100.
        empty_targets = (
            torch.ones([atts_img.shape[0], atts_img.shape[1] + 1],
                       dtype=torch.long).to(image.device).fill_(-100)
        )
        targets = torch.cat([empty_targets, targets], dim=1)

        batch_size = img_embeds.shape[0]
        bos = torch.ones([batch_size, 1],
                         dtype=to_regress_tokens.input_ids.dtype,
                         device=to_regress_tokens.input_ids.device) * self.llama_tokenizer.bos_token_id
        bos_embeds = self.llama_model.model.embed_tokens(bos)
        atts_bos = atts_img[:, :1]

        to_regress_embeds = self.llama_model.model.embed_tokens(to_regress_tokens.input_ids)
        inputs_embeds = torch.cat([bos_embeds, img_embeds, to_regress_embeds], dim=1)
        attention_mask = torch.cat([atts_bos, atts_img, to_regress_tokens.attention_mask], dim=1)

        with self.maybe_autocast():
            outputs = self.llama_model(
                inputs_embeds=inputs_embeds,
                attention_mask=attention_mask,
                return_dict=True,
                labels=targets,
            )
        loss = outputs.loss

        return {"loss": loss}

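    # from_config builds the model from a registry/yaml config. The optional 'ckpt'
    # entry loads a MiniGPT-4 stage checkpoint with strict=False; such checkpoints
    # typically carry only the llama_proj weights, which is why missing keys are
    # tolerated.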
    @classmethod
    def from_config(cls, cfg):
        vit_model = cfg.get("vit_model", "eva_clip_g")
        q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth")
        img_size = cfg.get("image_size")
        num_query_token = cfg.get("num_query_token")
        llama_model = cfg.get("llama_model")

        drop_path_rate = cfg.get("drop_path_rate", 0)
        use_grad_checkpoint = cfg.get("use_grad_checkpoint", False)
        vit_precision = cfg.get("vit_precision", "fp16")
        freeze_vit = cfg.get("freeze_vit", True)
        freeze_qformer = cfg.get("freeze_qformer", True)
        llama_cache_dir = cfg.get("llama_cache_dir", "")

        prompt_path = cfg.get("prompt_path", "")
        prompt_template = cfg.get("prompt_template", "")
        max_txt_len = cfg.get("max_txt_len", 32)
        end_sym = cfg.get("end_sym", '\n')

        model = cls(
            vit_model=vit_model,
            q_former_model=q_former_model,
            img_size=img_size,
            drop_path_rate=drop_path_rate,
            use_grad_checkpoint=use_grad_checkpoint,
            vit_precision=vit_precision,
            freeze_vit=freeze_vit,
            freeze_qformer=freeze_qformer,
            llama_cache_dir=llama_cache_dir,
            num_query_token=num_query_token,
            llama_model=llama_model,
            prompt_path=prompt_path,
            prompt_template=prompt_template,
            max_txt_len=max_txt_len,
            end_sym=end_sym
        )

        ckpt_path = cfg.get("ckpt", "")
        if ckpt_path:
            print("Load BLIP2-LLM Checkpoint: {}".format(ckpt_path))
            ckpt = torch.load(ckpt_path, map_location="cpu")
            msg = model.load_state_dict(ckpt['model'], strict=False)

        return model
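
# Usage sketch (comments only, not executed): the config keys mirror from_config();
# the checkpoint filename is a placeholder, not a file shipped with this module.
#
#   from omegaconf import OmegaConf
#   cfg = OmegaConf.load("configs/models/minigpt4.yaml").model
#   cfg.ckpt = "path/to/pretrained_minigpt4.pth"      # optional stage checkpoint
#   model = MiniGPT4.from_config(cfg).eval()
#   img_embeds, img_atts = model.encode_img(images)   # images: (B, 3, 224, 224) tensor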