| import argparse |
| import glob |
| import os |
| import sys |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| import torch.nn.functional as F |
| import transformers |
| from peft import LoraConfig, get_peft_model |
| from transformers import AutoTokenizer |
|
|
| from model.AffordanceVLM import AffordanceVLMForCausalLM |
| from utils.utils import DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN |
|
|
|
|
def parse_args(args):
    """Parse command-line options for merging LoRA weights into an HF model.

    Args:
        args: List of raw argument strings (typically ``sys.argv[1:]``).

    Returns:
        argparse.Namespace with the parsed options. Note that
        ``--vision-tower`` / ``--local-rank`` map to the attributes
        ``vision_tower`` / ``local_rank``.
    """
    parser = argparse.ArgumentParser(
        description="merge lora weights and save model with hf format"
    )
    parser.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for inference",
    )
    parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    parser.add_argument("--out_dim", default=256, type=int)
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--local-rank", default=0, type=int, help="node rank")
    # BUG FIX: the original used `action="store_true"` together with
    # `default=True`, which made these flags no-ops (the value could never be
    # switched off). Keep the original flags for backward compatibility and
    # add explicit `--no_*` counterparts to actually disable them.
    parser.add_argument(
        "--train_mask_decoder",
        dest="train_mask_decoder",
        action="store_true",
        default=True,
    )
    parser.add_argument(
        "--no_train_mask_decoder", dest="train_mask_decoder", action="store_false"
    )
    parser.add_argument(
        "--use_mm_start_end",
        dest="use_mm_start_end",
        action="store_true",
        default=True,
    )
    parser.add_argument(
        "--no_use_mm_start_end", dest="use_mm_start_end", action="store_false"
    )
    parser.add_argument(
        "--conv_type",
        default="llava_v1",
        type=str,
        choices=["llava_v1", "llava_llama_2"],
    )
    parser.add_argument("--weight", default="", type=str, required=True)
    parser.add_argument("--save_path", default="./lisa_model", type=str, required=True)
    return parser.parse_args(args)
|
|
|
|
def _find_lora_linear_names(model, lora_target_modules):
    """Collect names of ``torch.nn.Linear`` modules eligible for LoRA.

    Excludes modules whose qualified name belongs to the vision stack or the
    projection heads (``visual_model``, ``vision_tower``, ``mm_projector``,
    ``text_hidden_fcs``); keeps only modules whose name contains one of
    ``lora_target_modules`` (e.g. ``q_proj`` / ``v_proj``).

    Returns:
        Sorted list of matching module names.
    """
    excluded = ("visual_model", "vision_tower", "mm_projector", "text_hidden_fcs")
    names = set()
    for name, module in model.named_modules():
        if not isinstance(module, torch.nn.Linear):
            continue
        if any(x in name for x in excluded):
            continue
        if any(x in name for x in lora_target_modules):
            names.add(name)
    return sorted(names)


def main(args):
    """Merge trained LoRA weights into the base model and save in HF format.

    Rebuilds the tokenizer and LoRA-wrapped model exactly as at training time
    (same special tokens in the same order, so resized embedding rows line up
    with the checkpoint), loads the trained state dict, merges the LoRA
    adapters into the base weights, and writes model + tokenizer to
    ``args.save_path``.

    Args:
        args: Raw CLI argument strings (e.g. ``sys.argv[1:]``).
    """
    args = parse_args(args)
    os.makedirs(args.vis_save_path, exist_ok=True)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the task tokens and remember their ids; the model uses these
    # indices to locate [SEG]/[AFF] positions in the output sequence.
    # (The original bound the add_tokens return value to an unused local.)
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )

    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "seg_token_idx": args.seg_token_idx,
        "aff_token_idx": args.aff_token_idx,
        "vision_tower": args.vision_tower,
    }

    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half
    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype)
    model.get_model().initialize_lisa_modules(model.get_model().config)

    if args.lora_r > 0:
        # Re-wrap with LoRA so the checkpoint's adapter keys resolve.
        lora_config = LoraConfig(
            r=args.lora_r,
            lora_alpha=args.lora_alpha,
            target_modules=_find_lora_linear_names(
                model, args.lora_target_modules.split(",")
            ),
            lora_dropout=args.lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()

    # Must happen after adding [SEG]/[AFF]/IM tokens so the embedding matrix
    # matches the checkpoint's vocabulary size.
    model.resize_token_embeddings(len(tokenizer))

    # NOTE(security): torch.load unpickles arbitrary objects — only load
    # checkpoints from trusted sources.
    state_dict = torch.load(args.weight, map_location="cpu")
    model.load_state_dict(state_dict, strict=True)

    # Fold LoRA adapters into the base weights, then drop the frozen CLIP
    # vision tower from the saved state dict (it is re-downloadable).
    model = model.merge_and_unload()
    state_dict = {
        k: v for k, v in model.state_dict().items() if "vision_tower" not in k
    }
    model.save_pretrained(args.save_path, state_dict=state_dict)
    tokenizer.save_pretrained(args.save_path)
|
|
|
|
# Script entry point: forward the CLI arguments (minus the program name)
# to main() only when run directly, not when imported.
if __name__ == "__main__":
    main(sys.argv[1:])
|
|