# SimToken / load_model.py
import os
import random
import re
import warnings
from functools import partial

import numpy as np
import torch
import torch.multiprocessing as mp
import transformers
from peft import LoraConfig, get_peft_model
from PIL import Image
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoConfig, logging

from configs import args
from datasets import REFAVS
from models.llava import conversation as conversation_lib
# from models.avs_model import VISAForCausalLM
from models.avs_model import Simtoken_ForCausalLM
from utils import utility

warnings.filterwarnings("ignore")
logging.set_verbosity_error()
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
DEFAULT_VIDEO_TOKEN = "<video>"
AUDIO_TOKEN_INDEX = -300
DEFAULT_AUDIO_TOKEN = "<audio>"
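# The negative indices above are sentinel values: they are spliced into
# input_ids in place of real vocabulary ids, and the model is expected to
# swap them for the corresponding image/audio features at embedding time.
# A <video> placeholder expands into `num_frames` image sentinels (see
# tokenizer_image_audio_token below).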
def set_seed(seed=42):
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
def dict_to_cuda(input_dict):
    for k, v in input_dict.items():
        if isinstance(v, torch.Tensor):
            input_dict[k] = v.cuda(non_blocking=True)
        elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], torch.Tensor):
            input_dict[k] = [ele.cuda(non_blocking=True) for ele in v]
    return input_dict
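# Example: dict_to_cuda({"x": torch.zeros(2), "y": [torch.ones(1)], "id": "v0"})
# moves "x" and every tensor inside the list "y" onto the GPU, and leaves the
# non-tensor entry "id" untouched.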
def tokenizer_image_audio_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, audio_token_index=AUDIO_TOKEN_INDEX, num_frames=10, return_tensors=None):
prompt_chunks = re.split(r'(<image>|<audio>|<video>)', prompt)
prompt_chunks = [chunk for chunk in prompt_chunks if chunk]
    # Split the prompt into plain-text chunks and the special-token types
    # that separate them.
    text_chunks = []  # text between special tokens
    token_types = []  # "image" / "audio" / "video"
for chunk in prompt_chunks:
if chunk == "<image>":
token_types.append("image")
elif chunk == "<audio>":
token_types.append("audio")
elif chunk == "<video>":
token_types.append("video")
else:
text_chunks.append(chunk)
# Tokenize the text
tokenized_chunks = [tokenizer(chunk).input_ids for chunk in text_chunks]
def insert_separators(text_chunks, tokenized_chunks, token_types, image_token_index, audio_token_index, num_frames):
input_ids = []
offset = 0
if (
len(tokenized_chunks) > 0
and len(tokenized_chunks[0]) > 0
and tokenized_chunks[0][0] == tokenizer.bos_token_id
):
offset = 1
input_ids.append(tokenized_chunks[0][0])
min_length = min(len(text_chunks), len(token_types))
for i in range(min_length):
input_ids.extend(tokenized_chunks[i][offset:])
if token_types[i] == "image":
input_ids.append(image_token_index)
elif token_types[i] == "audio":
input_ids.append(audio_token_index)
elif token_types[i] == "video":
input_ids.extend([image_token_index] * num_frames)
if len(text_chunks) > min_length:
input_ids.extend(tokenized_chunks[min_length][offset:])
return input_ids
input_ids = insert_separators(text_chunks, tokenized_chunks, token_types, image_token_index, audio_token_index, num_frames)
if return_tensors is not None:
if return_tensors == "pt":
return torch.tensor(input_ids, dtype=torch.long)
raise ValueError(f"Unsupported tensor type: {return_tensors}")
return input_ids
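# Example (schematic ids): for prompt "A photo <image> of a dog",
#   tokenizer_image_audio_token(prompt, tokenizer)
# yields [bos, ids("A photo "), IMAGE_TOKEN_INDEX, ids(" of a dog")], i.e.
# -200 is spliced where "<image>" appeared. Note that text and special
# tokens are interleaved text-first, so prompts are assumed to begin with
# a text chunk.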
def collate_fn(batch, tokenizer=None):
vids = []
images = []
image_clips = []
masks = []
conversations = []
audio_feats = []
image_feats = []
resizes = []
orgsizes = []
    first_refs = []
    refs = []
    refs_num = []
fids = []
for data in batch:
vids.append(data['vid'])
images.append(data['image'])
image_clips.append(data['img_clip'])
masks.append(data['mask'])
conversations.append(data['conversation'])
audio_feats.append(data['feat_aud'])
resizes.append(data['resize'])
orgsizes.append(data['orgsize'])
image_feats.append(data['feat_sam'])
refs_num.append(len(data['ref']))
fids.append(data['fids'])
refs.append(data['ref'])
first_refs.append(data['ref'][0])
input_ids = [tokenizer_image_audio_token(conv, tokenizer, return_tensors="pt") for conv in conversations] # list
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
attention_masks = input_ids.ne(tokenizer.pad_token_id)
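    # Caveat: padding is detected via pad_token_id, which __main__ aliases to
    # the unk token, so a genuine <unk> inside a conversation would also be
    # masked out (a standard LLaVA-style trade-off).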
ref_ids = [tokenizer_image_audio_token(ref, tokenizer, return_tensors="pt") for ref in first_refs]
conv = conversation_lib.default_conversation.copy()
labels = input_ids.clone()
sep = 'Sure, it is [SEG]'
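    # Supervise only the assistant's answer: every span before an occurrence
    # of `sep`, and everything after the last one, is set to IGNORE_INDEX, so
    # the LM loss falls on the "Sure, it is [SEG]" tokens alone. The -1/-2
    # offsets below appear to account for BOS/leading tokens added when each
    # piece is re-tokenized in isolation.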
for conversation, target in zip(conversations, labels):
parts = conversation.split(sep)
cur_len = 1
target[:cur_len] = IGNORE_INDEX
sep_len = len(tokenizer_image_audio_token(sep, tokenizer)) - 1
for i in range(len(parts)-1):
part_len = len(tokenizer_image_audio_token(parts[i], tokenizer)) - 2
target[cur_len: cur_len + part_len] = IGNORE_INDEX
cur_len += part_len + sep_len
target[cur_len:] = IGNORE_INDEX
return {"vids": vids,
"images": images, # list[B]:[T, 3, 1024, 1024]
"images_clip": image_clips, # list[B]:[T, 3, 224, 224]
"masks": masks, # list[B]:[num_ref, T, H, W]
"convs": conversations, # list[B]: str
"input_ids": input_ids, # list[B]:[max_len]
"attention_masks": attention_masks, # list[B]:[max_len]
"labels": labels, # list[B]:[max_len]
"audio_feats": audio_feats, # list[B]:[10, 128]
"resizes": resizes, # list[B]
"orgsizes": orgsizes, # list[B]
"image_feats": image_feats,
"ref_ids": ref_ids, # list[B]: [ref_id_len]
"refs_num": refs_num,
"fids": fids,
"refs": refs,
}
if __name__ == "__main__":
mp.set_start_method("spawn", force=True)
set_seed(42)
tokenizer = transformers.AutoTokenizer.from_pretrained(
args.mllm,
cache_dir=None,
model_max_length=2048, # 2048
padding_side="right",
use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token
num_added_tokens = tokenizer.add_tokens("[SEG]")
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0] # 32000
print("seg_token_idx: ", seg_token_idx)
eval_splits = {split.strip() for split in args.eval_splits.split(",") if split.strip()}
val_dataset_s = REFAVS('test_s', args, tokenizer, input_type='refer') if 'test_s' in eval_splits else None
val_dataset_u = REFAVS('test_u', args, tokenizer, input_type='refer') if 'test_u' in eval_splits else None
val_dataset_n = REFAVS('test_n', args, tokenizer, input_type='refer') if 'test_n' in eval_splits else None
val_dataloader_s = DataLoader(val_dataset_s, batch_size=1, shuffle=False, num_workers=4, collate_fn=partial(collate_fn, tokenizer=tokenizer)) if val_dataset_s is not None else None
val_dataloader_u = DataLoader(val_dataset_u, batch_size=1, shuffle=False, num_workers=4, collate_fn=partial(collate_fn, tokenizer=tokenizer)) if val_dataset_u is not None else None
val_dataloader_n = DataLoader(val_dataset_n, batch_size=1, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer)) if val_dataset_n is not None else None
model_args = {
"train_mask_decoder": True,
"out_dim": 256, # 256
"ce_loss_weight": 1.0,
"dice_loss_weight": 0.5,
"bce_loss_weight": 2.0,
"seg_token_idx": seg_token_idx,
"vision_pretrained": args.vision_pretrained, # sam_vit_h_xxx.pth
"vision_tower": args.vision_tower,
"use_im_start_end": False,
"compress": args.compress,
"start": args.start,
}
# model = Simtoken_ForCausalLM.from_pretrained(args.mllm, torch_dtype=torch.float32, low_cpu_mem_usage=True, **model_args)
    model = Simtoken_ForCausalLM.from_pretrained(
        args.mllm, torch_dtype=torch.bfloat16, low_cpu_mem_usage=True, **model_args
    )
print("\nmodel loaded")
model.config.eos_token_id = tokenizer.eos_token_id
model.config.bos_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id
model.enable_input_require_grads()
model.gradient_checkpointing_enable()
model.get_model().initialize_vision_modules(model.get_model().config)
vision_tower = model.get_model().get_vision_tower()
vision_tower.to(dtype=torch.float32, device="cuda")
model_args_from_pt = AutoConfig.from_pretrained(args.mllm)
model_args_from_pt.use_cluster = True
model_args_from_pt.freeze = False
model_args_from_pt.mm_tune = True
model_args_from_pt.spatial_cluster_rate0 = 64
model_args_from_pt.spatial_cluster_rate1 = 32
model_args_from_pt.spatial_cluster_rate2 = 16
model_args_from_pt.temporal_cluster_rate = 0.0625
model_args_from_pt.vision_tune = False
model.get_model().initialize_cluster_modules(model_args_from_pt)
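    # These rates appear to control token merging in the visual branch
    # (progressively fewer spatial tokens per level, plus a 1/16 temporal
    # rate); the exact semantics live in initialize_cluster_modules.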
model.get_model().initialize_lisa_modules(model.get_model().config)
for p in vision_tower.parameters():
p.requires_grad = False
for p in model.get_model().mm_projector.parameters():
p.requires_grad = False
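    # With the vision tower and mm_projector frozen, the trainable weights
    # from here on are the LoRA adapters injected below, plus whatever
    # initialize_lisa_modules left unfrozen (e.g. the mask decoder, given
    # train_mask_decoder=True).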
lora_r = 8
target_modules = "q_proj,v_proj"
if lora_r > 0:
def find_linear_layers(model, lora_target_modules):
cls = torch.nn.Linear
lora_module_names = set()
for name, module in model.named_modules():
if (
isinstance(module, cls)
and all(
[
x not in name
for x in [
"visual_model",
"vision_tower",
"mm_projector",
"text_hidden_fcs",
"audio_feature_layer",
]
]
)
and any([x in name for x in lora_target_modules])
):
lora_module_names.add(name)
return sorted(list(lora_module_names))
lora_alpha = 16
lora_dropout = 0.05
lora_target_modules = find_linear_layers(
model, target_modules.split(",")
)
lora_config = LoraConfig(
r=lora_r,
lora_alpha=lora_alpha,
target_modules=lora_target_modules,
lora_dropout=lora_dropout,
bias="none",
task_type="CAUSAL_LM",
)
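        # LoRA recap: each targeted q_proj/v_proj nn.Linear with weight W is
        # augmented with a low-rank update, h = W x + (lora_alpha / r) * B A x,
        # where A is (r x d_in) and B is (d_out x r); only A and B receive
        # gradients. With r=8 and alpha=16 the update is scaled by 2.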
model = get_peft_model(model, lora_config)
print("\nLora deployed")
model.print_trainable_parameters()
model = model.to("cuda")
model.resize_token_embeddings(len(tokenizer))
    model.load_state_dict(torch.load(args.saved_model, map_location="cpu"), strict=False)
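    # strict=False: the checkpoint is expected to hold only the fine-tuned
    # subset (LoRA adapters, resized embeddings, decoder-side heads), so
    # missing base-model keys are tolerated rather than raising an error.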
print("saved model loaded")
save_root = args.visualization_root
def visualization(model, dataloader, save_root, name):
save_root = os.path.join(save_root, name)
os.makedirs(save_root, exist_ok=True)
print(f"save_root: {save_root}")
model.eval()
for batch in tqdm(dataloader, desc=f"Visualization on {name} "):
input_dict = dict_to_cuda(batch)
with torch.no_grad():
output_dict = model.forward(images=input_dict["images"],
images_clip=input_dict["images_clip"],
audio_features=input_dict["audio_feats"],
image_features=input_dict["image_feats"],
input_ids=input_dict["input_ids"],
labels=input_dict["labels"],
attention_masks=input_dict["attention_masks"],
masks_list=input_dict["masks"],
resize_list=input_dict["resizes"],
orgsize_list=input_dict["orgsizes"],
conversation_list=input_dict["convs"],
refs_num=input_dict["refs_num"],
fids=input_dict["fids"],
vids=input_dict["vids"],
contrast=args.ct_weight,
ref_ids=input_dict["ref_ids"],
inference=True)
pred_masks = output_dict["pred_masks"] # list[B]:[num_seg, T, H, W]
gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
for b in range(len(pred_masks)):
sample = torch.sigmoid(pred_masks[b]) # [num_seg, T, H, W]
vid = input_dict["vids"][b]
vid_root = os.path.join(save_root, vid)
os.makedirs(vid_root, exist_ok=True)
# print("vid_root:", vid_root)
binary_sample = (sample > 0.4).to(torch.uint8)
num_seg, T, H, W = sample.shape
for seg_idx in range(num_seg):
ref = input_dict["refs"][b][seg_idx]
ref_root = os.path.join(vid_root, ref)
os.makedirs(ref_root, exist_ok=True)
# print("ref_root:", ref_root)
for t in range(T):
mask_np = binary_sample[seg_idx, t].cpu().numpy() * 255
mask_img = Image.fromarray(mask_np.astype(np.uint8))
save_path = os.path.join(ref_root, f"frame{t}.png")
mask_img.save(save_path)
# print(f"image saved as {save_path}")
print("visualization finished")
    def evaluate(model, dataloader, name):
model.eval()
total_iou = 0
total_fscore = 0
count = 0
for batch in tqdm(dataloader, desc=f"Evaluating on {name}"):
input_dict = dict_to_cuda(batch)
with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
with torch.no_grad():
output_dict = model.forward(images=input_dict["images"],
images_clip=input_dict["images_clip"],
audio_features=input_dict["audio_feats"],
image_features=input_dict["image_feats"],
input_ids=input_dict["input_ids"],
labels=input_dict["labels"],
attention_masks=input_dict["attention_masks"],
masks_list=input_dict["masks"],
resize_list=input_dict["resizes"],
orgsize_list=input_dict["orgsizes"],
conversation_list=input_dict["convs"],
refs_num=input_dict["refs_num"],
fids=input_dict["fids"],
vids=input_dict["vids"],
contrast=args.ct_weight,
ref_ids=input_dict["ref_ids"],
inference=True)
pred_masks = output_dict["pred_masks"] # list[B]:[num_seg, T, H, W]
gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
for i in range(len(pred_masks)):
num_seg = pred_masks[i].shape[0]
T = pred_masks[i].shape[1]
iou = utility.mask_iou(pred_masks[i], gt_masks[i])
fscore = utility.Eval_Fmeasure(pred_masks[i], gt_masks[i], None)
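                # Weight each video's scores by num_seg * T so the final
                # average is per (reference, frame) rather than per video.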
total_iou += iou * num_seg * T
total_fscore += fscore * num_seg * T
count += num_seg * T
print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
    def evaluate_null(model, dataloader):
model.eval()
total_metric = 0
count = 0
        for batch in tqdm(dataloader, desc="Evaluating on Null"):
input_dict = dict_to_cuda(batch)
with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
with torch.no_grad():
output_dict = model.forward(images=input_dict["images"],
images_clip=input_dict["images_clip"],
audio_features=input_dict["audio_feats"],
image_features=input_dict["image_feats"],
input_ids=input_dict["input_ids"],
labels=input_dict["labels"],
attention_masks=input_dict["attention_masks"],
masks_list=input_dict["masks"],
resize_list=input_dict["resizes"],
orgsize_list=input_dict["orgsizes"],
conversation_list=input_dict["convs"],
refs_num=input_dict["refs_num"],
fids=input_dict["fids"],
vids=input_dict["vids"],
contrast=args.ct_weight,
ref_ids=input_dict["ref_ids"],
inference=True)
pred_masks = output_dict["pred_masks"] # list[B]:[num_seg, T, H, W]
gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
for i in range(len(pred_masks)):
num_seg = pred_masks[i].shape[0]
T = pred_masks[i].shape[1]
null_metric = utility.metric_s_for_null(pred_masks[i])
total_metric += null_metric * num_seg * T
count += num_seg * T
print(f"\n valuate on test_n_refer, metric: {total_metric / count}")
    if val_dataloader_s is not None:
        evaluate(model, val_dataloader_s, 'test_seen')
    if val_dataloader_u is not None:
        evaluate(model, val_dataloader_u, 'test_unseen')
    if val_dataloader_n is not None:
        evaluate_null(model, val_dataloader_n)