# ObjectRelator-Original / psalm / eval / eval_egoexo_competition_final_hardcode.py
# Uploaded by YuqianFu using huggingface_hub (commit 625a17f, verified).
import torch
import os
from enum import Enum
from tqdm import tqdm
import numpy as np
# from detectron2.structures import BitMasks
# from psalm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
# DEFAULT_IM_END_TOKEN, DEFAULT_SEG_TOKEN, SEG_TOKEN_INDEX
# from psalm.model.builder import load_pretrained_model
# from psalm.utils import disable_torch_init
# from psalm.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
import cv2
# from torch.utils.data import Dataset, DataLoader
# from psalm import conversation as conversation_lib
from psalm.train.train_datasets_eval import COCO_interactive_dataset_extrametric # debug
# from detectron2.structures import BoxMode
# from detectron2.data import MetadataCatalog, DatasetCatalog
from typing import Dict, Optional, Sequence, List
from dataclasses import dataclass, field
import torch.distributed as dist
import transformers
from pathlib import Path
from psalm.eval.segmentation_evaluation import openseg_classes
from natsort import natsorted
# ADE20K category list from psalm.eval.segmentation_evaluation.openseg_classes.
# NOTE(review): not referenced anywhere else in this script; presumably a leftover
# palette for visualisation — confirm before removing.
COLOR_MAP = openseg_classes.ADE20K_150_CATEGORIES
import re
from psalm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, SEG_TOKEN_INDEX, CLS_TOKEN_INDEX, REGION_TOKEN_INDEX, REFER_TOKEN_INDEX
import json
from pycocotools import mask as mask_utils
# Command-line arguments (parsed with transformers.HfArgumentParser in evaluation()).
@dataclass
class DataArguments:
    """Options for data loading, model loading and evaluation.

    HfArgumentParser exposes each field as a command-line option, so the
    declaration order, types and defaults below form the script's CLI.
    """

    # --- data sources ---
    data_path: str = field(
        default=None,
        metadata={"help": "Path to the training data."},
    )
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    # Root folder; evaluation() reads <image_folder>/<take_id>/annotation.json.
    image_folder: Optional[str] = '/path/to/val2017'

    # --- model ---
    model_path: Optional[str] = "/path/to/model"
    mask_config: Optional[str] = "./psalm/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml"
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = None
    json_path: str = '/path/to/coco'
    model_map_name: str = 'psalm_video'
    version: str = 'llava_phi'

    # --- evaluation ---
    segmentation: bool = True
    eval_batch_size: int = 1  # debug
    dataloader_num_workers: int = 8
    seg_task: Optional[str] = "region"
    # Optional '||'-separated list of region mask types.
    region_mask_type: Optional[str] = None
    with_memory: bool = False

    # --- resume / output ---
    resume: bool = False
    using_autocast: bool = False
    # JSON file of previously saved results, read when --resume is set.
    resume_path: Optional[str] = None
    # 'rle' or 'png'; any other value is rejected by evaluation().
    save_format: Optional[str] = None
# Collate function used to batch per-frame examples.
@dataclass
class DataCollatorForCOCODatasetV2(object):
    """Collate per-frame examples into one batch dict.

    ``instances`` is an ordered sequence of dicts, one per frame/image, as
    returned by a dataset's ``__getitem__``.

    * ``input_ids`` / ``labels`` are right-padded (with the tokenizer's pad id /
      ``IGNORE_INDEX``) and truncated to ``tokenizer.model_max_length``;
      ``attention_mask`` marks the non-pad tokens.
    * ``image`` / ``vp_image`` are stacked when every entry is a tensor of the
      same shape, otherwise passed through as a list.
    * Optional per-instance fields (``class_name_ids``, ``cls_indices``,
      ``random_idx``, ...) are stacked or padded when the first instance has them.

    Note: the input dicts are mutated in place — ``input_ids``, ``labels`` and
    ``image`` are removed, and the remaining dicts are returned as
    ``batch['seg_info']``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    @staticmethod
    def _stack_if_uniform(tensors):
        """Stack ``tensors`` if none is None and all shapes match; otherwise return the list unchanged."""
        if all(t is not None and t.shape == tensors[0].shape for t in tensors):
            return torch.stack(tensors)
        return tensors

    @staticmethod
    def _stack_or_pad(tensors):
        """Stack ``tensors`` along dim 0 when shapes agree, else right-pad them with -1."""
        if any(t.shape != tensors[0].shape for t in tensors):
            return torch.nn.utils.rnn.pad_sequence(
                tensors,
                batch_first=True,
                padding_value=-1,
            )
        return torch.stack(tensors, dim=0)

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        pad_id = self.tokenizer.pad_token_id
        max_len = self.tokenizer.model_max_length

        input_ids = torch.nn.utils.rnn.pad_sequence(
            [inst["input_ids"] for inst in instances],
            batch_first=True,
            padding_value=pad_id)
        labels = torch.nn.utils.rnn.pad_sequence(
            [inst["labels"] for inst in instances],
            batch_first=True,
            padding_value=IGNORE_INDEX)
        input_ids = input_ids[:, :max_len]
        labels = labels[:, :max_len]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(pad_id),
        )

        first = instances[0]
        if 'image' in first:
            batch['images'] = self._stack_if_uniform([inst['image'] for inst in instances])
        if 'vp_image' in first:
            batch['vp_images'] = self._stack_if_uniform([inst['vp_image'] for inst in instances])

        # Drop the fields already moved into the batch. pop() (instead of del)
        # tolerates instances that carry no 'image', which the check above
        # explicitly treats as optional.
        for inst in instances:
            for key in ('input_ids', 'labels', 'image'):
                inst.pop(key, None)
        batch['seg_info'] = list(instances)

        if 'dataset_type' in first:
            batch['dataset_type'] = [inst['dataset_type'] for inst in instances]
        if 'class_name_ids' in first:
            batch['class_name_ids'] = self._stack_or_pad([inst['class_name_ids'] for inst in instances])
        if 'token_refer_id' in first:
            batch['token_refer_id'] = [inst['token_refer_id'] for inst in instances]
        if 'cls_indices' in first:
            batch['cls_indices'] = self._stack_or_pad([inst['cls_indices'] for inst in instances])
        if 'random_idx' in first:
            batch['random_idx'] = torch.stack([inst['random_idx'] for inst in instances], dim=0)
        if 'class_name_embedding_indices' in first:
            batch['class_name_embedding_indices'] = torch.nn.utils.rnn.pad_sequence(
                [inst['class_name_embedding_indices'] for inst in instances],
                batch_first=True,
                padding_value=0)
        if 'refer_embedding_indices' in first:
            batch['refer_embedding_indices'] = torch.nn.utils.rnn.pad_sequence(
                [inst['refer_embedding_indices'] for inst in instances],
                batch_first=True,
                padding_value=0)
        return batch
# Post-processing of model outputs.
def parse_outputs(outputs, gt_mask):
    """Flatten model outputs into one record per ground-truth mask.

    Args:
        outputs: iterable of per-image output dicts. Each must provide
            ``output['instances']`` with ``pred_masks`` and ``scores`` tensors
            (``pred_classes`` is optional) and a ``output['gt']`` tensor.
        gt_mask: unused — kept only for backward compatibility. The ground
            truth is always read from ``output['gt']``.

    Returns:
        A list of dicts, one per row of ``output['gt']``, with keys:
        ``'pred'`` (all predicted masks of that image, numpy), ``'gt'`` (that
        ground-truth mask as uint8), ``'scores'`` (the matching row of the
        transposed score matrix) and ``'pred_cls'`` (numpy array, or None when
        the instances have no usable ``pred_classes``).

    Raises:
        AssertionError: if the transposed scores have a different number of
            rows than there are ground-truth masks.
    """
    records = []
    for output in outputs:
        instances = output['instances']
        pred_masks = instances.pred_masks.cpu().numpy()
        # Transpose so that row i of ``scores`` lines up with ground-truth mask i.
        scores = instances.scores.transpose(1, 0).cpu().numpy()
        gt_masks = output['gt'].cpu().numpy().astype(np.uint8)
        try:
            pred_cls = instances.pred_classes.cpu().numpy()
        except AttributeError:
            # No pred_classes field (or it is None): record classes as unknown.
            pred_cls = None
        assert scores.shape[0] == gt_masks.shape[0], (
            f"{scores.shape[0]} score rows for {gt_masks.shape[0]} ground-truth masks"
        )
        records.extend(
            {'pred': pred_masks, 'gt': gt, 'scores': score_row, 'pred_cls': pred_cls}
            for gt, score_row in zip(gt_masks, scores)
        )
    return records
# Dataset class.
class DAVIS_Dataset(COCO_interactive_dataset_extrametric):
    """Frame-level dataset turning one record of ``self.data`` into a model input dict.

    ``self.data``, ``self.data_args``, ``self.tokenizer`` and
    ``self.preprocess_llama2`` are presumably provided by the base class —
    confirm against COCO_interactive_dataset_extrametric.

    NOTE(review): ``BoxMode`` is not imported in this file (the detectron2
    import is commented out), so ``__getitem__`` raises NameError as written;
    restore ``from detectron2.structures import BoxMode`` before using this
    class. It also requires ``self.data_args.image_processor`` to be a mapping
    with a ``'null_mask'`` entry.
    """
    # Note: all processing below operates on a single frame (one image).
    def __getitem__(self, idx):
        """Build the model input dict for frame ``idx``.

        Returns the dict produced by ``processor.preprocess`` extended with
        ``input_ids``, ``labels`` and ``dataset_type`` (always 'region_coco').

        Side effect: the annotation dicts of ``self.data[idx]`` ('anns' and
        'first_frame_anns') are modified in place (bbox_mode, bbox, image_id).
        """
        data = self.data[idx]
        # Image path relative to image_folder, e.g. 2017/trainval/JPEGImages/480p/bike-packing/00001.jpg
        image_file = data['image']
        # image_folder is the data root directory (here: data_segswap).
        image_folder = self.data_args.image_folder
        data_dict = {}
        # file_name: full path of the image, e.g. /data/Davis/2017/trainval/JPEGImages/480p/bike-packing/00001.jpg
        data_dict['file_name'] = os.path.join(image_folder, image_file)
        data_dict['height'] = data['image_info']['height']
        data_dict['width'] = data['image_info']['width']
        # image_id: a running index / identifier for this frame.
        data_dict['image_id'] = data['new_img_id']
        # annotations: this frame's annotations (COCO-format segmentation masks);
        # one image may contain several instance masks.
        data_dict['annotations'] = data['anns']
        # vp_annotations: annotations of the first frame of the video.
        data_dict['vp_annotations'] = data['first_frame_anns']
        # vp_image: full path of the video's first frame, e.g. /data/Davis/2017/trainval/JPEGImages/480p/bike-packing/00000.jpg
        data_dict['vp_image'] = os.path.join(image_folder,data['first_frame_image'])
        for annotation in data_dict['annotations']:
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            # Both box corners at the origin, i.e. an empty bounding box.
            annotation['bbox'] = [0,0,0,0]
            annotation['image_id'] = data['new_img_id']
        for annotation in data_dict['vp_annotations']:
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            annotation['bbox'] = [0,0,0,0]
            # The first-frame annotations are tagged with the current frame's id.
            annotation['image_id'] = data['new_img_id']
        processor = self.data_args.image_processor['null_mask'] # debug: handle null masks
        # Optional region_mask_type from the command-line args, given as a '||'-separated list.
        region_mask_type = getattr(self.data_args,'region_mask_type',None)
        if region_mask_type is not None:
            region_mask_type = region_mask_type.split('||')
        #print("region_mask_type:", region_mask_type)
        # Preprocess data_dict according to region_mask_type and mask_format (binary 0/1
        # masks here); per the original author this converts the Detectron2-style dataset
        # dict into MaskFormer format.
        data_dict = processor.preprocess(data_dict,region_mask_type=region_mask_type,mask_format='bitmask')
        # num_target: number of objects (instances) in this frame.
        # The code below builds the text prompt and tokenizes it.
        num_target = len(data_dict['instances'])
        # <image> is a special placeholder standing for the image input.
        prefix_inst = 'This is an image <image>, Please segment by given regions'
        # One <region> placeholder per target, comma-separated, the last ending with a period,
        # e.g. 3 regions -> ' <region>, <region>, <region>.'
        # NOTE(review): with num_target == 0 this still yields one ' <region>.' — confirm intended.
        regions_inst = ' <region>,' * (num_target - 1) + ' <region>.'
        sources_value = f'\nThis is all regions: {regions_inst}\n'
        # sources: one conversation turn — the human prompt followed by the model reply.
        sources = [
            [{'from': 'human', 'value': prefix_inst + sources_value},
             {'from': 'gpt', 'value': '\n[SEG]<seg>'}]]
        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        # input_ids: token ids produced by the tokenizer from `sources`.
        input_ids = text_dict['input_ids'][0]
        # labels: target token ids aligned with input_ids.
        labels = text_dict['labels'][0]
        data_dict['input_ids'] = input_ids
        data_dict['labels'] = labels
        data_dict['dataset_type'] = 'region_coco'
        return data_dict
def _group_frames_by_take(frames):
    """Return ``{video_name: [frame, ...]}``, preserving input order within each take."""
    frames_by_take = {}
    for frame in frames:
        frames_by_take.setdefault(frame['video_name'], []).append(frame)
    return frames_by_take


def _encode_prediction(mask, save_format, png_path):
    """Serialize one binary mask and return the value stored under ``'pred_mask'``.

    ``'rle'``: returns a COCO RLE dict whose ``counts`` is decoded to an ASCII
    string so it is JSON-serializable.
    ``'png'``: writes ``mask`` to ``png_path`` (creating parent directories)
    and returns that path.

    Raises:
        ValueError: for any other ``save_format`` (including None).
    """
    if save_format == 'rle':
        rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
        rle['counts'] = rle['counts'].decode('ascii')
        return rle
    if save_format == 'png':
        os.makedirs(os.path.dirname(png_path), exist_ok=True)
        cv2.imwrite(png_path, mask.astype(np.uint8))
        return png_path
    raise ValueError(f"Unsupported save format: {save_format}")


def evaluation(*, save_results=False):
    """Build hard-coded (random-mask) predictions for a shard of the Ego-Exo test takes.

    For every test take in the selected shard, each query object of each
    frame-level record gets a uniformly random binary mask of the target
    frame's size with confidence 1.0. Results are collected as
    ``result[take_id] = {'masks': {obj: {f'{query_cam}_{target_cam}': {frame_id: {...}}}},
    'subsample_idx': ...}``. Objects of a take that never appear as a query are
    added with an empty dict.

    Args:
        save_results: when True, dump ``result`` as JSON to ``save_path_json``
            after every take (checkpoint for ``--resume``) and once at the end.
            Defaults to False, which matches the previous behaviour where
            saving was commented out.

    Raises:
        ValueError: if ``--resume`` is set without ``--resume_path``; if a
            query annotation's ``category_id`` does not map to an object of
            the take; if the target camera has no annotation for that object;
            or if ``--save_format`` is neither 'rle' nor 'png'.
    """
    parser = transformers.HfArgumentParser(DataArguments)
    data_args = parser.parse_args_into_dataclasses()[0]

    # --- Model / dataloader setup: disabled in this hard-coded baseline ---
    # Predictions below are random masks, so no model is loaded. To evaluate a
    # real model, re-enable:
    # disable_torch_init()
    # model_path = os.path.expanduser(data_args.model_path)
    # model_name = get_model_name_from_path(model_path)
    # print(f'current model is {model_path}')
    # tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, model_args=data_args, mask_config=data_args.mask_config, device='cuda')
    # data_args.image_processor = image_processor
    # data_args.is_multimodal = True
    # conversation_lib.default_conversation = conversation_lib.conv_templates[data_args.version]
    # data_collator = DataCollatorForCOCODatasetV2(tokenizer=tokenizer)
    # device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # model.to(device=device).eval()  (or dtype=torch.float when not using autocast)
    # NOTE: --using_autocast has no effect here because no model is run; an
    # autocast context would have to wrap the model call to matter.

    # debug: update these paths before each experiment.
    save_path_json = "/scratch/yuqian_fu/competition_test_20250518_hardcode_v2_exoego_new.json"
    data_path = "/home/yuqian_fu/Projects/PSALM/egoexo_test_framelevel.json"
    splits_path = "/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/split.json"
    png_root = "/scratch/yuqian_fu/results_v2"

    with open(data_path, "r") as fp:
        datas = json.load(fp)
    with open(splits_path, "r") as fp:
        splits = json.load(fp)
    takes_all = splits["test"]

    # debug: shard of the test takes processed by this run.
    # NOTE: with fewer than 8 takes NUM is 0 and nothing is processed.
    NUM = len(takes_all) // 8
    takes_all = takes_all[:NUM]
    # takes_all = takes_all[NUM:NUM*2]
    # takes_all = takes_all[NUM*2:NUM*3]
    # takes_all = takes_all[NUM*3:]
    # takes_all = ["40dc3bbc-c8c4-4e6d-a2a5-32a357f3c291"]  # debug: single take

    # Resume: skip takes already present in a previously saved result file.
    if data_args.resume:
        if data_args.resume_path is None:
            raise ValueError("--resume requires --resume_path")
        with open(data_args.resume_path, "r") as fp:
            result = json.load(fp)
        # To redo the last (possibly incomplete) take, delete it first:
        # del result[last_processed_take]
        takes_all = [take_id for take_id in takes_all if take_id not in result]
    else:
        result = {}

    # Group the frame-level records once instead of rescanning them per take.
    frames_by_take = _group_frames_by_take(datas)
    # Number of (object, target frame) pairs with no target annotation.
    anno_miss_num = 0
    with torch.no_grad():
        for take_id in tqdm(takes_all):
            print("current take_id:", take_id)
            with open(f'{data_args.image_folder}/{take_id}/annotation.json', 'r') as fp:
                annotations = json.load(fp)
            # Fill numbers (category ids) are 1-based positions in the natsorted object list.
            objs = natsorted(list(annotations["masks"].keys()))
            fill_number_to_obj = {fill_number: obj for fill_number, obj in enumerate(objs, start=1)}

            pred_json = {'masks': {}, 'subsample_idx': annotations['subsample_idx']}
            seen_objs = set()
            for frame in frames_by_take.get(take_id, []):
                # Camera names are the parent directory of the frame paths.
                target_cam = frame['image'].split('/')[-2]
                query_cam = frame['first_frame_image'].split('/')[-2]
                pair_key = f'{query_cam}_{target_cam}'
                frame_id = frame['image'].split('/')[-1].split('.')[0]
                h = frame['image_info']['height']
                w = frame['image_info']['width']

                for query_ann in frame['first_frame_anns']:
                    fill_number = query_ann['category_id']
                    if fill_number not in fill_number_to_obj:
                        raise ValueError(
                            f"cur_fill_number {fill_number} not in id_range "
                            f"(take {take_id} has {len(objs)} objects)")
                    obj_name = fill_number_to_obj[fill_number]
                    seen_objs.add(obj_name)

                    obj_masks = annotations['masks'][obj_name]
                    if target_cam not in obj_masks:
                        raise ValueError(f"target_cam {target_cam} not in {obj_name}")
                    if frame_id not in obj_masks[target_cam]:
                        anno_miss_num += 1

                    cur_pred = np.random.randint(0, 2, (h, w), dtype=np.uint8)  # debug: random mask
                    png_path = f'{png_root}/{take_id}/{target_cam}/{obj_name}/{frame_id}.png'
                    pred_mask = _encode_prediction(cur_pred, data_args.save_format, png_path)
                    pred_json['masks'].setdefault(obj_name, {}).setdefault(pair_key, {})[frame_id] = {
                        'pred_mask': pred_mask,
                        'confidence': 1.0,
                    }

            if not pred_json['masks']:
                # The take is still stored below (skipping it here was a bug).
                print(f"pred_json['masks'] is empty for take_id {take_id}")

            # Every object of the take must appear in the output; objects that
            # never occurred as a query (e.g. no frame for this target view) get
            # an empty entry. Iterating `objs` keeps the order deterministic.
            for obj in objs:
                if obj not in seen_objs:
                    print(f"{take_id}缺失物体{obj}")
                    pred_json['masks'][obj] = {}

            result[take_id] = pred_json
            if save_results:
                # Checkpoint after every take so an interrupted run can be resumed.
                with open(save_path_json, 'w') as fp:
                    json.dump(result, fp)

    if save_results:
        with open(save_path_json, "w") as fp:
            json.dump(result, fp)
    print(f"Total number of missing annotations: {anno_miss_num}")
if "__main__" == __name__:
    # Script entry point.
    evaluation()

    # Debug helper: count the objects written for a single take of a saved result.
    # path = "/scratch/yuqian_fu/competition_test_20250516_single_take_complete.json"
    # with open(path, "r") as fp:
    #     result = json.load(fp)
    # pred = result["40dc3bbc-c8c4-4e6d-a2a5-32a357f3c291"]
    # print(len(list(pred['masks'].keys())))