File size: 10,914 Bytes
0f8411f fef87a3 ab6ae1b fef87a3 ab6ae1b fef87a3 ab6ae1b 0f8411f b94ff1b 0f8411f fef87a3 0f8411f e7d1dd0 0f8411f e09b1c8 9651d7a ab6ae1b 0f8411f e7d1dd0 b94ff1b e7d1dd0 b94ff1b 0f8411f e09b1c8 0f8411f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 | """
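Input: a chest X-ray image and a free-text prompt.
Intermediate outputs: bounding boxes (VG), generated images with similarity scores (CXRGen), and the x/y shift between the original and the best generated image (DETR).
Output: Localization Score, Reliability Score.
Usage: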
python inference.py \
--image_path VG/38708899-5132e206-88cb58cf-d55a7065-6cbc983d.jpg \
--text_prompt "Cardiomegaly with mild pulmonary vascular congestion."
"""
import sys, os
# ---------------------------------------------------------------------
# Make CheXbert's `src` folder importable (so `import utils` works)
# ---------------------------------------------------------------------
# BASE_DIR = os.path.dirname(__file__)
# CHEXBERT_SRC = os.path.join(BASE_DIR, "CheXbert", "src")
# if CHEXBERT_SRC not in sys.path:
# sys.path.insert(0, CHEXBERT_SRC)
# from label import label # now imports /app/CheXbert/src/label.py
import pandas as pd
import numpy as np
import time
import cv2
import argparse
from ast import literal_eval
# from nltk import tokenize
# sys.path.append('/home/gholipos-admin/Desktop/Thesis/Training_Code/VICCA')
from pathlib import Path
import shutil
from huggingface_hub import hf_hub_download
from weights_utils import get_weight
# def ensure_vicca_weights():
# """
# Download all VICCA weights from the vicca-weights repo into the paths
# expected by the original code, with caching and safe subfolder handling.
# """
# repo_id = "sayehghp/vicca-weights"
# base = Path(__file__).parent
# weight_files = [
# # CheXbert
# "CheXbert/checkpoint/chexbert.pth",
# # Uniformer
# "CXRGen/annotator/ckpts/upernet_global_small.pth",
# # Diffusion
# "CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth",
# # Encoders
# "CXRGen/ldm/modules/encoders/BiomedVLP-CXR-BERT/pytorch_model.bin",
# "VG/weights/BiomedVLP-CXR-BERT/pytorch_model.bin",
# # Lung UNet
# "CXRGen/LungDetection/models/unet-2v.pt",
# "CXRGen/LungDetection/models/unet-6v.pt",
# # DETR
# "DETR/output/checkpoint.pth",
# # VG weights
# "VG/weights/checkpoint0399.pth",
# "VG/weights/checkpoint0399_log4.pth",
# "VG/weights/checkpoint_best_regular.pth",
# ]
# for rel_path in weight_files:
# local_path = base / rel_path
# local_path.parent.mkdir(parents=True, exist_ok=True)
# if local_path.exists():
# continue # skip if already mirrored into repo tree
# # Split repo path
# if "/" in rel_path:
# subfolder, filename = rel_path.rsplit("/", 1)
# else:
# subfolder, filename = None, rel_path
# cached_path = hf_hub_download(
# repo_id=repo_id,
# filename=filename,
# subfolder=subfolder if subfolder else None
# )
# # Copy from HF cache → repo tree
# shutil.copy2(cached_path, local_path)
# # Run once at import time so all weights are present before anything loads them
# ensure_vicca_weights()
# ---- SHIM FOR basicsr / torchvision ----
import types
from torchvision.transforms import functional as F
# basicsr imports `rgb_to_grayscale` from the private module
# torchvision.transforms.functional_tensor, which recent torchvision releases
# no longer ship. Only install a minimal stand-in (exposing just
# rgb_to_grayscale from torchvision.transforms.functional) when the real module
# cannot be imported, so installs that still have it keep the full module.
try:
    import torchvision.transforms.functional_tensor  # noqa: F401
except ImportError:
    mod = types.ModuleType("torchvision.transforms.functional_tensor")
    mod.rgb_to_grayscale = F.rgb_to_grayscale
    sys.modules["torchvision.transforms.functional_tensor"] = mod
# ---- END SHIM ----
from CXRGen import sample_generation
from DETR import svc
from DETR.arguments import get_args_parser as get_detr_args_parser
from VG import localization
from ssim import ssim
import torch
from CheXbert.src.label import label
def get_args_parser():
    """Build the command-line parser for the inference pipeline.

    Returns:
        argparse.ArgumentParser: parser exposing the CXRGen and VG weight
        paths, the required image path and text prompt, the VG box/text
        thresholds, the number of generated samples and the output directory.
    """
    # (flag, keyword options) pairs, registered in this exact order so the
    # --help listing is unchanged.
    argument_specs = (
        ('--weight_path_gencxr', dict(
            type=str,
            default="CXRGen/checkpoints/cn_d25ofd18_epoch-v18.pth",
            help="Path to the CXR generation trained model",
        )),
        ('--weight_path_vg', dict(
            type=str,
            default="VG/weights/checkpoint0399_log4.pth",
            help="Path to the Visual Grounding trained model",
        )),
        ('--image_path', dict(
            type=str,
            required=True,
            help="Path to the input image file.",
        )),
        ('--text_prompt', dict(
            type=str,
            required=True,
            help="Text prompt describing pathology.",
        )),
        ('--box_threshold', dict(default=0.2, type=float, help="Box threshold for VG")),
        ('--text_threshold', dict(default=0.2, type=float, help="Text threshold for VG")),
        ('--num_samples', dict(type=int, default=4, help="Number of generated image samples.")),
        ('--output_path', dict(
            type=str,
            default="CXRGen/test/samples/output/",
            help="Path to save generated files.",
        )),
    )

    cli = argparse.ArgumentParser('Set the Input', add_help=True)
    for flag, options in argument_specs:
        cli.add_argument(flag, **options)
    return cli
import re
def simple_sentence_split(text: str):
    """
    Very lightweight sentence splitter good enough for radiology reports.

    Splits on ';', newlines and '.', then strips whitespace and drops empty
    pieces. A '.' that sits between two digits (a decimal such as "3.2 cm")
    is NOT treated as a sentence boundary, so measurements stay intact.

    Args:
        text: Free text to split.

    Returns:
        List of non-empty, stripped sentence fragments in original order.
    """
    # A period splits unless it is both preceded AND followed by a digit;
    # consecutive separators collapse into one split point.
    parts = re.split(r"(?:(?<!\d)\.|\.(?!\d)|[\n;])+", text)
    return [p.strip() for p in parts if p.strip()]
# CheXbert label names. chexbert_pathology maps positions of CheXbert's label
# vector onto this list and treats the last entry as "No Finding", so the order
# must match CheXbert's output order -- do not reorder.
path_list = [
    'Enlarged Cardiomediastinum',
    'Cardiomegaly',
    'Lung Opacity',
    'Lung Lesion',
    'Edema',
    'Consolidation',
    'Pneumonia',
    'Atelectasis',
    'Pneumothorax',
    'Pleural Effusion',
    'Pleural Other',
    'Fracture',
    'Support Devices',
    'No Finding',
]
# Cache CheXbert weights once at import time, so chexbert_pathology reuses this
# value instead of calling get_weight for every sentence.
# NOTE(review): get_weight presumably fetches/caches the checkpoint and returns
# a local file path -- confirm in weights_utils.
CHEXBERT_WEIGHTS = get_weight("CheXbert/checkpoint/chexbert.pth")
# def chexbert_pathology(text):
# sentences = list(set(tokenize.sent_tokenize(text)))
# path_dict = []
# for sentence in sentences:
# sentence = sentence.replace('\n',' ')
# sentence = sentence.replace('\s+',' ')
# chexbert_weight_path = get_weight("CheXbert/checkpoint/chexbert.pth")
# # pathology = np.array(label("CheXbert/checkpoint/chexbert.pth", sentence)).T[0]
# pathology = np.array(label(chexbert_weight_path, sentence)).T[0]
# if pathology[-1]==1 or len(list(set(pathology)))==1 or not any(e==1 for e in pathology):
# pass
# else:
# indice = [i for i, e in enumerate(pathology) if e==1]
# for ind in indice:
# path_dict.append(path_list[ind])
# return path_dict
def chexbert_pathology(text: str):
    """
    Run CheXbert on the text and return the positive pathology labels.

    The text is split with `simple_sentence_split`, whitespace inside each
    sentence is collapsed to single spaces, and duplicate sentences are
    dropped. CheXbert is then run once per unique sentence. A sentence
    contributes nothing when any of these holds:
    - its last label ("No Finding") is 1;
    - all of its label values are identical;
    - none of its labels equals 1.

    Args:
        text: Free-text report or prompt.

    Returns:
        Alphabetically sorted list of unique label names taken from `path_list`.
    """
    # Collapse whitespace *before* de-duplicating so "a  b" and "a b" count as
    # the same sentence. (str.replace is literal, not regex, so the old
    # replace("\s+", " ") never collapsed anything.)
    sentences = {" ".join(s.split()) for s in simple_sentence_split(text)}

    path_terms = set()
    for sentence in sentences:
        # Take element [0] of each label's entry after transposing CheXbert's
        # output. NOTE(review): assumes it is one class value per label.
        pathology = np.array(label(CHEXBERT_WEIGHTS, sentence)).T[0]
        positives = [i for i, e in enumerate(pathology) if e == 1]

        # Skip if: "No Finding" active, or all labels same, or no positives
        if pathology[-1] == 1 or len(set(pathology)) == 1 or not positives:
            continue

        path_terms.update(path_list[i] for i in positives)

    return sorted(path_terms)
def extract_tensor(value):
    """
    Parse a value from the similarity CSV into a Python number (or list).

    The CSV cells may hold the string repr of a torch tensor, for example:
    - "tensor(0.85)";
    - "tensor(0.85, device='cuda:0')";
    - "tensor(0.9, grad_fn=<MeanBackward0>)".
    They may also hold a plain number that pandas has already parsed.

    Args:
        value: Cell value (str, or an already-numeric value).

    Returns:
        The parsed literal: a number for scalar tensors, a list for
        1-D tensors. Non-string input is returned unchanged.

    Raises:
        ValueError / SyntaxError: if the cleaned string is not a Python literal.
    """
    # pandas parses plain numeric columns as floats; nothing to clean then.
    if not isinstance(value, str):
        return value
    cleaned = value.strip()
    if cleaned.startswith("tensor(") and cleaned.endswith(")"):
        cleaned = cleaned[len("tensor("):-1]
    # Drop keyword suffixes of the tensor repr (device/dtype/grad info), which
    # are not valid literals and would make literal_eval fail.
    cleaned = re.split(r",\s*(?:device|dtype|requires_grad|grad_fn)\s*=", cleaned, maxsplit=1)[0]
    return literal_eval(cleaned.strip())
def gen_cxr(weight_path, image_path, text_prompt, num_samples, output_path, device=None):
    """
    Generate synthetic CXR samples for a text prompt with CXRGen.

    Starts from `sample_generation`'s default arguments and overrides the
    fields below before calling `sample_generation.main`.

    Args:
        weight_path: Checkpoint identifier, resolved through `get_weight`.
        image_path: Path of the conditioning input image.
        text_prompt: Text describing the pathology to generate.
        num_samples: Number of samples to generate.
        output_path: Directory handed to CXRGen for its outputs.
        device: torch device or device string (e.g. "cpu", "cuda:0").
            When None (the default), CUDA is used if available and CPU
            otherwise. An explicit value is honoured as given.
    """
    args = sample_generation.get_args_parser().parse_args([])
    args.image_path = image_path
    args.text_prompt = text_prompt
    args.num_samples = num_samples
    args.output_path = output_path
    args.weight_path = get_weight(weight_path)

    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    args.device = torch.device(device)

    sample_generation.main(args)
def cal_shift(img_org_path, img_gen_path):
    """Estimate the (x, y) shift between an original and a generated image.

    Runs the DETR-based estimator `svc.main` with its default arguments,
    overriding only the checkpoint (resolved via `get_weight`) and the two
    image paths.

    Args:
        img_org_path: Path of the original image.
        img_gen_path: Path of the generated image.

    Returns:
        tuple: ``(shift_x, shift_y)`` as produced by `svc.main`.
    """
    detr_args = get_detr_args_parser().parse_args([])
    detr_args.read_checkpoint = get_weight("DETR/output/checkpoint.pth")
    detr_args.img_org = img_org_path
    detr_args.img_gen = img_gen_path

    estimated = svc.main(detr_args)
    shift_x, shift_y = estimated
    return shift_x, shift_y
def get_local_bbox(weight_path, image_path, text_prompt, box_threshold, text_threshold):
    """Ground the text prompt in the image with the VG localization model.

    Args:
        weight_path: VG checkpoint identifier, resolved through `get_weight`.
        image_path: Path of the image to ground the prompt in.
        text_prompt: Text describing the finding to localize.
        box_threshold: Box threshold passed to the VG model.
        text_threshold: Text threshold passed to the VG model.

    Returns:
        tuple: ``(bbox, logits, phrases)`` exactly as returned by
        `localization.main`.
    """
    vg_args = localization.get_args_parser().parse_args([])
    overrides = {
        "weight_path": get_weight(weight_path),
        "image_path": image_path,
        "text_prompt": text_prompt,
        "box_threshold": box_threshold,
        "text_threshold": text_threshold,
    }
    for field_name, field_value in overrides.items():
        setattr(vg_args, field_name, field_value)

    bbox, logits, phrases = localization.main(vg_args)
    return bbox, logits, phrases
if __name__ == "__main__":
    args = get_args_parser().parse_args()

    # Step 1: generate candidate images for the prompt with CXRGen.
    gen_cxr(args.weight_path_gencxr, args.image_path, args.text_prompt, args.num_samples, args.output_path)
    # NOTE(review): gen_cxr appears synchronous; this delay may be unnecessary.
    time.sleep(4)  # ensure outputs are written

    # Step 2: pick the generated sample with the highest similarity score.
    # os.path.join works whether or not --output_path ends with a separator.
    info_csv = os.path.join(args.output_path, "info_path_similarity.csv")
    df = pd.read_csv(info_csv)
    if df.empty:
        sys.exit(f"No generated samples listed in {info_csv}.")
    sim_ratios = [extract_tensor(val) for val in df["similarity_rate"]]
    max_sim_index = sim_ratios.index(max(sim_ratios))
    # max_sim_index is a list position, so look it up positionally.
    max_sim_gen_path = df["gen_sample_path"].iloc[max_sim_index]

    # Step 3: estimate the shift between the original and the chosen sample.
    sx, sy = cal_shift(args.image_path, max_sim_gen_path)

    # Step 4: localize the prompt in the original image.
    boxes, logits, phrases = get_local_bbox(
        args.weight_path_vg,
        args.image_path,
        args.text_prompt,
        args.box_threshold,
        args.text_threshold
    )
    print("Boxes:", boxes)
    print("Phrases:", phrases)

    # cv2.imread returns None instead of raising on unreadable files.
    image_org_cv = cv2.imread(args.image_path, cv2.IMREAD_GRAYSCALE)
    image_gen_cv = cv2.imread(max_sim_gen_path, cv2.IMREAD_GRAYSCALE)
    if image_org_cv is None or image_gen_cv is None:
        sys.exit("Could not read the original or the generated image with OpenCV.")

    # Step 5: compare each box in the original with the shifted box in the
    # generated image. Boxes are used as (x1, y1, x2, y2) pixel corners --
    # NOTE(review): confirm this matches the VG output format.
    ssim_scores = []
    for bbox in boxes:
        x1, y1, x2, y2 = bbox
        bbox1 = [x1, y1, x2 - x1, y2 - y1]
        bbox2 = [x1 + sx, y1 + sy, x2 - x1, y2 - y1]
        bx1, by1, bw1, bh1 = [int(val) for val in bbox1]
        bx2, by2, bw2, bh2 = [int(val) for val in bbox2]
        # A negative start would make NumPy slice from the end of the array
        # and could silently compare the wrong region, so skip such boxes.
        if min(bx1, by1, bx2, by2) < 0:
            continue
        roi_org = image_org_cv[by1:by1 + bh1, bx1:bx1 + bw1]
        roi_gen = image_gen_cv[by2:by2 + bh2, bx2:bx2 + bw2]
        if roi_org.shape == roi_gen.shape and roi_org.size > 0:
            score = ssim(roi_org, roi_gen)
            ssim_scores.append(score)

    if ssim_scores:
        print("SSIM scores per box:", ssim_scores)
        print("Localization Detection Scores per bbox:", boxes, logits)
    else:
        print("No valid SSIM scores (e.g., mismatched shapes or empty ROIs).")
|