import argparse
import os
import shutil
import sys
import time
from functools import partial

import deepspeed
import numpy as np
import torch
import tqdm
import transformers
from peft import LoraConfig, get_peft_model
from torch.utils.tensorboard import SummaryWriter

from model.AffordanceVLM import AffordanceVLMForCausalLM
from model.llava import conversation as conversation_lib
from utils.dataset import HybridDataset, ValDataset, collate_fn
from utils.utils import (DEFAULT_IM_END_TOKEN, DEFAULT_IM_START_TOKEN,
                         AverageMeter, ProgressMeter, Summary, dict_to_cuda,
                         intersectionAndUnionGPU)
from utils.aff_seg_dataset import AffValDataset
from utils.reason_aff_dataset import ReasonAffValDataset


def parse_args(args):
    parser = argparse.ArgumentParser(description="LISA Model Training")
    parser.add_argument(
        "--local_rank", default=0, type=int, help="local rank for distributed training"
    )
    parser.add_argument(
        "--version", default="liuhaotian/llava-llama-2-13b-chat-lightning-preview"
    )
    parser.add_argument("--vis_save_path", default="./vis_output", type=str)
    parser.add_argument(
        "--precision",
        default="bf16",
        type=str,
        choices=["fp32", "bf16", "fp16"],
        help="precision for training",
    )
    parser.add_argument("--image_size", default=1024, type=int, help="image size")
    parser.add_argument("--model_max_length", default=512, type=int)
    parser.add_argument("--lora_r", default=8, type=int)
    parser.add_argument(
        "--vision-tower", default="openai/clip-vit-large-patch14", type=str
    )
    parser.add_argument("--load_in_8bit", action="store_true", default=False)
    parser.add_argument("--load_in_4bit", action="store_true", default=False)

    parser.add_argument(
        "--dataset", default="sem_seg||refer_seg||vqa||reason_seg", type=str
    )
    parser.add_argument("--sample_rates", default="9,3,3,1", type=str)
    parser.add_argument(
        "--sem_seg_data",
        default="ade20k||cocostuff||pascal_part||paco_lvis||mapillary",
        type=str,
    )
    parser.add_argument(
        "--refer_seg_data", default="refclef||refcoco||refcoco+||refcocog", type=str
    )
    parser.add_argument("--vqa_data", default="llava_instruct_150k", type=str)
    parser.add_argument("--reason_seg_data", default="ReasonSeg|train", type=str)
    parser.add_argument("--aff_seg_data", default="handal", type=str)
    parser.add_argument("--aff_sample_rates", default="1", type=str)
    parser.add_argument("--reason_aff_data", default="handal_hard_reasoning", type=str)
    parser.add_argument("--reason_aff_sample_rates", default="1", type=str)
    parser.add_argument("--val_dataset", default="ReasonSeg|val", type=str)
    parser.add_argument("--dataset_dir", default="./dataset", type=str)
    parser.add_argument("--log_base_dir", default="./runs", type=str)
    parser.add_argument("--exp_name", default="lisa", type=str)
    parser.add_argument("--epochs", default=10, type=int)
    parser.add_argument("--steps_per_epoch", default=500, type=int)
    parser.add_argument(
        "--batch_size", default=2, type=int, help="batch size per device per step"
    )
    parser.add_argument(
        "--grad_accumulation_steps",
        default=10,
        type=int,
    )
    parser.add_argument("--val_batch_size", default=1, type=int)
    parser.add_argument("--workers", default=4, type=int)
    parser.add_argument("--lr", default=0.0003, type=float)
    parser.add_argument("--ce_loss_weight", default=1.0, type=float)
    parser.add_argument("--dice_loss_weight", default=0.5, type=float)
    parser.add_argument("--bce_loss_weight", default=2.0, type=float)
    parser.add_argument("--lora_alpha", default=16, type=int)
    parser.add_argument("--lora_dropout", default=0.05, type=float)
    parser.add_argument("--lora_target_modules", default="q_proj,v_proj", type=str)
    parser.add_argument("--explanatory", default=0.1, type=float)
    parser.add_argument("--beta1", default=0.9, type=float)
    parser.add_argument("--beta2", default=0.95, type=float)
    parser.add_argument("--num_classes_per_sample", default=3, type=int)
    parser.add_argument("--exclude_val", action="store_true", default=False)
    parser.add_argument("--no_eval", action="store_true", default=False)
    parser.add_argument("--eval_only", action="store_true", default=False)
    parser.add_argument("--eval_affordance", action="store_true", default=False)
    parser.add_argument("--eval_reason_aff", action="store_true", default=False)
    parser.add_argument("--vision_pretrained", default="PATH_TO_SAM_ViT-H", type=str)
    parser.add_argument("--out_dim", default=256, type=int)
    parser.add_argument("--resume", default="", type=str)
    parser.add_argument("--print_freq", default=1, type=int)
    parser.add_argument("--start_epoch", default=0, type=int)
| parser.add_argument("--gradient_checkpointing", action="store_true", default=True) |
| parser.add_argument("--train_mask_decoder", action="store_true", default=True) |
| parser.add_argument("--use_mm_start_end", action="store_true", default=True) |
| parser.add_argument("--auto_resume", action="store_true", default=True) |
| parser.add_argument( |
| "--conv_type", |
| default="llava_v1", |
| type=str, |
| choices=["llava_v1", "llava_llama_2"], |
| ) |
| return parser.parse_args(args) |
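

# Example launch (illustrative; the script filename and local paths are
# assumptions, adjust them to your checkout):
#   deepspeed --master_port=24999 train_ds.py \
#       --version="liuhaotian/llava-llama-2-13b-chat-lightning-preview" \
#       --dataset_dir=./dataset \
#       --vision_pretrained=PATH_TO_SAM_ViT-H \
#       --exp_name=lisa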


def main(args):
    args = parse_args(args)
    args.log_dir = os.path.join(args.log_base_dir, args.exp_name)
    if args.local_rank == 0:
        os.makedirs(args.log_dir, exist_ok=True)
        writer = SummaryWriter(args.log_dir)
    else:
        writer = None

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.version,
        cache_dir=None,
        model_max_length=args.model_max_length,
        padding_side="right",
        use_fast=False,
    )
    tokenizer.pad_token = tokenizer.unk_token
    # Register the special tokens whose hidden states drive mask prediction:
    # [SEG] for segmentation queries, [AFF] for affordance queries.
    tokenizer.add_tokens("[SEG]")
    args.seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    tokenizer.add_tokens("[AFF]")
    args.aff_token_idx = tokenizer("[AFF]", add_special_tokens=False).input_ids[0]

    if args.use_mm_start_end:
        tokenizer.add_tokens(
            [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN], special_tokens=True
        )
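
    # The kwargs below are not standard HF config fields; AffordanceVLMForCausalLM
    # is expected to consume them in its __init__ (as LISA does) and attach them
    # to its config.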
    model_args = {
        "train_mask_decoder": args.train_mask_decoder,
        "out_dim": args.out_dim,
        "ce_loss_weight": args.ce_loss_weight,
        "dice_loss_weight": args.dice_loss_weight,
        "bce_loss_weight": args.bce_loss_weight,
        "seg_token_idx": args.seg_token_idx,
        "aff_token_idx": args.aff_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_mm_start_end": args.use_mm_start_end,
    }
    torch_dtype = torch.float32
    if args.precision == "bf16":
        torch_dtype = torch.bfloat16
    elif args.precision == "fp16":
        torch_dtype = torch.half
    model = AffordanceVLMForCausalLM.from_pretrained(
        args.version, torch_dtype=torch_dtype, low_cpu_mem_usage=True, **model_args
    )
    model.config.eos_token_id = tokenizer.eos_token_id
    model.config.bos_token_id = tokenizer.bos_token_id
    model.config.pad_token_id = tokenizer.pad_token_id

    model.enable_input_require_grads()
    model.gradient_checkpointing_enable()

    model.get_model().initialize_vision_modules(model.get_model().config)
    vision_tower = model.get_model().get_vision_tower()
    vision_tower.to(dtype=torch_dtype, device=args.local_rank)
    if not args.eval_only:
        model.get_model().initialize_lisa_modules(model.get_model().config)
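
    # Freeze the CLIP vision tower and the LLaVA projector; the SAM backbone is
    # expected to be frozen inside initialize_lisa_modules as well (as in LISA),
    # so only LoRA adapters and the modules re-enabled below remain trainable.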
    for p in vision_tower.parameters():
        p.requires_grad = False
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False

    conversation_lib.default_conversation = conversation_lib.conv_templates[
        args.conv_type
    ]

    lora_r = args.lora_r
    if lora_r > 0:

        def find_linear_layers(model, lora_target_modules):
            # Collect every nn.Linear in the language model whose name matches
            # a target, skipping the vision and mask-decoding modules.
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        x not in name
                        for x in [
                            "visual_model",
                            "vision_tower",
                            "mm_projector",
                            "text_hidden_fcs",
                        ]
                    )
                    and any(x in name for x in lora_target_modules)
                ):
                    lora_module_names.add(name)
            return sorted(lora_module_names)

        lora_alpha = args.lora_alpha
        lora_dropout = args.lora_dropout
        lora_target_modules = find_linear_layers(
            model, args.lora_target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )
        model = get_peft_model(model, lora_config)
        model.print_trainable_parameters()
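
    # Resize after adding [SEG]/[AFF] and the image start/end tokens so the
    # embedding table and lm_head cover the enlarged vocabulary.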
    model.resize_token_embeddings(len(tokenizer))

    # Re-enable gradients for the token embeddings, LM head, mask decoder,
    # and the text-to-mask projection layers.
    for n, p in model.named_parameters():
        if any(
            x in n
            for x in ["lm_head", "embed_tokens", "mask_decoder", "text_hidden_fcs"]
        ):
            print("n: ", n, "p.shape: ", p.shape)
            p.requires_grad = True

    world_size = torch.cuda.device_count()
    args.distributed = world_size > 1
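    # An "epoch" is a fixed budget of samples, batch_size *
    # grad_accumulation_steps * steps_per_epoch * world_size, independent of
    # the sizes of the underlying datasets.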
    train_dataset = HybridDataset(
        args.dataset_dir,
        tokenizer,
        args.vision_tower,
        samples_per_epoch=args.batch_size
        * args.grad_accumulation_steps
        * args.steps_per_epoch
        * world_size,
        precision=args.precision,
        image_size=args.image_size,
        num_classes_per_sample=args.num_classes_per_sample,
        exclude_val=args.exclude_val,
        dataset=args.dataset,
        sample_rate=[float(x) for x in args.sample_rates.split(",")],
        sem_seg_data=args.sem_seg_data,
        refer_seg_data=args.refer_seg_data,
        vqa_data=args.vqa_data,
        reason_seg_data=args.reason_seg_data,
        aff_seg_data=args.aff_seg_data,
        aff_sample_rate=[float(x) for x in args.aff_sample_rates.split(",")],
        reason_aff_data=args.reason_aff_data,
        reason_aff_sample_rate=[
            float(x) for x in args.reason_aff_sample_rates.split(",")
        ],
        explanatory=args.explanatory,
    )

    if not args.no_eval:
        if args.eval_affordance:
            val_dataset = AffValDataset(
                args.dataset_dir,
                tokenizer,
                args.vision_tower,
                args.val_dataset,
                args.image_size,
            )
        elif args.eval_reason_aff:
            val_dataset = ReasonAffValDataset(
                args.dataset_dir,
                tokenizer,
                args.vision_tower,
                args.val_dataset,
                args.image_size,
            )
        else:
            val_dataset = ValDataset(
                args.dataset_dir,
                tokenizer,
                args.vision_tower,
                args.val_dataset,
                args.image_size,
            )
        print(
            f"Training with {len(train_dataset)} examples and validating with "
            f"{len(val_dataset)} examples."
        )
    else:
        val_dataset = None
        print(f"Training with {len(train_dataset)} examples.")

    ds_config = {
        "train_micro_batch_size_per_gpu": args.batch_size,
        "gradient_accumulation_steps": args.grad_accumulation_steps,
        "optimizer": {
            "type": "AdamW",
            "params": {
                "lr": args.lr,
                "weight_decay": 0.0,
                "betas": (args.beta1, args.beta2),
            },
        },
        "scheduler": {
            "type": "WarmupDecayLR",
            "params": {
                "total_num_steps": args.epochs * args.steps_per_epoch,
                "warmup_min_lr": 0,
                "warmup_max_lr": args.lr,
                "warmup_num_steps": 100,
                "warmup_type": "linear",
            },
        },
        "fp16": {
            "enabled": args.precision == "fp16",
        },
        "bf16": {
            "enabled": args.precision == "bf16",
        },
        "gradient_clipping": 1.0,
        "zero_optimization": {
            "stage": 2,
            "contiguous_gradients": True,
            "overlap_comm": True,
            "reduce_scatter": True,
            "reduce_bucket_size": 5e8,
            "allgather_bucket_size": 5e8,
        },
    }
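
    # ZeRO stage 2 shards optimizer state and gradients across ranks; the
    # scheduler's total_num_steps counts optimizer steps (one per gradient-
    # accumulation window), matching epochs * steps_per_epoch.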
    model_engine, optimizer, train_loader, scheduler = deepspeed.initialize(
        model=model,
        model_parameters=model.parameters(),
        training_data=train_dataset,
        collate_fn=partial(
            collate_fn,
            tokenizer=tokenizer,
            conv_type=args.conv_type,
            use_mm_start_end=args.use_mm_start_end,
            local_rank=args.local_rank,
        ),
        config=ds_config,
    )

    if args.auto_resume and len(args.resume) == 0:
        resume = os.path.join(args.log_dir, "ckpt_model")
        if os.path.exists(resume):
            args.resume = resume
    if args.resume:
        load_path, client_state = model_engine.load_checkpoint(args.resume)
        with open(os.path.join(args.resume, "latest"), "r") as f:
            ckpt_dir = f.readlines()[0].strip()
        args.start_epoch = (
            int(ckpt_dir.replace("global_step", "")) // args.steps_per_epoch
        )
        print(
            "resume training from {}, start from epoch {}".format(
                args.resume, args.start_epoch
            )
        )

    if val_dataset is not None:
        assert args.val_batch_size == 1
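        # Shard the validation set across ranks; per-rank batch size stays 1
        # and the metric meters are all-reduced in validate().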
        val_sampler = torch.utils.data.distributed.DistributedSampler(
            val_dataset, shuffle=False, drop_last=False
        )
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=args.val_batch_size,
            shuffle=False,
            num_workers=args.workers,
            pin_memory=False,
            sampler=val_sampler,
            collate_fn=partial(
                collate_fn,
                tokenizer=tokenizer,
                conv_type=args.conv_type,
                use_mm_start_end=args.use_mm_start_end,
                local_rank=args.local_rank,
            ),
        )

    train_iter = iter(train_loader)
    best_score, cur_ciou = 0.0, 0.0

    if args.eval_only:
        giou, ciou = validate(val_loader, model_engine, 0, writer, args)
        if args.local_rank == 0:
            # NOTE: assumes --version points at a writable local directory.
            with open(os.path.join(args.version, "eval_result.txt"), "a") as f:
                f.write(f"dataset: {args.val_dataset}, giou: {giou}, ciou: {ciou}\n")
        sys.exit(0)

    for epoch in range(args.start_epoch, args.epochs):
        train_iter = train(
            train_loader,
            model_engine,
            epoch,
            scheduler,
            writer,
            train_iter,
            args,
        )

        is_best = False
        if not args.no_eval:
            giou, ciou = validate(val_loader, model_engine, epoch, writer, args)
            is_best = giou > best_score
            best_score = max(giou, best_score)
            cur_ciou = ciou if is_best else cur_ciou

        if args.no_eval or is_best:
            save_dir = os.path.join(args.log_dir, "ckpt_model")
            if args.local_rank == 0:
                torch.save(
                    {"epoch": epoch},
                    os.path.join(
                        args.log_dir,
                        "meta_log_giou{:.3f}_ciou{:.3f}.pth".format(
                            best_score, cur_ciou
                        ),
                    ),
                )
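                # Clear the stale best checkpoint on rank 0 only; all ranks
                # must then enter save_checkpoint together (hence the barrier).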
                if os.path.exists(save_dir):
                    shutil.rmtree(save_dir)
            torch.distributed.barrier()
            model_engine.save_checkpoint(save_dir)


def train(
    train_loader,
    model,
    epoch,
    scheduler,
    writer,
    train_iter,
    args,
):
    """Run one epoch of args.steps_per_epoch optimizer steps."""
    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":.4f")
    ce_losses = AverageMeter("CeLoss", ":.4f")
    mask_bce_losses = AverageMeter("MaskBCELoss", ":.4f")
    mask_dice_losses = AverageMeter("MaskDICELoss", ":.4f")
    mask_losses = AverageMeter("MaskLoss", ":.4f")

    progress = ProgressMeter(
        args.steps_per_epoch,
        [
            batch_time,
            losses,
            ce_losses,
            mask_losses,
            mask_bce_losses,
            mask_dice_losses,
        ],
        prefix="Epoch: [{}]".format(epoch),
    )

    model.train()
    end = time.time()
    for global_step in range(args.steps_per_epoch):
        for i in range(args.grad_accumulation_steps):
            try:
                input_dict = next(train_iter)
            except StopIteration:
                # The sampler ran out mid-epoch; restart it and keep going.
                train_iter = iter(train_loader)
                input_dict = next(train_iter)

            data_time.update(time.time() - end)
            input_dict = dict_to_cuda(input_dict)

            if args.precision == "fp16":
                input_dict["images"] = input_dict["images"].half()
                input_dict["images_clip"] = input_dict["images_clip"].half()
            elif args.precision == "bf16":
                input_dict["images"] = input_dict["images"].bfloat16()
                input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
            else:
                input_dict["images"] = input_dict["images"].float()
                input_dict["images_clip"] = input_dict["images_clip"].float()

            output_dict = model(**input_dict)

            loss = output_dict["loss"]
            ce_loss = output_dict["ce_loss"]
            mask_bce_loss = output_dict["mask_bce_loss"]
            mask_dice_loss = output_dict["mask_dice_loss"]
            mask_loss = output_dict["mask_loss"]

            losses.update(loss.item(), input_dict["images"].size(0))
            ce_losses.update(ce_loss.item(), input_dict["images"].size(0))
            mask_bce_losses.update(mask_bce_loss.item(), input_dict["images"].size(0))
            mask_dice_losses.update(
                mask_dice_loss.item(), input_dict["images"].size(0)
            )
            mask_losses.update(mask_loss.item(), input_dict["images"].size(0))
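            # model is the DeepSpeed engine here: backward() scales the loss
            # for gradient accumulation, and step() applies the optimizer
            # update only at the end of each accumulation window.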
            model.backward(loss)
            model.step()

        batch_time.update(time.time() - end)
        end = time.time()

        if global_step % args.print_freq == 0:
            if args.distributed:
                batch_time.all_reduce()
                data_time.all_reduce()
                losses.all_reduce()
                ce_losses.all_reduce()
                mask_bce_losses.all_reduce()
                mask_dice_losses.all_reduce()
                mask_losses.all_reduce()

            if args.local_rank == 0:
                progress.display(global_step + 1)
                # Log against the global optimizer step so curves from later
                # epochs do not overwrite earlier ones in TensorBoard.
                log_step = epoch * args.steps_per_epoch + global_step
                writer.add_scalar("train/loss", losses.avg, log_step)
                writer.add_scalar("train/ce_loss", ce_losses.avg, log_step)
                writer.add_scalar(
                    "train/mask_bce_loss", mask_bce_losses.avg, log_step
                )
                writer.add_scalar(
                    "train/mask_dice_loss", mask_dice_losses.avg, log_step
                )
                writer.add_scalar("train/mask_loss", mask_losses.avg, log_step)
                writer.add_scalar(
                    "metrics/total_secs_per_batch", batch_time.avg, log_step
                )
                writer.add_scalar(
                    "metrics/data_secs_per_batch", data_time.avg, log_step
                )

            batch_time.reset()
            data_time.reset()
            losses.reset()
            ce_losses.reset()
            mask_bce_losses.reset()
            mask_dice_losses.reset()
            mask_losses.reset()

        if global_step != 0:
            curr_lr = scheduler.get_last_lr()
            if args.local_rank == 0:
                writer.add_scalar(
                    "train/lr", curr_lr[0], epoch * args.steps_per_epoch + global_step
                )

    return train_iter


def validate(val_loader, model_engine, epoch, writer, args):
    intersection_meter = AverageMeter("Intersec", ":6.3f", Summary.SUM)
    union_meter = AverageMeter("Union", ":6.3f", Summary.SUM)
    acc_iou_meter = AverageMeter("gIoU", ":6.3f", Summary.SUM)

    model_engine.eval()

    for input_dict in tqdm.tqdm(val_loader):
        torch.cuda.empty_cache()

        input_dict = dict_to_cuda(input_dict)
        if args.precision == "fp16":
            input_dict["images"] = input_dict["images"].half()
            input_dict["images_clip"] = input_dict["images_clip"].half()
        elif args.precision == "bf16":
            input_dict["images"] = input_dict["images"].bfloat16()
            input_dict["images_clip"] = input_dict["images_clip"].bfloat16()
        else:
            input_dict["images"] = input_dict["images"].float()
            input_dict["images_clip"] = input_dict["images_clip"].float()

        with torch.no_grad():
            output_dict = model_engine(**input_dict)

        pred_masks = output_dict["pred_masks"]
        assert len(pred_masks) == 1  # val_batch_size is fixed to 1
        masks_list = output_dict["gt_masks"][0].int()
        output_list = (pred_masks[0] > 0).int()

        intersection, union, acc_iou = 0.0, 0.0, 0.0
        for mask_i, output_i in zip(masks_list, output_list):
            intersection_i, union_i, _ = intersectionAndUnionGPU(
                output_i.contiguous().clone(), mask_i.contiguous(), 2, ignore_index=255
            )
            intersection += intersection_i
            union += union_i
            acc_iou += intersection_i / (union_i + 1e-5)
            # Treat classes absent from both prediction and target as IoU 1.
            acc_iou[union_i == 0] += 1.0
        intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
        acc_iou = acc_iou.cpu().numpy() / masks_list.shape[0]
        intersection_meter.update(intersection)
        union_meter.update(union)
        acc_iou_meter.update(acc_iou, n=masks_list.shape[0])
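
    # gIoU averages per-image IoU across the split, while cIoU is computed
    # from intersections and unions summed over the whole split; meters are
    # all-reduced across ranks before the final numbers are taken.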
    intersection_meter.all_reduce()
    union_meter.all_reduce()
    acc_iou_meter.all_reduce()

    iou_class = intersection_meter.sum / (union_meter.sum + 1e-10)
    ciou = iou_class[1]
    giou = acc_iou_meter.avg[1]

    if args.local_rank == 0:
        writer.add_scalar("val/giou", giou, epoch)
        writer.add_scalar("val/ciou", ciou, epoch)
        print("giou: {:.4f}, ciou: {:.4f}".format(giou, ciou))

    return giou, ciou


if __name__ == "__main__":
    main(sys.argv[1:])