"""Visual grounding inference for chest X-rays (CXRs).

The model is trained on the Grounding DINO architecture with the image and
text encoders replaced. The testing code comes from the GD model with minor
configuration changes to suit CXRs.

Example:
    python test.py --weight_path="weights/checkpoint0399.pth" \
        --image_path="38708899-5132e206-88cb58cf-d55a7065-6cbc983d.jpg" \
        --text_prompt="Cardiomegaly with mild pulmonary vascular congestion." \
        --box_threshold=0.3 --text_threshold=0.2 --plot_boxes
"""
from groundingdino.util.inference import load_model, load_image, predict, annotate
import cv2
import matplotlib.pyplot as plt
import supervision as sv
import torch
from torchvision.ops import box_convert
import numpy as np
import argparse


def get_args_parser():
    """Build the argument parser for the visual-grounding script.

    Returns:
        argparse.ArgumentParser: parser with weight/image/prompt/threshold
        options; ``add_help=False`` so it can be used as a parent parser.
    """
    parser = argparse.ArgumentParser('Set Visual Grounding', add_help=False)
    parser.add_argument('--weight_path', type=str,
                        default="weights/checkpoint_best_regular.pth",
                        help="The path to the trained model")
    parser.add_argument('--image_path', type=str,
                        help="The path to the image file.")
    parser.add_argument('--text_prompt', type=str,
                        help="The text prompt.")
    parser.add_argument('--box_threshold', default=0.22, type=float)
    parser.add_argument('--text_threshold', default=0.2, type=float)
    parser.add_argument('--plot_boxes', action='store_true')
    return parser


def convert_boxes_to_numpy(boxes, image_source):
    """Convert normalized cxcywh boxes to absolute xyxy pixel coordinates.

    Args:
        boxes: (N, 4) tensor of boxes in normalized (cx, cy, w, h) format,
            as returned by ``predict``.
        image_source: H x W x C image array, used to recover pixel dimensions.

    Returns:
        numpy.ndarray: (N, 4) array of boxes in absolute (x1, y1, x2, y2).
    """
    h, w, _ = image_source.shape
    # Scale from normalized [0, 1] coordinates to pixel units.
    bbox = boxes * torch.Tensor([w, h, w, h])
    bbox = box_convert(boxes=bbox, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    return bbox


def main(args):
    """Run grounded detection on a single CXR image.

    Loads the model and image, predicts boxes for the text prompt, and
    optionally plots the annotated image.

    Returns:
        tuple: (bbox, logits, phrases) — absolute xyxy boxes as a numpy
        array, per-box confidence scores, and the matched text phrases.
    """
    model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py",
                       args.weight_path)
    image_source, image = load_image(args.image_path)
    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=args.text_prompt,
        box_threshold=args.box_threshold,
        text_threshold=args.text_threshold,
    )
    if args.plot_boxes:
        annotate_dict = dict(color=sv.ColorPalette.DEFAULT, thickness=2,
                             text_thickness=1)
        annotated_frame = annotate(image_source=image_source, boxes=boxes,
                                   logits=logits, phrases=phrases,
                                   bbox_annot=annotate_dict)
        # NOTE(review): annotate() in stock Grounding DINO returns a BGR
        # frame (cv2 convention) — if colors look swapped under matplotlib,
        # convert with cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB).
        plt.imshow(annotated_frame, cmap="gray")
        plt.axis('off')
        # Fix: the figure was previously built but never displayed when the
        # script ran non-interactively.
        plt.show()
    bbox = convert_boxes_to_numpy(boxes, image_source)
    print(bbox, logits, phrases)
    return bbox, logits, phrases


if __name__ == '__main__':
    parser = argparse.ArgumentParser('Visual Grounding of CXR Image Prompt',
                                     parents=[get_args_parser()])
    args = parser.parse_args()
    main(args)