"""Evaluate referring-expression segmentation (RefCOCO / RefCOCO+ / RefCOCOg)
for an OMG-LLaVA model trained with xtuner."""
import argparse
import logging
import math
import os
import os.path as osp

import numpy as np
import torch
import torch.nn.functional as F
import tqdm
from mmengine import print_log
from mmengine.config import Config, ConfigDict
from mmengine.dist import (collect_results, get_dist_info, get_rank, init_dist,
                           master_only)
from mmengine.fileio import PetrelBackend, get_file_backend
from mmengine.utils.dl_utils import set_multi_processing
from PIL import Image
from pycocotools import mask
from torch.utils.data import Dataset
from transformers import GenerationConfig

from projects.omg_llava.dataset.utils import expand2square
from projects.omg_llava.dataset.utils.refcoco_refer import REFER
from projects.omg_llava.tools.utils_refcoco import (AverageMeter, Summary,
                                                    intersectionAndUnionGPU)
from xtuner.configs import cfgs_name_path
from xtuner.model.utils import (guess_load_checkpoint,
                                prepare_inputs_labels_for_multimodal)
from xtuner.registry import BUILDER
from xtuner.tools.utils import get_stop_criteria
from xtuner.utils import (DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX,
                          PROMPT_TEMPLATE)


def convert_dict2config_dict(input):
    # Recursively wrap plain dicts as ConfigDicts so nested config entries
    # support attribute-style access.
    input = ConfigDict(**input)
    for key in input.keys():
        if isinstance(input[key], dict):
            input[key] = convert_dict2config_dict(input[key])
    return input


TORCH_DTYPE_MAP = dict(
    fp16=torch.float16, bf16=torch.bfloat16, fp32=torch.float32, auto='auto')


def parse_args():
    parser = argparse.ArgumentParser(description='RefCocoSeg')
    parser.add_argument('config', help='config file name or path.')
    parser.add_argument('--pth_model', help='pth model file')
    parser.add_argument(
        '--dataset',
        choices=DATASETS_ATTRIBUTES.keys(),
        default='refcoco',
        help='Specify a ref dataset')
    parser.add_argument(
        '--split',
        default='val',
        help='Specify a split')
    parser.add_argument(
        '--prompt-template',
        choices=PROMPT_TEMPLATE.keys(),
        default='internlm2_chat',
        help='Specify a prompt template')
    parser.add_argument(
        '--stop-words', nargs='+', type=str, default=[], help='Stop words')
    parser.add_argument(
        '--torch-dtype',
        default='fp16',
        choices=TORCH_DTYPE_MAP.keys(),
        help='Override the default `torch.dtype` and load the model under '
        'a specific `dtype`.')
    parser.add_argument(
        '--bits',
        type=int,
        choices=[4, 8, None],
        default=None,
        help='LLM bits')
    parser.add_argument(
        '--bot-name', type=str, default='BOT', help='Name for Bot')
    parser.add_argument(
        '--offload-folder',
        default=None,
        help='The folder in which to offload the model weights (or where the '
        'model weights are already offloaded).')
    parser.add_argument(
        '--max-new-tokens',
        type=int,
        default=100,
        help='Maximum number of new tokens allowed in generated text')
    # Read below when the visual encoder returns hidden states; -2 is
    # xtuner's usual default for CLIP-style encoders.
    parser.add_argument(
        '--visual-select-layer',
        type=int,
        default=-2,
        help='Which hidden state of the visual encoder to feed the projector')
    parser.add_argument(
        '--seed',
        type=int,
        default=0,
        help='Random seed for reproducible text generation')
    parser.add_argument(
        '--launcher',
        choices=['none', 'pytorch', 'slurm', 'mpi'],
        default='none',
        help='job launcher')
    args = parser.parse_args()
    return args
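

# Example invocation (script name and paths are illustrative):
#   python refcoco_omg_seg_llava.py CONFIG --pth_model iter_XXXX.pth \
#       --dataset refcoco --split val
# Multi-GPU evaluation goes through a distributed launcher, e.g.:
#   torchrun --nproc_per_node=8 refcoco_omg_seg_llava.py CONFIG \
#       --pth_model iter_XXXX.pth --launcher pytorch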


DATASETS_ATTRIBUTES = {
    'refcoco': {'splitBy': 'unc', 'dataset_name': 'refcoco'},
    'refcoco_plus': {'splitBy': 'unc', 'dataset_name': 'refcoco+'},
    'refcocog': {'splitBy': 'umd', 'dataset_name': 'refcocog'},
}
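# RefCOCO and RefCOCO+ ship with the UNC split files, while RefCOCOg uses the
# UMD split; `splitBy` selects the matching annotation file inside REFER.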


@master_only
def master_print(msg):
    print(msg)


class RefcocoReferringSegDataset(Dataset):

    def __init__(self,
                 image_folder,
                 image_processor,
                 dataset_name,
                 data_path=None,
                 tokenizer=None,
                 offline_processed_text_folder=None,
                 pad_image_to_square=False,
                 debug=False,
                 repeats=1,
                 split='val'):
        self.split = split
        self._set_attribute(dataset_name)
        self.debug = debug
        if offline_processed_text_folder and data_path:
            print_log(
                'Both `offline_processed_text_folder` and `data_path` are '
                'set, and we load dataset from '
                f'`offline_processed_text_folder` '
                f'({offline_processed_text_folder})',
                logger='current',
                level=logging.WARNING)

        if offline_processed_text_folder is not None:
            raise NotImplementedError
        else:
            self.json_datas = self.json_file_preprocess(data_path)

        self.image_folder = image_folder
        size = image_processor.crop_size
        # `crop_size` may be an int, a (w, h) sequence, or a dict like
        # {'height': ..., 'width': ...} depending on the processor config.
        if isinstance(size, int):
            self.image_h, self.image_w = size, size
        elif isinstance(size, dict):
            self.image_h, self.image_w = size['height'], size['width']
        else:
            self.image_w, self.image_h = size

        if isinstance(image_processor, (dict, Config, ConfigDict)):
            self.image_processor = BUILDER.build(image_processor)
        else:
            self.image_processor = image_processor
        self.pad_image_to_square = pad_image_to_square
        self.down_ratio = 1
        self.repeats = repeats

    def _set_attribute(self, dataset_name):
        attr_dict = DATASETS_ATTRIBUTES[dataset_name]
        self.splitBy = attr_dict['splitBy']
        self.dataset_name = attr_dict['dataset_name']

    def __len__(self):
        return len(self.json_datas) * self.repeats

    def real_len(self):
        return len(self.json_datas)
    def json_file_preprocess(self, data_path):
        refer_api = REFER(data_path, self.dataset_name, self.splitBy)
        ref_ids_train = refer_api.getRefIds(split=self.split)
        images_ids_train = refer_api.getImgIds(ref_ids=ref_ids_train)
        refs_train = refer_api.loadRefs(ref_ids=ref_ids_train)
        self.img2refs = self.create_img_to_refs_mapping(refs_train)

        image_infos = [
            item.copy()
            for item in refer_api.loadImgs(image_ids=images_ids_train)
        ]
        self.annotations = refer_api.Anns

        refs = [self.img2refs[image_info['id']] for image_info in image_infos]

        ret = []
        for image_info, ref in zip(image_infos, refs):
            if len(ref) == 0:
                continue

            sents = []
            ann_ids = []
            for _ref in ref:
                for sent in _ref['sentences']:
                    sents.append(sent['sent'])
                    ann_ids.append(_ref['ann_id'])

            # Evaluate every referring expression for this image.
            sampled_inds = list(range(len(sents)))
            sampled_sents = [sents[ind] for ind in sampled_inds]
            sampled_ann_ids = [ann_ids[ind] for ind in sampled_inds]
            ret.append({
                'image_info': image_info,
                'sampled_ann_id': sampled_ann_ids,
                'selected_labels': sampled_sents,
                'image': image_info['file_name'],
            })
        if self.debug:
            return ret[:10]
        return ret
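
    # Each record produced by `json_file_preprocess` looks roughly like
    # (values illustrative):
    #   {'image_info': {...}, 'sampled_ann_id': [...],
    #    'selected_labels': ['the man in red', ...],
    #    'image': 'COCO_train2014_000000xxxxxx.jpg'}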

    def load_images(self, refer_api, images_ids_train, dataset_dir,
                    dataset_name, inference=False):
        images = []
        loaded_images = refer_api.loadImgs(image_ids=images_ids_train)
        for item in loaded_images:
            item = item.copy()
            images.append(item)
        return images

    def create_img_to_refs_mapping(self, refs_train):
        img2refs = {}
        for ref in refs_train:
            img2refs[ref['image_id']] = img2refs.get(ref['image_id'], []) + [ref]
        return img2refs
    def decode_mask(self, annotations_ids, image_info):
        masks = []

        for ann_id in annotations_ids:
            if isinstance(ann_id, list):
                # Multiple annotations refer to one expression; union their
                # masks into a single binary mask.
                if -1 in ann_id:
                    assert len(ann_id) == 1
                    m = np.zeros(
                        (image_info['height'], image_info['width']),
                        dtype=np.uint8)
                else:
                    m_final = np.zeros(
                        (image_info['height'], image_info['width']),
                        dtype=np.uint8)
                    for ann_id_i in ann_id:
                        ann = self.annotations[ann_id_i]

                        if len(ann['segmentation']) == 0:
                            m = np.zeros(
                                (image_info['height'], image_info['width']),
                                dtype=np.uint8)
                        else:
                            if isinstance(ann['segmentation'][0], list):
                                # Polygons -> COCO RLE.
                                rle = mask.frPyObjects(
                                    ann['segmentation'], image_info['height'],
                                    image_info['width'])
                            else:
                                rle = ann['segmentation']
                                for i in range(len(rle)):
                                    if not isinstance(rle[i]['counts'], bytes):
                                        rle[i]['counts'] = rle[i]['counts'].encode()
                            m = mask.decode(rle)
                            # Collapse the per-polygon channel dimension.
                            m = np.sum(m, axis=2).astype(np.uint8)
                        m_final = m_final | m
                    m = m_final
                masks.append(m)
                continue

            ann = self.annotations[ann_id]

            if len(ann['segmentation']) == 0:
                m = np.zeros(
                    (image_info['height'], image_info['width']),
                    dtype=np.uint8)
                masks.append(m)
                continue

            if isinstance(ann['segmentation'][0], list):
                rle = mask.frPyObjects(
                    ann['segmentation'], image_info['height'],
                    image_info['width'])
            else:
                rle = ann['segmentation']
                for i in range(len(rle)):
                    if not isinstance(rle[i]['counts'], bytes):
                        rle[i]['counts'] = rle[i]['counts'].encode()
            m = mask.decode(rle)
            m = np.sum(m, axis=2).astype(np.uint8)
            masks.append(m)
        masks = np.stack(masks, axis=0)
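        # masks: (N, H, W) uint8 array with one binary mask per annotation id.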

        masks = torch.from_numpy(masks)
        return masks

    def only_get_text_infos(self, json_data):
        return {'sampled_sents': json_data['selected_labels']}

    def get_questions(self, text_require_infos):
        sampled_sents = text_require_infos['sampled_sents']
        ret = []
        for sent in sampled_sents:
            ret.append('Please segment {} in this image.'.format(sent))
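            # e.g. 'the man in red' -> 'Please segment the man in red in this image.'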
        return ret

    def filter_data_dict(self, data_dict):
        names = ['pixel_values', 'masks', 'ori_size', 'questions']
        ret = {name: data_dict[name] for name in names}
        return ret

    def __getitem__(self, index):
        index = index % self.real_len()
        data_dict = self.json_datas[index]
        text_require_infos = self.only_get_text_infos(data_dict)
        questions = self.get_questions(text_require_infos)

        assert data_dict.get('image', None) is not None
        if data_dict.get('image', None) is not None:
            image_file = data_dict['image']
            image_file = os.path.join(self.image_folder, image_file)
            image = Image.open(image_file).convert('RGB')
            ori_width, ori_height = image.size
            if self.pad_image_to_square:
                image = expand2square(
                    image,
                    tuple(
                        int(x * 255) for x in self.image_processor.image_mean))
            image = self.image_processor.preprocess(
                image, return_tensors='pt')['pixel_values'][0]
            data_dict['pixel_values'] = image

            masks = self.decode_mask(data_dict['sampled_ann_id'],
                                     data_dict['image_info'])
            data_dict['masks'] = masks
            data_dict['ori_size'] = (ori_width, ori_height)
            data_dict['questions'] = questions
        else:
            # Unreachable given the assert above; kept as a defensive fallback.
            if hasattr(self.image_processor, 'crop_size'):
                crop_size = self.image_processor.crop_size
            else:
                crop_size = self.image_processor.size
            data_dict['pixel_values'] = torch.zeros(3, crop_size['height'],
                                                    crop_size['width'])
            data_dict['masks'] = None
        return self.filter_data_dict(data_dict)


def main():
    args = parse_args()

    torch.manual_seed(args.seed)

    if args.launcher != 'none':
        set_multi_processing(distributed=True)
        init_dist(args.launcher)
        rank, world_size = get_dist_info()
        torch.cuda.set_device(rank)
    else:
        rank = 0
        world_size = 1

    # Resolve a registered config name to its file path if needed.
    if not osp.isfile(args.config):
        try:
            args.config = cfgs_name_path[args.config]
        except KeyError:
            raise FileNotFoundError(f'Cannot find {args.config}')

    cfg = Config.fromfile(args.config)

    model_name = cfg.model.type if isinstance(cfg.model.type,
                                              str) else cfg.model.type.__name__
    if 'LLaVAModel' in model_name or 'OMG' in model_name:
        cfg.model.pretrained_pth = None

    model = BUILDER.build(cfg.model)
    backend = get_file_backend(args.pth_model)

    if cfg.get('pretrained_pth') is not None and os.path.exists(
            cfg.pretrained_pth):
        if isinstance(backend, PetrelBackend):
            from xtuner.utils.fileio import patch_fileio
            with patch_fileio():
                state_dict = guess_load_checkpoint(cfg.pretrained_pth)
        else:
            state_dict = guess_load_checkpoint(cfg.pretrained_pth)
        model.load_state_dict(state_dict, strict=False)
        print(f'Load pretrained checkpoint from {cfg.pretrained_pth}')

    if isinstance(backend, PetrelBackend):
        from xtuner.utils.fileio import patch_fileio
        with patch_fileio():
            state_dict = guess_load_checkpoint(args.pth_model)
    else:
        state_dict = guess_load_checkpoint(args.pth_model)
    model.load_state_dict(state_dict, strict=False)
    print(f'Load PTH model from {args.pth_model}')

    image_processor_cfg = dict(cfg.image_processor)
    image_processor_type = image_processor_cfg.pop('type')
    image_processor = image_processor_type(**image_processor_cfg)

    llm = model.llm
    tokenizer = model.tokenizer

    model.cuda()
    model.eval()
    llm.eval()
    visual_encoder = model.visual_encoder
    projector = model.projector
    projector_text2vision = model.projector_text2vision

    projector.cuda()
    projector.eval()

    visual_encoder.cuda()
    visual_encoder.eval()

    stop_words = args.stop_words
    if args.prompt_template:
        template = PROMPT_TEMPLATE[args.prompt_template]
        stop_words += template.get('STOP_WORDS', [])
    stop_criteria = get_stop_criteria(
        tokenizer=tokenizer, stop_words=stop_words)

    gen_config = GenerationConfig(
        max_new_tokens=args.max_new_tokens,
        do_sample=False,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
    )
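    # Greedy decoding (do_sample=False) keeps the evaluation deterministic.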
    dataset = RefcocoReferringSegDataset(
        dataset_name=args.dataset,
        image_folder='./data/glamm_data/images/coco2014/train2014/',
        image_processor=image_processor,
        data_path='./data/ref_seg/',
        tokenizer=tokenizer,
        pad_image_to_square=True,
        debug=False,
        split=args.split,
    )

    results = []
    n_samples = len(dataset)
    per_rank_samples = math.ceil(n_samples / world_size)

    per_rank_ids = range(per_rank_samples * rank,
                         min(n_samples, per_rank_samples * (rank + 1)))

    trackers = {
        'intersection': AverageMeter('Intersec', ':6.3f', Summary.SUM),
        'union': AverageMeter('Union', ':6.3f', Summary.SUM),
        'gIoU': AverageMeter('gIoU', ':6.3f', Summary.SUM)
    }
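    # cIoU (reported at the end) is sum(intersection) / sum(union) over all
    # masks; gIoU is the mean of per-mask IoUs, weighted by mask counts.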

    for i in tqdm.tqdm(per_rank_ids, desc=f'Rank {rank}'):
        data_sample = dataset[i]
        questions = data_sample['questions']
        texts = []
        for question in questions:
            texts.append(DEFAULT_IMAGE_TOKEN + '\n' + question)

        if args.prompt_template:
            template = PROMPT_TEMPLATE[args.prompt_template]
            prompt_texts = []
            for text in texts:
                prompt_texts.append(template['INSTRUCTION'].format(
                    input=text, round=1, bot_name=args.bot_name))
        else:
            prompt_texts = texts

        batch_inputs = prompt_texts

        image = data_sample['pixel_values']
        image = image.cuda().unsqueeze(0).to(visual_encoder.dtype)
        visual_outputs = visual_encoder(image, output_hidden_states=True)
        if isinstance(visual_outputs, (list, tuple, torch.Tensor)):
            pixel_values = projector(visual_outputs)
        else:
            # Skip the CLS token when projecting a CLIP-style hidden state.
            pixel_values = projector(
                visual_outputs.hidden_states[args.visual_select_layer][:, 1:])

        ori_size = data_sample['ori_size']
        target_masks = data_sample['masks'].cuda().to(torch.uint8)

        intersection, union, accuracy_iou = 0.0, 0.0, 0.0

        for idx_inp, inputs in enumerate(batch_inputs):
            # Encode the prompt around the image placeholder so the image can
            # be spliced in at IMAGE_TOKEN_INDEX.
            chunk_encode = []
            for idx, chunk in enumerate(inputs.split(DEFAULT_IMAGE_TOKEN)):
                if idx == 0:
                    cur_encode = tokenizer.encode(chunk)
                else:
                    cur_encode = tokenizer.encode(chunk, add_special_tokens=False)
                chunk_encode.append(cur_encode)
            assert len(chunk_encode) == 2
            ids = []
            for idx, cur_chunk_encode in enumerate(chunk_encode):
                ids.extend(cur_chunk_encode)
                if idx != len(chunk_encode) - 1:
                    ids.append(IMAGE_TOKEN_INDEX)
            ids = torch.tensor(ids).cuda().unsqueeze(0)
            mm_inputs = prepare_inputs_labels_for_multimodal(
                llm=llm, input_ids=ids, pixel_values=pixel_values)
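            # `prepare_inputs_labels_for_multimodal` replaces each
            # IMAGE_TOKEN_INDEX placeholder with the projected visual tokens,
            # yielding embedding-level inputs for generation.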

            generate_output = llm.generate(
                **mm_inputs,
                generation_config=gen_config,
                streamer=None,
                bos_token_id=tokenizer.bos_token_id,
                stopping_criteria=stop_criteria,
                output_hidden_states=True,
                return_dict_in_generate=True)
            # The decoded answer is only useful for debugging; scoring relies
            # on the hidden states of the generated [SEG] token below.
            predict = tokenizer.decode(generate_output.sequences[0]).strip()

            # Collect the final-layer hidden states from each generation step.
            hidden_states = generate_output.hidden_states
            last_hidden_states = [item[-1][-1] for item in hidden_states]
            last_hidden_states = torch.cat(last_hidden_states, dim=0)
            seg_hidden_states = get_seg_hidden_states(
                last_hidden_states, generate_output.sequences[0][:-1],
                seg_id=model.seg_token_idx)

            if len(seg_hidden_states) == 0:
                print('Warning, no [SEG] tokens !!!')
                continue
            elif len(seg_hidden_states) > 1:
                print('Warning, {} [SEG] tokens !!!'.format(
                    len(seg_hidden_states)))
                seg_hidden_states = seg_hidden_states[:1]

            seg_hidden_states = projector_text2vision(seg_hidden_states)
            batch_idxs = torch.zeros((seg_hidden_states.shape[0], ),
                                     dtype=torch.int64).to(seg_hidden_states.device)
            pred_masks_list = model.visual_encoder.forward_llm_seg(
                seg_hidden_states, batch_idxs)
            pred_masks = pred_masks_list[-1]
            w, h = ori_size
            # The model predicts on the square (padded) image; upsample to the
            # padded resolution, then crop away the expand2square padding.
            masks = F.interpolate(pred_masks, size=(max(w, h), max(w, h)),
                                  mode='bilinear', align_corners=False)
            masks = masks[:, 0]
            if w == h:
                pass
            elif w > h:
                n_pad = w - h
                n_pad_1 = n_pad // 2
                n_pad_2 = n_pad - n_pad_1
                masks = masks[:, n_pad_1:w - n_pad_2]
            else:
                n_pad = h - w
                n_pad_1 = n_pad // 2
                n_pad_2 = n_pad - n_pad_1
                masks = masks[:, :, n_pad_1:h - n_pad_2]

            masks = masks.sigmoid() > 0.5
            masks = masks.int()
            _target = target_masks[idx_inp:idx_inp + 1].int()

            for prediction, target in zip(masks, _target):
                intersect, union_, _ = intersectionAndUnionGPU(
                    prediction.contiguous().clone(), target.contiguous(), 2,
                    ignore_index=255)
                intersection += intersect
                union += union_
                accuracy_iou += intersect / (union_ + 1e-5)
                # An empty target with an empty prediction counts as IoU 1.
                accuracy_iou[union_ == 0] += 1.0

        # Finalize this sample: cumulative class stats plus mean per-mask IoU.
        intersection, union = intersection.cpu().numpy(), union.cpu().numpy()
        accuracy_iou = accuracy_iou.cpu().numpy() / target_masks.shape[0]
        trackers['intersection'].update(intersection)
        trackers['union'].update(union)
        trackers['gIoU'].update(accuracy_iou, n=target_masks.shape[0])

    # One aggregate record per rank; index [1] selects the foreground class.
    cur_results = {'pixel_intersection': trackers['intersection'].sum[1],
                   'pixel_union': trackers['union'].sum[1],
                   'gIoU': trackers['gIoU'].avg[1],
                   'mask_counts': trackers['gIoU'].count}
    results.append(cur_results)

    results = collect_results(results, n_samples)

    if get_rank() == 0:
        pixel_intersection = [cur_result['pixel_intersection']
                              for cur_result in results]
        pixel_union = [cur_result['pixel_union'] for cur_result in results]
        gIoUs = [cur_result['gIoU'] for cur_result in results]
        mask_counts = [cur_result['mask_counts'] for cur_result in results]

        class_iou = sum(pixel_intersection) / (sum(pixel_union) + 1e-10)
        global_iou = sum(giou * n_masks for giou, n_masks in zip(
            gIoUs, mask_counts)) / sum(mask_counts)
        print('cIoU: ', class_iou)
        print('gIoU: ', global_iou)


def get_seg_hidden_states(hidden_states, output_ids, seg_id):
    # Select the hidden-state rows at the positions of the [SEG] token.
    seg_mask = output_ids == seg_id
    n_out = len(seg_mask)
    return hidden_states[-n_out:][seg_mask]


if __name__ == '__main__':
    main()