|
|
import torch |
|
|
import os |
|
|
from enum import Enum |
|
|
from tqdm import tqdm |
|
|
import numpy as np |
|
|
from detectron2.structures import BitMasks |
|
|
from psalm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, \ |
|
|
DEFAULT_IM_END_TOKEN, DEFAULT_SEG_TOKEN, SEG_TOKEN_INDEX |
|
|
from psalm.model.builder import load_pretrained_model |
|
|
from psalm.utils import disable_torch_init |
|
|
from psalm.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria |
|
|
import cv2 |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
from psalm import conversation as conversation_lib |
|
|
|
|
|
from psalm.train.train_datasets import COCO_interactive_dataset |
|
|
|
|
|
|
|
|
from detectron2.structures import BoxMode |
|
|
from detectron2.data import MetadataCatalog, DatasetCatalog |
|
|
from typing import Dict, Optional, Sequence, List |
|
|
from dataclasses import dataclass, field |
|
|
import torch.distributed as dist |
|
|
import transformers |
|
|
from pathlib import Path |
|
|
|
|
|
from psalm.eval.segmentation_evaluation import openseg_classes |
|
|
from natsort import natsorted |
|
|
# ADE20K-150 category metadata; appears unused in this file's visible code —
# presumably kept for visualization/debugging. TODO confirm before removing.
COLOR_MAP = openseg_classes.ADE20K_150_CATEGORIES
|
|
|
|
|
|
|
|
import re |
|
|
from psalm.constants import IGNORE_INDEX, IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_PATCH_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN, SEG_TOKEN_INDEX, CLS_TOKEN_INDEX, REGION_TOKEN_INDEX, REFER_TOKEN_INDEX |
|
|
|
|
|
class Multicondition_Dataset(COCO_interactive_dataset):
    """Dataset that combines region visual prompts with a free-text referring
    instruction in one conversation.

    Each sample is a detectron2-style record (file name, size, annotations,
    visual-prompt fields) augmented with conversation token ids plus
    ``token_refer_id`` / ``refer_embedding_indices`` describing where the
    referring text plugs into the sequence.
    """

    def preprocess_referring_instruction(self,instruction, REFER_token='[SEG]'):
        """Tokenize the referring text and append the [SEG] marker.

        Returns a 1-D LongTensor of token ids (no special tokens added by
        the tokenizer itself).
        """
        tokenized = self.tokenizer.encode(instruction, add_special_tokens=False)
        # Append only the first sub-token of REFER_token to terminate the
        # referring expression.
        tokenized = tokenized + [self.tokenizer.encode(REFER_token, add_special_tokens=False)[0]]

        token_refer_id = torch.tensor(tokenized)

        return token_refer_id

    def tokenizer_special_tokens(self, prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX,
                                 seg_token_index=SEG_TOKEN_INDEX, cls_token_index=CLS_TOKEN_INDEX,
                                 region_token_index=REGION_TOKEN_INDEX,refer_token_index=REFER_TOKEN_INDEX, return_tensors=None):
        """Encode ``prompt`` while mapping placeholder tags (<image>, <seg>,
        <cls>, <region>, <refer>) to their reserved token indices.

        Returns a plain list of ids, or a LongTensor when
        ``return_tensors == 'pt'``; any other non-None value raises.
        """
        input_ids = []
        special_token_map = {'<image>': image_token_index, '<seg>': seg_token_index, '<cls>': cls_token_index, '<region>':region_token_index, '<refer>':refer_token_index}
        # The capturing group makes re.split keep the delimiters, so the
        # result interleaves plain-text chunks with the special tags.
        prompt_chunks = re.split('(<image>|<seg>|<cls>|<region>|<refer>)', prompt)

        for chunk in prompt_chunks:
            if chunk in special_token_map:
                input_ids.append(special_token_map[chunk])
            else:
                input_ids.extend(tokenizer.encode(chunk, add_special_tokens=False))
        if return_tensors is not None:
            if return_tensors == 'pt':
                # NOTE(review): squeeze() turns a single-token result into a
                # 0-dim tensor — confirm callers expect that.
                return torch.tensor(input_ids, dtype=torch.long).squeeze()
            raise ValueError(f'Unsupported tensor type: {return_tensors}')
        else:
            return input_ids

    def __getitem__(self, idx):
        """Build one sample: preprocessed image record + conversation token
        ids + referring-instruction metadata."""
        data = self.data[idx]

        image_file = data['image']

        image_folder = self.data_args.image_folder

        data_dict = {}

        data_dict['file_name'] = os.path.join(image_folder, image_file)
        data_dict['height'] = data['image_info']['height']
        data_dict['width'] = data['image_info']['width']

        data_dict['image_id'] = data['new_img_id']

        data_dict['annotations'] = data['anns']

        # Visual-prompt ('vp') fields come from the sample's first-frame
        # annotations/image.
        data_dict['vp_annotations'] = data['first_frame_anns']

        data_dict['vp_image'] = os.path.join(image_folder,data['first_frame_image'])

        for annotation in data_dict['annotations']:
            annotation['bbox_mode'] = BoxMode.XYXY_ABS

            # Boxes are not used here; zero them but keep the field so the
            # downstream mapper still finds it.
            annotation['bbox'] = [0,0,0,0]
            annotation['image_id'] = data['new_img_id']

        for annotation in data_dict['vp_annotations']:
            annotation['bbox_mode'] = BoxMode.XYXY_ABS
            annotation['bbox'] = [0,0,0,0]
            annotation['image_id'] = data['new_img_id']

        # image_processor may be a dict of processors keyed by task, or a
        # single processor instance.
        if isinstance(self.data_args.image_processor,dict):

            processor = self.data_args.image_processor['instance']

        else:
            processor = self.data_args.image_processor

        # region_mask_type is configured as a '||'-separated string.
        region_mask_type = getattr(self.data_args,'region_mask_type',None)
        if region_mask_type is not None:
            region_mask_type = region_mask_type.split('||')

        # The processor replaces the record wholesale; it adds 'instances'
        # (used below) and presumably the 'image' tensor the collator
        # expects — TODO confirm against the processor implementation.
        data_dict = processor.preprocess(data_dict,region_mask_type=region_mask_type,mask_format='bitmask')

        sentences = data['instruction']

        num_target = len(data_dict['instances'])

        prefix_inst = 'This is an image <image>, Please segment by given regions and instruction'

        # Concatenate every referring sentence into one instruction string.
        instruction = ''
        for sent in sentences:
            instruction += ' {}.'.format(sent['sent'])

        # One <region> placeholder per target instance, comma-separated,
        # with a trailing period on the last one.
        regions_inst = ' <region>,' * (num_target - 1) + ' <region>.'
        sources_value = f'\nThis is all regions: {regions_inst}\n'

        sources = [[{'from': 'human', 'value': prefix_inst + sources_value + "and this is the instruction: " + '<refer>\n'},
                    {'from': 'gpt', 'value': '\n[SEG]<seg>'}]]

        text_dict = self.preprocess_llama2(sources, self.tokenizer)

        input_ids = text_dict['input_ids'][0]

        labels = text_dict['labels'][0]

        token_refer_id = self.preprocess_referring_instruction(instruction)
        # Binary mask marking where the <refer> placeholder sits in the
        # sequence, so its embedding can be substituted later.
        refer_embedding_indices = torch.zeros_like(input_ids)
        refer_embedding_indices[input_ids == REFER_TOKEN_INDEX] = 1

        data_dict['input_ids'] = input_ids
        data_dict['labels'] = labels
        data_dict['dataset_type'] = 'referring_coco'

        data_dict['token_refer_id'] = token_refer_id
        data_dict['refer_embedding_indices'] = refer_embedding_indices
        return data_dict
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class DataCollatorForCOCODatasetV2(object):
    """Collate examples for supervised fine-tuning.

    Pads ``input_ids``/``labels`` into batch tensors, stacks image tensors
    when shapes agree (otherwise keeps a list), and forwards every remaining
    per-sample field to the model under ``seg_info``.
    """

    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
        input_ids, labels = tuple([instance[key] for instance in instances]
                                  for key in ("input_ids", "labels"))
        # Right-pad to the longest sequence; labels pad with IGNORE_INDEX so
        # padded positions are excluded from the loss.
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
        # Truncate to the tokenizer's context window after padding.
        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
        )
        # Images can only be stacked when every sample has the same shape;
        # otherwise hand the model a plain list.
        if 'image' in instances[0]:
            images = [instance['image'] for instance in instances]
            if all(x is not None and x.shape == images[0].shape for x in images):
                batch['images'] = torch.stack(images)
            else:
                batch['images'] = images
        if 'vp_image' in instances[0]:
            vp_images = [instance['vp_image'] for instance in instances]
            if all(x is not None and x.shape == vp_images[0].shape for x in vp_images):
                batch['vp_images'] = torch.stack(vp_images)
            else:
                batch['vp_images'] = vp_images
        # Strip the fields already batched above, then pass each sample's
        # remaining metadata through as `seg_info`.
        # pop(..., None) instead of del: 'image' may legitimately be absent
        # (the stacking above is conditional on it), and `del` would raise
        # KeyError for such samples.
        for instance in instances:
            for key in ['input_ids', 'labels', 'image']:
                instance.pop(key, None)
        batch['seg_info'] = [instance for instance in instances]

        if 'dataset_type' in instances[0]:
            batch['dataset_type'] = [instance['dataset_type'] for instance in instances]

        # Variable-length id tensors pad with -1; equal shapes stack directly.
        if 'class_name_ids' in instances[0]:
            class_name_ids = [instance['class_name_ids'] for instance in instances]
            if any(x.shape != class_name_ids[0].shape for x in class_name_ids):
                batch['class_name_ids'] = torch.nn.utils.rnn.pad_sequence(
                    class_name_ids,
                    batch_first=True,
                    padding_value=-1,
                )
            else:
                batch['class_name_ids'] = torch.stack(class_name_ids, dim=0)
        # token_refer_id entries have differing lengths; keep them as a list.
        if 'token_refer_id' in instances[0]:
            token_refer_id = [instance['token_refer_id'] for instance in instances]
            batch['token_refer_id'] = token_refer_id
        if 'cls_indices' in instances[0]:
            cls_indices = [instance['cls_indices'] for instance in instances]
            if any(x.shape != cls_indices[0].shape for x in cls_indices):
                batch['cls_indices'] = torch.nn.utils.rnn.pad_sequence(
                    cls_indices,
                    batch_first=True,
                    padding_value=-1,
                )
            else:
                batch['cls_indices'] = torch.stack(cls_indices, dim=0)
        if 'random_idx' in instances[0]:
            random_idxs = [instance['random_idx'] for instance in instances]
            batch['random_idx'] = torch.stack(random_idxs, dim=0)
        # Indicator masks pad with 0 (= "no embedding at this position").
        if 'class_name_embedding_indices' in instances[0]:
            class_name_embedding_indices = [instance['class_name_embedding_indices'] for instance in instances]
            class_name_embedding_indices = torch.nn.utils.rnn.pad_sequence(
                class_name_embedding_indices,
                batch_first=True,
                padding_value=0)
            batch['class_name_embedding_indices'] = class_name_embedding_indices
        if 'refer_embedding_indices' in instances[0]:
            refer_embedding_indices = [instance['refer_embedding_indices'] for instance in instances]
            refer_embedding_indices = torch.nn.utils.rnn.pad_sequence(
                refer_embedding_indices,
                batch_first=True,
                padding_value=0)
            batch['refer_embedding_indices'] = refer_embedding_indices

        return batch
|
|
|
|
|
|
|
|
@dataclass
class DataArguments:
    """Command-line arguments for this evaluation script, parsed via
    ``transformers.HfArgumentParser`` in ``evaluation()``."""

    data_path: str = field(default=None,
                           metadata={"help": "Path to the training data."})
    lazy_preprocess: bool = False
    is_multimodal: bool = False
    # Root folder holding the evaluation images / takes.
    image_folder: Optional[str] = field(default='/path/to/val2017')
    # Checkpoint directory of the pretrained model.
    model_path: Optional[str] = field(default="/path/to/model")
    # Mask2Former config used to build the segmentation head.
    mask_config: Optional[str] = field(default="./psalm/mask_config/maskformer2_swin_base_384_bs16_50ep.yaml")
    image_aspect_ratio: str = 'square'
    image_grid_pinpoints: Optional[str] = field(default=None)
    # JSON annotation path consumed by DAVIS_Dataset.
    json_path: str = '/path/to/coco'
    model_map_name: str = 'psalm_video'
    # Conversation template key (see conversation_lib.conv_templates).
    version: str = 'llava_phi'
    segmentation: bool = True
    eval_batch_size: int = 1
    dataloader_num_workers: int = 8
    seg_task: Optional[str] = field(default="region")
    # '||'-separated list of region mask types (split in the datasets).
    region_mask_type: Optional[str] = field(default=None)
    with_memory: bool = False
    # When True, load resume_path and skip takes already present in it.
    resume: bool = False
    resume_path: Optional[str] = field(default=None)
|
|
|
|
|
def parse_outputs(outputs, gt_mask):
    """Flatten model outputs into one result dict per ground-truth mask.

    Args:
        outputs: iterable of dicts, each with an ``'instances'`` object
            exposing ``pred_masks``, ``scores`` and optionally
            ``pred_classes``, plus a ``'gt'`` mask tensor.
        gt_mask: unused; kept for backward compatibility with existing
            callers (it is immediately shadowed by each output's own 'gt').

    Returns:
        list of dicts with keys 'pred', 'gt', 'scores', 'pred_cls'.
        Note 'pred' holds the full prediction mask array for every entry,
        while 'gt' and 'scores' are the per-ground-truth slices.
    """
    res_list = []
    for output in outputs:
        pred_mask = output['instances'].pred_masks
        pred_mask = pred_mask.cpu().numpy()
        # Transpose so the leading axis lines up with the ground-truth
        # masks (asserted below).
        scores = output['instances'].scores.transpose(1, 0).cpu().numpy()
        gt_mask = output['gt'].cpu().numpy().astype(np.uint8)
        try:
            pred_cls = output['instances'].pred_classes.cpu().numpy()
        except AttributeError:
            # Class-agnostic outputs carry no pred_classes.
            # (Was a bare `except:`, which also hid unrelated errors.)
            pred_cls = None
        assert scores.shape[0] == gt_mask.shape[0]
        for i in range(gt_mask.shape[0]):
            res_list.append({
                'pred': pred_mask,
                'gt': gt_mask[i],
                'scores': scores[i],
                'pred_cls': pred_cls,
            })
    return res_list
|
|
|
|
|
|
|
|
class DAVIS_Dataset(COCO_interactive_dataset):
    """Region-prompt variant of the interactive dataset: each sample pairs a
    target frame with the first ('vp') frame's annotations and a
    region-only conversation (no referring text)."""

    def __getitem__(self, idx):
        """Return a detectron2-style record plus conversation token ids."""
        sample = self.data[idx]
        folder = self.data_args.image_folder
        img_id = sample['new_img_id']

        record = {
            'file_name': os.path.join(folder, sample['image']),
            'height': sample['image_info']['height'],
            'width': sample['image_info']['width'],
            'image_id': img_id,
            'annotations': sample['anns'],
            'vp_annotations': sample['first_frame_anns'],
            'vp_image': os.path.join(folder, sample['first_frame_image']),
        }

        # Boxes are unused; zero them while tagging every annotation (target
        # frame first, then the visual-prompt frame) with the image id.
        for ann_group in (record['annotations'], record['vp_annotations']):
            for ann in ann_group:
                ann['bbox_mode'] = BoxMode.XYXY_ABS
                ann['bbox'] = [0,0,0,0]
                ann['image_id'] = img_id

        processor = self.data_args.image_processor['instance']

        # region_mask_type is configured as a '||'-separated string.
        mask_types = getattr(self.data_args, 'region_mask_type', None)
        if mask_types is not None:
            mask_types = mask_types.split('||')

        # The processor replaces the record with its preprocessed form.
        record = processor.preprocess(record, region_mask_type=mask_types, mask_format='bitmask')

        # One <region> placeholder per target instance.
        n_regions = len(record['instances'])
        region_tokens = ' <region>,' * (n_regions - 1) + ' <region>.'
        human_turn = ('This is an image <image>, Please segment by given regions'
                      + f'\nThis is all regions: {region_tokens}\n')

        conversation = [[
            {'from': 'human', 'value': human_turn},
            {'from': 'gpt', 'value': '\n[SEG]<seg>'},
        ]]

        text_dict = self.preprocess_llama2(conversation, self.tokenizer)
        record['input_ids'] = text_dict['input_ids'][0]
        record['labels'] = text_dict['labels'][0]
        record['dataset_type'] = 'region_coco'

        return record
|
|
|
|
|
|
|
|
import zlib |
|
|
import base64 |
|
|
|
|
|
def compress_mask(mask):
    """Encode a binary mask as a compact ASCII string.

    The mask is flattened to a bit stream (``np.packbits``), deflated with
    zlib, and base64-encoded so the result can be embedded in JSON.
    """
    bit_bytes = np.packbits(mask.astype(bool), axis=None).tobytes()
    return base64.b64encode(zlib.compress(bit_bytes)).decode('ascii')
|
|
|
|
|
def decompress_mask(encoded_str, shape):
    """Invert compress_mask: base64 -> zlib inflate -> bit-unpack -> bool array.

    ``np.packbits`` pads the bit stream up to a whole byte, so the unpacked
    array can be longer than ``prod(shape)``; the pad bits are dropped before
    reshaping. (The previous version passed the padded array straight to
    ``reshape`` and crashed for any mask whose size is not a multiple of 8.)
    """
    packed = np.frombuffer(zlib.decompress(base64.b64decode(encoded_str)), dtype=np.uint8)
    n_bits = int(np.prod(shape))
    return np.unpackbits(packed)[:n_bits].reshape(shape).astype(bool)
|
|
|
|
|
|
|
|
|
|
|
def evaluation():
    """Run visual-prompt video-segmentation evaluation over the 'val' takes.

    Loads the pretrained model, iterates the DAVIS-style dataloader once per
    take, picks one predicted mask per visual-prompt object, and writes all
    compressed masks to a single JSON file at ``save_path``.
    """
    parser = transformers.HfArgumentParser(DataArguments)
    data_args = parser.parse_args_into_dataclasses()[0]
    disable_torch_init()
    model_path = os.path.expanduser(data_args.model_path)
    model_name = get_model_name_from_path(model_path)
    print(f'current model is {model_path}')
    tokenizer, model, image_processor, context_len = load_pretrained_model(model_path, None, model_name, model_args=data_args, mask_config=data_args.mask_config, device='cuda')

    data_args.image_processor = image_processor
    data_args.is_multimodal = True
    conversation_lib.default_conversation = conversation_lib.conv_templates[data_args.version]

    eval_dataset = DAVIS_Dataset(json_path=data_args.json_path, tokenizer=tokenizer, data_args=data_args)
    data_collator = DataCollatorForCOCODatasetV2(tokenizer=tokenizer)

    dataloader_params = {
        "batch_size": data_args.eval_batch_size,
        "num_workers": data_args.dataloader_num_workers,
    }
    eval_dataloader = DataLoader(eval_dataset, batch_size=dataloader_params['batch_size'], collate_fn=data_collator,
                                 num_workers=dataloader_params['num_workers'])

    def load_ref_dataset():
        # Factory for detectron2's DatasetCatalog registration below.
        return DAVIS_Dataset(json_path=data_args.json_path, tokenizer=tokenizer, data_args=data_args)

    DatasetCatalog.register('refcoco_dataset', load_ref_dataset)
    MetadataCatalog.get('refcoco_dataset').set(stuff_classes=['object'],)
    gt_json_path = data_args.json_path

    # NOTE(review): save_dir is computed but never used below.
    save_dir = os.path.dirname(gt_json_path)
    save_dir = os.path.join(save_dir,'predictions')

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    model.to(device=device,dtype=torch.float).eval()

    # NOTE(review): these prev_* locals are never read — leftovers from an
    # earlier memory/tracking variant.
    prev_image = None
    prev_mask_list = None
    prev_fill_number_list = None
    prev_video = None
    prev_transformer = None

    import json
    from pycocotools import mask as mask_utils

    # NOTE(review): hard-coded machine-specific paths — should come from
    # DataArguments.
    splits_path = "/home/yuqian_fu/Projects/ego-exo4d-relation/correspondence/SegSwap/data/split.json"
    save_path = "/work/yuqian_fu/Ego/bisai_results_base64.json"
    with open(splits_path, "r") as fp:
        splits = json.load(fp)
    takes_all = splits["val"]

    if data_args.resume:
        # Resume: reload previous results and skip takes already processed.
        with open(data_args.resume_path, "r") as fp:
            result = json.load(fp)

        processed_takes = set(result.keys())
        takes_all = [take_id for take_id in takes_all if take_id not in processed_takes]

    else:
        result = {}

    # NOTE(review): `scaler` is created but never used; autocast is entered
    # separately inside the loop.
    scaler = torch.cuda.amp.autocast(enabled=True)

    with torch.no_grad():
        for take_id in tqdm(takes_all):
            print("current take_id:", take_id)

            with open(f'{data_args.image_folder}/{take_id}/annotation.json', 'r') as fp:
                annotations = json.load(fp)

            # Map 1-based contiguous fill numbers back to the take's
            # (natsorted) object names.
            objs = natsorted(list(annotations["masks"].keys()))
            coco_id_to_cont_id = {cont_id + 1: coco_id for cont_id, coco_id in enumerate(objs)}
            id_range = list(coco_id_to_cont_id.keys())

            pred_json = {'masks': {}, 'subsample_idx': annotations['subsample_idx']}

            # NOTE(review): the full dataloader is re-scanned for every take
            # and non-matching samples skipped — O(takes * dataset) passes.
            for idx, inputs in enumerate(eval_dataloader):
                inputs = {k: v.to(device) if torch.is_tensor(v) else v for k, v in inputs.items()}

                # file_name layout: .../<take_id>/<camera>/<frame>.<ext>
                video_name = inputs['seg_info'][0]['file_name'].split('/')[-3]

                if video_name != take_id:
                    continue

                target_cam = inputs['seg_info'][0]['file_name'].split('/')[-2]
                query_cam = inputs['seg_info'][0]['vp_file_path'].split('/')[-2]

                pair_key = f'{query_cam}_{target_cam}'

                # Frame id = file stem. NOTE(review): shadows the builtin `id`.
                id = inputs['seg_info'][0]['file_name'].split('/')[-1].split('.')[0]

                with torch.cuda.amp.autocast():
                    outputs = model.eval_video(
                        input_ids=inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        images=inputs['images'].float(),
                        vp_images=inputs['vp_images'].float(),
                        seg_info=inputs['seg_info'],
                        labels=inputs['labels']
                    )
                if torch.cuda.is_available():
                    torch.cuda.synchronize()

                # batch_size is 1 in practice; only the first output is used.
                output = outputs[0]
                pred_mask = output['instances'].pred_masks
                pred_mask = pred_mask.cpu().numpy()
                scores = output['instances'].scores.transpose(1, 0).cpu().numpy()
                gt_mask = output['gt'].cpu().numpy().astype(np.uint8)
                assert len(scores) == len(inputs['seg_info'][0]['instances'].vp_fill_number)
                # Greedy assignment: each object takes its best-scoring mask
                # that no earlier object in this frame has claimed.
                prev_idx = []
                for i in range(len(scores)):
                    cur_scores = scores[i]
                    cur_fill_number = inputs['seg_info'][0]['instances'].vp_fill_number[i]

                    if cur_fill_number not in id_range:
                        print(f"cur_fill_number {cur_fill_number} not in id_range, skipping...")
                        continue
                    max_score, idx = torch.topk(torch.tensor(cur_scores), 10, largest=True, sorted=True)
                    idx = idx.cpu().numpy()
                    # NOTE(review): this inner loop reuses `i` (harmless in
                    # Python for-loops but confusing) and shadows the outer
                    # dataloader `idx`; if all 10 candidates were already
                    # claimed, pick_idx/pick_score silently keep their value
                    # from the previous object (or are unbound on the first).
                    for i in range(10):
                        if idx[i] not in prev_idx:
                            prev_idx.append(idx[i])
                            pick_idx = idx[i]
                            pick_score = max_score[i]
                            break

                    cur_pred = pred_mask[pick_idx, :].astype(bool)
                    compressed_str = compress_mask(cur_pred)

                    obj_name = coco_id_to_cont_id[cur_fill_number.item()]

                    # Only keep predictions for frames that exist in the GT
                    # annotation structure.
                    if target_cam not in annotations['masks'][obj_name].keys():
                        print(f"target_cam {target_cam} not in {obj_name}, skipping...")
                        continue
                    if id not in annotations["masks"][obj_name][target_cam].keys():
                        print(f"id {id} not in {target_cam}, skipping...")
                        continue

                    if obj_name not in pred_json['masks']:
                        pred_json['masks'][obj_name] = {}

                    if pair_key not in pred_json['masks'][obj_name]:
                        pred_json['masks'][obj_name][pair_key] = {}

                    pred_json['masks'][obj_name][f'{query_cam}_{target_cam}'][id] = {'pred_mask': compressed_str, 'confidence': pick_score.item(), 'shape': cur_pred.shape}

            if len(pred_json['masks']) == 0:
                print(f"pred_json['masks'] is empty for take_id {take_id}, skipping...")
                continue

            result[take_id] = pred_json

    # NOTE(review): results are written once at the very end; a crash loses
    # all takes processed in this run despite the resume mechanism.
    with open(save_path, "w") as fp:
        json.dump(result, fp)
|
|
|
|
|
|
|
|
|
|
|
# Script entry point.
if __name__ == '__main__':
    evaluation()
|
|
|