|
|
import torch |
|
|
import os |
|
|
from enum import Enum |
|
|
from tqdm import tqdm |
|
|
import numpy as np |
|
|
from detectron2.structures import BitMasks |
|
|
from psalm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \ |
|
|
DEFAULT_IM_END_TOKEN, DEFAULT_SEG_TOKEN, SEG_TOKEN_INDEX |
|
|
from psalm.model.builder import load_pretrained_model |
|
|
from psalm.utils import disable_torch_init |
|
|
from psalm.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria |
|
|
import cv2 |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
from psalm import conversation as conversation_lib |
|
|
|
|
|
from psalm.train.train_datasets_eval import COCO_interactive_dataset_extrametric |
|
|
from psalm.eval.eval_davis_evaonly import Multicondition_Dataset_extrametric |
|
|
from pycocotools.mask import encode, decode, frPyObjects |
|
|
|
|
|
from detectron2.structures import BoxMode |
|
|
from detectron2.data import MetadataCatalog, DatasetCatalog |
|
|
from typing import Dict, Optional, Sequence, List |
|
|
from dataclasses import dataclass, field |
|
|
import torch.distributed as dist |
|
|
import transformers |
|
|
from pathlib import Path |
|
|
from segmentation_evaluation import openseg_classes |
|
|
COLOR_MAP = openseg_classes.ADE20K_150_CATEGORIES |
|
|
from detectron2.data import detection_utils as utils |
|
|
import pickle |
|
|
import math |
|
|
import json |
|
|
import utils_metric |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DataCollatorForCOCODatasetV2(object):
    """Collate examples for supervised fine-tuning.

    Calling the collator on a sequence of per-sample dicts returns one batch
    dict.  ``input_ids`` and ``labels`` are right-padded (with the tokenizer's
    pad id and ``IGNORE_INDEX`` respectively) and truncated to
    ``tokenizer.model_max_length``.  Optional per-sample entries are batched
    only when the first sample carries them.

    Note: the input dicts are mutated -- ``input_ids``, ``labels`` and
    ``image`` are removed from every sample, and what remains is exposed as
    ``batch['seg_info']``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    @staticmethod
    def _stack_or_list(items):
        """Stack `items` when all are non-None with equal shapes, else return the list."""
        if all(x is not None and x.shape == items[0].shape for x in items):
            return torch.stack(items)
        return items

    @staticmethod
    def _stack_or_pad(tensors, padding_value):
        """Stack equally-shaped tensors; otherwise pad them to a common length."""
        if any(x.shape != tensors[0].shape for x in tensors):
            return torch.nn.utils.rnn.pad_sequence(
                tensors,
                batch_first=True,
                padding_value=padding_value,
            )
        return torch.stack(tensors, dim=0)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Build a batch dict from `instances`.

        Returns an empty dict when the first sample is empty.
        """
        first = instances[0]
        if len(first) == 0:
            return {}

        pad_id = self.tokenizer.pad_token_id
        max_len = self.tokenizer.model_max_length

        input_ids = torch.nn.utils.rnn.pad_sequence(
            [instance["input_ids"] for instance in instances],
            batch_first=True,
            padding_value=pad_id,
        )[:, :max_len]
        labels = torch.nn.utils.rnn.pad_sequence(
            [instance["labels"] for instance in instances],
            batch_first=True,
            padding_value=IGNORE_INDEX,
        )[:, :max_len]

        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(pad_id),
        )

        if 'image' in first:
            batch['images'] = self._stack_or_list(
                [instance['image'] for instance in instances])
        if 'vp_image' in first:
            batch['vp_images'] = self._stack_or_list(
                [instance['vp_image'] for instance in instances])

        # Drop the entries that were already batched above.  'image' is
        # optional (see the membership test above), so pop() is used instead
        # of `del` to avoid a KeyError for samples without it.
        for instance in instances:
            for key in ('input_ids', 'labels', 'image'):
                instance.pop(key, None)
        batch['seg_info'] = list(instances)

        if 'dataset_type' in first:
            batch['dataset_type'] = [instance['dataset_type'] for instance in instances]

        if 'class_name_ids' in first:
            batch['class_name_ids'] = self._stack_or_pad(
                [instance['class_name_ids'] for instance in instances],
                padding_value=-1,
            )
        if 'token_refer_id' in first:
            # Kept as a plain list; the entries are not stacked or padded.
            batch['token_refer_id'] = [instance['token_refer_id'] for instance in instances]
        if 'cls_indices' in first:
            batch['cls_indices'] = self._stack_or_pad(
                [instance['cls_indices'] for instance in instances],
                padding_value=-1,
            )
        if 'random_idx' in first:
            batch['random_idx'] = torch.stack(
                [instance['random_idx'] for instance in instances], dim=0)
        if 'class_name_embedding_indices' in first:
            batch['class_name_embedding_indices'] = torch.nn.utils.rnn.pad_sequence(
                [instance['class_name_embedding_indices'] for instance in instances],
                batch_first=True,
                padding_value=0,
            )
        if 'refer_embedding_indices' in first:
            batch['refer_embedding_indices'] = torch.nn.utils.rnn.pad_sequence(
                [instance['refer_embedding_indices'] for instance in instances],
                batch_first=True,
                padding_value=0,
            )

        return batch
|
|
|
|
|
|
|
|
def __str__(self):
    """Render ``name``, ``val`` and ``avg`` using ``self.fmt`` as format spec.

    The template is ``"{name} {val<fmt>} ({avg<fmt>})"`` and is filled from
    the instance ``__dict__``.

    NOTE(review): nothing visible in this file defines ``fmt``/``name``/
    ``val``/``avg``; this looks like a leftover from a removed meter class --
    confirm before relying on it.
    """
    spec = self.fmt
    template = "".join(("{name} {val", spec, "} ({avg", spec, "})"))
    return template.format(**self.__dict__)
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DataArguments:
    """Command-line options for this evaluation script.

    Parsed further down in this file through ``transformers.HfArgumentParser``.
    Field order is kept as-is because it is part of the generated
    ``__init__`` signature.
    """

    # --- data -----------------------------------------------------------
    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    only_two_class: bool = False
    old_two_class: bool = False
    is_multimodal: bool = False
    image_folder: Optional[str] = '/home/emzhang/data/segmentation/refer_seg/images/mscoco/images/train2014'

    mask_config: Optional[str] = "./psalm/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml"
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = None
    region_mask_type: Optional[str] = None

    # --- model / checkpoint ---------------------------------------------
    json_path: str = '/home/emzhang/code/LLaVA/datasets/refcoco/refcoco_val.json'
    model_path: str = '/home/emzhang/code/llava_zem/checkpoints/SEG_class_refcoco_after_fixbug'
    model_map_name: str = 'psalm_video'
    version: str = 'opt-iml-1.3b'
    SEG_norm: bool = False
    SEG_proj: bool = True
    criterion_type: Optional[str] = "concat_seg"
    matcher_type: Optional[str] = "wo_class"
    llm_pos: Optional[str] = "none"
    ln_2048: bool = False
    seg_idx_back: bool = False

    # --- evaluation -----------------------------------------------------
    segmentation: bool = True
    eval_batch_size: int = 1
    dataloader_num_workers: int = 4
    thr: float = 0.5
    topk: int = 1
    fuse_score: bool = False
    seg_task: Optional[str] = "region"
    seg_last: bool = True
    num_chunks: int = 1
    chunk_idx: int = 0
|
|
|
|
|
|
|
|
def fuse_davis_mask(mask_list, fill_number_list):
    """Merge per-object binary masks into a single label map.

    A zero array shaped and typed like ``mask_list[0]`` is created, then for
    each (mask, fill number) pair the positions where the mask equals 1 are
    set to ``int(fill number)``.  Pairs are applied in order, so a later mask
    overwrites an earlier one where they overlap.  The inputs are not
    modified.
    """
    canvas = np.zeros_like(mask_list[0])
    labels = map(int, fill_number_list)
    for layer, label in zip(mask_list, labels):
        canvas[layer == 1] = label
    return canvas
|
|
|
|
|
|
|
|
import os |
|
|
import re |
|
|
|
|
|
def get_latest_checkpoint_path(model_path):
    """Resolve `model_path` to a concrete checkpoint path.

    * If the last path component looks like ``checkpoint-<N>``, `model_path`
      is returned unchanged.
    * Otherwise, if `model_path` is a directory, the entry named
      ``checkpoint-<N>`` with the largest ``N`` is joined onto it and returned.
    * Any other existing path is returned unchanged.

    Raises:
        ValueError: `model_path` is a directory without ``checkpoint-<N>``
            entries.
        FileNotFoundError: `model_path` does not exist.
    """
    checkpoint_pattern = re.compile(r"checkpoint-(\d+)")

    # normpath strips a trailing separator; without it basename() of
    # "run/checkpoint-5/" is "" and the checkpoint directory itself would be
    # scanned for nested checkpoints.
    leaf_name = os.path.basename(os.path.normpath(model_path))
    if checkpoint_pattern.match(leaf_name):
        return model_path

    if os.path.isdir(model_path):
        checkpoints = [d for d in os.listdir(model_path) if checkpoint_pattern.match(d)]
        if not checkpoints:
            raise ValueError("No checkpoints found in the specified directory.")
        # Compare by the numeric step, not lexicographically ("10" > "9").
        max_checkpoint = max(checkpoints, key=lambda x: int(checkpoint_pattern.match(x).group(1)))
        return os.path.join(model_path, max_checkpoint)

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"The specified path '{model_path}' does not exist.")

    return model_path
|
|
|
|
|
|
|
|
# Hard-coded locations of the instruction JSON, the prediction folder and the
# dataset root used by this script.
file_path = "/data/work-gcp-europe-west4-a/yuqian_fu/Ego/data_segswap/egoexo_val_framelevel_newprompt_all_instruction.json"
pred_path = "/data/work-gcp-europe-west4-a/yuqian_fu/Ego/data_segswap/mask_predictions/egofullmodel_smalljson"
root_path = "/data/work-gcp-europe-west4-a/yuqian_fu/Ego/data_segswap"

# os.listdir() returns entries in arbitrary order; the entry point slices this
# list by position, so sort it to make the evaluated subset reproducible.
val_set = sorted(os.listdir(pred_path))
with open(file_path, 'r') as f:
    datas = json.load(f)

# Parse command-line options into a DataArguments instance.
parser = transformers.HfArgumentParser(DataArguments)
data_args = parser.parse_args_into_dataclasses()[0]

# Load the model once at import time; evaluation() reuses these globals.
disable_torch_init()
model_path = os.path.expanduser(data_args.model_path)
model_name = get_model_name_from_path(model_path)
print(f'current model is {model_path}')
tokenizer, model, image_processor, context_len = load_pretrained_model(
    model_path,
    None,
    model_name,
    model_args=data_args,
    mask_config=data_args.mask_config,
    device='cuda',
)

# Attach runtime-only settings to the parsed arguments object.
data_args.image_processor = image_processor
data_args.is_multimodal = True
conversation_lib.default_conversation = conversation_lib.conv_templates[data_args.version]
|
|
|
|
|
def _pick_distinct_queries(scores, top_k=10):
    """Greedily assign each score row its best not-yet-used column.

    Rows are processed in order.  For every row the highest-scoring column
    that no earlier row has taken is selected.  The candidate window is the
    top ``top_k`` columns, widened when that many columns are already taken
    and clamped to the number of available columns.  Rows for which no unused
    column remains are omitted from the result.

    Returns:
        list of ``(row_index, column_index)`` tuples.
    """
    used = set()
    picks = []
    for row, row_scores in enumerate(scores):
        row_tensor = torch.tensor(row_scores)
        # At least len(used) + 1 candidates guarantees one unused column
        # whenever one exists; never ask topk for more than there are.
        k = min(row_tensor.shape[-1], max(top_k, len(used) + 1))
        _, top_idx = torch.topk(row_tensor, k, largest=True, sorted=True)
        for query in top_idx.tolist():
            if query not in used:
                used.add(query)
                picks.append((row, query))
                break
    return picks


def evaluation(take_id):
    """Run the model on every frame of one take and score the predictions.

    Samples whose ``video_name`` equals `take_id` are selected from the
    module-level ``datas`` and fed through the module-level ``model``.  For
    each frame, one predicted mask per visual prompt is chosen (no predicted
    mask is used twice), fused into a label map with ``fuse_davis_mask`` and
    compared with the masks in ``<root_path>/<take_id>/annotation.json``
    using the ``utils_metric`` helpers.

    Returns:
        ``(ious, shape_accs, existence_accs, location_scores, num_frame)``:
        four flat lists with one entry per scored (frame, object) pair, and
        the number of frames whose fused prediction contained at least one
        known object id.  If no sample belongs to `take_id`, four empty lists
        and 0 are returned.
    """
    data_list = [data for data in datas if data['video_name'] == take_id]
    if not data_list:
        # Nothing to evaluate; indexing data_list[0] below would fail.
        print(f'no data found for take {take_id}')
        return [], [], [], [], 0

    eval_dataset = Multicondition_Dataset_extrametric(
        data_list=data_list, tokenizer=tokenizer, data_args=data_args)
    data_collator = DataCollatorForCOCODatasetV2(tokenizer=tokenizer)
    eval_dataloader = DataLoader(
        eval_dataset,
        batch_size=data_args.eval_batch_size,
        collate_fn=data_collator,
        num_workers=data_args.dataloader_num_workers,
    )

    # The camera name is the parent folder of the first sample's image path.
    cam_target = data_list[0]['image'].split('/')[-2]
    gt_path = f"{root_path}/{take_id}/annotation.json"
    with open(gt_path, 'r') as fp:
        gt = json.load(fp)

    # Label value k (1-based) in the fused prediction refers to the k-th
    # object key of the annotation file.
    objs = list(gt["masks"].keys())
    label_to_obj = {label + 1: obj for label, obj in enumerate(objs)}
    # Objects annotated for the target camera.
    obj_target = [obj for obj in objs if cam_target in gt["masks"][obj]]

    IoUs = []
    ShapeAcc = []
    ExistenceAcc = []
    LocationScores = []
    num_frame = 0

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device=device, dtype=torch.float).eval()

    with torch.no_grad():
        for inputs in tqdm(eval_dataloader, total=len(eval_dataloader)):
            if len(inputs) == 0:
                print('no data load')
                continue

            inputs = {k: v.to(device) if torch.is_tensor(v) else v for k, v in inputs.items()}
            # The collator only emits this key when the samples carry it.
            if 'token_refer_id' in inputs:
                inputs['token_refer_id'] = [ids.to(device) for ids in inputs['token_refer_id']]

            frame_id = inputs['seg_info'][0]['file_name'].split('/')[-1].split('.')[0]

            try:
                if 'instance' in data_args.model_map_name:
                    outputs = model.eval_video(
                        input_ids=inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        images=inputs['images'].float(),
                        vp_images=inputs['vp_images'].float(),
                        seg_info=inputs['seg_info'],
                        class_name_embedding_indices=inputs['class_name_embedding_indices'],
                        class_name_ids=inputs['class_name_ids'],
                        cls_indices=inputs['cls_indices'],
                        labels=inputs['labels'],
                    )
                else:
                    outputs = model.eval_video(
                        input_ids=inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        images=inputs['images'].float(),
                        vp_images=inputs['vp_images'].float(),
                        seg_info=inputs['seg_info'],
                        token_refer_id=inputs['token_refer_id'],
                        refer_embedding_indices=inputs['refer_embedding_indices'],
                        labels=inputs['labels'],
                    )
                if torch.cuda.is_available():
                    torch.cuda.synchronize()
            except Exception as exc:
                # Best-effort: skip the frame, but report what failed and let
                # KeyboardInterrupt / SystemExit propagate.
                print(f'something wrong when infer (frame {frame_id}): {exc!r}')
                continue

            output = outputs[0]
            pred_mask = output['instances'].pred_masks.cpu().numpy()
            # After the transpose each row is indexed like vp_fill_number
            # (checked by the assert below).
            scores = output['instances'].scores.transpose(1, 0).cpu().numpy()
            fill_numbers = inputs['seg_info'][0]['instances'].vp_fill_number
            assert len(scores) == len(fill_numbers)

            pred_mask_list = []
            fill_number_list = []
            for row, query in _pick_distinct_queries(scores):
                pred_mask_list.append(pred_mask[query, :].astype(np.uint8))
                fill_number_list.append(fill_numbers[row])
            if not pred_mask_list:
                # No prompt received a mask; nothing to fuse or score.
                continue
            fused_pred_mask = fuse_davis_mask(pred_mask_list, fill_number_list)

            # Objects that have a ground-truth mask for this camera and frame.
            obj_range = {
                obj for obj in obj_target
                if frame_id in gt["masks"][obj][cam_target]
            }

            unique_instances = [
                int(value) for value in np.unique(fused_pred_mask)
                if value != 0 and int(value) in label_to_obj
            ]
            if len(unique_instances) == 0:
                continue

            num_frame += 1
            for instance_value in unique_instances:
                obj_name = label_to_obj[instance_value]
                if obj_name not in obj_range:
                    continue
                binary_mask = (fused_pred_mask == instance_value).astype(np.uint8)
                h, w = binary_mask.shape
                gt_mask = decode(gt["masks"][obj_name][cam_target][frame_id])
                # Bring the ground truth to the prediction's resolution.
                gt_mask = cv2.resize(gt_mask, (w, h), interpolation=cv2.INTER_NEAREST)
                iou, shape_acc = utils_metric.eval_mask(gt_mask, binary_mask)
                ex_acc = utils_metric.existence_accuracy(gt_mask, binary_mask)
                location_score = utils_metric.location_score(gt_mask, binary_mask, size=(h, w))
                IoUs.append(iou)
                ShapeAcc.append(shape_acc)
                ExistenceAcc.append(ex_acc)
                LocationScores.append(location_score)

    # Round-trip through numpy so the returned lists hold plain Python values.
    return (
        np.array(IoUs).tolist(),
        np.array(ShapeAcc).tolist(),
        np.array(ExistenceAcc).tolist(),
        np.array(LocationScores).tolist(),
        num_frame,
    )
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Accumulate the per-object metrics of every evaluated take, then report
    # their means together with the total number of scored frames.
    total_iou, total_shape_acc = [], []
    total_existence_acc, total_location_scores = [], []
    num_total = 0

    # Only the entries from position 100 onwards are evaluated.
    for take_id in val_set[100:]:
        ious, shape_accs, existence_accs, location_scores, num_frame = evaluation(take_id)
        total_iou.extend(ious)
        total_shape_acc.extend(shape_accs)
        total_existence_acc.extend(existence_accs)
        total_location_scores.extend(location_scores)
        num_total = num_total + num_frame

    print('TOTAL IOU: ', np.mean(total_iou))
    print('TOTAL LOCATION SCORE: ', np.mean(total_location_scores))
    print('TOTAL SHAPE ACC: ', np.mean(total_shape_acc))
    print('TOTAL EXISTENCE ACC: ', np.mean(total_existence_acc))
    print("total frames:", num_total)
|
|
|