"""Frame-level hard-coded evaluation script for ego/exo correspondence takes.

In its current state the model loading and inference code is commented out
and every prediction is a random binary mask; the script exercises the result
bookkeeping (per-take / per-object / per-camera-pair structure) only.
"""
import torch
import os
from enum import Enum
from tqdm import tqdm
import numpy as np
# from detectron2.structures import BitMasks
# from psalm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \
#     DEFAULT_IM_END_TOKEN, DEFAULT_SEG_TOKEN, SEG_TOKEN_INDEX
# from psalm.model.builder import load_pretrained_model
# from psalm.utils import disable_torch_init
# from psalm.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
import cv2
# from torch.utils.data import Dataset, DataLoader
# from psalm import conversation as conversation_lib
from psalm.train.train_datasets_eval import COCO_interactive_dataset_extrametric
# debug
# from detectron2.structures import BoxMode
# from detectron2.data import MetadataCatalog, DatasetCatalog
from collections import defaultdict
from typing import Dict, Optional, Sequence, List
from dataclasses import dataclass, field
import torch.distributed as dist
import transformers
from pathlib import Path
from psalm.eval.segmentation_evaluation import openseg_classes
from natsort import natsorted
import re
from psalm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, SEG_TOKEN_INDEX, CLS_TOKEN_INDEX, REGION_TOKEN_INDEX, REFER_TOKEN_INDEX
import json
from pycocotools import mask as mask_utils

COLOR_MAP = openseg_classes.ADE20K_150_CATEGORIES


# Command-line arguments (parsed with transformers.HfArgumentParser).
@dataclass
class DataArguments:
    """Command-line options for the evaluation script."""

    data_path: str = field(default=None, metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    # Root folder; per-take annotations are read from
    # <image_folder>/<take_id>/annotation.json and image paths are joined onto it.
    image_folder: Optional[str] = field(default='/path/to/val2017')
    model_path: Optional[str] = field(default="/path/to/model")
    mask_config: Optional[str] = field(default="./psalm/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml")
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = field(default=None)
    json_path: str = '/path/to/coco'
    model_map_name: str = 'psalm_video'
    version: str = 'llava_phi'
    segmentation: bool = True
    eval_batch_size: int = 1  # debug
    dataloader_num_workers: int = 8
    seg_task: Optional[str] = field(default="region")
    # '||'-separated list of region mask types (split in DAVIS_Dataset.__getitem__).
    region_mask_type: Optional[str] = field(default=None)
    with_memory: bool = False
    # When True, results are loaded from resume_path and finished takes are skipped.
    resume: bool = False
    using_autocast: bool = False
    resume_path: Optional[str] = field(default=None)
    # Either 'rle' or 'png'; any other value raises ValueError during evaluation.
    save_format: Optional[str] = field(default=None)


# Collate function.
@dataclass
class DataCollatorForCOCODatasetV2(object):
    """Collate examples for supervised fine-tuning."""

    tokenizer: transformers.PreTrainedTokenizer

    # `instances` is an ordered collection of dicts; each dict describes one frame.
    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        """Pad and stack a list of per-frame dicts into a single batch dict.

        Side effect: the ``input_ids``, ``labels`` and ``image`` entries are
        removed from every instance; what remains is returned under
        ``batch['seg_info']``.
        """
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        # Truncate to the tokenizer's maximum sequence length.
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )

        # Images are stacked only when all share one shape; otherwise kept as a list.
        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images
        if 'vp_image' in instances[0]:
            vp_images = [instance['vp_image'] for instance in instances]
            if all(x is not None and x.shape == vp_images[0].shape for x in vp_images):
                batch['vp_images'] = torch.stack(vp_images)
            else:
                batch['vp_images'] = vp_images

        # 'image' is optional (see the check above), so use pop() instead of
        # del to avoid a KeyError for instances that do not carry it.
        for instance in instances:
            for key in ['input_ids', 'labels', 'image']:
                instance.pop(key, None)
        batch['seg_info'] = [instance for instance in instances]

        if 'dataset_type' in instances[0]:
            batch['dataset_type'] = [instance['dataset_type'] for instance in instances]

        if 'class_name_ids' in instances[0]:
            class_name_ids = [instance['class_name_ids'] for instance in instances]
            if any(x.shape != class_name_ids[0].shape for x in class_name_ids):
                batch['class_name_ids'] = torch.nn.utils.rnn.pad_sequence(
                    class_name_ids,
                    batch_first=True,
                    padding_value=-1,
                )
            else:
                batch['class_name_ids'] = torch.stack(class_name_ids, dim=0)
        if 'token_refer_id' in instances[0]:
            token_refer_id = [instance['token_refer_id'] for instance in instances]
            batch['token_refer_id'] = token_refer_id
        if 'cls_indices' in instances[0]:
            cls_indices = [instance['cls_indices'] for instance in instances]
            if any(x.shape != cls_indices[0].shape for x in cls_indices):
                batch['cls_indices'] = torch.nn.utils.rnn.pad_sequence(
                    cls_indices,
                    batch_first=True,
                    padding_value=-1,
                )
            else:
                batch['cls_indices'] = torch.stack(cls_indices, dim=0)
        if 'random_idx' in instances[0]:
            random_idxs = [instance['random_idx'] for instance in instances]
            batch['random_idx'] = torch.stack(random_idxs, dim=0)
        if 'class_name_embedding_indices' in instances[0]:
            class_name_embedding_indices = [instance['class_name_embedding_indices'] for instance in instances]
            class_name_embedding_indices = torch.nn.utils.rnn.pad_sequence(
                class_name_embedding_indices,
                batch_first=True,
                padding_value=0)
            batch['class_name_embedding_indices'] = class_name_embedding_indices
        if 'refer_embedding_indices' in instances[0]:
            refer_embedding_indices = [instance['refer_embedding_indices'] for instance in instances]
            refer_embedding_indices = torch.nn.utils.rnn.pad_sequence(
                refer_embedding_indices,
                batch_first=True,
                padding_value=0)
            batch['refer_embedding_indices'] = refer_embedding_indices

        return batch


# Post-processing of model outputs.
def parse_outputs(outputs, gt_mask):
    """Flatten model outputs into one result dict per ground-truth mask.

    Args:
        outputs: iterable of dicts with an ``'instances'`` object (exposing
            ``pred_masks``, ``scores`` and optionally ``pred_classes``) and a
            ``'gt'`` tensor.
        gt_mask: ignored; the ground truth is always taken from
            ``output['gt']``. Kept only for interface compatibility.

    Returns:
        A list of dicts with keys ``'pred'``, ``'gt'``, ``'scores'`` and
        ``'pred_cls'`` (``None`` when the instances have no ``pred_classes``).
    """
    res_list = []
    for output in outputs:
        # gt = output['gt'].cpu().numpy().astype(np.uint8)
        pred_mask = output['instances'].pred_masks
        pred_mask = pred_mask.cpu().numpy()
        scores = output['instances'].scores.transpose(1, 0).cpu().numpy()
        gt_mask = output['gt'].cpu().numpy().astype(np.uint8)
        try:
            pred_cls = output['instances'].pred_classes.cpu().numpy()
        except AttributeError:
            # The instances object carries no class predictions.
            pred_cls = None
        assert scores.shape[0] == gt_mask.shape[0]
        for i in range(gt_mask.shape[0]):
            res = {
                'pred': pred_mask,
                'gt': gt_mask[i],
                'scores': scores[i],
                'pred_cls': pred_cls
            }
            res_list.append(res)
    return res_list


# Dataset class.
class DAVIS_Dataset(COCO_interactive_dataset_extrametric):
    """Frame-level dataset: every item describes one target frame plus the
    first (visual-prompt) frame of the same video."""

    # Note: all processing below is for a single frame.
    def __getitem__(self, idx):
        """Build the model-ready dict for frame ``idx``.

        Relies on ``self.data``, ``self.data_args``, ``self.tokenizer`` and
        ``self.preprocess_llama2`` being provided by the base class.
        """
        # BoxMode is required below but the module-level detectron2 import is
        # commented out, so import it here to avoid a NameError.
        from detectron2.structures import BoxMode

        data = self.data[idx]
        # Relative image path, e.g. 2017/trainval/JPEGImages/480p/bike-packing/00001.jpg
        image_file = data['image']
        # image_folder is the data root directory.
        image_folder = self.data_args.image_folder

        data_dict = {}
        # Full image path, e.g. /data/Davis/2017/trainval/JPEGImages/480p/bike-packing/00001.jpg
        data_dict['file_name'] = os.path.join(image_folder, image_file)
        data_dict['height'] = data['image_info']['height']
        data_dict['width'] = data['image_info']['width']
        # image_id acts as a running index.
        data_dict['image_id'] = data['new_img_id']
        # Annotations of this frame (COCO-style segmentation masks); one image
        # may contain masks for several instances.
        data_dict['annotations'] = data['anns']
        # Annotations of the first frame of the video.
        data_dict['vp_annotations'] = data['first_frame_anns']
        # Full path of the first frame of the video,
        # e.g. /data/Davis/2017/trainval/JPEGImages/480p/bike-packing/00000.jpg
        data_dict['vp_image'] = os.path.join(image_folder, data['first_frame_image'])
        for annotation in data_dict['annotations']:
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            # Both corners at the origin, i.e. an empty placeholder box.
            annotation['bbox'] = [0, 0, 0, 0]
            annotation['image_id'] = data['new_img_id']
        for annotation in data_dict['vp_annotations']:
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            annotation['bbox'] = [0, 0, 0, 0]
            annotation['image_id'] = data['new_img_id']

        processor = self.data_args.image_processor['null_mask']  # debug: handles null masks
        # Try to read region_mask_type from the command-line arguments.
        region_mask_type = getattr(self.data_args, 'region_mask_type', None)
        if region_mask_type is not None:
            region_mask_type = region_mask_type.split('||')
        # print("region_mask_type:", region_mask_type)
        # Preprocess the raw dict according to region_mask_type with bitmask
        # mask format.
        data_dict = processor.preprocess(data_dict, region_mask_type=region_mask_type, mask_format='bitmask')

        # Number of objects in this frame.
        num_target = len(data_dict['instances'])
        # Build the text prompt and tokenize it.
        # NOTE(review): the original comments describe an image placeholder in
        # prefix_inst and one region placeholder per target in regions_inst,
        # but the literals below contain none - confirm against upstream
        # whether placeholder tokens were stripped from these strings.
        prefix_inst = 'This is an image , Please segment by given regions'
        # One comma-terminated entry per region, the last one ending with a period.
        regions_inst = ' ,' * (num_target - 1) + ' .'
        sources_value = f'\nThis is all regions: {regions_inst}\n'
        # Conversation with one human turn (the prompt) and one model turn.
        sources = [
            [{'from': 'human', 'value': prefix_inst + sources_value},
             {'from': 'gpt', 'value': '\n[SEG]'}]]

        text_dict = self.preprocess_llama2(sources, self.tokenizer)
        # Token ids fed to the model.
        input_ids = text_dict['input_ids'][0]
        # Training labels aligned with input_ids.
        labels = text_dict['labels'][0]

        data_dict['input_ids'] = input_ids
        data_dict['labels'] = labels
        data_dict['dataset_type'] = 'region_coco'

        return data_dict


def _export_mask(mask, save_format, png_path):
    """Serialize ``mask`` and return the value stored as ``'pred_mask'``.

    ``'rle'`` returns a COCO RLE dict with ASCII-decoded counts; ``'png'``
    writes the mask to ``png_path`` and returns that path.

    Raises:
        ValueError: if ``save_format`` is neither ``'rle'`` nor ``'png'``.
    """
    if save_format == 'rle':
        rle = mask_utils.encode(np.asfortranarray(mask.astype(np.uint8)))
        # bytes are not JSON serializable.
        rle['counts'] = rle['counts'].decode('ascii')
        return rle
    if save_format == 'png':
        os.makedirs(os.path.dirname(png_path), exist_ok=True)
        cv2.imwrite(png_path, mask.astype(np.uint8))
        return png_path
    raise ValueError(f"Unsupported save format: {save_format}")


def evaluation():
    """Run the (currently stubbed) evaluation over one shard of the test takes.

    Model loading and inference are commented out; each prediction is a random
    binary mask. For every take a ``{'masks': ..., 'subsample_idx': ...}``
    entry is built, and objects that received no prediction get an empty
    entry.

    Returns:
        dict mapping take id to its prediction entry (also includes resumed
        entries when ``--resume`` is set). Writing the JSON file is currently
        commented out.

    Raises:
        ValueError: on an unknown ``category_id``, a target camera missing
            from the annotations, an unsupported ``save_format``, or
            ``--resume`` without ``--resume_path``.
    """
    # Model loading.
    parser = transformers.HfArgumentParser(DataArguments)
    data_args = parser.parse_args_into_dataclasses()[0]
    # disable_torch_init()
    # model_path = os.path.expanduser(data_args.model_path)
    # model_name = get_model_name_from_path(model_path)
    # print(f'current model is {model_path}')
    # tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, model_args=data_args, mask_config=data_args.mask_config, device='cuda')

    # data_args.image_processor = image_processor
    # data_args.is_multimodal = True
    # conversation_lib.default_conversation = conversation_lib.conv_templates[data_args.version]

    # Dataset loading.
    # data_collator = DataCollatorForCOCODatasetV2(tokenizer=tokenizer)
    # dataloader_params = {
    #     "batch_size": data_args.eval_batch_size,
    #     "num_workers": data_args.dataloader_num_workers,
    # }
    # def load_ref_dataset():
    #     return DAVIS_Dataset(json_path=data_args.json_path, tokenizer=tokenizer, data_args=data_args)
    # Register load_ref_dataset for quick dataset access.
    # DatasetCatalog.register('refcoco_dataset', load_ref_dataset)
    # MetadataCatalog.get('refcoco_dataset').set(stuff_classes=['object'],)

    # Move the model to the device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    # if data_args.using_autocast:
    #     model.to(device=device).eval()  # debug: do not force a model precision
    # else:
    #     model.to(device=device,dtype=torch.float).eval()

    # Range of takes to process.
    save_path_json = "/scratch/yuqian_fu/competition_test_20250518_hardcode_v2_exoego_new.json"  # debug: change before each experiment
    data_path = "/home/yuqian_fu/Projects/PSALM/egoexo_test_framelevel.json"
    with open(data_path, "r") as fp:
        datas = json.load(fp)
    splits_path = "/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/split.json"
    with open(splits_path, "r") as fp:
        splits = json.load(fp)
    takes_all = splits["test"]
    # debug: change the shard of takes used for inference
    NUM = len(takes_all) // 8
    takes_all = takes_all[:NUM]
    # takes_all = takes_all[NUM:NUM*2]
    # takes_all = takes_all[NUM*2:NUM*3]
    # takes_all = takes_all[NUM*3:]
    # takes_all = ["40dc3bbc-c8c4-4e6d-a2a5-32a357f3c291"]  # debug: test a single take

    # Resume mechanism.
    if data_args.resume:
        if not data_args.resume_path:
            raise ValueError("--resume requires --resume_path to be set")
        with open(data_args.resume_path, "r") as fp:
            result = json.load(fp)
        # Drop the last processed take from `result` if it may be incomplete:
        # last_processed_take = "xxx"
        # del result[last_processed_take]
        processed_takes = set(result.keys())
        takes_all = [take_id for take_id in takes_all if take_id not in processed_takes]
    else:
        result = {}

    # Mixed-precision inference: when model inference is re-enabled and
    # data_args.using_autocast is set, wrap the forward pass in an autocast
    # context. (The previous code created an autocast object here without
    # ever entering it, which had no effect.)

    # Group the frames by take once instead of rescanning all frames per take.
    datas_by_take = defaultdict(list)
    for data in datas:
        datas_by_take[data['video_name']].append(data)

    # Number of target frames whose annotation is missing.
    anno_miss_num = 0

    with torch.no_grad():
        for take_id in tqdm(takes_all):
            print("current take_id:", take_id)
            # Per-take annotation file.
            with open(f'{data_args.image_folder}/{take_id}/annotation.json', 'r') as fp:
                annotations = json.load(fp)

            # All objects of this take; fill number (1-based position in the
            # naturally sorted object list) -> object name.
            objs = natsorted(list(annotations["masks"].keys()))  # debug: is natsort really needed?
            coco_id_to_cont_id = {cont_id + 1: coco_id for cont_id, coco_id in enumerate(objs)}

            # Frames whose video_name equals take_id.
            datas_list = datas_by_take.get(take_id, [])
            # print("len(datas_list):", len(datas_list))  # debug
            # eval_dataset = DAVIS_Dataset(datas_list, tokenizer=tokenizer, data_args=data_args)
            # eval_dataloader = DataLoader(eval_dataset, batch_size=dataloader_params['batch_size'], collate_fn=data_collator,
            #                              num_workers=dataloader_params['num_workers'])

            # Results of this take.
            pred_json = {'masks': {}, 'subsample_idx': annotations['subsample_idx']}
            seen_objs = set()  # debug: objects that received a prediction

            for data_idx in datas_list:
                # inputs = {k: v.to(device) if torch.is_tensor(v) else v for k, v in inputs.items()}

                # Camera names are the parent directory of the image paths.
                # target_cam = inputs['seg_info'][0]['file_name'].split('/')[-2]
                # query_cam = inputs['seg_info'][0]['vp_file_path'].split('/')[-2]  # debug
                target_cam = data_idx['image'].split('/')[-2]
                query_cam = data_idx['first_frame_image'].split('/')[-2]  # debug
                pair_key = f'{query_cam}_{target_cam}'
                # print("pair_key:", pair_key)  # debug

                # Frame id is the image file name up to the first dot.
                # id = inputs['seg_info'][0]['file_name'].split('/')[-1].split('.')[0]
                frame_id = data_idx['image'].split('/')[-1].split('.')[0]  # debug
                # print("id:", frame_id)  # debug

                # Height and width.
                # h = inputs['seg_info'][0]['height']
                # w = inputs['seg_info'][0]['width']
                h = data_idx['image_info']['height']
                w = data_idx['image_info']['width']

                for ann in data_idx['first_frame_anns']:
                    cur_fill_number = ann['category_id']
                    # debug: the fill number must index a known object.
                    if cur_fill_number not in coco_id_to_cont_id:
                        print(f"cur_fill_number {cur_fill_number} not in id_range, skipping...")
                        raise ValueError(f"cur_fill_number {cur_fill_number} not in id_range, skipping...")
                    # Map the fill number back to the object name.
                    obj_name = coco_id_to_cont_id[cur_fill_number]
                    seen_objs.add(obj_name)

                    # Validate the object against the annotations.
                    if target_cam not in annotations['masks'][obj_name]:
                        print(f"target_cam {target_cam} not in {obj_name}, skipping...")
                        raise ValueError(f"target_cam {target_cam} not in {obj_name}, skipping...")
                    if frame_id not in annotations["masks"][obj_name][target_cam]:
                        # Counted only; a prediction is still produced.
                        anno_miss_num += 1
                        # print(f"id {frame_id} not in {target_cam}, skipping...")

                    # Encode the prediction / save the mask image.
                    cur_pred = np.random.randint(0, 2, (h, w), dtype=np.uint8)  # debug: random mask
                    save_path = f'/scratch/yuqian_fu/results_v2/{take_id}/{target_cam}/{obj_name}/{frame_id}.png'
                    pred_value = _export_mask(cur_pred, data_args.save_format, save_path)

                    # Make sure both nesting levels exist, then store.
                    per_pair = pred_json['masks'].setdefault(obj_name, {}).setdefault(pair_key, {})
                    per_pair[frame_id] = {'pred_mask': pred_value, 'confidence': 1.0}

            # Check whether pred_json['masks'] is empty.
            if len(pred_json['masks']) == 0:
                print(f"pred_json['masks'] is empty for take_id {take_id}, skipping...")
                # continue  # bug

            # Before storing, make sure every object of the take is present;
            # objects without a prediction get a hard-coded empty entry.
            # Iterating `objs` keeps the output order deterministic.
            for obj in objs:
                if obj in seen_objs:
                    continue
                # Two cases for a missing object: no camera at all, or
                # cameras without frame ids.
                # cams = annotations['masks'][obj].keys()
                # exo_cams = [x for x in cams if 'aria' not in x]
                # ego_cams = [x for x in cams if 'aria' in x]
                # # If both exo_cams and ego_cams are non-empty, add a camera-pair key.
                # if len(exo_cams) > 0 and len(ego_cams) > 0:
                #     pred_json['masks'][obj] = {}
                #     ego = ego_cams[0]
                #     for exo in exo_cams:
                #         pred_json['masks'][obj][f"{exo}_{ego}"] = {}
                # else:
                print(f"{take_id}缺失物体{obj}")
                pred_json['masks'][obj] = {}
                # pred_json['masks'][obj] = "xxx"

            result[take_id] = pred_json
            # Save after every take to survive interruptions.
            # with open(save_path_json, 'w') as fp:
            #     json.dump(result, fp)

    # Save the final result.
    # with open(save_path_json, "w") as fp:
    #     json.dump(result, fp)

    # Report the number of samples with a missing annotation.
    print(f"Total number of missing annotations: {anno_miss_num}")
    return result


if __name__ == '__main__':
    evaluation()

    # path = "/scratch/yuqian_fu/competition_test_20250516_single_take_complete.json"
    # with open(path, "r") as fp:
    #     result = json.load(fp)
    # pred = result["40dc3bbc-c8c4-4e6d-a2a5-32a357f3c291"]
    # print(len(list(pred['masks'].keys())))