# DenseLabelDev/third_parts/APE/build_ape.py
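"""Thin wrapper around the APE (Aligning and Prompting Everything) demo code:
builds a LazyConfig-based predictor from a Hugging Face Hub checkpoint and
exposes a single `run_on_image` entry point that returns COCO-style JSON
results (and, optionally, a visualized PIL image)."""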
import os
import gc
import torch
from PIL import Image
from huggingface_hub import hf_hub_download
import argparse
# from demo_lazy import get_parser, setup_cfg
from ape.model_zoo import get_config_file
from third_parts.APE.demo.demo_lazy import get_parser, setup_cfg
from third_parts.APE.demo.predictor_lazy import VisualizationDemo
# from predictor_lazy import VisualizationDemo
from detectron2.config import CfgNode
from detectron2.data.detection_utils import read_image, _apply_exif_orientation, convert_PIL_to_numpy
import detectron2.data.transforms as T
from detectron2.evaluation.coco_evaluation import instances_to_coco_json
from detectron2.data.datasets.builtin_meta import COCO_CATEGORIES, CITYSCAPES_CATEGORIES
from detectron2.data.datasets.lvis_v1_categories import LVIS_CATEGORIES
import ape
class APEPredictor(object):
def __init__(self,
repo_id="shenyunhang/APE",
checkpoint_file="configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/model_final.pth",
config_file="LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k.py",
target_categories=None,
):
super().__init__()
# available_memory = [
# torch.cuda.mem_get_info(i)[0] / 1024**3 for i in range(torch.cuda.device_count())
# ]
# if len(available_memory) > 0:
# max_available_memory = max(available_memory)
# device_id = available_memory.index(max_available_memory)
# self.running_device = "cuda:" + str(device_id)
# else:
# max_available_memory = 0
# self.running_device = "cpu"
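        # Default to CUDA; re-enable the block above to pick the GPU with the
        # most free memory, or to fall back to CPU.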
self.running_device = "cuda"
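        # Resolve the checkpoint: prefer a local file, otherwise download it
        # from the Hugging Face Hub.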
if os.path.exists(checkpoint_file):
init_checkpoint = checkpoint_file
else:
init_checkpoint = hf_hub_download(repo_id=repo_id, filename=checkpoint_file)
# args = get_parser().parse_args()
args_dict = {
'config_file': "configs/quick_schedules/mask_rcnn_R_50_FPN_inference_acc_test.yaml",
'webcam': False,
'video_input': '',
'input': '',
'output': '',
'confidence_threshold': 0.5,
'opts': [],
'text_prompt': None,
'with_box': False,
'with_mask': False,
'with_sseg': False,
}
args = argparse.Namespace(**args_dict)
        args.config_file = get_config_file(config_file)
args.confidence_threshold = 0.01
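        # LazyConfig dotted-key overrides: point init at the resolved
        # checkpoint, keep up to 500 boxes per image at evaluation, and force
        # plain PyTorch attention so no xFormers/custom xattn kernels are needed.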
args.opts = [
"train.init_checkpoint='{}'".format(init_checkpoint),
"model.model_language.cache_dir=''",
"model.model_vision.select_box_nums_for_evaluation=500",
"model.model_vision.text_feature_bank_reset=True",
"model.model_vision.backbone.net.xattn=False",
"model.model_vision.transformer.encoder.pytorch_attn=True",
"model.model_vision.transformer.decoder.pytorch_attn=True",
]
if self.running_device == "cpu":
args.opts += [
"model.model_language.dtype='float32'",
]
cfg = setup_cfg(args)
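        # Federated loss relies on per-dataset class-frequency statistics that
        # only matter during training; switch it off for inference.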
cfg.model.model_vision.criterion[0].use_fed_loss = False
cfg.model.model_vision.criterion[2].use_fed_loss = False
cfg.train.device = self.running_device
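        # Shrink the EVA-02 CLIP vision tower to a single layer: only the CLIP
        # text encoder is used here, so truncating the image tower cuts memory
        # and load time. (Assumption inferred from this override; verify against
        # the APE model zoo if the exact behavior matters.)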
ape.modeling.text.eva02_clip.factory._MODEL_CONFIGS[cfg.model.model_language.clip_model][
"vision_cfg"
]["layers"] = 1
self.cfg = cfg
self.args = args
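        # Shortest-edge resize to at most 1024px, applied to oversized inputs
        # in run_on_image (and inverted afterwards for the visualization).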
self.aug = T.ResizeShortestEdge([1024, 1024], 1024)
self.demo = VisualizationDemo(cfg, args=args)
self.demo.predictor.model.to(self.running_device)
self.target_categories = target_categories
    def run_on_image(
        self,
        input_image_path,
        input_text=None,
        score_threshold=0.1,
        output_type=("object detection", "instance segmentation"),
        visualize=False,
    ):
"""
input_text: word1,word2,sentence1,sentence2,...
output_type: ["object detection", "instance segmentation", "semantic segmentation"]
"""
        self.demo.predictor.model.model_vision.test_score_thresh = score_threshold
        if input_text is None or input_text == "":
            assert self.target_categories is not None and self.target_categories != "", \
                "`input_text` and `self.target_categories` cannot both be empty."
            input_text = self.target_categories
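        # Map the requested output types onto the demo's boolean flags.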
with_box = False
with_mask = False
with_sseg = False
if "object detection" in output_type:
with_box = True
if "instance segmentation" in output_type:
with_mask = True
if "semantic segmentation" in output_type:
with_sseg = True
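        # Accept either a file path or an already-loaded PIL image; normalize
        # both to a BGR numpy array, honoring EXIF orientation.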
if not isinstance(input_image_path, str):
input_image = _apply_exif_orientation(input_image_path)
input_image = convert_PIL_to_numpy(input_image, format="BGR")
else:
input_image = read_image(input_image_path, format="BGR")
input_mask = None
if not with_box and not with_mask and not with_sseg:
return None
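        # Downscale oversized images before inference; the stored transform is
        # inverted below to bring the visualization back to input resolution.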
if input_image.shape[0] > 1024 or input_image.shape[1] > 1024:
transform = self.aug.get_transform(input_image)
input_image = transform.apply_image(input_image)
else:
transform = None
if not visualize:
predictions = self.demo.run_on_image(
input_image,
text_prompt=input_text,
mask_prompt=input_mask,
with_box=with_box,
with_mask=with_mask,
with_sseg=with_sseg,
visualize=visualize,
)
            output_image = None
            metadata = None
else:
predictions, visualized_output, _, metadata = self.demo.run_on_image(
input_image,
text_prompt=input_text,
mask_prompt=input_mask,
with_box=with_box,
with_mask=with_mask,
with_sseg=with_sseg,
visualize=visualize,
)
output_image = visualized_output.get_image()
if transform:
output_image = transform.inverse().apply_image(output_image)
output_image = Image.fromarray(output_image)
gc.collect()
torch.cuda.empty_cache()
        json_results = instances_to_coco_json(predictions["instances"].to(self.demo.cpu_device), 0)
        # `metadata` is only returned on the visualize path; otherwise fall back
        # to the comma-separated prompt (this assumes APE assigns category ids
        # in prompt order, matching how the demo builds its metadata).
        if metadata is not None:
            thing_classes = metadata.thing_classes
        else:
            thing_classes = [name.strip() for name in input_text.split(",")]
        for json_result in json_results:
            json_result["category_name"] = thing_classes[json_result["category_id"]]
            del json_result["image_id"]
return output_image, json_results
def build_ape_predictor(model_type: str = 'D', which_categories=None, override_ckpt_file=""):
repo_id = "shenyunhang/APE"
checkpoint_file = None
config_file = None
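    # Model types 'A'-'C' from the APE model zoo are listed for reference but
    # not wired up yet (they raise NotImplementedError); the default type 'D'
    # falls through to the multi-dataset checkpoint below.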
if model_type == 'A':
checkpoint_file = "configs/LVISCOCOCOCOSTUFF_O365_OID_VG/ape_deta/ape_deta_vitl_eva02_lsj_cp_720k_20230504_002019/model_final.pth"
config_file = "LVISCOCOCOCOSTUFF_O365_OID_VG/ape_deta/ape_deta_vitl_eva02_lsj1024_cp_720k.py"
raise NotImplementedError
elif model_type == 'B':
checkpoint_file = "configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_REFCOCO/ape_deta/ape_deta_vitl_eva02_vlf_lsj_cp_1080k_20230702_225418/model_final.pth"
config_file = "LVISCOCOCOCOSTUFF_O365_OID_VGR_REFCOCO/ape_deta/ape_deta_vitl_eva02_vlf_lsj1024_cp_1080k.py"
raise NotImplementedError
elif model_type == 'C':
checkpoint_file = "configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO/ape_deta/ape_deta_vitl_eva02_vlf_lsj_cp_1080k_20230702_210950/model_final.pth"
config_file = "LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO/ape_deta/ape_deta_vitl_eva02_vlf_lsj1024_cp_1080k.py"
raise NotImplementedError
    else:
        checkpoint_file = "configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/model_final.pth"
        config_file = "LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k.py"
        # Only override the default checkpoint when a non-empty path is given.
        if override_ckpt_file:
            checkpoint_file = override_ckpt_file
    category_dict = {'COCO': COCO_CATEGORIES, 'CITYSCAPES': CITYSCAPES_CATEGORIES, 'LVIS': LVIS_CATEGORIES}
    target_categories = ""
    if which_categories in category_dict:
        # Strip "name-suffix" qualifiers (e.g. "wall-brick" -> "wall"), then
        # de-duplicate while preserving the original category order (a bare
        # set() would make the prompt order non-deterministic).
        category_names = [item['name'].split('-')[0].strip() for item in category_dict[which_categories]]
        category_names = list(dict.fromkeys(category_names))
        target_categories = ",".join(category_names)
predictor = APEPredictor(repo_id=repo_id,
checkpoint_file=checkpoint_file,
config_file=config_file,
target_categories=target_categories)
return predictor
if __name__ == "__main__":
import json
ape_predictor = build_ape_predictor(which_categories='COCO',
override_ckpt_file="shenyunhang/APE/configs/LVISCOCOCOCOSTUFF_O365_OID_VGR_SA1B_REFCOCO_GQA_PhraseCut_Flickr30k/ape_deta/ape_deta_vitl_eva02_clip_vlf_lsj1024_cp_16x4_1080k_mdl_20230829_162438/model_final.pth")
    output_image, json_results = ape_predictor.run_on_image(
        "./zhouyik/zt_any_visual_prompt/sa_10293442.jpg",
        # input_text="person,bicycle,buildings,windows",
        visualize=True,
        score_threshold=0.1,
    )
with open('./sa_10293442_out.json', 'w') as savef:
json.dump(json_results, savef)
output_image.save("./sa_10293442_out.jpg")