import transformers
from torch.cuda.amp import autocast, GradScaler

from datasets import REFAVS
from configs import args
from torch.utils.data import DataLoader
from functools import partial
from models.llava import conversation as conversation_lib
from models.avs_model import Simtoken_ForCausalLM
import torch
from torch.cuda import amp
from transformers import AutoConfig
from peft import LoraConfig, get_peft_model
from torch import optim
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
from tqdm import tqdm

from utils import utility
import random
import numpy as np
import re
import time
import os
from PIL import Image

import warnings

from utils.metric.utility import mask_iou

warnings.filterwarnings("ignore")

from transformers import logging
logging.set_verbosity_error()
|
|
|
|
# Sentinel ids used in place of real tokens to mark where image/audio/video
# features are spliced into the text sequence; negative values cannot collide
# with vocabulary ids.
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_VIDEO_TOKEN = "<video>"

AUDIO_TOKEN_INDEX = -300
DEFAULT_AUDIO_TOKEN = "<audio>"
|
|
def set_seed(seed=42):
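    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN kernels for reproducible evaluation."""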
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
|
def dict_to_cuda(input_dict):
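    """Move every tensor (or list of tensors) in a batch dict onto the GPU; other entries are left untouched."""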
    for k, v in input_dict.items():
        if isinstance(v, torch.Tensor):
            input_dict[k] = v.cuda(non_blocking=True)
        elif (
            isinstance(v, list)
            and len(v) > 0
            and isinstance(v[0], torch.Tensor)
        ):
            input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
    return input_dict
|
|
def tokenizer_image_audio_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, audio_token_index=AUDIO_TOKEN_INDEX, num_frames=10, return_tensors=None):
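    """Tokenize a prompt that may contain <image>, <audio> and <video> placeholders.

    Text chunks are tokenized normally; each placeholder is replaced by its sentinel
    index (IMAGE_TOKEN_INDEX / AUDIO_TOKEN_INDEX), with <video> expanded to
    `num_frames` image indices, so the matching features can be spliced in later.
    """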
|
|
    prompt_chunks = re.split(r'(<image>|<audio>|<video>)', prompt)
    prompt_chunks = [chunk for chunk in prompt_chunks if chunk]

    text_chunks = []
    token_types = []
    for chunk in prompt_chunks:
        if chunk == "<image>":
            token_types.append("image")
        elif chunk == "<audio>":
            token_types.append("audio")
        elif chunk == "<video>":
            token_types.append("video")
        else:
            text_chunks.append(chunk)

    tokenized_chunks = [tokenizer(chunk).input_ids for chunk in text_chunks]
|
|
    def insert_separators(text_chunks, tokenized_chunks, token_types, image_token_index, audio_token_index, num_frames):
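        """Interleave tokenized text chunks with the placeholder indices, keeping only the leading BOS token."""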
        input_ids = []
        offset = 0
        if (
            len(tokenized_chunks) > 0
            and len(tokenized_chunks[0]) > 0
            and tokenized_chunks[0][0] == tokenizer.bos_token_id
        ):
            offset = 1
            input_ids.append(tokenized_chunks[0][0])

        min_length = min(len(text_chunks), len(token_types))
        for i in range(min_length):
            input_ids.extend(tokenized_chunks[i][offset:])
            if token_types[i] == "image":
                input_ids.append(image_token_index)
            elif token_types[i] == "audio":
                input_ids.append(audio_token_index)
            elif token_types[i] == "video":
                input_ids.extend([image_token_index] * num_frames)

        if len(text_chunks) > min_length:
            input_ids.extend(tokenized_chunks[min_length][offset:])

        return input_ids

    input_ids = insert_separators(text_chunks, tokenized_chunks, token_types, image_token_index, audio_token_index, num_frames)

    if return_tensors is not None:
        if return_tensors == "pt":
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f"Unsupported tensor type: {return_tensors}")
    return input_ids
|
|
def collate_fn(batch, tokenizer=None):
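    """Collate REFAVS samples: gather per-sample fields into lists, tokenize and pad the conversations, and build language-model labels."""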
    vids = []
    images = []
    image_clips = []
    masks = []
    conversations = []
    audio_feats = []
    image_feats = []
    resizes = []
    orgsizes = []

    refs = []
    first_refs = []
    refs_num = []
    fids = []

    for data in batch:
        vids.append(data['vid'])
        images.append(data['image'])
        image_clips.append(data['img_clip'])
        masks.append(data['mask'])
        conversations.append(data['conversation'])
        audio_feats.append(data['feat_aud'])
        resizes.append(data['resize'])
        orgsizes.append(data['orgsize'])
        image_feats.append(data['feat_sam'])
        refs_num.append(len(data['ref']))
        fids.append(data['fids'])

        refs.append(data['ref'])
        first_refs.append(data['ref'][0])
|
|
    input_ids = [tokenizer_image_audio_token(conv, tokenizer, return_tensors="pt") for conv in conversations]
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    attention_masks = input_ids.ne(tokenizer.pad_token_id)

    ref_ids = [tokenizer_image_audio_token(ref, tokenizer, return_tensors="pt") for ref in first_refs]

    conv = conversation_lib.default_conversation.copy()
    labels = input_ids.clone()

    sep = 'Sure, it is [SEG]'
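    # Build the supervision targets: every token that is not part of a
    # "Sure, it is [SEG]" answer span (leading BOS, prompt text, trailing tokens)
    # is set to IGNORE_INDEX so it does not contribute to the language-model loss.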
|
|
    for conversation, target in zip(conversations, labels):
        parts = conversation.split(sep)
        cur_len = 1
        target[:cur_len] = IGNORE_INDEX

        sep_len = len(tokenizer_image_audio_token(sep, tokenizer)) - 1

        for i in range(len(parts) - 1):
            part_len = len(tokenizer_image_audio_token(parts[i], tokenizer)) - 2
            target[cur_len: cur_len + part_len] = IGNORE_INDEX
            cur_len += part_len + sep_len

        target[cur_len:] = IGNORE_INDEX
|
|
|
|
| return {"vids": vids, |
| "images": images, |
| "images_clip": image_clips, |
| "masks": masks, |
| "convs": conversations, |
| "input_ids": input_ids, |
| "attention_masks": attention_masks, |
| "labels": labels, |
| "audio_feats": audio_feats, |
| "resizes": resizes, |
| "orgsizes": orgsizes, |
| "image_feats": image_feats, |
| "ref_ids": ref_ids, |
| "refs_num": refs_num, |
| "fids": fids, |
| "refs": refs, |
| } |
|
|
|
|
import torch.multiprocessing as mp

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    set_seed(42)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        args.mllm,
        cache_dir=None,
        model_max_length=2048,
        padding_side="right",
        use_fast=False,
    )

    tokenizer.pad_token = tokenizer.unk_token
    # Register the [SEG] token; its id is passed to the model (via model_args below)
    # so the [SEG] positions can be located in the token sequence.
    num_added_tokens = tokenizer.add_tokens("[SEG]")
    seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
    print("seg_token_idx: ", seg_token_idx)
|
|
|
|
    eval_splits = {split.strip() for split in args.eval_splits.split(",") if split.strip()}
    val_dataset_s = REFAVS('test_s', args, tokenizer, input_type='refer') if 'test_s' in eval_splits else None
    val_dataset_u = REFAVS('test_u', args, tokenizer, input_type='refer') if 'test_u' in eval_splits else None
    val_dataset_n = REFAVS('test_n', args, tokenizer, input_type='refer') if 'test_n' in eval_splits else None

    val_dataloader_s = DataLoader(val_dataset_s, batch_size=1, shuffle=False, num_workers=4, collate_fn=partial(collate_fn, tokenizer=tokenizer)) if val_dataset_s is not None else None
    val_dataloader_u = DataLoader(val_dataset_u, batch_size=1, shuffle=False, num_workers=4, collate_fn=partial(collate_fn, tokenizer=tokenizer)) if val_dataset_u is not None else None
    val_dataloader_n = DataLoader(val_dataset_n, batch_size=1, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer)) if val_dataset_n is not None else None
|
|
|
|
|
|
    model_args = {
        "train_mask_decoder": True,
        "out_dim": 256,
        "ce_loss_weight": 1.0,
        "dice_loss_weight": 0.5,
        "bce_loss_weight": 2.0,
        "seg_token_idx": seg_token_idx,
        "vision_pretrained": args.vision_pretrained,
        "vision_tower": args.vision_tower,
        "use_im_start_end": False,
        "compress": args.compress,
        "start": args.start,
    }

    model = Simtoken_ForCausalLM.from_pretrained(
        args.mllm, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, **model_args
    )
|
|
| print("\nmodel loaded") |
|
|
| model.config.eos_token_id = tokenizer.eos_token_id |
| model.config.bos_token_id = tokenizer.bos_token_id |
| model.config.pad_token_id = tokenizer.pad_token_id |
|
|
| model.enable_input_require_grads() |
| model.gradient_checkpointing_enable() |
|
|
| model.get_model().initialize_vision_modules(model.get_model().config) |
| vision_tower = model.get_model().get_vision_tower() |
| vision_tower.to(dtype=torch.float32, device="cuda") |
|
|
| model_args_from_pt = AutoConfig.from_pretrained(args.mllm) |
| model_args_from_pt.use_cluster = True |
| model_args_from_pt.freeze = False |
| model_args_from_pt.mm_tune = True |
| model_args_from_pt.spatial_cluster_rate0 = 64 |
| model_args_from_pt.spatial_cluster_rate1 = 32 |
| model_args_from_pt.spatial_cluster_rate2 = 16 |
| model_args_from_pt.temporal_cluster_rate = 0.0625 |
| model_args_from_pt.use_cluster = True |
| model_args_from_pt.vision_tune = False |
| model.get_model().initialize_cluster_modules(model_args_from_pt) |
|
|
| model.get_model().initialize_lisa_modules(model.get_model().config) |
|
|
    # Keep the vision tower and the multimodal projector frozen.
    for p in vision_tower.parameters():
        p.requires_grad = False
    for p in model.get_model().mm_projector.parameters():
        p.requires_grad = False
|
|
    # Attach LoRA adapters (r=8) to the q/v projection layers of the language model,
    # skipping the vision, projector and segmentation-specific modules.
    lora_r = 8
    target_modules = "q_proj,v_proj"
    if lora_r > 0:
        def find_linear_layers(model, lora_target_modules):
            cls = torch.nn.Linear
            lora_module_names = set()
            for name, module in model.named_modules():
                if (
                    isinstance(module, cls)
                    and all(
                        [
                            x not in name
                            for x in [
                                "visual_model",
                                "vision_tower",
                                "mm_projector",
                                "text_hidden_fcs",
                                "audio_feature_layer",
                            ]
                        ]
                    )
                    and any([x in name for x in lora_target_modules])
                ):
                    lora_module_names.add(name)
            return sorted(list(lora_module_names))

        lora_alpha = 16
        lora_dropout = 0.05

        lora_target_modules = find_linear_layers(
            model, target_modules.split(",")
        )
        lora_config = LoraConfig(
            r=lora_r,
            lora_alpha=lora_alpha,
            target_modules=lora_target_modules,
            lora_dropout=lora_dropout,
            bias="none",
            task_type="CAUSAL_LM",
        )

        model = get_peft_model(model, lora_config)
        print("\nLora deployed")

        model.print_trainable_parameters()

    model = model.to("cuda")
    model.resize_token_embeddings(len(tokenizer))

    model.load_state_dict(torch.load(args.saved_model), strict=False)
    print("saved model loaded")
|
|
|
|
    save_root = args.visualization_root
|
|
    def visualization(model, dataloader, save_root, name):
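        """Run inference on a dataloader and save the predicted masks (sigmoid, binarized at 0.4) as per-frame PNGs under save_root/<name>/<vid>/<ref>/."""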
        save_root = os.path.join(save_root, name)
        os.makedirs(save_root, exist_ok=True)
        print(f"save_root: {save_root}")
        model.eval()
        for batch in tqdm(dataloader, desc=f"Visualization on {name} "):
            input_dict = dict_to_cuda(batch)
            with torch.no_grad():
                output_dict = model.forward(
                    images=input_dict["images"],
                    images_clip=input_dict["images_clip"],
                    audio_features=input_dict["audio_feats"],
                    image_features=input_dict["image_feats"],
                    input_ids=input_dict["input_ids"],
                    labels=input_dict["labels"],
                    attention_masks=input_dict["attention_masks"],
                    masks_list=input_dict["masks"],
                    resize_list=input_dict["resizes"],
                    orgsize_list=input_dict["orgsizes"],
                    conversation_list=input_dict["convs"],
                    refs_num=input_dict["refs_num"],
                    fids=input_dict["fids"],
                    vids=input_dict["vids"],
                    contrast=args.ct_weight,
                    ref_ids=input_dict["ref_ids"],
                    inference=True,
                )
            pred_masks = output_dict["pred_masks"]
            gt_masks = output_dict["gt_masks"]

            for b in range(len(pred_masks)):
                sample = torch.sigmoid(pred_masks[b])
                vid = input_dict["vids"][b]
                vid_root = os.path.join(save_root, vid)
                os.makedirs(vid_root, exist_ok=True)

                binary_sample = (sample > 0.4).to(torch.uint8)
                num_seg, T, H, W = sample.shape

                for seg_idx in range(num_seg):
                    ref = input_dict["refs"][b][seg_idx]
                    ref_root = os.path.join(vid_root, ref)
                    os.makedirs(ref_root, exist_ok=True)

                    for t in range(T):
                        mask_np = binary_sample[seg_idx, t].cpu().numpy() * 255
                        mask_img = Image.fromarray(mask_np.astype(np.uint8))
                        save_path = os.path.join(ref_root, f"frame{t}.png")
                        mask_img.save(save_path)

        print("visualization finished")
|
|
|
|
    def valuate(model, dataloader, name):
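        """Evaluate mIoU and F-score on a dataloader, weighting each sample by its number of referent masks and frames."""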
        model.eval()

        total_iou = 0
        total_fscore = 0
        count = 0

        for batch in tqdm(dataloader, desc=f"Evaluating on {name}"):
            input_dict = dict_to_cuda(batch)

            with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
                with torch.no_grad():
                    output_dict = model.forward(
                        images=input_dict["images"],
                        images_clip=input_dict["images_clip"],
                        audio_features=input_dict["audio_feats"],
                        image_features=input_dict["image_feats"],
                        input_ids=input_dict["input_ids"],
                        labels=input_dict["labels"],
                        attention_masks=input_dict["attention_masks"],
                        masks_list=input_dict["masks"],
                        resize_list=input_dict["resizes"],
                        orgsize_list=input_dict["orgsizes"],
                        conversation_list=input_dict["convs"],
                        refs_num=input_dict["refs_num"],
                        fids=input_dict["fids"],
                        vids=input_dict["vids"],
                        contrast=args.ct_weight,
                        ref_ids=input_dict["ref_ids"],
                        inference=True,
                    )
            pred_masks = output_dict["pred_masks"]
            gt_masks = output_dict["gt_masks"]
            for i in range(len(pred_masks)):
                num_seg = pred_masks[i].shape[0]
                T = pred_masks[i].shape[1]
                iou = utility.mask_iou(pred_masks[i], gt_masks[i])
                fscore = utility.Eval_Fmeasure(pred_masks[i], gt_masks[i], None)

                total_iou += iou * num_seg * T
                total_fscore += fscore * num_seg * T
                count += num_seg * T

        print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
|
|
|
|
    def valuate_Null(model, dataloader):
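        """Evaluate the null split (test_n), where the referred object is absent; the metric is computed from the predictions alone."""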
        model.eval()

        total_metric = 0
        count = 0

        for batch in tqdm(dataloader, desc="Evaluating on Null"):
            input_dict = dict_to_cuda(batch)
            with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
                with torch.no_grad():
                    output_dict = model.forward(
                        images=input_dict["images"],
                        images_clip=input_dict["images_clip"],
                        audio_features=input_dict["audio_feats"],
                        image_features=input_dict["image_feats"],
                        input_ids=input_dict["input_ids"],
                        labels=input_dict["labels"],
                        attention_masks=input_dict["attention_masks"],
                        masks_list=input_dict["masks"],
                        resize_list=input_dict["resizes"],
                        orgsize_list=input_dict["orgsizes"],
                        conversation_list=input_dict["convs"],
                        refs_num=input_dict["refs_num"],
                        fids=input_dict["fids"],
                        vids=input_dict["vids"],
                        contrast=args.ct_weight,
                        ref_ids=input_dict["ref_ids"],
                        inference=True,
                    )
            pred_masks = output_dict["pred_masks"]
            gt_masks = output_dict["gt_masks"]
            for i in range(len(pred_masks)):
                num_seg = pred_masks[i].shape[0]
                T = pred_masks[i].shape[1]
                null_metric = utility.metric_s_for_null(pred_masks[i])

                total_metric += null_metric * num_seg * T
                count += num_seg * T

        print(f"\n valuate on test_n_refer, metric: {total_metric / count}")
|
|
|
|
|
|
|
|
    if val_dataloader_s is not None:
        valuate(model, val_dataloader_s, 'test_seen')
    if val_dataloader_u is not None:
        valuate(model, val_dataloader_u, 'test_unseen')
    if val_dataloader_n is not None:
        valuate_Null(model, val_dataloader_n)
|
|