import os
import json
import argparse
import os.path as osp

import cv2
import tqdm
import torch
import numpy as np
import tensorflow as tf
import supervision as sv
from torchvision.ops import nms

BOUNDING_BOX_ANNOTATOR = sv.BoundingBoxAnnotator(thickness=1)
MASK_ANNOTATOR = sv.MaskAnnotator()


class LabelAnnotator(sv.LabelAnnotator):
    """Label annotator that anchors the text background at the given point
    instead of centering it."""

    @staticmethod
    def resolve_text_background_xyxy(
        center_coordinates,
        text_wh,
        position,
    ):
        center_x, center_y = center_coordinates
        text_w, text_h = text_wh
        return center_x, center_y, center_x + text_w, center_y + text_h


LABEL_ANNOTATOR = LabelAnnotator(text_padding=4,
                                 text_scale=0.5,
                                 text_thickness=1)


def parse_args():
    parser = argparse.ArgumentParser('YOLO-World TFLite (INT8) Demo')
    parser.add_argument('path', help='TFLite model file (`.tflite`)')
    parser.add_argument('image',
                        help='image path: an image file or a directory')
    parser.add_argument(
        'text',
        help='detection texts (plain string, `.txt`, or `.json` file); '
        'should be consistent with the texts used to export the TFLite model')
    parser.add_argument('--output-dir',
                        default='./output',
                        help='directory to save output files')
    args = parser.parse_args()
    return args

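# Example invocation (illustrative only; the script, model, and image names
# below are placeholders, not files shipped with this demo):
#
#   python tflite_demo.py yolo_world_int8.tflite ./sample_images \
#       "dog,cat,person" --output-dir ./output

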
def preprocess(image, size=(640, 640)):
    # Pad the image to a square canvas, then resize it to the model input size.
    h, w = image.shape[:2]
    max_size = max(h, w)
    scale_factor = size[0] / max_size
    pad_h = (max_size - h) // 2
    pad_w = (max_size - w) // 2
    pad_image = np.zeros((max_size, max_size, 3), dtype=image.dtype)
    pad_image[pad_h:h + pad_h, pad_w:w + pad_w] = image
    image = cv2.resize(pad_image, size,
                       interpolation=cv2.INTER_LINEAR).astype('float32')
    # Normalize to [0, 1] and add the batch dimension (NHWC layout for TFLite).
    image /= 255.0
    image = image[None]
    return image, scale_factor, (pad_h, pad_w)

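# Illustrative example (not executed): for a 720x1280 (h, w) frame and the
# default size=(640, 640), max_size is 1280, scale_factor is 0.5, and the
# frame is padded with pad_h=280, pad_w=0 before the resize, so the function
# returns a (1, 640, 640, 3) float32 tensor plus scale_factor=0.5 and
# pad_param=(280, 0).

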
def generate_anchors_per_level(feat_size, stride, offset=0.5):
    # Anchor centers (in input-image pixels) for a single feature-map level.
    h, w = feat_size
    shift_x = (torch.arange(0, w) + offset) * stride
    shift_y = (torch.arange(0, h) + offset) * stride
    yy, xx = torch.meshgrid(shift_y, shift_x)
    anchors = torch.stack([xx, yy]).reshape(2, -1).transpose(0, 1)
    return anchors


def generate_anchors(feat_sizes=[(80, 80), (40, 40), (20, 20)],
                     strides=[8, 16, 32],
                     offset=0.5):
    anchors = [
        generate_anchors_per_level(fs, s, offset)
        for fs, s in zip(feat_sizes, strides)
    ]
    anchors = torch.cat(anchors)
    return anchors

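# For the default 640x640 input and strides [8, 16, 32], this produces
# 80*80 + 40*40 + 20*20 = 8400 anchor centers; their order must match the
# flattened box/score outputs of the exported model.

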
def simple_bbox_decode(points, pred_bboxes, stride):
    # Decode predicted (left, top, right, bottom) distances, given in units of
    # the stride, into (x1, y1, x2, y2) boxes around the anchor points.
    pred_bboxes = pred_bboxes * stride[None, :, None]
    x1 = points[..., 0] - pred_bboxes[..., 0]
    y1 = points[..., 1] - pred_bboxes[..., 1]
    x2 = points[..., 0] + pred_bboxes[..., 2]
    y2 = points[..., 1] + pred_bboxes[..., 3]
    bboxes = torch.stack([x1, y1, x2, y2], -1)
    return bboxes

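# Worked example (illustrative): an anchor at (4.0, 4.0) on the stride-8 level
# with predicted distances (l, t, r, b) = (0.5, 0.5, 1.0, 1.0) is first scaled
# to (4, 4, 8, 8) and then decoded to the box (x1, y1, x2, y2) = (0, 0, 12, 12).

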
def visualize(image, bboxes, labels, scores, texts):
    detections = sv.Detections(xyxy=bboxes, class_id=labels, confidence=scores)
    labels = [
        f"{texts[class_id][0]} {confidence:0.2f}" for class_id, confidence in
        zip(detections.class_id, detections.confidence)
    ]
    image = BOUNDING_BOX_ANNOTATOR.annotate(image, detections)
    image = LABEL_ANNOTATOR.annotate(image, detections, labels=labels)
    return image


def inference_per_sample(interp,
                         image_path,
                         texts,
                         priors,
                         strides,
                         output_dir,
                         size=(640, 640),
                         vis=False,
                         score_thr=0.05,
                         nms_thr=0.3,
                         max_dets=300):
    input_details = interp.get_input_details()
    output_details = interp.get_output_details()

    # Load the image and convert BGR -> RGB before preprocessing.
    ori_image = cv2.imread(image_path)
    h, w = ori_image.shape[:2]
    image, scale_factor, pad_param = preprocess(ori_image[:, :, [2, 1, 0]],
                                                size)

    interp.set_tensor(input_details[0]['index'], image)
    interp.invoke()

    # Assumes output 0 holds the box regressions and output 1 the class scores.
    scores = interp.get_tensor(output_details[1]['index'])
    bboxes = interp.get_tensor(output_details[0]['index'])

    ori_scores = torch.from_numpy(scores[0])
    ori_bboxes = torch.from_numpy(bboxes)

    decoded_bboxes = simple_bbox_decode(priors, ori_bboxes, strides)[0]
    scores_list = []
    labels_list = []
    bboxes_list = []
    # Per-class NMS first, then a class-agnostic NMS over the merged results.
    for cls_id in range(len(texts)):
        cls_scores = ori_scores[:, cls_id]
        labels = torch.ones(cls_scores.shape[0], dtype=torch.long) * cls_id
        keep_idxs = nms(decoded_bboxes, cls_scores, iou_threshold=0.5)
        cur_bboxes = decoded_bboxes[keep_idxs]
        cls_scores = cls_scores[keep_idxs]
        labels = labels[keep_idxs]
        scores_list.append(cls_scores)
        labels_list.append(labels)
        bboxes_list.append(cur_bboxes)

    scores = torch.cat(scores_list, dim=0)
    labels = torch.cat(labels_list, dim=0)
    bboxes = torch.cat(bboxes_list, dim=0)

    keep_idxs = scores > score_thr
    scores = scores[keep_idxs]
    labels = labels[keep_idxs]
    bboxes = bboxes[keep_idxs]

    keep_idxs = nms(bboxes, scores, iou_threshold=nms_thr)
    num_dets = min(len(keep_idxs), max_dets)
    bboxes = bboxes[keep_idxs].unsqueeze(0)
    scores = scores[keep_idxs].unsqueeze(0)
    labels = labels[keep_idxs].unsqueeze(0)

    scores = scores[0, :num_dets].numpy()
    bboxes = bboxes[0, :num_dets].numpy()
    labels = labels[0, :num_dets].numpy()

    # Undo the letterbox padding and scaling to map the boxes back onto the
    # original image, then clip them to the image bounds.
    bboxes -= np.array(
        [pad_param[1], pad_param[0], pad_param[1], pad_param[0]])
    bboxes /= scale_factor
    bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, w)
    bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, h)

    if vis:
        image_out = visualize(ori_image, bboxes, labels, scores, texts)
        cv2.imwrite(osp.join(output_dir, osp.basename(image_path)), image_out)
        print(f"detected {num_dets} objects.")
        return image_out, ori_scores, ori_bboxes[0]
    else:
        return bboxes, labels, scores

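# Optional helper (not wired into the demo above): if the exported INT8 model
# uses an int8 input tensor rather than a float32 one, the preprocessed image
# would need explicit quantization before `set_tensor`. This is a minimal
# sketch using the interpreter's reported quantization parameters; the helper
# name is ours, not part of the original script.
def quantize_input_if_needed(input_detail, image):
    if input_detail['dtype'] == np.int8:
        scale, zero_point = input_detail['quantization']
        image = (image / scale + zero_point).astype(np.int8)
    return image

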
def main():
    args = parse_args()
    tflite_file = args.path

    interpreter = tf.lite.Interpreter(model_path=tflite_file,
                                      experimental_preserve_all_tensors=True)
    interpreter.allocate_tensors()
    print("Init TFLite Interpreter")
    output_dir = args.output_dir
    if not osp.exists(output_dir):
        os.mkdir(output_dir)

    # Collect the input images: either a single file or all .png/.jpg files in
    # the given directory.
    if not osp.isfile(args.image):
        images = [
            osp.join(args.image, img) for img in os.listdir(args.image)
            if img.endswith('.png') or img.endswith('.jpg')
        ]
    else:
        images = [args.image]

    # Load the detection texts from a .txt file (one class per line), a .json
    # file, or a comma-separated string.
    if args.text.endswith('.txt'):
        with open(args.text) as f:
            lines = f.readlines()
        texts = [[t.rstrip('\r\n')] for t in lines]
    elif args.text.endswith('.json'):
        texts = json.load(open(args.text))
    else:
        texts = [[t.strip()] for t in args.text.split(',')]

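    # Example text inputs (illustrative): a plain string such as
    # "dog,cat,person", or a JSON file containing a nested list like
    # [["dog"], ["cat"], ["person"]]; each inner list is one class and its
    # first entry is used as the display name.
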
    size = (640, 640)
    strides = [8, 16, 32]

    # Precompute the flattened anchor points and the per-anchor strides shared
    # by all images.
    featmap_sizes = [(size[0] // s, size[1] // s) for s in strides]
    flatten_priors = generate_anchors(featmap_sizes, strides=strides)
    mlvl_strides = [
        flatten_priors.new_full((featmap_size[0] * featmap_size[1] * 1, ),
                                stride)
        for featmap_size, stride in zip(featmap_sizes, strides)
    ]
    flatten_strides = torch.cat(mlvl_strides)

    print("Start inference.")
    for img in tqdm.tqdm(images):
        inference_per_sample(interpreter,
                             img,
                             texts,
                             flatten_priors[None],
                             flatten_strides,
                             output_dir=output_dir,
                             vis=True,
                             score_thr=0.3,
                             nms_thr=0.5)
    print("Finished inference.")


if __name__ == "__main__":
    main()