| import os |
| import gc |
| import torch |
|
|
| from PIL import Image |
| from huggingface_hub import hf_hub_download |
| import argparse |
|
|
| |
| from ape.model_zoo import get_config_file |
| from third_parts.APE.demo.demo_lazy import get_parser, setup_cfg |
| from third_parts.APE.demo.predictor_lazy import VisualizationDemo |
| |
| from detectron2.config import CfgNode |
| from detectron2.data.detection_utils import read_image, _apply_exif_orientation, convert_PIL_to_numpy |
| import detectron2.data.transforms as T |
| from detectron2.evaluation.coco_evaluation import instances_to_coco_json |
| from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES, CITYSCAPES_CATEGORIES |
| from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES |
| import ape |
|
|
class APEPredictor(object):
    """Open-vocabulary detector/segmenter built on APE.

    Resolves an APE checkpoint (local path, or downloaded from the Hugging
    Face Hub), builds the detectron2 lazy config for inference, and exposes a
    single :meth:`run_on_image` entry point that runs text-prompted detection /
    segmentation and returns COCO-style JSON results (plus an optional
    visualization).
    """

    def __init__(self,
                 repo_id="shenyunhang/APE",
                 checkpoint_file="configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/model_final.pth",
                 config_file="LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k.py",
                 target_categories=None,
                 device="cuda",
                 ):
        """Build the APE model and its visualization demo wrapper.

        Args:
            repo_id: Hugging Face Hub repository that hosts the APE weights.
            checkpoint_file: local checkpoint path; when it does not exist it
                is treated as a filename inside ``repo_id`` and downloaded.
            config_file: APE lazy-config name resolved via
                ``ape.model_zoo.get_config_file``.
            target_categories: default comma-separated text prompt used by
                :meth:`run_on_image` when the caller passes no ``input_text``.
            device: torch device string to run on. Defaults to ``"cuda"``,
                matching the previously hard-coded behavior.
        """
        super().__init__()

        self.running_device = device

        # Prefer an existing local checkpoint; otherwise fetch from the Hub.
        if os.path.exists(checkpoint_file):
            init_checkpoint = checkpoint_file
        else:
            init_checkpoint = hf_hub_download(repo_id=repo_id, filename=checkpoint_file)

        # Minimal stand-in for the CLI args of APE's demo_lazy script.
        args = argparse.Namespace(
            config_file="configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
            webcam=False,
            video_input='',
            input='',
            output='',
            confidence_threshold=0.5,
            opts=[],
            text_prompt=None,
            with_box=False,
            with_mask=False,
            with_sseg=False,
        )

        args.config_file = get_config_file(
            config_file
        )
        # Keep the raw confidence cut very low; run_on_image applies the
        # caller's score threshold on top.
        args.confidence_threshold = 0.01
        args.opts = [
            "train.init_checkpoint='{}'".format(init_checkpoint),
            "model.model_language.cache_dir=''",
            "model.model_vision.select_box_nums_for_evaluation=500",
            "model.model_vision.text_feature_bank_reset=True",
            "model.model_vision.backbone.net.xattn=False",
            "model.model_vision.transformer.encoder.pytorch_attn=True",
            "model.model_vision.transformer.decoder.pytorch_attn=True",
        ]

        if self.running_device == "cpu":
            # Half-precision language weights are not usable on CPU.
            args.opts += [
                "model.model_language.dtype='float32'",
            ]

        cfg = setup_cfg(args)
        # Federated loss is a training-time construct; disable for inference.
        cfg.model.model_vision.criterion[0].use_fed_loss = False
        cfg.model.model_vision.criterion[2].use_fed_loss = False
        cfg.train.device = self.running_device

        # Shrink the CLIP vision tower to a single layer so its (unused for
        # text prompting) weights are never fully materialized.
        ape.modeling.text.eva02_clip.factory._MODEL_CONFIGS[cfg.model.model_language.clip_model][
            "vision_cfg"
        ]["layers"] = 1

        self.cfg = cfg
        self.args = args
        # Resize transform used to cap oversized inputs at a 1024 short edge.
        self.aug = T.ResizeShortestEdge([1024, 1024], 1024)

        self.demo = VisualizationDemo(cfg, args=args)
        self.demo.predictor.model.to(self.running_device)

        self.target_categories = target_categories

    def run_on_image(
        self,
        input_image_path,
        input_text=None,
        score_threhold=0.1,
        output_type=("object detection", "instance segmentation"),
        visualize=False,
    ):
        """Run text-prompted inference on a single image.

        Args:
            input_image_path: path to an image file, or a PIL image object.
            input_text: comma-separated prompt ("word1,word2,sentence1,...").
                Falls back to ``self.target_categories`` when empty/None.
            score_threhold: minimum detection score (parameter name kept,
                typo included, for backward compatibility with callers).
            output_type: iterable containing any of "object detection",
                "instance segmentation", "semantic segmentation". (Now a
                tuple default instead of a mutable list default.)
            visualize: when True, also render predictions onto the image.

        Returns:
            ``None`` when ``output_type`` selects nothing; otherwise a tuple
            ``(output_image, json_results)`` where ``output_image`` is a PIL
            image (``None`` when ``visualize`` is False) and ``json_results``
            is a list of COCO-style result dicts. ``category_name`` is only
            attached when visualization metadata is available.
        """
        self.demo.predictor.model.model_vision.test_score_thresh = score_threhold

        if input_text is None or input_text == "":
            assert self.target_categories is not None and self.target_categories != "", \
                "The `input_text` and `self.target_categories` can't be None at the same time."
            input_text = self.target_categories

        with_box = "object detection" in output_type
        with_mask = "instance segmentation" in output_type
        with_sseg = "semantic segmentation" in output_type

        # Accept either a file path or an in-memory PIL image.
        if not isinstance(input_image_path, str):
            input_image = _apply_exif_orientation(input_image_path)
            input_image = convert_PIL_to_numpy(input_image, format="BGR")
        else:
            input_image = read_image(input_image_path, format="BGR")

        input_mask = None

        if not with_box and not with_mask and not with_sseg:
            return None

        # Downscale large inputs so the shorter edge is at most 1024; the
        # inverse transform restores the visualization to full resolution.
        if input_image.shape[0] > 1024 or input_image.shape[1] > 1024:
            transform = self.aug.get_transform(input_image)
            input_image = transform.apply_image(input_image)
        else:
            transform = None

        metadata = None  # only provided by the visualize=True code path
        if not visualize:
            predictions = self.demo.run_on_image(
                input_image,
                text_prompt=input_text,
                mask_prompt=input_mask,
                with_box=with_box,
                with_mask=with_mask,
                with_sseg=with_sseg,
                visualize=visualize,
            )
            output_image = None
        else:
            predictions, visualized_output, _, metadata = self.demo.run_on_image(
                input_image,
                text_prompt=input_text,
                mask_prompt=input_mask,
                with_box=with_box,
                with_mask=with_mask,
                with_sseg=with_sseg,
                visualize=visualize,
            )
            output_image = visualized_output.get_image()
            if transform:
                output_image = transform.inverse().apply_image(output_image)
            output_image = Image.fromarray(output_image)

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        json_results = instances_to_coco_json(predictions["instances"].to(self.demo.cpu_device), 0)
        for json_result in json_results:
            # BUG FIX: `metadata` used to be read unconditionally, raising a
            # NameError whenever visualize=False. Category names are now
            # attached only when metadata is actually available.
            if metadata is not None:
                json_result["category_name"] = metadata.thing_classes[json_result["category_id"]]
            del json_result["image_id"]

        return output_image, json_results
| |
|
|
def build_ape_predictor(model_type: str = 'D', which_categories=None, override_ckpt_file=""):
    """Construct an :class:`APEPredictor` for one of the released APE models.

    Args:
        model_type: one of 'A'-'D'. Only 'D' (the default, strongest released
            checkpoint) is currently wired up; 'A', 'B' and 'C' raise
            ``NotImplementedError``.
        which_categories: optional fixed vocabulary name ('COCO',
            'CITYSCAPES' or 'LVIS') used to build the predictor's default
            text prompt; any other value leaves the prompt empty.
        override_ckpt_file: optional checkpoint path/Hub filename. When
            non-empty it replaces the model-type default. (BUG FIX: the
            override used to be applied unconditionally, so the default empty
            string wiped out the selected checkpoint.)

    Returns:
        A ready-to-use :class:`APEPredictor`.

    Raises:
        NotImplementedError: for model types 'A', 'B' and 'C'.
    """
    repo_id = "shenyunhang/APE"

    if model_type == 'A':
        # APE-A (LVIS+COCO+O365+OID+VG) -- config known but not wired up.
        raise NotImplementedError
    elif model_type == 'B':
        # APE-B (adds VGR+REFCOCO) -- not wired up.
        raise NotImplementedError
    elif model_type == 'C':
        # APE-C (adds SA-1B) -- not wired up.
        raise NotImplementedError
    else:
        # APE-D: default multi-dataset checkpoint.
        checkpoint_file = "configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/model_final.pth"
        config_file = "LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k.py"

    # Only honor a non-empty override; the previous unconditional assignment
    # discarded the model-type checkpoint whenever no override was given.
    if override_ckpt_file:
        checkpoint_file = override_ckpt_file

    category_dict = {'COCO': COCO_CATEGORIES, 'CITYSCAPES': CITYSCAPES_CATEGORIES, 'LVIS': LVIS_CATEGORIES}

    target_categories = ""
    if which_categories in category_dict:
        # Strip LVIS-style "-" qualifiers, deduplicate, and sort so the
        # prompt (and hence category-id ordering) is deterministic across
        # runs -- bare set() iteration order depends on hash randomization.
        category_names = sorted({item['name'].split('-')[0].strip()
                                 for item in category_dict[which_categories]})
        target_categories = ",".join(category_names)

    predictor = APEPredictor(repo_id=repo_id,
                             checkpoint_file=checkpoint_file,
                             config_file=config_file,
                             target_categories=target_categories)

    return predictor
| |
|
|
|
|
if __name__ == "__main__":
    # Smoke-test driver: COCO-vocabulary detection on one sample image,
    # saving both the JSON predictions and the rendered visualization.
    import json

    predictor = build_ape_predictor(
        which_categories='COCO',
        override_ckpt_file="shenyunhang/APE/configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/model_final.pth",
    )

    rendered, detections = predictor.run_on_image(
        "./zhouyik/zt_any_visual_prompt/sa_10293442.jpg",
        visualize=True,
        score_threhold=0.1,
    )

    with open('./sa_10293442_out.json', 'w') as savef:
        json.dump(detections, savef)

    rendered.save("./sa_10293442_out.jpg")
|
|
|
|
|
|
|
|