# vicca / VG / test.py
# Uploaded by sayehghp
# Commit message: "Add application file"
# Commit: 0f8411f
'''
The model follows the Grounding DINO (GD) architecture, with the image and text encoders replaced.
The testing code is adapted from the GD model, with minor configuration changes to suit chest X-rays (CXRs).
# Example command to run this script:
python test.py --weight_path="weights/checkpoint0399.pth" --image_path="38708899-5132e206-88cb58cf-d55a7065-6cbc983d.jpg" \\
--text_prompt="Cardiomegaly with mild pulmonary vascular congestion." --box_threshold=0.3 \\
--text_threshold=0.2 --plot_boxes
'''
from groundingdino.util.inference import load_model, load_image, predict, annotate
import cv2
import matplotlib.pyplot as plt
import supervision as sv
import torch
from torchvision.ops import box_convert
import numpy as np
import argparse
def get_args_parser():
    """Build the command-line option parser for the visual-grounding script.

    The parser is created with ``add_help=False`` so that it can be handed
    to another ``ArgumentParser`` through ``parents=[...]`` without the
    ``-h/--help`` option being defined twice.

    Returns:
        argparse.ArgumentParser: parser exposing ``--weight_path``,
        ``--image_path``, ``--text_prompt``, ``--box_threshold``,
        ``--text_threshold`` and the ``--plot_boxes`` flag.
    """
    # (flag, keyword arguments) pairs, registered in declaration order.
    option_specs = (
        ('--weight_path', dict(type=str, default="weights/checkpoint_best_regular.pth",
                               help="The path to the trained model")),
        ('--image_path', dict(type=str, help="The path to the image file.")),
        ('--text_prompt', dict(type=str, help="The text prompt.")),
        ('--box_threshold', dict(default=0.22, type=float)),
        ('--text_threshold', dict(default=0.2, type=float)),
        ('--plot_boxes', dict(action='store_true')),
    )

    grounding_parser = argparse.ArgumentParser('Set Visual Grounding', add_help=False)
    for flag, options in option_specs:
        grounding_parser.add_argument(flag, **options)
    return grounding_parser
def convert_boxes_to_numpy(boxes, image_source):
    """Scale ``cxcywh`` boxes to the image size and return them as ``xyxy``.

    Args:
        boxes: torch tensor whose last dimension holds
            ``(cx, cy, w, h)``. The values are multiplied by the image
            width/height, so they are presumably normalized to [0, 1] --
            confirm against the caller.
        image_source: array-like image whose first two dimensions are
            ``(height, width)``; any trailing channel dimension is ignored.

    Returns:
        numpy.ndarray: the boxes in ``(x1, y1, x2, y2)`` order, expressed in
        the pixel coordinates of ``image_source``.
    """
    # Only height and width are needed; slicing also accepts 2-D grayscale
    # arrays, which the former three-way unpacking rejected.
    h, w = image_source.shape[:2]

    # Tensor.numpy() only works on CPU tensors that do not require grad, so
    # normalize the input before doing any arithmetic on it.
    boxes = boxes.detach().cpu()

    scale = torch.tensor([w, h, w, h], dtype=torch.get_default_dtype())
    bbox = box_convert(boxes=boxes * scale, in_fmt="cxcywh", out_fmt="xyxy")
    return bbox.numpy()
def main(args):
    """Ground a text prompt in a single image and return the detections.

    Args:
        args: namespace providing ``weight_path``, ``image_path``,
            ``text_prompt``, ``box_threshold``, ``text_threshold`` and
            ``plot_boxes`` (the options defined by ``get_args_parser``).

    Returns:
        tuple: ``(bbox, logits, phrases)`` where ``bbox`` is the output of
        ``convert_boxes_to_numpy`` and ``logits`` / ``phrases`` are passed
        through unchanged from ``predict``.

    Raises:
        ValueError: if ``args.image_path`` or ``args.text_prompt`` is None.
    """
    # Both options have no default, so they are None when omitted. Fail fast
    # with a clear message instead of crashing inside the image/text loaders
    # after the (expensive) model load.
    if args.image_path is None:
        raise ValueError("--image_path is required")
    if args.text_prompt is None:
        raise ValueError("--text_prompt is required")

    model = load_model("groundingdino/config/GroundingDINO_SwinT_OGC.py", args.weight_path)
    image_source, image = load_image(args.image_path)

    boxes, logits, phrases = predict(
        model=model,
        image=image,
        caption=args.text_prompt,
        box_threshold=args.box_threshold,
        text_threshold=args.text_threshold,
    )

    if args.plot_boxes:
        annotate_dict = dict(color=sv.ColorPalette.DEFAULT, thickness=2, text_thickness=1)
        annotated_frame = annotate(
            image_source=image_source,
            boxes=boxes,
            logits=logits,
            phrases=phrases,
            bbox_annot=annotate_dict,
        )
        # NOTE(review): the upstream Grounding DINO `annotate` returns a BGR
        # frame; if this fork does the same, box colours appear swapped in
        # imshow -- confirm.
        plt.imshow(annotated_frame, cmap="gray")
        plt.axis('off')
        # Without show() the figure is built but never rendered when this
        # file is run as a script.
        plt.show()

    bbox = convert_boxes_to_numpy(boxes, image_source)
    print(bbox, logits, phrases)
    return bbox, logits, phrases
if __name__ == '__main__':
    # The shared option definitions come from get_args_parser(); this outer
    # parser contributes the program title and the -h/--help option.
    cli_parser = argparse.ArgumentParser(
        'Visual Grounding of CXR Image Prompt',
        parents=[get_args_parser()],
    )
    main(cli_parser.parse_args())