| | """ |
| | Input: image and text |
| | Middle output: bbox (VG), Gen Image and similarity score (CXRGen), Shift_x&y (DETR) |
| | Output: Localization Score, Reliability Score |
| | |
| | python inference.py \ |
| | --image_path VG/38708899-5132e206-88cb58cf-d55a7065-6cbc983d.jpg \ |
| | --text_prompt "Cardiomegaly with mild pulmonary vascular congestion." |
| | |
| | """ |
| | import sys, os |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | import pandas as pd |
| | import numpy as np |
| | import time |
| | import cv2 |
| |
|
| | import argparse |
| | from ast import literal_eval |
| | |
| |
|
| | |
| | from pathlib import Path |
| | import shutil |
| | from huggingface_hub import hf_hub_download |
| |
|
| | from weights_utils import get_weight |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| |
|
| |
|
| | |
| | |
| |
|
| |
|
| | |
| | import types |
| | from torchvision.transforms import functional as F |
| |
|
| | |
| | |
# Compatibility shim: register a synthetic module under the name
# "torchvision.transforms.functional_tensor" that exposes `rgb_to_grayscale`,
# so imports of that module path executed after this point resolve to it.
# NOTE(review): presumably needed because one of the packages imported below
# still imports this path while the installed torchvision no longer provides
# it -- confirm against the pinned torchvision version.
mod = types.ModuleType("torchvision.transforms.functional_tensor")
setattr(mod, "rgb_to_grayscale", F.rgb_to_grayscale)
sys.modules[mod.__name__] = mod
| | |
| | from CXRGen import sample_generation |
| | from DETR import svc |
| | from DETR.arguments import get_args_parser as get_detr_args_parser |
| | from VG import localization |
| | from ssim import ssim |
| | import torch |
| |
|
| | from CheXbert.src.label import label |
| |
|
def get_args_parser():
    """Build the command-line parser for the inference script.

    Options (in registration order):
        --weight_path_gencxr  checkpoint of the CXR generation model
        --weight_path_vg      checkpoint of the Visual Grounding model
        --image_path          input image file (required)
        --text_prompt         text describing the pathology (required)
        --box_threshold       box threshold passed to VG (default 0.2)
        --text_threshold      text threshold passed to VG (default 0.2)
        --num_samples         number of generated image samples (default 4)
        --output_path         directory where generated files are saved

    Returns:
        argparse.ArgumentParser: the configured, not yet parsed, parser.
    """
    # Each entry is (flag, keyword arguments for add_argument); the order here
    # is the order the options appear in --help.
    option_specs = (
        ('--weight_path_gencxr', dict(
            type=str,
            default="CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth",
            help="Path to the CXR generation trained model",
        )),
        ('--weight_path_vg', dict(
            type=str,
            default="VG/weights/checkpoint0399_log4.pth",
            help="Path to the Visual Grounding trained model",
        )),
        ('--image_path', dict(
            type=str,
            required=True,
            help="Path to the input image file.",
        )),
        ('--text_prompt', dict(
            type=str,
            required=True,
            help="Text prompt describing pathology.",
        )),
        ('--box_threshold', dict(
            default=0.2,
            type=float,
            help="Box threshold for VG",
        )),
        ('--text_threshold', dict(
            default=0.2,
            type=float,
            help="Text threshold for VG",
        )),
        ('--num_samples', dict(
            type=int,
            default=4,
            help="Number of generated image samples.",
        )),
        ('--output_path', dict(
            type=str,
            default="CXRGen/test/samples/output/",
            help="Path to save generated files.",
        )),
    )

    cli = argparse.ArgumentParser('Set the Input', add_help=True)
    for flag, options in option_specs:
        cli.add_argument(flag, **options)
    return cli
| |
|
| | import re |
| |
|
def simple_sentence_split(text: str):
    """
    Split free text into sentence-like fragments.

    The text is cut at every run of '.', ';' or newline characters; each
    fragment is stripped of surrounding whitespace and empty fragments are
    dropped. Note that any '.' is a cut point, so "1.5 cm" yields "1" and
    "5 cm".

    Returns:
        list[str]: the non-empty fragments in their original order.
    """
    sentences = []
    for fragment in re.split(r"[.\n;]+", text):
        trimmed = fragment.strip()
        if trimmed:
            sentences.append(trimmed)
    return sentences
| |
|
| |
|
# Pathology names, one per label position. The order matters:
# chexbert_pathology() maps the index of each positive entry in the label
# vector to the name at the same index in this list.
path_list = [
    'Enlarged Cardiomediastinum',
    'Cardiomegaly',
    'Lung Opacity',
    'Lung Lesion',
    'Edema',
    'Consolidation',
    'Pneumonia',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Effusion',
    'Pleural Other',
    'Fracture',
    'Support Devices',
    'No Finding',
]
| |
|
| | |
# CheXbert checkpoint reference, resolved once at import time and passed to
# label() on every call in chexbert_pathology().
# NOTE(review): get_weight presumably returns a local path to the checkpoint
# (possibly downloading it) -- confirm in weights_utils.
CHEXBERT_WEIGHTS = get_weight(
    "CheXbert/checkpoint/chexbert.pth"
)
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
def chexbert_pathology(text: str):
    """
    Run CheXbert on the text and return the *positive* pathology labels.

    The text is split with simple_sentence_split(); each distinct sentence is
    whitespace-normalised and labelled on its own via
    label(CHEXBERT_WEIGHTS, sentence). A sentence contributes nothing when its
    last label (the position of 'No Finding' in path_list) is positive or when
    it has no positive label at all.

    Args:
        text: free-text report or prompt.

    Returns:
        list[str]: names from path_list that were positive in at least one
        sentence, deduplicated and sorted alphabetically.
    """
    path_terms = set()

    # Deduplicate so that a repeated sentence is only labelled once.
    for sentence in set(simple_sentence_split(text)):
        # Collapse every run of whitespace into one space. A plain
        # str.replace("\s+", " ") would search for those literal characters
        # and never match, so a regex substitution is required here.
        sentence = re.sub(r"\s+", " ", sentence)

        # The first column of the transposed label output is used as the
        # per-class vector for this sentence.
        pathology = np.array(label(CHEXBERT_WEIGHTS, sentence)).T[0]

        positive = [i for i, e in enumerate(pathology) if e == 1]

        # Skip sentences flagged by the last label or with no positive class.
        # (An "all entries equal" vector is covered by these two checks.)
        if pathology[-1] == 1 or not positive:
            continue

        path_terms.update(path_list[i] for i in positive)

    return sorted(path_terms)
| |
|
def extract_tensor(value):
    """Convert a stringified tensor such as "tensor(0.93)" into a Python value.

    Handles the forms a tensor's text representation can take in the CSV:
    a bare literal ("0.93"), a wrapped literal ("tensor(0.93)",
    "tensor([0.1, 0.2])") and a wrapped literal followed by keyword
    arguments ("tensor(0.93, device='cuda:0')", "..., grad_fn=<...>)",
    "..., dtype=torch.float64)").

    Args:
        value: the cell content. Non-string values (e.g. numbers that pandas
            already parsed) are returned unchanged.

    Returns:
        The literal evaluated with ast.literal_eval (float, int, list, ...).

    Raises:
        ValueError / SyntaxError: if the remaining text is not a Python
            literal.
    """
    if not isinstance(value, str):
        # Already a number: nothing to parse.
        return value

    # Drop trailing keyword arguments of the tensor repr; they are not part of
    # the value and would make literal_eval fail.
    without_kwargs = re.sub(
        r",\s*(?:device|dtype|grad_fn|requires_grad)\s*=\s*[^,)]*", "", value
    )
    cleaned_value = without_kwargs.replace('tensor(', '').replace(')', '')
    return literal_eval(cleaned_value.strip())
| |
|
| |
|
def gen_cxr(weight_path, image_path, text_prompt, num_samples, output_path, device: str = "cpu"):
    """Run the CXRGen sample generation step.

    Starts from the default arguments of sample_generation's own parser,
    overrides the fields below and calls sample_generation.main().

    Args:
        weight_path: checkpoint reference, resolved through get_weight().
        image_path: path of the input image.
        text_prompt: text prompt forwarded to the generator.
        num_samples: number of samples to generate.
        output_path: directory handed to the generator for its outputs.
        device: NOTE(review): currently ignored -- the value is overwritten
            below with CUDA when available and CPU otherwise.

    Returns:
        None.
    """
    gen_args = sample_generation.get_args_parser().parse_args([])

    gen_args.image_path = image_path
    gen_args.text_prompt = text_prompt
    gen_args.num_samples = num_samples
    gen_args.output_path = output_path
    gen_args.weight_path = get_weight(weight_path)

    # The caller-supplied device is replaced by automatic detection.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    gen_args.device = device

    sample_generation.main(gen_args)
| |
|
| |
|
def cal_shift(img_org_path, img_gen_path):
    """Compute the (x, y) shift between the original and a generated image.

    Builds the DETR argument namespace from its parser defaults, points it at
    the "DETR/output/checkpoint.pth" checkpoint (resolved via get_weight) and
    at the two image paths, then delegates to svc.main().

    Args:
        img_org_path: path of the original image.
        img_gen_path: path of the generated image.

    Returns:
        tuple: (shift_x, shift_y) as returned by svc.main().
    """
    detr_args = get_detr_args_parser().parse_args([])
    detr_args.read_checkpoint = get_weight("DETR/output/checkpoint.pth")
    detr_args.img_org, detr_args.img_gen = img_org_path, img_gen_path

    shift = svc.main(detr_args)
    shift_x, shift_y = shift
    return shift_x, shift_y
| |
|
| |
|
def get_local_bbox(weight_path, image_path, text_prompt, box_threshold, text_threshold):
    """Run the Visual Grounding localization step.

    Takes the default arguments of localization's parser, overrides the
    fields listed below and calls localization.main().

    Args:
        weight_path: checkpoint reference, resolved through get_weight().
        image_path: path of the input image.
        text_prompt: text to ground in the image.
        box_threshold: box threshold forwarded to the model.
        text_threshold: text threshold forwarded to the model.

    Returns:
        tuple: (bbox, logits, phrases) exactly as returned by
        localization.main().
    """
    vg_args = localization.get_args_parser().parse_args([])

    overrides = {
        "weight_path": get_weight(weight_path),
        "image_path": image_path,
        "text_prompt": text_prompt,
        "box_threshold": box_threshold,
        "text_threshold": text_threshold,
    }
    for field, setting in overrides.items():
        setattr(vg_args, field, setting)

    bbox, logits, phrases = localization.main(vg_args)
    return bbox, logits, phrases
| |
|
| |
|
if __name__ == "__main__":
    # Pipeline: generate images from the prompt, pick the generated sample
    # with the highest similarity, estimate its shift against the input image,
    # localize the prompt in the input image, then compare each localized
    # region between the input and the generated image with SSIM.
    args = get_args_parser().parse_args()

    gen_cxr(args.weight_path_gencxr, args.image_path, args.text_prompt, args.num_samples, args.output_path)
    # NOTE(review): fixed pause kept as-is; presumably gives the generated
    # files time to land on disk -- confirm whether it is still needed.
    time.sleep(4)

    # os.path.join works whether or not --output_path ends with a separator.
    df = pd.read_csv(os.path.join(args.output_path, "info_path_similarity.csv"))
    sim_ratios = [extract_tensor(val) for val in df["similarity_rate"]]
    if not sim_ratios:
        raise ValueError("info_path_similarity.csv contains no generated samples.")
    max_sim_index = sim_ratios.index(max(sim_ratios))
    # Positional lookup: max_sim_index is a position in sim_ratios.
    max_sim_gen_path = df["gen_sample_path"].iloc[max_sim_index]

    sx, sy = cal_shift(args.image_path, max_sim_gen_path)

    boxes, logits, phrases = get_local_bbox(
        args.weight_path_vg,
        args.image_path,
        args.text_prompt,
        args.box_threshold,
        args.text_threshold
    )
    print("Boxes:", boxes)
    print("Phrases:", phrases)

    image_org_cv = cv2.imread(args.image_path, cv2.IMREAD_GRAYSCALE)
    image_gen_cv = cv2.imread(max_sim_gen_path, cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None instead of raising.
    if image_org_cv is None:
        raise FileNotFoundError(f"Could not read input image: {args.image_path}")
    if image_gen_cv is None:
        raise FileNotFoundError(f"Could not read generated image: {max_sim_gen_path}")

    ssim_scores = []
    for bbox in boxes:
        # Boxes are unpacked as corner coordinates (x1, y1, x2, y2).
        x1, y1, x2, y2 = bbox
        width, height = x2 - x1, y2 - y1

        # Same-sized window in both images; the generated one is offset by
        # the estimated shift.
        bx1, by1, bw1, bh1 = [int(val) for val in (x1, y1, width, height)]
        bx2, by2, bw2, bh2 = [int(val) for val in (x1 + sx, y1 + sy, width, height)]

        # A negative start index would make NumPy count from the opposite
        # edge and silently select the wrong region, so skip such boxes.
        if min(bx1, by1, bx2, by2) < 0:
            continue

        roi_org = image_org_cv[by1:by1 + bh1, bx1:bx1 + bw1]
        roi_gen = image_gen_cv[by2:by2 + bh2, bx2:bx2 + bw2]

        # Windows clipped by the image border differ in shape and are skipped.
        if roi_org.shape == roi_gen.shape and roi_org.size > 0:
            score = ssim(roi_org, roi_gen)
            ssim_scores.append(score)

    if ssim_scores:
        print("SSIM scores per box:", ssim_scores)
        print("Localization Detection Scores per bbox:", boxes, logits)
    else:
        print("No valid SSIM scores (e.g., mismatched shapes or empty ROIs).")
| |
|